
kups.relaxation.optax.optimizer

Factory utilities for building Optax optimizers from config specs.

Transform = str | dict[str, bool | int | float | str | list | None] (module attribute)

A single transform spec: either a transform name as a string, or a dict whose "transform" key holds the name and whose remaining keys are keyword arguments for that transform.

TransformationConfig = list[Transform] (module attribute)

Ordered list of transform specs to chain into an optimizer.
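Both aliases use only JSON-compatible value types, so a TransformationConfig can be kept in a JSON file and loaded with the standard library. The sketch below assumes that storage format (kups itself may load configs differently) and uses a hypothetical `validate_config` helper, not part of kups, to check each entry against the shape described above.

```python
import json
from typing import Any

# Value types allowed inside a dict spec, mirroring the Transform alias
# (None appears as type(None)).
_ALLOWED = (bool, int, float, str, list, type(None))


def validate_config(config: Any) -> list:
    """Hypothetical helper: check that `config` looks like a TransformationConfig."""
    if not isinstance(config, list):
        raise TypeError("config must be a list of transform specs")
    for i, spec in enumerate(config):
        if isinstance(spec, str):
            continue  # plain name form, e.g. "scale_by_adam"
        if not isinstance(spec, dict):
            raise TypeError(f"entry {i} must be a str or a dict")
        if "transform" not in spec:
            raise ValueError(f"entry {i} is missing the 'transform' key")
        for key, value in spec.items():
            if not isinstance(value, _ALLOWED):
                raise TypeError(
                    f"entry {i}: value for {key!r} has type {type(value).__name__}"
                )
    return config


# JSON null/arrays map to None/list, matching the alias's value types.
raw = """
[
  "scale_by_adam",
  {"transform": "clip_by_global_norm", "max_norm": 1.0},
  {"transform": "scale", "step_size": -0.01}
]
"""
config = validate_config(json.loads(raw))
```

The validated list can then be passed directly to make_optimizer.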

get_transform(transform)

Convert a transform config entry to an Optax GradientTransformation.

Parameters:

transform (Transform, required): Either a plain string name (e.g. "scale_by_adam") or a dict with a "transform" key and additional keyword arguments.

Returns:

GradientTransformation: The constructed GradientTransformation.

Raises:

ValueError: If the transform name is found neither among the custom transforms nor in optax.

Source code in src/kups/relaxation/optax/optimizer.py
def get_transform(transform: Transform) -> optax.GradientTransformation:
    """Convert a transform config entry to an Optax GradientTransformation.

    Args:
        transform: Either a plain string name (e.g. ``"scale_by_adam"``) or a
            dict with a ``"transform"`` key and additional keyword arguments.

    Returns:
        The constructed GradientTransformation.

    Raises:
        ValueError: If the transform name is not found in custom transforms or optax.
    """
    if isinstance(transform, str):
        name = transform
        kwargs: dict[str, Any] = {}
    else:
        transform = transform.copy()
        name = str(transform.pop("transform"))
        kwargs = transform

    if name in _CUSTOM_TRANSFORMS:
        constructor = _CUSTOM_TRANSFORMS[name]
    elif hasattr(optax, name):
        constructor = getattr(optax, name)
    else:
        raise ValueError(f"Unknown transformation: {name}")

    return constructor(**kwargs)
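The lookup order above means a name registered in `_CUSTOM_TRANSFORMS` shadows an optax function of the same name; only otherwise does the name fall through to `getattr(optax, name)`. The pure-Python sketch below reproduces that dispatch with stand-ins: `fake_optax` (a `types.SimpleNamespace`) plays the optax module, `custom` plays `_CUSTOM_TRANSFORMS`, and the constructors return tagged tuples instead of real GradientTransformations. All of these names are illustrative assumptions.

```python
from __future__ import annotations

import types
from typing import Any, Callable

# Stand-in for the optax module: attribute lookup behaves the same way.
fake_optax = types.SimpleNamespace(
    scale=lambda step_size: ("optax.scale", step_size),
    scale_by_adam=lambda **kw: ("optax.scale_by_adam", kw),
)

# Stand-in for _CUSTOM_TRANSFORMS; "scale" deliberately collides with fake_optax.
custom: dict[str, Callable[..., Any]] = {
    "scale": lambda step_size: ("custom.scale", step_size),
}


def resolve(spec: str | dict) -> Any:
    """Mirror of get_transform's dispatch, using the stand-ins above."""
    if isinstance(spec, str):
        name, kwargs = spec, {}
    else:
        kwargs = dict(spec)  # shallow copy so the caller's dict is untouched
        name = str(kwargs.pop("transform"))
    if name in custom:  # the custom registry is checked first ...
        constructor = custom[name]
    elif hasattr(fake_optax, name):  # ... then the optax namespace
        constructor = getattr(fake_optax, name)
    else:
        raise ValueError(f"Unknown transformation: {name}")
    return constructor(**kwargs)
```

Because the fallback is a plain `hasattr` check, any optax attribute passes it, not only transformation constructors; a name that resolves to something non-callable, or to a function with a different signature, fails at the constructor call rather than with the ValueError.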

get_transformations(transformations)

Convert a list of transform configs to Optax GradientTransformations.

Parameters:

transformations (TransformationConfig, required): List of transform specifications.

Returns:

list[GradientTransformation]: List of GradientTransformations, in the same order as the input specs.

Source code in src/kups/relaxation/optax/optimizer.py
def get_transformations(
    transformations: TransformationConfig,
) -> list[optax.GradientTransformation]:
    """Convert a list of transform configs to Optax GradientTransformations.

    Args:
        transformations: List of transform specifications.

    Returns:
        List of GradientTransformations in the same order.
    """
    return [get_transform(t) for t in transformations]
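Since get_transform copies each dict before popping its "transform" key, get_transformations (and therefore make_optimizer) leaves the caller's config unchanged, so the same list can be converted more than once, for example to build a fresh optimizer per run. The sketch below contrasts that with a hypothetical variant that pops in place; `build` is an assumed stand-in for the real constructor lookup and simply returns the name and keyword arguments.

```python
from typing import Any


def build(name: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    # Stand-in for looking up and calling an optax constructor.
    return name, kwargs


def convert_copying(config: list) -> list:
    out = []
    for spec in config:
        if isinstance(spec, str):
            out.append(build(spec, {}))
        else:
            spec = spec.copy()  # same idea as get_transform
            out.append(build(str(spec.pop("transform")), spec))
    return out


def convert_in_place(config: list) -> list:
    # Hypothetical buggy variant: pop() mutates the caller's dicts.
    out = []
    for spec in config:
        if isinstance(spec, str):
            out.append(build(spec, {}))
        else:
            out.append(build(str(spec.pop("transform")), spec))
    return out


config = [{"transform": "clip_by_global_norm", "max_norm": 1.0}, "scale_by_adam"]
first = convert_copying(config)
second = convert_copying(config)  # works: config was not modified
```

Note that `dict.copy()` is shallow, so list-valued arguments are still shared between the config and the constructed kwargs.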

make_optimizer(transformations)

Create a chained optimizer from a list of transform configs.

Parameters:

transformations (TransformationConfig, required): List of transform specifications.

Returns:

GradientTransformationExtraArgs: Chained Optax GradientTransformation.

Example:

    >>> config = [
    ...     {"transform": "clip_by_global_norm", "max_norm": 1.0},
    ...     {"transform": "scale_by_fire", "dt_start": 0.1},
    ... ]
    >>> optimizer = make_optimizer(config)

Source code in src/kups/relaxation/optax/optimizer.py
def make_optimizer(
    transformations: TransformationConfig,
) -> optax.GradientTransformationExtraArgs:
    """Create a chained optimizer from a list of transform configs.

    Args:
        transformations: List of transform specifications.

    Returns:
        Chained Optax GradientTransformation.

    Example:
        >>> config = [
        ...     {"transform": "clip_by_global_norm", "max_norm": 1.0},
        ...     {"transform": "scale_by_fire", "dt_start": 0.1},
        ... ]
        >>> optimizer = make_optimizer(config)
    """
    return optax.chain(*get_transformations(transformations))
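make_optimizer hands the transformations to optax.chain, which applies them in list order, feeding each one's output updates into the next, so the position of entries in a TransformationConfig matters. The pure-Python sketch below imitates that (init, update) protocol on plain lists of floats to show the effect of ordering. It is a simplified model with made-up names (`ToyTransformation`, `clip_by_norm`, `scale`, `chain`), not optax's implementation, and it ignores details such as passing params or extra arguments to `update`.

```python
import math
from typing import Any, Callable, NamedTuple

Vec = list[float]


class ToyTransformation(NamedTuple):
    """Simplified stand-in for optax.GradientTransformation."""
    init: Callable[[Vec], Any]
    update: Callable[[Vec, Any], tuple[Vec, Any]]


def clip_by_norm(max_norm: float) -> ToyTransformation:
    # Rescale the whole update vector if its L2 norm exceeds max_norm.
    def update(u: Vec, state: Any) -> tuple[Vec, Any]:
        norm = math.sqrt(sum(x * x for x in u))
        factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return [x * factor for x in u], state

    return ToyTransformation(lambda params: None, update)


def scale(step_size: float) -> ToyTransformation:
    # Multiply every update by a constant.
    return ToyTransformation(
        lambda params: None,
        lambda u, state: ([x * step_size for x in u], state),
    )


def chain(*ts: ToyTransformation) -> ToyTransformation:
    # One state per transformation; updates flow through ts in order.
    def init(params: Vec) -> tuple:
        return tuple(t.init(params) for t in ts)

    def update(u: Vec, states: tuple) -> tuple[Vec, tuple]:
        new_states = []
        for t, s in zip(ts, states):
            u, s = t.update(u, s)
            new_states.append(s)
        return u, tuple(new_states)

    return ToyTransformation(init, update)


grads = [3.0, 4.0]  # L2 norm 5.0
clip_then_scale = chain(clip_by_norm(1.0), scale(-10.0))
scale_then_clip = chain(scale(-10.0), clip_by_norm(1.0))
a, _ = clip_then_scale.update(grads, clip_then_scale.init(grads))
b, _ = scale_then_clip.update(grads, scale_then_clip.init(grads))
```

With clipping first, the final update has norm 10 (about [-6, -8]); with scaling first, clipping caps it at norm 1 (about [-0.6, -0.8]). In the same way, swapping the two entries in the config example above would change what the clipping step acts on.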