Skip to content

kups.application.utils.propagate

Shared propagation utilities for simulation loops.

Provides warmup, sampling, and data-parallelism helpers used across MD, MCMC, and relaxation application modules.

data_parallelism_put(state)

Place state on the batch-sharded mesh.

Source code in src/kups/application/utils/propagate.py
def data_parallelism_put[S](state: S) -> S:
    """Place ``state`` on the batch-sharded mesh."""
    return jax.device_put(state, BATCH_SHARDING)

data_parallelism_vmap(f)

Vmap f with multi-device sharding when available.

On a single device the function is simply vmapped and JIT-compiled. With multiple devices it is additionally shard-mapped across the "batch" mesh axis.

Source code in src/kups/application/utils/propagate.py
def data_parallelism_vmap[C: Callable](f: C) -> C:
    """Vmap ``f`` with multi-device sharding when available.

    On a single device the function is simply vmapped and JIT-compiled.
    With multiple devices it is additionally shard-mapped across the
    ``"batch"`` mesh axis.
    """
    vmapped_fn = jax.vmap(f)
    if jax.device_count() > 1:
        return jit(
            shard_map(
                vmapped_fn,
                out_specs=BATCH_P,
                in_specs=BATCH_P,
                mesh=BATCH_MESH,
            ),
            donate_argnums=(1,),
        )
    return jit(vmapped_fn, donate_argnums=(1,))

propagate_and_fix(fn, key, state, *, max_tries=10)

Execute a propagator repeatedly until all assertions pass or retries are exhausted.

On each attempt, failed assertions are repaired via their fix functions. Raises if a failed assertion has no fix function or retries run out.

Parameters:

Name Type Description Default
fn Callable[[Array, State], Result[State, State]]

Assertion-aware propagator produced by propagator_with_assertions.

required
key Array

JAX PRNG key.

required
state State

Current simulation state.

required
max_tries int

Maximum number of repair attempts.

10

Returns:

Type Description
State

Propagated state with all assertions satisfied.

Raises:

Type Description
ValueError

If called inside a JAX transform.

RuntimeError

If assertions still fail after max_tries attempts.

Source code in src/kups/core/propagator.py
def propagate_and_fix[State](
    fn: Callable[[Array, State], Result[State, State]],
    key: Array,
    state: State,
    *,
    max_tries: int = 10,
) -> State:
    """Execute a propagator repeatedly until all assertions pass or retries are exhausted.

    On each attempt, failed assertions are repaired via their fix functions.
    Raises if a failed assertion has no fix function or retries run out.

    Args:
        fn: Assertion-aware propagator produced by :func:`propagator_with_assertions`.
        key: JAX PRNG key.
        state: Current simulation state.
        max_tries: Maximum number of repair attempts.

    Returns:
        Propagated state with all assertions satisfied.

    Raises:
        ValueError: If called inside a JAX transform.
        RuntimeError: If assertions still fail after ``max_tries`` attempts.
    """
    is_traced = any(isinstance(x, jax.core.Tracer) for x in jax.tree.leaves(state))
    if is_traced:
        raise ValueError("propagate_and_fix cannot be jax transformed.")

    for _ in range(max_tries):
        out = fn(key, state)
        state = out.value
        if not out.failed_assertions:
            return state
        state = out.fix_or_raise(state)
    raise RuntimeError("Failed to resolve potential after multiple attempts")

propagator_with_assertions(propagator)

Wrap a propagator to capture assertion results alongside the state.

Parameters:

Name Type Description Default
propagator Propagator[State]

Propagator to wrap.

required

Returns:

Type Description
Callable[[Array, State], Result[State, State]]

Function returning a Result that pairs the new state with assertion metadata.

Source code in src/kups/core/propagator.py
def propagator_with_assertions[State](
    propagator: Propagator[State],
) -> Callable[[Array, State], Result[State, State]]:
    """Wrap a propagator to capture assertion results alongside the state.

    Args:
        propagator: Propagator to wrap.

    Returns:
        Function returning a Result that pairs the new state with assertion metadata.
    """
    return as_result_function(propagator)

run_simulation_cycles(key, propagator, state, num_cycles, logger, *, convergence_fn=None)

Run simulation steps with logging and optional early stopping.

Parameters:

Name Type Description Default
key Array

JAX PRNG key for stochastic propagators (e.g. MD thermostats).

required
propagator Propagator[State]

Step propagator.

required
state State

Initial state.

required
num_cycles int

Maximum number of steps.

required
logger Logger[State]

Logger receiving state each step.

required
convergence_fn Callable[[State], bool] | None

If provided, called after each step; stops early when it returns True.

None

Returns:

Type Description
State

State after all steps or early convergence.

Source code in src/kups/application/utils/propagate.py
def run_simulation_cycles[State](
    key: Array,
    propagator: Propagator[State],
    state: State,
    num_cycles: int,
    logger: Logger[State],
    *,
    convergence_fn: Callable[[State], bool] | None = None,
) -> State:
    """Run simulation steps with logging and optional early stopping.

    Args:
        key: JAX PRNG key for stochastic propagators (e.g. MD thermostats).
        propagator: Step propagator.
        state: Initial state.
        num_cycles: Maximum number of steps.
        logger: Logger receiving state each step.
        convergence_fn: If provided, called after each step; stops early when
            it returns True.

    Returns:
        State after all steps or early convergence.
    """
    chain = key_chain(key)
    prop_with_assertions = jit(as_result_function(propagator), donate_argnums=(1,))
    with logger:
        for i in range(num_cycles):
            state = propagate_and_fix(prop_with_assertions, next(chain), state)
            logger.log(state, i)
            if convergence_fn is not None and convergence_fn(state):
                logging.info("Converged at step %d", i + 1)
                break
    return state

run_warmup_cycles(key, propagator, state, num_cycles)

Run warmup propagation cycles without logging.

Parameters:

Name Type Description Default
key Array

JAX PRNG key.

required
propagator Propagator[State]

Step propagator.

required
state State

Initial simulation state.

required
num_cycles int

Number of warmup steps.

required

Returns:

Type Description
State

State after warmup.

Source code in src/kups/application/utils/propagate.py
def run_warmup_cycles[State](
    key: Array, propagator: Propagator[State], state: State, num_cycles: int
) -> State:
    """Run warmup propagation cycles without logging.

    Args:
        key: JAX PRNG key.
        propagator: Step propagator.
        state: Initial simulation state.
        num_cycles: Number of warmup steps.

    Returns:
        State after warmup.
    """
    chain = key_chain(key)
    propagator_with_assertion = jit(as_result_function(propagator), donate_argnums=(1,))
    for _ in tqdm.trange(num_cycles):
        state = propagate_and_fix(propagator_with_assertion, next(chain), state)
    return state

warmup_and_sample(chain, propagator, state, warmup_cycles, num_samples, *, print_progress=True)

Warm up and collect evenly-spaced state snapshots.

Samples num_samples states from the last 20% of warmup steps. Falls back to the final num_samples steps when 20% is too few.

Parameters:

Name Type Description Default
chain Generator[Array, Any, Any]

PRNG key generator.

required
propagator Callable[[Array, State], Result[State, State]]

JIT-compiled propagator returning Result.

required
state State

Initial state.

required
warmup_cycles int

Total number of propagation steps.

required
num_samples int

Number of snapshots to collect.

required
print_progress bool

Show a tqdm progress bar.

True

Returns:

Type Description
list[State]

List of num_samples deep-copied state snapshots.

Raises:

Type Description
ValueError

If warmup_cycles is too small for num_samples.

Source code in src/kups/application/utils/propagate.py
def warmup_and_sample[State](
    chain: Generator[Array, Any, Any],
    propagator: Callable[[Array, State], Result[State, State]],
    state: State,
    warmup_cycles: int,
    num_samples: int,
    *,
    print_progress: bool = True,
) -> list[State]:
    """Warm up and collect evenly-spaced state snapshots.

    Samples ``num_samples`` states from the last 20% of warmup steps.
    Falls back to the final ``num_samples`` steps when 20% is too few.

    Args:
        chain: PRNG key generator.
        propagator: JIT-compiled propagator returning ``Result``.
        state: Initial state.
        warmup_cycles: Total number of propagation steps.
        num_samples: Number of snapshots to collect.
        print_progress: Show a tqdm progress bar.

    Returns:
        List of ``num_samples`` deep-copied state snapshots.

    Raises:
        ValueError: If ``warmup_cycles`` is too small for ``num_samples``.
    """
    # Spread samples over the last 20% of steps; fall back to the tail.
    if warmup_cycles * 0.2 < num_samples:
        take_from = set(range(warmup_cycles - num_samples, warmup_cycles))
    else:
        take_from = set(
            np.linspace(
                warmup_cycles * 0.8,
                warmup_cycles - 1,
                num_samples + 1,
                dtype=int,
            )[1:]
        )
    if len(take_from) < num_samples:
        raise ValueError(
            "Not enough warmup cycles to take the requested number of states."
        )

    states = []
    with tqdm.trange(warmup_cycles, disable=not print_progress) as pbar:
        for i in pbar:
            state = propagate_and_fix(propagator, next(chain), state)
            if i in take_from:
                states.append(deepcopy(state))
    assert len(states) == num_samples
    return states