Skip to content

kups.application.relaxation.simulation

Relaxation propagator construction and simulation runner.

IsRelaxGradients

Bases: Protocol

Protocol for gradient containers returned by relaxation potentials.

Source code in src/kups/application/relaxation/simulation.py
class IsRelaxGradients(Protocol):
    """Structural type for the gradients a relaxation potential produces.

    Any object exposing the two read-only properties below (per-particle
    position gradients and per-system cell gradients) satisfies it.
    """

    @property
    def positions(self) -> Table[ParticleId, Array]:
        """Position gradients, one entry per particle."""

    @property
    def cell(self) -> Table[SystemId, Cell]:
        """Cell gradients, one entry per system."""

IsRelaxState

Bases: Protocol

Protocol for relaxation simulation states.

Source code in src/kups/application/relaxation/simulation.py
class IsRelaxState(Protocol):
    """Structural type for the state a relaxation simulation operates on.

    The relaxation machinery reads and updates the particle table, the
    system table, the optimizer state and the step counter exposed here.
    """

    @property
    def particles(self) -> Table[ParticleId, RelaxParticles]:
        """Per-particle data (positions, forces, gradients, system membership)."""

    @property
    def systems(self) -> Table[SystemId, RelaxSystems]:
        """Per-system data (cell, potential energy, cell gradients)."""

    @property
    def opt_state(self) -> optax.OptState:
        """Optax optimizer state carried between relaxation steps."""

    @property
    def step(self) -> Array:
        """Counter of completed relaxation steps."""

OptInit

Bases: Protocol

Protocol for initialising an Optax optimizer state from the particle and system tables of a relaxation state.

Source code in src/kups/application/relaxation/simulation.py
class OptInit(Protocol):
    """Callable that creates the initial Optax optimizer state.

    Implementations receive the particle and system tables of a relaxation
    state and return the matching ``optax.OptState``.
    """

    def __call__(
        self,
        particles: Table[ParticleId, RelaxParticles],
        systems: Table[SystemId, RelaxSystems],
    ) -> optax.OptState:
        """Return a fresh optimizer state for ``particles`` and ``systems``."""

make_relax_propagator(state_lens, potential, optimizer, optimize_cell=False)

Build a relaxation propagator with step counting and error recovery.

Parameters:

Name Type Description Default
state_lens Lens[State, State]

Lens focusing on the relaxation sub-state.

required
potential Potential[State, Gradients, EmptyType, Any]

Potential whose gradients drive the optimisation.

required
optimizer Optimizer

Optimizer (e.g. FIRE, Adam, L-BFGS).

required
optimize_cell bool

If True, optimise both positions and lattice vectors; otherwise optimise positions only.

False

Returns:

Type Description
tuple[Propagator[State], OptInit]

Tuple of (propagator, opt_init) where propagator performs one optimisation step and opt_init initialises the optimizer state.

Source code in src/kups/application/relaxation/simulation.py
def make_relax_propagator[State: IsRelaxState, Gradients: IsRelaxGradients](
    state_lens: Lens[State, State],
    potential: Potential[State, Gradients, EmptyType, Any],
    optimizer: Optimizer,
    optimize_cell: bool = False,
) -> tuple[Propagator[State], OptInit]:
    """Build a relaxation propagator with step counting and error recovery.

    Args:
        state_lens: Lens focusing on the relaxation sub-state.
        potential: Potential whose gradients drive the optimisation.
        optimizer: Optimizer (e.g. FIRE, Adam, L-BFGS).
        optimize_cell: If True, optimise both positions and lattice vectors;
            otherwise optimise positions only.

    Returns:
        Tuple of ``(propagator, opt_init)`` where *propagator* performs one
        optimisation step and *opt_init* initialises the optimizer state.
    """
    # Cache the gradient and forces within the state
    pot = CachedPotential(
        MappedPotential(potential, lambda x: (x.positions.data, x.cell.data), identity),
        lens(
            lambda x: PotentialOut(
                x.systems.map_data(lambda x: x.potential_energy),
                (
                    x.particles.data.position_gradients,
                    x.systems.data.cell_gradients,
                ),
                EMPTY,
            )
        ),
        lambda x: PotentialOut(
            x.systems.index,  # type: ignore
            (x.particles.data.system, x.systems.index),
            EMPTY,
        ),  # type: ignore
    )

    def relax_prop_and_opt_init[T](prop_view: View[tuple[Array, Cell], T]):
        prop_lens = state_lens.focus(
            lambda x: prop_view((x.particles.data.positions, x.systems.data.cell))
        )

        def opt_init(
            particles: Table[ParticleId, RelaxParticles],
            systems: Table[SystemId, RelaxSystems],
        ) -> optax.OptState:
            params = (particles.data.positions, systems.data.cell)
            indices = (particles.data.system, systems.index)
            return optimizer.init(prop_view(params), prop_view(indices))  # type: ignore

        return RelaxationPropagator(
            potential=MappedPotential(pot, prop_view, identity),
            property=prop_lens,
            opt_state=state_lens.focus(lambda x: x.opt_state),
            optimizer=optimizer,
        ), opt_init

    relax_prop, opt_init = (
        relax_prop_and_opt_init(lens(identity))
        if optimize_cell
        else relax_prop_and_opt_init(lens(lambda x: x[0]))
    )
    step_prop = step_counter_propagator(state_lens.focus(lambda x: x.step))
    return ResetOnErrorPropagator(
        SequentialPropagator((relax_prop, step_prop))
    ), opt_init

run_relax(key, propagator, state, config)

Run structure relaxation with early stopping on force convergence.

Parameters:

Name Type Description Default
key Array

JAX PRNG key.

required
propagator Propagator[State]

Relaxation propagator from make_relax_propagator.

required
state State

Initial simulation state.

required
config RelaxRunConfig

Run configuration (max_steps, force_tolerance, out_file).

required

Returns:

Type Description
State

Final relaxation state after convergence or max_steps.

Source code in src/kups/application/relaxation/simulation.py
def run_relax[State: IsRelaxState](
    key: Array, propagator: Propagator[State], state: State, config: RelaxRunConfig
) -> State:
    """Run structure relaxation with early stopping on force convergence.

    Args:
        key: JAX PRNG key.
        propagator: Relaxation propagator from ``make_relax_propagator``.
        state: Initial simulation state.
        config: Run configuration (max_steps, force_tolerance, out_file).

    Returns:
        Final relaxation state after convergence or ``max_steps``.
    """

    def converged(s: State) -> bool:
        forces = s.particles.data.forces
        max_force = jnp.max(jnp.linalg.norm(forces, axis=-1))
        return bool(max_force < config.force_tolerance)

    def _postfix(s: State) -> dict[str, Any]:
        e = jnp.asarray(s.systems.data.potential_energy).sum()
        fmax = jnp.max(jnp.linalg.norm(s.particles.data.forces, axis=-1))
        return {"E[eV]": f"{float(e): .6f}", "fmax[eV/Å]": f"{float(fmax): .4e}"}

    logger = CompositeLogger(
        TqdmLogger(config.max_steps, postfix=_postfix),
        HDF5StorageWriter(config.out_file, RelaxLoggedData(), state, config.max_steps),
    )
    state = run_simulation_cycles(
        key, propagator, state, config.max_steps, logger, convergence_fn=converged
    )
    return state