Skip to content

kups.application.relaxation.simulation

Relaxation propagator construction and simulation runner.

IsRelaxGradients

Bases: Protocol

Protocol for gradient containers returned by relaxation potentials.

Source code in src/kups/application/relaxation/simulation.py
class IsRelaxGradients(Protocol):
    """Protocol for gradient containers returned by relaxation potentials.

    Structural interface: any object exposing these two read-only
    properties satisfies it (no inheritance required).
    """

    # Gradient of the potential energy w.r.t. each particle's position.
    @property
    def positions(self) -> Table[ParticleId, Array]: ...
    # Gradient w.r.t. each system's unit cell — presumably the lattice
    # degrees of freedom used when optimize_unitcell is enabled; see
    # make_relax_propagator.
    @property
    def unitcell(self) -> Table[SystemId, UnitCell]: ...

IsRelaxState

Bases: Protocol

Protocol for relaxation simulation states.

Source code in src/kups/application/relaxation/simulation.py
class IsRelaxState(Protocol):
    """Protocol for relaxation simulation states.

    Structural interface consumed by ``make_relax_propagator`` and
    ``run_relax``; any state exposing these four read-only properties
    satisfies it.
    """

    # Per-particle data; the relaxation code reads .data.positions,
    # .data.forces and .data.position_gradients from it.
    @property
    def particles(self) -> Table[ParticleId, RelaxParticles]: ...
    # Per-system data; the relaxation code reads .data.unitcell,
    # .data.unitcell_gradients and per-system potential_energy from it.
    @property
    def systems(self) -> Table[SystemId, RelaxSystems]: ...
    # Optax optimizer state threaded through successive relaxation steps.
    @property
    def opt_state(self) -> optax.OptState: ...
    # Step counter advanced by the step-counting propagator.
    @property
    def step(self) -> Array: ...

OptInit

Bases: Protocol

Protocol for initialising an Optax optimizer state from gradients.

Source code in src/kups/application/relaxation/simulation.py
class OptInit(Protocol):
    """Protocol for initialising an Optax optimizer state from gradients.

    ``make_relax_propagator`` returns a callable of this shape alongside
    the propagator; call it once with an initial gradient pair to seed
    the optimizer state stored on the simulation state.
    """

    # ``grads`` is the (position_gradients, unitcell_gradients) pair —
    # the same tuple layout make_relax_propagator extracts from the state.
    def __call__(self, grads: tuple[Array, Array]) -> optax.OptState: ...

make_relax_propagator(state_lens, potential, optimizer, optimize_unitcell=False)

Build a relaxation propagator with step counting and error recovery.

Parameters:

Name Type Description Default
state_lens Lens[State, State]

Lens focusing on the relaxation sub-state.

required
potential Potential[State, Gradients, EmptyType, Any]

Potential whose gradients drive the optimisation.

required
optimizer GradientTransformationExtraArgs

Optax gradient transformation (e.g. FIRE, Adam, L-BFGS).

required
optimize_unitcell bool

If True, optimise both positions and lattice vectors; otherwise optimise positions only.

False

Returns:

Type Description
Propagator[State]

The propagator half of the returned (propagator, opt_init) tuple; performs one optimisation step per call.

OptInit

The opt_init half of the returned tuple; initialises the Optax optimizer state from an initial gradient pair.

Source code in src/kups/application/relaxation/simulation.py
def make_relax_propagator[State: IsRelaxState, Gradients: IsRelaxGradients](
    state_lens: Lens[State, State],
    potential: Potential[State, Gradients, EmptyType, Any],
    optimizer: optax.GradientTransformationExtraArgs,
    optimize_unitcell: bool = False,
) -> tuple[Propagator[State], OptInit]:
    """Build a relaxation propagator with step counting and error recovery.

    Args:
        state_lens: Lens focusing on the relaxation sub-state.
        potential: Potential whose gradients drive the optimisation.
        optimizer: Optax gradient transformation (e.g. FIRE, Adam, L-BFGS).
        optimize_unitcell: If True, optimise both positions and lattice vectors;
            otherwise optimise positions only.

    Returns:
        Tuple of ``(propagator, opt_init)`` where *propagator* performs one
        optimisation step and *opt_init* initialises the Optax optimizer state.
    """
    # Cache the gradient and forces within the state
    pot = CachedPotential(
        MappedPotential(
            potential, lambda x: (x.positions.data, x.unitcell.data), identity
        ),
        lens(
            lambda x: PotentialOut(
                x.systems.map_data(lambda x: x.potential_energy),
                (
                    x.particles.data.position_gradients,
                    x.systems.data.unitcell_gradients,
                ),
                EMPTY,
            )
        ),
        lambda x: PotentialOut(
            Index.new(x.systems.keys),  # type: ignore
            (x.particles.data.system, Index.new(x.systems.keys)),
            EMPTY,
        ),  # type: ignore
    )

    def relax_prop_and_opt_init[T](prop_view: View[tuple[Array, UnitCell], T]):
        prop_lens = state_lens.focus(
            lambda x: prop_view((x.particles.data.positions, x.systems.data.unitcell))
        )
        return RelaxationPropagator(
            potential=MappedPotential(pot, prop_view, identity),
            property=prop_lens,
            opt_state=state_lens.focus(lambda x: x.opt_state),
            optimizer=optimizer,
        ), lambda grads: optimizer.init(prop_view(grads))  # type: ignore

    relax_prop, opt_init = (
        relax_prop_and_opt_init(lens(identity))
        if optimize_unitcell
        else relax_prop_and_opt_init(lens(lambda x: x[0]))
    )
    step_prop = step_counter_propagator(state_lens.focus(lambda x: x.step))
    return ResetOnErrorPropagator(
        SequentialPropagator((relax_prop, step_prop))
    ), opt_init

run_relax(key, propagator, state, config)

Run structure relaxation with early stopping on force convergence.

Parameters:

Name Type Description Default
key Array

JAX PRNG key.

required
propagator Propagator[State]

Relaxation propagator from make_relax_propagator.

required
state State

Initial simulation state.

required
config RelaxRunConfig

Run configuration (max_steps, force_tolerance, out_file).

required

Returns:

Type Description
State

Final relaxation state after convergence or max_steps.

Source code in src/kups/application/relaxation/simulation.py
def run_relax[State: IsRelaxState](
    key: Array, propagator: Propagator[State], state: State, config: RelaxRunConfig
) -> State:
    """Run structure relaxation with early stopping on force convergence.

    Args:
        key: JAX PRNG key.
        propagator: Relaxation propagator from ``make_relax_propagator``.
        state: Initial simulation state.
        config: Run configuration (max_steps, force_tolerance, out_file).

    Returns:
        Final relaxation state after convergence or ``max_steps``.
    """

    def converged(s: State) -> bool:
        forces = s.particles.data.forces
        max_force = jnp.max(jnp.linalg.norm(forces, axis=-1))
        return bool(max_force < config.force_tolerance)

    logger = CompositeLogger(
        TqdmLogger(config.max_steps),
        HDF5StorageWriter(config.out_file, RelaxLoggedData(), state, config.max_steps),
    )
    state = run_simulation_cycles(
        key, propagator, state, config.max_steps, logger, convergence_fn=converged
    )
    return state