kups.application.utils.propagate
Shared propagation utilities for simulation loops.
Provides warmup, sampling, and data-parallelism helpers used across MD, MCMC, and relaxation application modules.
data_parallelism_put(state)
data_parallelism_vmap(f)
Vmap f with multi-device sharding when available.
On a single device the function is simply vmapped and JIT-compiled.
With multiple devices it is additionally shard-mapped across the
"batch" mesh axis.
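As a rough illustration of the idea, the sketch below vmaps and JIT-compiles a per-system function and places the batch along a `"batch"` mesh axis. It is not the library's implementation: `sketch_vmap`, `sketch_put`, and `energy` are hypothetical names, and it relies on input sharding via `jax.device_put` with a `NamedSharding` rather than on shard-mapping, so the multi-device path differs from the one described above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def sketch_vmap(f):
    """Hypothetical stand-in: vmap over the leading axis, then JIT-compile."""
    return jax.jit(jax.vmap(f))


def sketch_put(batch):
    """Hypothetical helper: split the leading axis of `batch` over all devices."""
    mesh = Mesh(np.array(jax.devices()), ("batch",))
    sharding = NamedSharding(mesh, PartitionSpec("batch"))
    return jax.device_put(batch, sharding)


def energy(x):
    # Per-system function: one (3,) vector in, one scalar out.
    return jnp.sum(x ** 2)


# Keep the batch size divisible by the device count so it shards evenly.
n = 4 * jax.device_count()
batch = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
energies = sketch_vmap(energy)(sketch_put(batch))
```

On a single device the mesh has one entry, so the call reduces to a plain vmapped, JIT-compiled function.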
Source code in src/kups/application/utils/propagate.py
propagate_and_fix(fn, key, state, *, max_tries=10)
Execute a propagator repeatedly until all assertions pass or retries are exhausted.
On each attempt, failed assertions are repaired via their fix functions. Raises if a failed assertion has no fix function or retries run out.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable[[Array, State], Result[State, State]]` | Assertion-aware propagator produced by :func: | required |
| `key` | `Array` | JAX PRNG key. | required |
| `state` | `State` | Current simulation state. | required |
| `max_tries` | `int` | Maximum number of repair attempts. | `10` |
Returns:
| Type | Description |
|---|---|
| `State` | Propagated state with all assertions satisfied. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If called inside a JAX transform. |
| `RuntimeError` | If assertions still fail after |
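The retry-and-repair loop described above can be sketched in plain Python. Everything here is a hypothetical stand-in: the `Failure` and `Result` classes, the choice to apply each fix to the input state before retrying, and the exception type raised for a missing fix are assumptions made for illustration, not the library's types or behaviour.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class Failure:
    """Hypothetical record of one failed assertion."""
    message: str
    fix: Optional[Callable[[dict], dict]] = None  # repairs the state, if provided


@dataclass
class Result:
    """Hypothetical stand-in: the new state plus any failed assertions."""
    value: dict
    failures: list = field(default_factory=list)


def propagate_and_fix_sketch(fn, key, state, *, max_tries=10):
    for _ in range(max_tries):
        result = fn(key, state)
        if not result.failures:
            return result.value  # every assertion passed
        for failure in result.failures:
            if failure.fix is None:
                # Assumed exception type; the real one may differ.
                raise RuntimeError(f"no fix available: {failure.message}")
            state = failure.fix(state)  # repair, then retry from the repaired state
    raise RuntimeError(f"assertions still failing after {max_tries} tries")


def step(key, state):
    # Toy propagator: fails while a buffer is too small for the particle count.
    if state["n_particles"] > state["capacity"]:
        grow = lambda s: {**s, "capacity": 2 * s["capacity"]}
        return Result(state, [Failure("capacity exceeded", fix=grow)])
    return Result({**state, "step": state["step"] + 1})


final = propagate_and_fix_sketch(
    step, key=0, state={"n_particles": 5, "capacity": 2, "step": 0}
)
```

In this toy run the capacity doubles twice (2, 4, 8) before the step succeeds on the third attempt.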
Source code in src/kups/core/propagator.py
propagator_with_assertions(propagator)
Wrap a propagator to capture assertion results alongside the state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `propagator` | `Propagator[State]` | Propagator to wrap. | required |
Returns:
| Type | Description |
|---|---|
| `Callable[[Array, State], Result[State, State]]` | Function returning a Result that pairs the new state with assertion metadata. |
Source code in src/kups/core/propagator.py
run_simulation_cycles(key, propagator, state, num_cycles, logger, *, convergence_fn=None)
Run simulation steps with logging and optional early stopping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Array` | JAX PRNG key for stochastic propagators (e.g. MD thermostats). | required |
| `propagator` | `Propagator[State]` | Step propagator. | required |
| `state` | `State` | Initial state. | required |
| `num_cycles` | `int` | Maximum number of steps. | required |
| `logger` | `Logger[State]` | Logger receiving state each step. | required |
| `convergence_fn` | `Callable[[State], bool] \| None` | If provided, called after each step; stops early when it returns True. | `None` |
Returns:
| Type | Description |
|---|---|
| `State` | State after all steps or early convergence. |
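The control flow implied by these parameters (step, log, then check for convergence) might look like the following plain-Python sketch. The function name, the use of a bare callable in place of `Logger[State]`, and the `(key, step)` tuple standing in for a per-step JAX PRNG key are all assumptions for illustration.

```python
def run_cycles_sketch(key, propagator, state, num_cycles, log, *, convergence_fn=None):
    for step in range(num_cycles):
        # Hypothetical per-step key; real code would derive a fresh JAX key here.
        state = propagator((key, step), state)
        log(state)  # the logger sees the state after every step
        if convergence_fn is not None and convergence_fn(state):
            break  # early stop once the criterion is met
    return state


def halve(key, x):
    # Toy deterministic propagator: a relaxation that halves a scalar residual.
    return x / 2


history = []
final = run_cycles_sketch(
    key=0,
    propagator=halve,
    state=1.0,
    num_cycles=100,
    log=history.append,
    convergence_fn=lambda x: abs(x) < 0.1,
)
```

Here the loop stops after four of the 100 allowed steps, once the residual reaches 0.0625.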
Source code in src/kups/application/utils/propagate.py
run_warmup_cycles(key, propagator, state, num_cycles)
Run warmup propagation cycles without logging.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Array` | JAX PRNG key. | required |
| `propagator` | `Propagator[State]` | Step propagator. | required |
| `state` | `State` | Initial simulation state. | required |
| `num_cycles` | `int` | Number of warmup steps. | required |
Returns:
| Type | Description |
|---|---|
| `State` | State after warmup. |
Source code in src/kups/application/utils/propagate.py
warmup_and_sample(chain, propagator, state, warmup_cycles, num_samples, *, print_progress=True)
Warm up and collect evenly-spaced state snapshots.
Samples num_samples states from the last 20% of warmup steps.
Falls back to the final num_samples steps when 20% is too few.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `chain` | `Generator[Array, Any, Any]` | PRNG key generator. | required |
| `propagator` | `Callable[[Array, State], Result[State, State]]` | JIT-compiled propagator returning | required |
| `state` | `State` | Initial state. | required |
| `warmup_cycles` | `int` | Total number of propagation steps. | required |
| `num_samples` | `int` | Number of snapshots to collect. | required |
| `print_progress` | `bool` | Show a tqdm progress bar. | `True` |
Returns:
| Type | Description |
|---|---|
| `list[State]` | List of |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If |
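To make the sampling rule concrete, the sketch below computes which zero-based step indices would be snapshotted: evenly spaced within the last 20% of the run, falling back to the final `num_samples` steps when that window is too short. The function name, the exact spacing formula, and the guard on `num_samples > warmup_cycles` are this sketch's own assumptions and may not match the library.

```python
def sample_steps_sketch(warmup_cycles, num_samples):
    if num_samples > warmup_cycles:
        # Assumed guard: cannot take more snapshots than there are steps.
        raise ValueError("num_samples exceeds warmup_cycles")
    window = warmup_cycles // 5  # length of the last 20% of steps
    if window < num_samples:
        window = num_samples  # fallback: use the final num_samples steps
    start = warmup_cycles - window
    # Evenly spaced indices inside the window, ending on the last step.
    return [start + ((i + 1) * window) // num_samples - 1 for i in range(num_samples)]


regular = sample_steps_sketch(100, 4)   # window of 20 steps, one sample every 5
fallback = sample_steps_sketch(10, 4)   # 20% is only 2 steps, so take the last 4
```

With 100 steps and 4 samples this selects steps 84, 89, 94, and 99; with 10 steps it selects the final four, 6 through 9.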