Skip to content

kups.relaxation.optax

Optax-based optimizers for structure relaxation.

Transform = str | dict[str, bool | int | float | str | list | None] module-attribute

A single transform spec: either a name string or a dict with "transform" key.

TransformationConfig = list[Transform] module-attribute

Ordered list of transform specs to chain into an optimizer.

ScaleByFireState

Bases: NamedTuple

State for scale_by_fire transform.

Attributes:

Name Type Description
velocity Params

Velocity estimate (PyTree matching params).

dt Array

Current adaptive timestep.

alpha Array

Current velocity mixing parameter.

n_pos Array

Count of consecutive positive power steps.

Source code in src/kups/relaxation/optax/fire.py
class ScaleByFireState(NamedTuple):
    """Optimizer state carried between ``scale_by_fire`` update steps.

    Attributes:
        velocity: Current velocity estimate; a PyTree with the same
            structure as the parameters.
        dt: Adaptive timestep applied on the next integration step.
        alpha: Velocity-mixing coefficient, decayed while progress is made.
        n_pos: Count of consecutive steps with positive power (F . v > 0).
    """

    velocity: optax.Params
    dt: Array
    alpha: Array
    n_pos: Array

get_transform(transform)

Convert a transform config entry to an Optax GradientTransformation.

Parameters:

Name Type Description Default
transform Transform

Either a plain string name (e.g. "scale_by_adam") or a dict with a "transform" key and additional keyword arguments.

required

Returns:

Type Description
GradientTransformation

The constructed GradientTransformation.

Raises:

Type Description
ValueError

If the transform name is not found in custom transforms or optax.

Source code in src/kups/relaxation/optax/optimizer.py
def get_transform(transform: Transform) -> optax.GradientTransformation:
    """Convert a transform config entry to an Optax GradientTransformation.

    Args:
        transform: Either a plain string name (e.g. ``"scale_by_adam"``) or a
            dict with a ``"transform"`` key and additional keyword arguments.

    Returns:
        The constructed GradientTransformation.

    Raises:
        ValueError: If a dict spec is missing the ``"transform"`` key, or if
            the transform name is not found in custom transforms or optax.
    """
    if isinstance(transform, str):
        name = transform
        kwargs: dict[str, Any] = {}
    else:
        # Copy so popping the name does not mutate the caller's config dict.
        kwargs = dict(transform)
        if "transform" not in kwargs:
            # Previously this surfaced as a bare KeyError; raise the
            # documented ValueError with an actionable message instead.
            raise ValueError(
                f"Transform dict spec must contain a 'transform' key, got: {transform}"
            )
        name = str(kwargs.pop("transform"))

    # Custom transforms take precedence over same-named optax attributes.
    if name in _CUSTOM_TRANSFORMS:
        constructor = _CUSTOM_TRANSFORMS[name]
    elif hasattr(optax, name):
        constructor = getattr(optax, name)
    else:
        raise ValueError(f"Unknown transformation: {name}")

    return constructor(**kwargs)

get_transformations(transformations)

Convert a list of transform configs to Optax GradientTransformations.

Parameters:

Name Type Description Default
transformations TransformationConfig

List of transform specifications.

required

Returns:

Type Description
list[GradientTransformation]

List of GradientTransformations in the same order.

Source code in src/kups/relaxation/optax/optimizer.py
def get_transformations(
    transformations: TransformationConfig,
) -> list[optax.GradientTransformation]:
    """Build Optax GradientTransformations from a list of transform configs.

    Args:
        transformations: Ordered list of transform specifications.

    Returns:
        The corresponding GradientTransformations, in the same order.
    """
    return [get_transform(spec) for spec in transformations]

make_optimizer(transformations)

Create a chained optimizer from a list of transform configs.

Parameters:

Name Type Description Default
transformations TransformationConfig

List of transform specifications.

required

Returns:

Type Description
GradientTransformationExtraArgs

Chained Optax GradientTransformation.

Example

config = [
    {"transform": "clip_by_global_norm", "max_norm": 1.0},
    {"transform": "scale_by_fire", "dt_start": 0.1},
]
optimizer = make_optimizer(config)

Source code in src/kups/relaxation/optax/optimizer.py
def make_optimizer(
    transformations: TransformationConfig,
) -> optax.GradientTransformationExtraArgs:
    """Create a chained optimizer from a list of transform configs.

    Args:
        transformations: Ordered list of transform specifications.

    Returns:
        A single Optax transformation chaining all configured transforms.

    Example:
        >>> config = [
        ...     {"transform": "clip_by_global_norm", "max_norm": 1.0},
        ...     {"transform": "scale_by_fire", "dt_start": 0.1},
        ... ]
        >>> optimizer = make_optimizer(config)
    """
    transforms = get_transformations(transformations)
    return optax.chain(*transforms)

scale_by_ase_lbfgs(memory_size=100, alpha=70.0)

L-BFGS preconditioner with ASE-style initial inverse Hessian.

Equivalent to optax.scale_by_lbfgs except the initial Hessian approximation is (1/alpha) * I (following the ASE convention) rather than the curvature-based initialization used by default in Optax.

Parameters:

Name Type Description Default
memory_size int

Number of past (param, gradient) differences to store. Must be >= 1.

100
alpha float

Initial inverse Hessian is (1/alpha) * I. In the ASE convention this controls the initial step size.

70.0

Returns:

Type Description
GradientTransformation

Optax GradientTransformation applying L-BFGS preconditioning.

Raises:

Type Description
ValueError

If memory_size < 1.

Example

optimizer = optax.chain(
    scale_by_ase_lbfgs(memory_size=10, alpha=70.0),
    optax.scale(-1.0),
)

Source code in src/kups/relaxation/optax/lbfgs.py
def scale_by_ase_lbfgs(
    memory_size: int = 100, alpha: float = 70.0
) -> optax.GradientTransformation:
    """L-BFGS preconditioner with ASE-style initial inverse Hessian.

    Equivalent to ``optax.scale_by_lbfgs`` except the initial Hessian
    approximation is ``(1/alpha) * I`` (following the ASE convention)
    rather than the curvature-based initialization used by default in Optax.

    Args:
        memory_size: Number of past (param, gradient) differences to store.
            Must be >= 1.
        alpha: Initial inverse Hessian is ``(1/alpha) * I``.  In the ASE
            convention this controls the initial step size.

    Returns:
        Optax GradientTransformation applying L-BFGS preconditioning.

    Raises:
        ValueError: If ``memory_size < 1``.

    Example:
        >>> optimizer = optax.chain(
        ...     scale_by_ase_lbfgs(memory_size=10, alpha=70.0),
        ...     optax.scale(-1.0),
        ... )
    """
    if memory_size < 1:
        raise ValueError("memory_size must be >= 1")

    def init_fn(params: optax.Params) -> ScaleByLBFGSState:
        # Per-leaf circular buffers with a leading `memory_size` axis that
        # will hold the most recent (param, gradient) difference pairs.
        stacked_zero_params = jax.tree.map(
            lambda leaf: jnp.zeros((memory_size,) + leaf.shape, dtype=leaf.dtype),
            params,
        )
        return ScaleByLBFGSState(
            count=jnp.asarray(0, dtype=jnp.int32),
            params=optax.tree.zeros_like(params),
            updates=optax.tree.zeros_like(params),
            diff_params_memory=stacked_zero_params,
            diff_updates_memory=optax.tree.zeros_like(stacked_zero_params),
            weights_memory=jnp.zeros(memory_size),
        )

    def update_fn[P](
        updates: P, state: ScaleByLBFGSState, params: P
    ) -> tuple[P, ScaleByLBFGSState]:
        # Slots in the circular memory: `memory_idx` is where the next pair
        # would go; `prev_memory_idx` receives this step's fresh differences.
        memory_idx = state.count % memory_size  # type: ignore[arg-type] - optax typing
        prev_memory_idx = (state.count - 1) % memory_size  # type: ignore

        # Update the memory buffers with fresh params and gradients
        diff_params = optax.tree.sub(params, state.params)
        diff_updates = optax.tree.sub(updates, state.updates)
        vdot_diff_params_updates = optax.tree.real(
            optax.tree.vdot(diff_updates, diff_params)
        )
        # Avoid division by zero when the curvature product is exactly zero;
        # the pair then carries zero weight.
        weight = jnp.where(
            vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates
        )
        # Differences are undefined at first iteration; keep at zero
        diff_params, diff_updates, weight = jax.tree.map(
            lambda x: jnp.where(state.count > 0, x, jnp.zeros_like(x)),  # type: ignore
            (diff_params, diff_updates, weight),
        )
        diff_params_memory, diff_updates_memory, weights_memory = jax.tree.map(
            lambda x, y: x.at[prev_memory_idx].set(y),
            (
                state.diff_params_memory,
                state.diff_updates_memory,
                state.weights_memory,
            ),
            (diff_params, diff_updates, weight),
        )
        # ASE convention: initial inverse Hessian is (1/alpha) * I, replacing
        # Optax's default curvature-based scaling.
        identity_scale = 1.0 / alpha

        # Compute the L-BFGS preconditioned update
        precond_updates = _precondition_by_lbfgs(
            updates,  # type: ignore[arg-type] - optax typing
            diff_params_memory,
            diff_updates_memory,
            weights_memory,
            identity_scale,
            memory_idx,  # type: ignore[arg-type] - optax typing
        )
        return precond_updates, ScaleByLBFGSState(  # type: ignore[arg-type] - optax typing
            count=state.count + 1,
            params=params,  # type: ignore[arg-type] - optax typing
            updates=updates,  # type: ignore[arg-type] - optax typing
            diff_params_memory=diff_params_memory,
            diff_updates_memory=diff_updates_memory,
            weights_memory=weights_memory,
        )

    return optax.GradientTransformation(init_fn, update_fn)  # type: ignore[arg-type] - optax typing

scale_by_fire(dt_start=0.1, dt_max=None, dt_min=None, max_step=0.2, f_inc=1.1, f_dec=0.5, alpha_start=0.1, f_alpha=0.99, n_min=5)

FIRE (Fast Inertial Relaxation Engine) optimizer.

Composable Optax transform implementing the FIRE algorithm for structure relaxation. Can be chained with other transforms.

Parameters:

Name Type Description Default
dt_start float

Initial timestep.

0.1
dt_max float | None

Maximum timestep. Defaults to 10 * dt_start.

None
dt_min float | None

Minimum timestep. Defaults to dt_start * 1e-4.

None
max_step float | None

Maximum step size (clips position updates). Defaults to 0.2 Å. Set to None to disable clipping.

0.2
f_inc float

Factor to increase dt when making progress.

1.1
f_dec float

Factor to decrease dt on bad step.

0.5
alpha_start float

Initial velocity mixing parameter.

0.1
f_alpha float

Factor to decay alpha when making progress.

0.99
n_min int

Minimum positive power steps before increasing dt.

5

Returns:

Type Description
GradientTransformation

Optax GradientTransformation implementing FIRE.

Reference

Bitzek et al., Phys. Rev. Lett. 97, 170201 (2006).

Source code in src/kups/relaxation/optax/fire.py
def scale_by_fire(
    dt_start: float = 0.1,
    dt_max: float | None = None,
    dt_min: float | None = None,
    max_step: float | None = 0.2,
    f_inc: float = 1.1,
    f_dec: float = 0.5,
    alpha_start: float = 0.1,
    f_alpha: float = 0.99,
    n_min: int = 5,
) -> optax.GradientTransformation:
    """FIRE (Fast Inertial Relaxation Engine) optimizer.

    Composable Optax transform implementing the FIRE algorithm for
    structure relaxation. Can be chained with other transforms.

    Args:
        dt_start: Initial timestep.
        dt_max: Maximum timestep. Defaults to 10 * dt_start.
        dt_min: Minimum timestep. Defaults to dt_start * 1e-4.
        max_step: Maximum step size (clips position updates). Defaults to 0.2 Å.
            Set to None to disable clipping.
        f_inc: Factor to increase dt when making progress.
        f_dec: Factor to decrease dt on bad step.
        alpha_start: Initial velocity mixing parameter.
        f_alpha: Factor to decay alpha when making progress.
        n_min: Minimum positive power steps before increasing dt.

    Returns:
        Optax GradientTransformation implementing FIRE.

    Reference:
        Bitzek et al., Phys. Rev. Lett. 97, 170201 (2006).
    """
    if dt_max is None:
        dt_max = 10.0 * dt_start
    if dt_min is None:
        dt_min = dt_start * 1e-4

    def init_fn(params: optax.Params) -> ScaleByFireState:
        # Start at rest with the configured timestep and mixing parameter.
        return ScaleByFireState(
            velocity=jax.tree.map(jnp.zeros_like, params),
            dt=jnp.array(dt_start),
            alpha=jnp.array(alpha_start),
            n_pos=jnp.array(0, dtype=jnp.int32),
        )

    def update_fn(
        updates: optax.Updates,
        state: ScaleByFireState,
        params: optax.Params | None = None,
    ) -> tuple[optax.Updates, ScaleByFireState]:
        del params

        # F = -gradient (FIRE uses forces, pointing downhill)
        forces = jax.tree.map(lambda g: -g, updates)

        # Update velocity: v = v + dt * F
        velocity = jax.tree.map(lambda v, f: v + state.dt * f, state.velocity, forces)

        # Compute power: P = F · v (positive when moving downhill)
        power = optax.tree_utils.tree_vdot(forces, velocity)
        positive_power = power > 0.0  # type: ignore

        # Velocity mixing: v = (1-α)v + α|v|F̂
        v_norm = optax.tree_utils.tree_norm(velocity)
        f_norm = optax.tree_utils.tree_norm(forces)
        # Avoid division by zero when forces vanish (e.g. at convergence).
        safe_f_norm = jnp.maximum(f_norm, 1e-10)

        mixed_velocity = jax.tree.map(
            lambda v, f: (1 - state.alpha) * v + state.alpha * v_norm * f / safe_f_norm,
            velocity,
            forces,
        )

        # Adaptive timestep and mixing parameter: dt only grows after n_min
        # consecutive good (P > 0) steps.
        should_increase = jnp.logical_and(positive_power, state.n_pos >= n_min)

        new_dt = jnp.where(
            positive_power,
            jnp.where(should_increase, jnp.minimum(state.dt * f_inc, dt_max), state.dt),
            jnp.maximum(state.dt * f_dec, dt_min),
        )
        new_alpha = jnp.where(
            positive_power,
            jnp.where(should_increase, state.alpha * f_alpha, state.alpha),
            alpha_start,
        )
        new_n_pos = jnp.where(positive_power, state.n_pos + 1, 0)

        # If P > 0: use mixed velocity for next step and position update
        # If P <= 0: reset velocity to zero, no position update
        final_velocity = jax.tree.map(
            lambda v: jnp.where(positive_power, v, jnp.zeros_like(v)),
            mixed_velocity,
        )

        # Position update: step only when making progress (P > 0)
        position_updates = jax.tree.map(
            lambda v: jnp.where(positive_power, state.dt * v, jnp.zeros_like(v)),
            mixed_velocity,
        )

        # Clip position updates to max_step (prevents runaway steps)
        # NOTE(review): clipping rescales by the global update norm, not
        # per-atom displacement — confirm this matches the intended limit.
        if max_step is not None:
            update_norm = optax.tree_utils.tree_norm(position_updates)
            scale = jnp.minimum(1.0, max_step / jnp.maximum(update_norm, 1e-10))
            position_updates = jax.tree.map(lambda u: u * scale, position_updates)

        return position_updates, ScaleByFireState(
            velocity=final_velocity, dt=new_dt, alpha=new_alpha, n_pos=new_n_pos
        )

    return optax.GradientTransformation(init_fn, update_fn)  # type: ignore[arg-type]