kups.relaxation.transforms

Per-system relaxation transforms compatible with the kups.relaxation.optimizer.Optimizer protocol.

These are batch-aware versions of the transforms in kups.relaxation.optax. Each accepts an index_prefix pytree at init time that identifies which system every element belongs to, so batched systems are clipped or scaled independently.

ClipByGlobalNorm

Bases: Optimizer[Params, ClipByGlobalNormState]

Clip the per-system L2 norm of updates to max_norm.

With index_prefix=None this reduces to the standard optax.clip_by_global_norm (a single tree-global L2 norm).

Attributes:

Name Type Description
max_norm float

Maximum allowed per-system L2 norm of the update.
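
As an illustration of the per-system clipping rule, here is a minimal plain-Python sketch that runs without JAX or kups. It assumes a flat list of per-particle update vectors and one integer system id per row; `clip_by_system_norm` and `system_ids` are hypothetical names standing in for the library's Index-based index_prefix and are not part of the kups API.

```python
import math


def clip_by_system_norm(updates, system_ids, max_norm, eps=1e-12):
    """Scale each system's rows so that the system's L2 norm is at most max_norm.

    updates:    list of per-particle vectors (lists of floats).
    system_ids: system_ids[i] is the system that row i belongs to (assumed layout).
    """
    # Sum of squares per system, taken over all rows and components.
    sq = {}
    for row, s in zip(updates, system_ids):
        sq[s] = sq.get(s, 0.0) + sum(x * x for x in row)
    # Same scale rule as the listed source: min(1, max_norm / (norm + eps)).
    scale = {s: min(1.0, max_norm / (math.sqrt(v) + eps)) for s, v in sq.items()}
    return [[x * scale[s] for x in row] for row, s in zip(updates, system_ids)]


updates = [[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.3, 0.0, 0.4]]
system_ids = [0, 0, 1]
clipped = clip_by_system_norm(updates, system_ids, max_norm=1.0)
```

System 0 has norm 5 and is scaled down to norm 1, while system 1 has norm 0.5 and passes through unchanged. A single global norm would have shrunk both.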

Source code in src/kups/relaxation/transforms/clip_by_global_norm.py
@dataclass
class ClipByGlobalNorm[Params](Optimizer[Params, ClipByGlobalNormState]):
    """Clip the per-system L2 norm of updates to ``max_norm``.

    With ``index_prefix=None`` this reduces to the standard
    :func:`optax.clip_by_global_norm` (a single tree-global L2 norm).

    Attributes:
        max_norm: Maximum allowed per-system L2 norm of the update.
    """

    max_norm: float = field(static=True)

    def init(
        self, parameters: Params, index_prefix: PyTree | None = None
    ) -> ClipByGlobalNormState:
        del parameters
        return ClipByGlobalNormState(index_prefix=tree_copy(index_prefix))

    def update(
        self,
        updates: Params,
        state: ClipByGlobalNormState,
        params: Params | None = None,
        **kwargs: Any,
    ) -> tuple[Params, ClipByGlobalNormState]:
        del params, kwargs
        index_prefix = state.index_prefix
        if index_prefix is None:
            index_prefix = jax.tree.map(lambda x: Index.new((0,) * len(x)), updates)
        norm = tree_segment_norm(updates, index_prefix)
        scale = norm.map_data(lambda x: jnp.minimum(1.0, self.max_norm / (x + 1e-12)))
        return tree_scale_per_row(updates, scale, index_prefix), state

ClipByGlobalNormState

State carrying the index_prefix captured at init time.

Attributes:

Name Type Description
index_prefix PyTree | None

Tree prefix of the parameter pytree whose leaves are Index[K] objects, or None to clip with a single global (cross-system) L2 norm.

Source code in src/kups/relaxation/transforms/clip_by_global_norm.py
@dataclass
class ClipByGlobalNormState:
    """State carrying the ``index_prefix`` captured at init time.

    Attributes:
        index_prefix: Tree prefix of the parameter pytree whose leaves are
            ``Index[K]`` objects, or ``None`` to clip with a single global
            (cross-system) L2 norm.
    """

    index_prefix: PyTree | None

MaxStepSize

Bases: Optimizer[Params, MaxStepSizeState]

Clip updates so no element of any system moves more than max_step_size.

Per-element norms are computed along the last axis. For every system, the maximum norm across all elements assigned to that system (across every leaf of updates) is found, and updates for those elements are uniformly scaled so the worst-case norm does not exceed max_step_size. Different systems are scaled independently.

Attributes:

Name Type Description
max_step_size float

Maximum allowed per-element displacement norm.
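
The scaling rule described above can be sketched in plain Python, with no JAX or kups required. The flat row layout and the names `limit_max_step` and `system_ids` are assumptions made for this illustration; the library itself works on pytrees with an index_prefix.

```python
import math


def limit_max_step(updates, system_ids, max_step_size, eps=1e-12):
    """Uniformly scale each system so its largest per-particle norm is capped."""
    # Per-particle displacement norm along the last axis.
    norms = [math.sqrt(sum(x * x for x in row)) for row in updates]
    # Largest per-particle norm within each system.
    worst = {}
    for n, s in zip(norms, system_ids):
        worst[s] = max(worst.get(s, 0.0), n)
    # One scale factor per system, never larger than 1.
    scale = {s: min(1.0, max_step_size / (w + eps)) for s, w in worst.items()}
    return [[x * scale[s] for x in row] for row, s in zip(updates, system_ids)]


updates = [[0.0, 0.0, 0.4], [0.1, 0.0, 0.0], [0.05, 0.0, 0.0]]
system_ids = [0, 0, 1]
limited = limit_max_step(updates, system_ids, max_step_size=0.2)
```

Both rows of system 0 are halved because its worst particle moves 0.4, so the direction of the combined step is preserved. System 1 is already within the limit and is left untouched.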

Source code in src/kups/relaxation/transforms/max_step_size.py
@dataclass
class MaxStepSize[Params](Optimizer[Params, MaxStepSizeState]):
    """Clip updates so no element of any system moves more than ``max_step_size``.

    Per-element norms are computed along the last axis. For every system, the
    maximum norm across all elements assigned to that system (across every leaf
    of ``updates``) is found, and updates for those elements are uniformly
    scaled so the worst-case norm does not exceed ``max_step_size``. Different
    systems are scaled independently.

    Attributes:
        max_step_size: Maximum allowed per-element displacement norm.
    """

    max_step_size: float = field(static=True)

    def init(
        self, parameters: Params, index_prefix: PyTree | None = None
    ) -> MaxStepSizeState:
        del parameters
        return MaxStepSizeState(index_prefix=tree_copy(index_prefix))

    def update(
        self,
        updates: Params,
        state: MaxStepSizeState,
        params: Params | None = None,
        **kwargs: Any,
    ) -> tuple[Params, MaxStepSizeState]:
        del params, kwargs
        index_prefix = state.index_prefix
        if index_prefix is None:
            index_prefix = jax.tree.map(lambda x: Index.new((0,) * len(x)), updates)
        per_particle_size = jax.tree.map(
            functools.partial(jnp.linalg.norm, axis=-1), updates
        )
        max_size = tree_segment_max(per_particle_size, index_prefix)
        scale = max_size.map_data(
            lambda x: jnp.minimum(1.0, self.max_step_size / (x + 1e-12))
        )
        updates = tree_scale_per_row(updates, scale, index_prefix)
        return updates, state

MaxStepSizeState

Optimizer state holding the index_prefix captured at init time.

Attributes:

Name Type Description
index_prefix PyTree | None

Tree prefix of the parameter pytree whose leaves are Index objects, or None for global (cross-system) clipping.

Source code in src/kups/relaxation/transforms/max_step_size.py
@dataclass
class MaxStepSizeState:
    """Optimizer state holding the ``index_prefix`` captured at init time.

    Attributes:
        index_prefix: Tree prefix of the parameter pytree whose leaves are
            :class:`Index` objects, or ``None`` for global (cross-system)
            clipping.
    """

    index_prefix: PyTree | None

ScaleByAseLbfgs

Bases: Optimizer[Params, ScaleByAseLbfgsState]

L-BFGS preconditioner with per-system block-diagonal Hessian.

With a trivial index_prefix (one system) this reduces to the same algorithm as kups.relaxation.optax.scale_by_ase_lbfgs: the initial inverse Hessian is (1/alpha) * I (ASE convention) and the recursion keeps a buffer of the last memory_size (diff_params, diff_updates) pairs. With multiple systems, every system maintains its own independent inverse-Hessian approximation and its own ρᵢ weights.

Attributes:

Name Type Description
memory_size int

Number of past difference pairs to store. Must be >= 1.

alpha float

Initial inverse Hessian is (1/alpha) * I.
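
For orientation, the following is a minimal single-system sketch of the textbook L-BFGS two-loop recursion with the initial inverse Hessian (1/alpha) * I. It is written in plain Python and is not the library's internal `_precondition_by_lbfgs_segmented`, whose body is not shown on this page; `lbfgs_apply`, `s_hist` and `y_hist` are hypothetical names. The guard that sets ρ to 0 when y·s is exactly 0 follows the listed source.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lbfgs_apply(grad, s_hist, y_hist, alpha=70.0):
    """Return H @ grad, where H approximates the inverse Hessian.

    s_hist / y_hist hold past parameter and gradient differences, oldest first.
    """
    q = list(grad)
    # rho_i = 1 / (y_i . s_i), or 0 when the product vanishes.
    rhos = [0.0 if dot(y, s) == 0.0 else 1.0 / dot(y, s)
            for s, y in zip(s_hist, y_hist)]
    pairs = list(zip(s_hist, y_hist, rhos))
    coeffs = []
    # First loop: newest pair to oldest.
    for s, y, rho in reversed(pairs):
        a = rho * dot(s, q)
        q = [qi - a * yi for qi, yi in zip(q, y)]
        coeffs.append(a)
    # Apply the initial inverse Hessian (1/alpha) * I.
    r = [qi / alpha for qi in q]
    # Second loop: oldest pair to newest (coeffs was filled newest first).
    for (s, y, rho), a in zip(pairs, reversed(coeffs)):
        b = rho * dot(y, r)
        r = [ri + (a - b) * si for ri, si in zip(r, s)]
    return r


# One stored pair from a quadratic with Hessian diag(2, 10): y = A s.
s = [0.1, -0.2]
y = [0.2, -2.0]
secant = lbfgs_apply(y, [s], [y])                 # should reproduce s
cold_start = lbfgs_apply([7.0, -14.0], [], [])    # empty memory: grad / alpha
```

With an empty memory the result is just the input scaled by 1/alpha, which is why alpha controls the size of the very first step. With one stored pair the approximation satisfies the secant condition H y = s.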

Source code in src/kups/relaxation/transforms/lbfgs.py
@dataclass
class ScaleByAseLbfgs[Params](Optimizer[Params, ScaleByAseLbfgsState]):
    """L-BFGS preconditioner with per-system block-diagonal Hessian.

    With a trivial ``index_prefix`` (one system) this reduces to the same
    algorithm as :func:`kups.relaxation.optax.scale_by_ase_lbfgs`:
    the initial inverse Hessian is ``(1/alpha) * I`` (ASE convention) and
    the recursion buffers ``memory_size`` past ``(diff_params, diff_updates)``
    pairs. With multiple systems, every system maintains its own
    independent inverse-Hessian approximation and its own ``ρᵢ`` weights.

    Attributes:
        memory_size: Number of past difference pairs to store. Must be ``>= 1``.
        alpha: Initial inverse Hessian is ``(1/alpha) * I``.
    """

    memory_size: int = field(static=True, default=100)
    alpha: float = field(static=True, default=70.0)

    def __post_init__(self) -> None:
        if self.memory_size < 1:
            raise ValueError("memory_size must be >= 1")

    def init(
        self, parameters: Params, index_prefix: PyTree | None = None
    ) -> ScaleByAseLbfgsState:
        if index_prefix is None:
            index_prefix = jax.tree.map(lambda x: Index.new((0,) * len(x)), parameters)
        idx_leaves = jax.tree.leaves(
            index_prefix, is_leaf=lambda x: isinstance(x, Index)
        )
        first = next(x for x in idx_leaves if isinstance(x, Index))
        keys = first.keys
        n_systems = len(keys)

        stacked_zero = jax.tree.map(
            lambda leaf: jnp.zeros((self.memory_size,) + leaf.shape, dtype=leaf.dtype),
            parameters,
        )
        return ScaleByAseLbfgsState(
            count=jnp.asarray(0, dtype=jnp.int32),
            params=jax.tree.map(jnp.zeros_like, parameters),
            updates=jax.tree.map(jnp.zeros_like, parameters),
            diff_params_memory=stacked_zero,
            diff_updates_memory=jax.tree.map(jnp.zeros_like, stacked_zero),
            weights_memory=Table(keys, jnp.zeros((n_systems, self.memory_size))),
            index_prefix=tree_copy(index_prefix),
        )

    def update(
        self,
        updates: Params,
        state: ScaleByAseLbfgsState,
        params: Params | None = None,
        **kwargs: Any,
    ) -> tuple[Params, ScaleByAseLbfgsState]:
        del kwargs
        if params is None:
            raise ValueError("ScaleByAseLbfgs.update requires params")
        idx = state.index_prefix
        memory_idx = state.count % self.memory_size
        prev_memory_idx = (state.count - 1) % self.memory_size

        # Compute fresh (s, y) differences and corresponding ρ = 1/(y·s).
        diff_params = jax.tree.map(jnp.subtract, params, state.params)
        diff_updates = jax.tree.map(jnp.subtract, updates, state.updates)
        vdot_data = tree_vdot(diff_updates, diff_params, idx).data
        weight = jnp.where(vdot_data == 0.0, 0.0, 1.0 / vdot_data)

        # Differences are undefined at the very first iteration; stay zero.
        is_first = state.count == 0
        diff_params = jax.tree.map(
            lambda x: jnp.where(is_first, jnp.zeros_like(x), x), diff_params
        )
        diff_updates = jax.tree.map(
            lambda x: jnp.where(is_first, jnp.zeros_like(x), x), diff_updates
        )
        weight = jnp.where(is_first, jnp.zeros_like(weight), weight)

        diff_params_memory = jax.tree.map(
            lambda mem, x: mem.at[prev_memory_idx].set(x),
            state.diff_params_memory,
            diff_params,
        )
        diff_updates_memory = jax.tree.map(
            lambda mem, x: mem.at[prev_memory_idx].set(x),
            state.diff_updates_memory,
            diff_updates,
        )
        weights_data = state.weights_memory.data.at[:, prev_memory_idx].set(weight)

        precond = _precondition_by_lbfgs_segmented(
            updates,
            diff_params_memory,
            diff_updates_memory,
            weights_data,
            identity_scale=1.0 / self.alpha,
            memory_idx=memory_idx,
            index_prefix=idx,
            keys=state.weights_memory.keys,
        )
        return precond, ScaleByAseLbfgsState(
            count=state.count + 1,
            params=params,
            updates=updates,
            diff_params_memory=diff_params_memory,
            diff_updates_memory=diff_updates_memory,
            weights_memory=state.weights_memory.set_data(weights_data),
            index_prefix=idx,
        )

ScaleByAseLbfgsState

State for the per-system ASE-flavor L-BFGS preconditioner.

Attributes:

Name Type Description
count Array

Total update steps taken so far (scalar int32).

params PyTree

Last seen parameters, pytree matching parameters.

updates PyTree

Last seen gradients/updates.

diff_params_memory PyTree

Stacked past parameter differences, shape (memory_size, *leaf_shape) per leaf.

diff_updates_memory PyTree

Stacked past update differences, same shape.

weights_memory Table[Any, Array]

Per-system per-slot ρᵢ = 1/(yᵢ · sᵢ) weights as Table[K, Array] with data shape (n_systems, memory_size).

index_prefix PyTree

Tree prefix of the parameter pytree whose leaves are Index[K] objects, captured at init time.
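
The ScaleByAseLbfgs.update source above addresses the memory as a ring buffer: the fresh difference pair is written to slot (count - 1) % memory_size, and count % memory_size is passed on as memory_idx. A small plain-Python sketch of that indexing, where `slots` is a hypothetical helper name:

```python
def slots(count, memory_size):
    """Return (prev_memory_idx, memory_idx) as computed in the listed update()."""
    prev_memory_idx = (count - 1) % memory_size  # slot written at this step
    memory_idx = count % memory_size             # index handed to the preconditioner
    return prev_memory_idx, memory_idx


history = [slots(count, memory_size=3) for count in range(5)]
```

At count 0 the write goes to the last slot, and per the source the stored differences and weight are zeroed there because no previous step exists. From count 3 onwards the oldest entries are overwritten.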

Source code in src/kups/relaxation/transforms/lbfgs.py
@dataclass
class ScaleByAseLbfgsState:
    """State for the per-system ASE-flavor L-BFGS preconditioner.

    Attributes:
        count: Total update steps taken so far (scalar int32).
        params: Last seen parameters, pytree matching ``parameters``.
        updates: Last seen gradients/updates.
        diff_params_memory: Stacked past parameter differences, shape
            ``(memory_size, *leaf_shape)`` per leaf.
        diff_updates_memory: Stacked past update differences, same shape.
        weights_memory: Per-system per-slot ``ρᵢ = 1/(yᵢ · sᵢ)`` weights as
            ``Table[K, Array]`` with data shape ``(n_systems, memory_size)``.
        index_prefix: Tree prefix of the parameter pytree whose leaves are
            ``Index[K]`` objects, captured at init time.
    """

    count: Array
    params: PyTree
    updates: PyTree
    diff_params_memory: PyTree
    diff_updates_memory: PyTree
    weights_memory: Table[Any, Array]
    index_prefix: PyTree

ScaleByFire

Bases: Optimizer[Params, ScaleByFireState]

FIRE (Fast Inertial Relaxation Engine) optimizer with per-system state.

Implements Bitzek et al. Phys. Rev. Lett. 97, 170201 (2006), but every global tree reduction is replaced by a per-system reduction over the index_prefix. Each system independently adapts its own dt / alpha / n_pos and sees its own per-system power and norms.

The transform follows the optax convention: the updates passed to update are interpreted as the force F = -∇L (the descent direction). Sign conversion from a raw gradient and any clipping live in the surrounding kups.relaxation.optimizer.chain; see the module docstring for a worked example.

Note: This is the original FIRE 1.0. For most production relaxations prefer kups.relaxation.transforms.ScaleByFire2, which Guénolé et al. 2020 (Fig. 4–6) report converges in ~1.5–3× fewer force calls on Lennard-Jones, EAM and Tersoff benchmarks. ABC-FIRE (ScaleByFire2(use_abc=True), Echeverri Restrepo & Andric 2023, Fig. 2–3) is typically a further ~10–40% faster, but takes more aggressive steps and is correspondingly more prone to diverging on poorly conditioned or noisy landscapes; enable it only after a plain FIRE 2.0 run is known to be stable. FIRE 1.0 remains useful as a well-tested baseline and for comparison with legacy results.

Composable clipping:

  • kups.relaxation.transforms.ClipByGlobalNorm: per-system L2 cap on the input force (prepend before FIRE).
  • kups.relaxation.transforms.MaxStepSize: per-system cap on the largest per-particle L2 norm of the output displacement (append after FIRE).

Attributes:

Name Type Description
dt_start float

Initial timestep.

dt_max float | None

Maximum timestep. Defaults to 10 * dt_start.

dt_min float | None

Minimum timestep. Defaults to dt_start * 1e-4.

f_inc float

Factor to increase dt when making progress.

f_dec float

Factor to decrease dt on a bad step.

alpha_start float

Initial velocity-mixing parameter.

f_alpha float

Factor to decay alpha when making progress.

n_min int

Minimum positive-power steps before dt is allowed to grow.
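
To make the control flow concrete, here is a single-system sketch of one FIRE 1.0 step in plain Python, following the logic of the update method listed below. `fire_step` is a hypothetical name, vectors are flat lists, and dt_max / dt_min are passed as explicit values rather than derived from dt_start. It is an illustration under those assumptions, not the library implementation.

```python
import math


def fire_step(force, vel, dt, alpha, n_pos, *, dt_max=1.0, dt_min=1e-5,
              f_inc=1.1, f_dec=0.5, alpha_start=0.1, f_alpha=0.99, n_min=5):
    """Return (position_update, velocity, dt, alpha, n_pos) after one step."""
    # Kick: v <- v + dt * F.
    vel = [v + dt * f for v, f in zip(vel, force)]
    power = sum(f * v for f, v in zip(force, vel))
    v_norm = math.sqrt(sum(v * v for v in vel))
    f_norm = max(math.sqrt(sum(f * f for f in force)), 1e-10)
    # Mix: v' = (1 - alpha) * v + alpha * |v| / |F| * F.
    mixed = [(1.0 - alpha) * v + alpha * v_norm / f_norm * f
             for v, f in zip(vel, force)]
    if power > 0.0:
        # Downhill: move with the current dt, then maybe grow dt / decay alpha.
        step = [dt * v for v in mixed]
        if n_pos >= n_min:
            dt, alpha = min(dt * f_inc, dt_max), alpha * f_alpha
        return step, mixed, dt, alpha, n_pos + 1
    # Uphill: stop, shrink dt, reset alpha and the positive-step counter.
    zeros = [0.0] * len(force)
    return zeros, zeros, max(dt * f_dec, dt_min), alpha_start, 0


step, vel, dt, alpha, n_pos = fire_step([-1.0], [0.0], 0.1, 0.1, 0)
stopped = fire_step([1.0], [-1.0], 0.1, 0.1, 3)
grown = fire_step([-1.0], [0.0], 0.1, 0.1, 5)
```

The first call is a downhill step with too few positive-power steps, so dt and alpha are unchanged. The second call has negative power, so the velocity is zeroed and dt is halved. The third has reached n_min, so dt grows by f_inc.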

Source code in src/kups/relaxation/transforms/fire.py
@dataclass
class ScaleByFire[Params](Optimizer[Params, ScaleByFireState]):
    """FIRE (Fast Inertial Relaxation Engine) optimizer with per-system state.

    Implements Bitzek et al. *Phys. Rev. Lett.* **97**, 170201 (2006), but
    every global tree reduction is replaced by a per-system reduction over
    the ``index_prefix``. Each system independently adapts its own
    ``dt`` / ``alpha`` / ``n_pos`` and sees its own per-system power and
    norms.

    The transform follows the optax convention: ``updates`` passed to
    :meth:`update` is interpreted as the force ``F = -∇L`` (the descent
    direction). Sign conversion from a raw gradient and any clipping live
    in the surrounding :func:`kups.relaxation.optimizer.chain` — see the
    module docstring for a worked example.

    .. note::

        This is the original FIRE 1.0. For most production relaxations
        prefer :class:`kups.relaxation.transforms.ScaleByFire2`, which
        Guénolé et al. 2020 (Fig. 4–6) report converges in ~1.5–3×
        fewer force calls on Lennard-Jones, EAM and Tersoff
        benchmarks. ABC-FIRE (``ScaleByFire2(use_abc=True)``, Echeverri
        Restrepo & Andric 2023, Fig. 2–3) is typically a further
        ~10–40% faster, but takes more aggressive steps and is
        correspondingly more prone to diverging on poorly conditioned
        or noisy landscapes — enable it only after a plain FIRE 2.0
        run is known to be stable. FIRE 1.0 remains useful as a
        well-tested baseline and for comparison with legacy results.

    Composable clipping:

    * :class:`kups.relaxation.transforms.ClipByGlobalNorm`: per-system
      L2 cap on the input force (prepend before FIRE).
    * :class:`kups.relaxation.transforms.MaxStepSize`: per-system cap on
      the largest per-particle L2 norm of the output displacement
      (append after FIRE).

    Attributes:
        dt_start: Initial timestep.
        dt_max: Maximum timestep. Defaults to ``10 * dt_start``.
        dt_min: Minimum timestep. Defaults to ``dt_start * 1e-4``.
        f_inc: Factor to increase dt when making progress.
        f_dec: Factor to decrease dt on a bad step.
        alpha_start: Initial velocity-mixing parameter.
        f_alpha: Factor to decay alpha when making progress.
        n_min: Minimum positive-power steps before dt is allowed to grow.
    """

    dt_start: float = field(static=True, default=0.1)
    dt_max: float | None = field(static=True, default=None)
    dt_min: float | None = field(static=True, default=None)
    f_inc: float = field(static=True, default=1.1)
    f_dec: float = field(static=True, default=0.5)
    alpha_start: float = field(static=True, default=0.1)
    f_alpha: float = field(static=True, default=0.99)
    n_min: int = field(static=True, default=5)

    @property
    def _dt_max(self) -> float:
        return self.dt_max if self.dt_max is not None else 10.0 * self.dt_start

    @property
    def _dt_min(self) -> float:
        return self.dt_min if self.dt_min is not None else self.dt_start * 1e-4

    def init(
        self, parameters: Params, index_prefix: PyTree | None = None
    ) -> ScaleByFireState:
        if index_prefix is None:
            index_prefix = jax.tree.map(lambda x: Index.new((0,) * len(x)), parameters)
        idx_leaves = jax.tree.leaves(
            index_prefix, is_leaf=lambda x: isinstance(x, Index)
        )
        first = next(x for x in idx_leaves if isinstance(x, Index))
        keys = first.keys
        n = len(keys)
        return ScaleByFireState(
            velocity=jax.tree.map(jnp.zeros_like, parameters),
            dt=Table(keys, jnp.full((n,), self.dt_start)),
            alpha=Table(keys, jnp.full((n,), self.alpha_start)),
            n_pos=Table(keys, jnp.zeros((n,), dtype=jnp.int32)),
            index_prefix=tree_copy(index_prefix),
        )

    def update(
        self,
        updates: Params,
        state: ScaleByFireState,
        params: Params | None = None,
        **kwargs: Any,
    ) -> tuple[Params, ScaleByFireState]:
        del params, kwargs
        idx = state.index_prefix

        # ``updates`` IS the force F = -∇L (optax convention); see module
        # docstring. v <- v + dt[s] · F (per-system dt broadcast per particle).
        velocity = jax.tree.map(
            lambda v, sf: v + sf,
            state.velocity,
            tree_scale_per_row(updates, state.dt, idx),
        )

        # Per-system power P = F · v and its sign.
        power = tree_vdot(updates, velocity, idx)
        positive = power.data > 0.0

        # Per-system L2 norms; safe denominator for ||F||.
        v_norm = tree_segment_norm(velocity, idx)
        f_norm = tree_segment_norm(updates, idx)
        safe_f_norm = jnp.maximum(f_norm.data, 1e-10)

        # Mixed velocity per particle: v' = (1-α)·v + α·||v||/||F|| · F.
        v_scale = state.alpha.set_data(1.0 - state.alpha.data)
        f_scale = state.alpha.set_data(state.alpha.data * v_norm.data / safe_f_norm)
        mixed_velocity = jax.tree.map(
            lambda a, b: a + b,
            tree_scale_per_row(velocity, v_scale, idx),
            tree_scale_per_row(updates, f_scale, idx),
        )

        # Adaptive dt / alpha / n_pos updates per system.
        should_increase = positive & (state.n_pos.data >= self.n_min)
        new_dt_data = jnp.where(
            positive,
            jnp.where(
                should_increase,
                jnp.minimum(state.dt.data * self.f_inc, self._dt_max),
                state.dt.data,
            ),
            jnp.maximum(state.dt.data * self.f_dec, self._dt_min),
        )
        new_alpha_data = jnp.where(
            positive,
            jnp.where(
                should_increase,
                state.alpha.data * self.f_alpha,
                state.alpha.data,
            ),
            jnp.full_like(state.alpha.data, self.alpha_start),
        )
        new_n_pos_data = jnp.where(
            positive, state.n_pos.data + 1, jnp.zeros_like(state.n_pos.data)
        )

        # Per-system gating: zero velocity / position update where P <= 0.
        gate = state.dt.set_data(positive.astype(state.dt.data.dtype))
        final_velocity = tree_scale_per_row(mixed_velocity, gate, idx)
        position_updates = tree_scale_per_row(
            mixed_velocity,
            state.dt.set_data(jnp.where(positive, state.dt.data, 0.0)),
            idx,
        )

        return position_updates, ScaleByFireState(
            velocity=final_velocity,
            dt=state.dt.set_data(new_dt_data),
            alpha=state.alpha.set_data(new_alpha_data),
            n_pos=state.n_pos.set_data(new_n_pos_data),
            index_prefix=idx,
        )

ScaleByFire2

Bases: Optimizer[Params, ScaleByFire2State]

FIRE 2.0 (with optional ABC-FIRE) with per-system block-diagonal state.

Per-system port of the LAMMPS-style FIRE 2.0 integrator described in Guénolé et al. 2020, with the ABC-FIRE bias correction (use_abc=True) of Echeverri Restrepo & Andric 2023. With a single system this reduces to the algorithm from kups.relaxation.optax.scale_by_fire2; with multiple systems each system independently adapts its own dt / alpha / n_pos and sees its own per-system power, norms and dmax.

The transform follows the optax convention: the updates passed to update are interpreted as the force F = -∇L (the descent direction). Sign conversion from a raw gradient and any external clipping live in the surrounding kups.relaxation.optimizer.chain; see the module docstring for a worked example. The LAMMPS-style dmax clip configured via max_step is internal to FIRE 2.0 and applies on top of any composed clipping.
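
The ABC-FIRE bias correction mentioned above changes only the mixing coefficients. A minimal sketch for a positive-power step, mirroring the use_abc=True branch of the source listed below, might look as follows. `abc_mixing_scales` is a hypothetical name, and the sketch omits the source's lower clamp on alpha.

```python
import math


def abc_mixing_scales(alpha, n_pos, v_sq, f_sq):
    """Coefficients (s1, s2) in v <- s1 * v + s2 * F for a P > 0 step.

    v_sq and f_sq are the squared L2 norms of velocity and force for one system.
    """
    # Bias-correction denominator 1 - (1 - alpha)**n_pos, guarded against 0.
    abc = max(1.0 - (1.0 - alpha) ** n_pos, 1e-30)
    s1 = (1.0 - alpha) / abc
    # No force contribution when the force is numerically zero.
    s2 = 0.0 if f_sq <= 1e-20 else alpha * math.sqrt(v_sq / f_sq) / abc
    return s1, s2


early = abc_mixing_scales(0.25, 1, 4.0, 1.0)    # (3.0, 2.0)
late = abc_mixing_scales(0.25, 100, 4.0, 1.0)   # close to (0.75, 0.5)
```

Right after a restart (small n_pos) the correction boosts both coefficients. As n_pos grows the denominator tends to 1 and the coefficients approach the uncorrected values 1 - alpha and alpha * |v| / |F|.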

Attributes:

Name Type Description
dt_start float

Initial timestep.

dt_max float

Maximum timestep (LAMMPS dtmax).

dt_min float

Minimum timestep (LAMMPS dtmin).

max_step float | None

Per-step displacement bound dmax. use_abc=False applies it as a one-shot ∞-norm timestep rescale (max_i |Δx_i| ≤ max_step); use_abc=True applies it as a per-component velocity clip that persists into the next step. None disables it. The clip is per-system: each system's ∞-norm or component limit is computed independently.

f_inc float

Factor to grow dt (LAMMPS dtgrow).

f_dec float

Factor to shrink dt (LAMMPS dtshrink).

alpha_start float

Initial velocity-mixing parameter (LAMMPS alpha0).

f_alpha float

Factor to decay alpha (LAMMPS alphashrink).

n_min int

Minimum positive-power steps before dt is allowed to grow (LAMMPS delaystep).

use_abc bool

If True, apply ABC-FIRE bias correction to the mixing.

halfstepback bool

If True, apply x -= 0.5·new_dt·v_old on the non-positive-power branch.

delaystep_start bool

If True, suppress dt shrink and alpha reset while n_total < n_min.
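
For the use_abc=False case, the max_step bound acts as a one-shot timestep rescale. A minimal sketch of that rule for a single system follows; `effective_timestep` is a hypothetical name, and vmax stands for the largest absolute velocity component in the system, which the caller is assumed to supply.

```python
def effective_timestep(dt, vmax, max_step):
    """Shrink dt so that dt * vmax does not exceed max_step (None disables)."""
    if max_step is None or dt * vmax <= max_step:
        return dt
    # Guard against division by zero, as in the listed source.
    return max_step / max(vmax, 1e-30)


capped = effective_timestep(0.5, 1.0, 0.1)      # 0.5 * 1.0 > 0.1, so dt becomes 0.1
untouched = effective_timestep(0.05, 1.0, 0.1)  # already within the bound
disabled = effective_timestep(0.5, 1.0, None)   # max_step=None keeps dt
```

Because the rescaled value is used only for the current step, the adaptive dt stored in the state is not affected by this clip.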

References
  • Guénolé et al., Comput. Mater. Sci. 175, 109584 (2020).
  • Echeverri Restrepo & Andric, Comput. Mater. Sci. 218, 111978 (2023).
  • LAMMPS src/min_fire.cpp (develop branch).
Source code in src/kups/relaxation/transforms/fire2.py
@dataclass
class ScaleByFire2[Params](Optimizer[Params, ScaleByFire2State]):
    """FIRE 2.0 (with optional ABC-FIRE) with per-system block-diagonal state.

    Per-system port of the LAMMPS-style FIRE 2.0 integrator described in
    Guénolé et al. 2020, with the ABC-FIRE bias correction
    (``use_abc=True``) of Echeverri Restrepo & Andric 2023. With a single
    system this reduces to the algorithm from
    ``kups.relaxation.optax.scale_by_fire2``; with multiple systems each
    system independently adapts its own ``dt`` / ``alpha`` / ``n_pos``
    and sees its own per-system power, norms and ``dmax``.

    The transform follows the optax convention: ``updates`` passed to
    :meth:`update` is interpreted as the force ``F = -∇L`` (the descent
    direction). Sign conversion from a raw gradient and any external
    clipping live in the surrounding
    :func:`kups.relaxation.optimizer.chain` — see the module docstring
    for a worked example. The LAMMPS-style ``dmax`` clip configured via
    :attr:`max_step` is internal to FIRE 2.0 and applies on top of any
    composed clipping.

    Attributes:
        dt_start: Initial timestep.
        dt_max: Maximum timestep (LAMMPS ``dtmax``).
        dt_min: Minimum timestep (LAMMPS ``dtmin``).
        max_step: Per-step displacement bound ``dmax``. ``use_abc=False``
            applies it as a one-shot ∞-norm timestep rescale
            (``max_i |Δx_i| ≤ max_step``); ``use_abc=True`` applies it as
            a per-component velocity clip that persists into the next
            step. ``None`` disables it. The clip is per-system: each
            system's ∞-norm or component limit is computed independently.
        f_inc: Factor to grow ``dt`` (LAMMPS ``dtgrow``).
        f_dec: Factor to shrink ``dt`` (LAMMPS ``dtshrink``).
        alpha_start: Initial velocity-mixing parameter (LAMMPS ``alpha0``).
        f_alpha: Factor to decay ``alpha`` (LAMMPS ``alphashrink``).
        n_min: Minimum positive-power steps before ``dt`` is allowed to
            grow (LAMMPS ``delaystep``).
        use_abc: If True, apply ABC-FIRE bias correction to the mixing.
        halfstepback: If True, apply ``x -= 0.5·new_dt·v_old`` on the
            non-positive-power branch.
        delaystep_start: If True, suppress ``dt`` shrink and ``alpha``
            reset while ``n_total < n_min``.

    References:
        * Guénolé et al., *Comput. Mater. Sci.* **175**, 109584 (2020).
        * Echeverri Restrepo & Andric, *Comput. Mater. Sci.* **218**,
          111978 (2023).
        * LAMMPS ``src/min_fire.cpp`` (develop branch).
    """

    dt_start: float = field(static=True, default=0.1)
    dt_max: float = field(static=True, default=1.0)
    dt_min: float = field(static=True, default=2e-3)
    max_step: float | None = field(static=True, default=0.1)
    f_inc: float = field(static=True, default=1.1)
    f_dec: float = field(static=True, default=0.5)
    alpha_start: float = field(static=True, default=0.25)
    f_alpha: float = field(static=True, default=0.99)
    n_min: int = field(static=True, default=20)
    use_abc: bool = field(static=True, default=False)
    halfstepback: bool = field(static=True, default=True)
    delaystep_start: bool = field(static=True, default=True)

    def init(
        self, parameters: Params, index_prefix: PyTree | None = None
    ) -> ScaleByFire2State:
        if index_prefix is None:
            index_prefix = jax.tree.map(lambda x: Index.new((0,) * len(x)), parameters)
        idx_leaves = jax.tree.leaves(
            index_prefix, is_leaf=lambda x: isinstance(x, Index)
        )
        first = next(x for x in idx_leaves if isinstance(x, Index))
        keys = first.keys
        n = len(keys)
        return ScaleByFire2State(
            velocity=jax.tree.map(jnp.zeros_like, parameters),
            dt=Table(keys, jnp.full((n,), self.dt_start)),
            alpha=Table(keys, jnp.full((n,), self.alpha_start)),
            n_pos=Table(keys, jnp.zeros((n,), dtype=jnp.int32)),
            n_total=jnp.asarray(0, dtype=jnp.int32),
            index_prefix=tree_copy(index_prefix),
        )

    def update(
        self,
        updates: Params,
        state: ScaleByFire2State,
        params: Params | None = None,
        **kwargs: Any,
    ) -> tuple[Params, ScaleByFire2State]:
        del params, kwargs
        idx = state.index_prefix
        keys = state.dt.keys
        dt_data = state.dt.data
        alpha_data = state.alpha.data
        float_dtype = dt_data.dtype
        n_total = state.n_total + 1

        # ``updates`` IS the force F = -∇L (optax convention); see module
        # docstring. P = v_old · F per system (LAMMPS: vdotfall).
        power = tree_vdot(updates, state.velocity, idx).data
        positive = power > 0.0

        # ----- n_pos (LAMMPS: ntimestep - last_negative) ------------------
        new_n_pos = jnp.where(positive, state.n_pos.data + 1, 0)
        should_increase = positive & (new_n_pos > self.n_min)

        # ----- dt adaptation per system -----------------------------------
        dt_increased = jnp.minimum(dt_data * self.f_inc, self.dt_max)
        dt_decreased = jnp.where(
            dt_data * self.f_dec >= self.dt_min,
            dt_data * self.f_dec,
            dt_data,
        )
        new_dt = jnp.where(
            positive,
            jnp.where(should_increase, dt_increased, dt_data),
            dt_decreased,
        )

        # ----- alpha adaptation per system --------------------------------
        alpha_for_mixing = (
            jnp.maximum(alpha_data, 1e-10) if self.use_abc else alpha_data
        )
        new_alpha = jnp.where(
            positive,
            jnp.where(
                should_increase,
                alpha_for_mixing * self.f_alpha,
                alpha_for_mixing,
            ),
            jnp.full_like(alpha_data, self.alpha_start),
        )

        # ----- delaystep_start: suppress shrink during startup ------------
        if self.delaystep_start:
            in_startup = (~positive) & (n_total < self.n_min)
            new_dt = jnp.where(in_startup, dt_data, new_dt)
            new_alpha = jnp.where(in_startup, alpha_data, new_alpha)

        # ----- Mixing scales (use OLD velocity, per system) --------------
        v_old_sq = tree_vdot(state.velocity, state.velocity, idx).data
        f_sq = tree_vdot(updates, updates, idx).data

        if self.use_abc:
            abc = jnp.where(
                positive,
                1.0 - jnp.power(1.0 - alpha_for_mixing, new_n_pos.astype(float_dtype)),
                1.0,
            )
            safe_abc = jnp.maximum(abc, 1e-30)
            scale1 = jnp.where(positive, (1.0 - alpha_for_mixing) / safe_abc, 1.0)
            scale2_raw = jnp.where(
                f_sq <= 1e-20,  # type: ignore[operator]
                0.0,
                (alpha_for_mixing * jnp.sqrt(v_old_sq / jnp.maximum(f_sq, 1e-20)))
                / safe_abc,
            )
            scale2 = jnp.where(positive, scale2_raw, 0.0)
        else:
            scale1 = 1.0 - alpha_data
            scale2 = jnp.where(
                f_sq <= 1e-20,  # type: ignore[operator]
                0.0,
                alpha_data * jnp.sqrt(v_old_sq / jnp.maximum(f_sq, 1e-20)),
            )

        # ----- dmax: compute dtv (non-ABC only) per system ----------------
        if self.max_step is not None and not self.use_abc:
            abs_v = jax.tree.map(jnp.abs, state.velocity)
            abs_f = jax.tree.map(jnp.abs, updates)
            vmax_pos = tree_segment_max(abs_v, idx).data
            vmax_neg = new_dt * tree_segment_max(abs_f, idx).data
            vmax = jnp.where(positive, vmax_pos, vmax_neg)
            dtv = jnp.where(
                new_dt * vmax > self.max_step,
                self.max_step / jnp.maximum(vmax, 1e-30),
                new_dt,
            )
        else:
            dtv = new_dt

        # ----- Half-step backtrack: -0.5·new_dt·v_old per particle --------
        if self.halfstepback:
            backtrack = tree_scale_per_row(
                state.velocity, Table(keys, -0.5 * new_dt), idx
            )
        else:
            backtrack = jax.tree.map(jnp.zeros_like, state.velocity)

        # ----- v_pre: zero on P<=0, keep on P>0 ---------------------------
        gate = Table(keys, positive.astype(float_dtype))
        v_pre = tree_scale_per_row(state.velocity, gate, idx)

        # ----- Euler-implicit kick: v += dtv · F --------------------------
        scaled_f = tree_scale_per_row(updates, Table(keys, dtv), idx)
        v_int = jax.tree.map(jnp.add, v_pre, scaled_f)

        # ----- Mixing (applied only when P > 0): v = s1·v + s2·F ----------
        v_mixed = jax.tree.map(
            jnp.add,
            tree_scale_per_row(v_int, Table(keys, scale1), idx),
            tree_scale_per_row(updates, Table(keys, scale2), idx),
        )
        new_velocity = tree_where_per_row(Table(keys, positive), v_mixed, v_int, idx)

        # ----- ABC per-component dmax clip (P>0 only) ---------------------
        if self.max_step is not None and self.use_abc:
            effective_limit = jnp.where(
                positive,
                self.max_step / jnp.maximum(dtv, 1e-30),
                jnp.inf,
            )
            new_velocity = tree_clip_per_row(
                new_velocity, Table(keys, effective_limit), idx
            )

        # ----- Position update: dtv · v + (~positive) · backtrack ---------
        main = tree_scale_per_row(new_velocity, Table(keys, dtv), idx)
        not_positive = Table(keys, (1.0 - positive.astype(float_dtype)))
        gated_backtrack = tree_scale_per_row(backtrack, not_positive, idx)
        position_updates = jax.tree.map(jnp.add, main, gated_backtrack)

        return position_updates, ScaleByFire2State(
            velocity=new_velocity,
            dt=state.dt.set_data(new_dt),
            alpha=state.alpha.set_data(new_alpha),
            n_pos=state.n_pos.set_data(new_n_pos),
            n_total=n_total,
            index_prefix=idx,
        )
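
To make the order of operations in the listing above easier to follow, here is a minimal scalar sketch of the non-ABC path for a single system, in plain Python rather than JAX. The function name `fire2_step_scalar` and its argument layout are hypothetical and not part of kups. `dt` stands for the already-adapted `new_dt`, and ordinary lists replace the per-system `Table`/`Index` machinery.

```python
import math


def fire2_step_scalar(v_old, f, dt, alpha, positive, max_step=None, halfstepback=True):
    """One non-ABC update for a single system (illustrative sketch only).

    v_old, f: lists of floats (velocity and force components).
    dt: the already-adapted timestep (``new_dt`` in the listing).
    positive: True when the power P = F.v is positive for this system.
    """
    v_sq = sum(v * v for v in v_old)
    f_sq = sum(x * x for x in f)

    # Mixing coefficients: v <- scale1 * v + scale2 * F
    scale1 = 1.0 - alpha
    scale2 = 0.0 if f_sq <= 1e-20 else alpha * math.sqrt(v_sq / f_sq)

    # dmax: shrink the integration step so no component moves further than max_step
    dtv = dt
    if max_step is not None:
        if positive:
            vmax = max(abs(v) for v in v_old)
        else:
            vmax = dt * max(abs(x) for x in f)
        if dt * vmax > max_step:
            dtv = max_step / max(vmax, 1e-30)

    # Gate: the old velocity is dropped when P <= 0
    v_pre = [v if positive else 0.0 for v in v_old]

    # Kick: v += dtv * F
    v_int = [v + dtv * x for v, x in zip(v_pre, f)]

    # Mixing is applied only when P > 0
    if positive:
        v_new = [scale1 * v + scale2 * x for v, x in zip(v_int, f)]
    else:
        v_new = v_int

    # Half-step backtrack along the old velocity, only when P <= 0
    if halfstepback and not positive:
        backtrack = [-0.5 * dt * v for v in v_old]
    else:
        backtrack = [0.0 for _ in v_old]

    dx = [dtv * v + b for v, b in zip(v_new, backtrack)]
    return dx, v_new
```

With `v_old=[1.0, 0.0]`, `f=[0.0, 2.0]`, `dt=0.1`, `alpha=0.1` and positive power, the sketch gives a new velocity of approximately `[0.9, 0.28]` and a displacement of approximately `[0.09, 0.028]`. With non-positive power the old velocity is discarded, and the displacement becomes the fresh kick plus the half-step backtrack, approximately `[-0.05, 0.02]`.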

ScaleByFire2State

Optimizer state for the per-system FIRE 2.0 transform.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| velocity | PyTree | Velocity estimate, pytree matching the parameters. |
| dt | Table[Any, Array] | Per-system adaptive timestep. |
| alpha | Table[Any, Array] | Per-system velocity-mixing parameter. |
| n_pos | Table[Any, Array] | Per-system count of consecutive positive-power steps (LAMMPS ntimestep - last_negative; also the ABC-FIRE bias exponent). |
| n_total | Array | Scalar; total update steps taken so far (drives delaystep_start). |
| index_prefix | PyTree | Tree prefix of the parameter pytree whose leaves are Index[K] objects, captured at init time. |

Source code in src/kups/relaxation/transforms/fire2.py
@dataclass
class ScaleByFire2State:
    """Optimizer state for the per-system FIRE 2.0 transform.

    Attributes:
        velocity: Velocity estimate, pytree matching the parameters.
        dt: Per-system adaptive timestep.
        alpha: Per-system velocity-mixing parameter.
        n_pos: Per-system count of consecutive positive-power steps
            (LAMMPS ``ntimestep - last_negative``; also the ABC-FIRE bias
            exponent).
        n_total: Scalar — total update steps taken so far (drives
            ``delaystep_start``).
        index_prefix: Tree prefix of the parameter pytree whose leaves are
            ``Index[K]`` objects, captured at init time.
    """

    velocity: PyTree
    dt: Table[Any, Array]
    alpha: Table[Any, Array]
    n_pos: Table[Any, Array]
    n_total: Array
    index_prefix: PyTree
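
The role of `n_pos` as the ABC-FIRE bias exponent can be illustrated with a small scalar sketch of the `use_abc` branch of the update, for a system with positive power. The helper name `abc_scales` is hypothetical. `alpha` stands in for the `alpha_for_mixing` value used in the source, and `v_sq` and `f_sq` are the squared norms of the old velocity and of the force.

```python
import math


def abc_scales(alpha, n_pos, v_sq, f_sq):
    """Bias-corrected mixing coefficients for one system with P > 0 (sketch)."""
    # Bias-correction factor: 1 - (1 - alpha)^n_pos, floored to avoid division by zero
    abc = max(1.0 - (1.0 - alpha) ** n_pos, 1e-30)
    scale1 = (1.0 - alpha) / abc
    if f_sq <= 1e-20:
        scale2 = 0.0  # vanishing force: no force-direction contribution
    else:
        scale2 = alpha * math.sqrt(v_sq / f_sq) / abc
    return abc, scale1, scale2


# The correction is strongest right after a reset and fades as n_pos grows.
for n in (1, 2, 10, 100):
    abc, s1, s2 = abc_scales(alpha=0.25, n_pos=n, v_sq=4.0, f_sq=1.0)
    print(n, abc, s1, s2)
```

As `n_pos` grows, `abc` tends to 1 and both coefficients approach their uncorrected values `1 - alpha` and `alpha * |v| / |F|`. For systems with non-positive power the source uses `scale1 = 1` and `scale2 = 0` instead.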

ScaleByFireState

Optimizer state for the per-system FIRE transform.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| velocity | PyTree | Velocity estimate (PyTree matching params). |
| dt | Table[Any, Array] | Per-system adaptive timestep. |
| alpha | Table[Any, Array] | Per-system velocity-mixing parameter. |
| n_pos | Table[Any, Array] | Per-system count of consecutive positive-power steps. |
| index_prefix | PyTree | Tree prefix of the parameter pytree whose leaves are Index[K] objects, captured at init time. |

Source code in src/kups/relaxation/transforms/fire.py
@dataclass
class ScaleByFireState:
    """Optimizer state for the per-system FIRE transform.

    Attributes:
        velocity: Velocity estimate (PyTree matching params).
        dt: Per-system adaptive timestep.
        alpha: Per-system velocity-mixing parameter.
        n_pos: Per-system count of consecutive positive-power steps.
        index_prefix: Tree prefix of the parameter pytree whose leaves are
            ``Index[K]`` objects, captured at init time.
    """

    velocity: PyTree
    dt: Table[Any, Array]
    alpha: Table[Any, Array]
    n_pos: Table[Any, Array]
    index_prefix: PyTree
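
As a rough illustration of how the per-system fields `dt`, `alpha` and `n_pos` evolve independently, the sketch below applies the textbook FIRE adaptation rule to a batch of systems held in plain Python lists. It is not the kups implementation: the function name `adapt_fire_state` and all keyword defaults (`dt_max`, `n_min`, `f_inc`, `f_dec`, `alpha_start`, `f_alpha`) are assumptions, taken from commonly quoted FIRE settings rather than from this module.

```python
def adapt_fire_state(dt, alpha, n_pos, power, *, dt_max=1.0, n_min=5,
                     f_inc=1.1, f_dec=0.5, alpha_start=0.1, f_alpha=0.99):
    """Textbook FIRE bookkeeping, applied to each system on its own (sketch)."""
    new_dt, new_alpha, new_n_pos = [], [], []
    for dt_i, alpha_i, n_i, p_i in zip(dt, alpha, n_pos, power):
        if p_i > 0.0:
            # Moving downhill: count the step, accelerate after n_min good steps
            n_i += 1
            if n_i > n_min:
                dt_i = min(dt_i * f_inc, dt_max)
                alpha_i *= f_alpha
        else:
            # Moving uphill: shrink the timestep and reset mixing and counter
            dt_i *= f_dec
            alpha_i = alpha_start
            n_i = 0
        new_dt.append(dt_i)
        new_alpha.append(alpha_i)
        new_n_pos.append(n_i)
    return new_dt, new_alpha, new_n_pos


# Two systems in one batch: the first keeps accelerating, the second resets.
dt, alpha, n_pos = adapt_fire_state(
    dt=[0.1, 0.1], alpha=[0.1, 0.05], n_pos=[5, 3], power=[1.0, -1.0]
)
```

In this example the first system's timestep grows and its mixing parameter decays, while the second system's timestep is halved and its counter returns to zero. This per-system independence is what keeping `dt`, `alpha` and `n_pos` as `Table` entries rather than scalars makes possible.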