# kups.relaxation.optax.optimizer
Factory utilities for building Optax optimizers from config specs.
`Transform = str | dict[str, bool | int | float | str | list | None]` *module-attribute*
A single transform spec: either a name string or a dict with "transform" key.
`TransformationConfig = list[Transform]` *module-attribute*
Ordered list of transform specs to chain into an optimizer.
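A minimal illustration of the two entry forms a `TransformationConfig` accepts (`"zero_nans"` is used here purely as an example of a zero-argument factory name; the dict entry is taken from the `make_optimizer` example on this page):

```python
# Illustrative TransformationConfig: a bare string names a factory
# called with no arguments; a dict's "transform" key names the factory
# and its remaining keys are passed as keyword arguments.
config = [
    "zero_nans",
    {"transform": "clip_by_global_norm", "max_norm": 1.0},
]
```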
## get_transform(transform)
Convert a transform config entry to an Optax GradientTransformation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transform` | `Transform` | Either a plain string name (e.g. `"clip_by_global_norm"`) or a dict whose `"transform"` key names the factory, with the remaining keys passed as keyword arguments. | required |
Returns:

| Type | Description |
|---|---|
| `GradientTransformation` | The constructed GradientTransformation. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the transform name is not found in custom transforms or optax. |
Source code in src/kups/relaxation/optax/optimizer.py
## get_transformations(transformations)
Convert a list of transform configs to Optax GradientTransformations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transformations` | `TransformationConfig` | List of transform specifications. | required |
Returns:

| Type | Description |
|---|---|
| `list[GradientTransformation]` | List of GradientTransformations in the same order. |
Source code in src/kups/relaxation/optax/optimizer.py
## make_optimizer(transformations)
Create a chained optimizer from a list of transform configs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transformations` | `TransformationConfig` | List of transform specifications. | required |
Returns:

| Type | Description |
|---|---|
| `GradientTransformationExtraArgs` | Chained Optax GradientTransformation. |
Example:

```python
>>> config = [
...     {"transform": "clip_by_global_norm", "max_norm": 1.0},
...     {"transform": "scale_by_fire", "dt_start": 0.1},
... ]
>>> optimizer = make_optimizer(config)
```