kups.core.propagator
State propagators for simulation dynamics and Monte Carlo moves.
This module provides a composable framework for evolving simulation states over time. Propagators represent any operation that transitions a state from one configuration to another.
Key components:
- Propagator: Protocol for state evolution functions
- propagator_with_assertions: Wrap a propagator to track assertion results
- propagate_and_fix: Retry propagation until assertions pass
- MCMCPropagator: Metropolis-Hastings Monte Carlo with acceptance/rejection
- SequentialPropagator: Chain multiple propagators sequentially
- PalindromePropagator: Reversible composition maintaining detailed balance
- LoopPropagator: Repeat a propagator multiple times
- SwitchPropagator: Randomly select from multiple propagators
- ResetOnErrorPropagator: Rollback state on assertion failures
- ScheduledPropertyPropagator: Update properties according to schedules
Propagators are composable and JIT-compilable, enabling efficient simulation loops.
LogProbabilityRatio = Table[SystemId, Array]
Type alias for log probability ratio arrays.
BakeConstantsPropagator
Bases: Propagator[State]
Wraps a propagator by identifying and caching state leaves that are unchanged.
Uses eval_shape to trace the inner propagator and detect which leaves are
returned via identity (i.e. not modified). Those leaves are snapshotted as
read-only NumPy arrays and injected on every call, avoiding redundant
device transfers. This may also enable XLA constant folding, as the baked
values become compile-time constants visible to the compiler.
Note
Baked values are frozen at construction time. Any external mutation of
those leaves after new() will not be reflected in subsequent
calls — the cached snapshot is used instead.
Attributes:
| Name | Type | Description |
|---|---|---|
| propagator | Propagator[State] | The inner propagator to wrap. |
| const_indices | tuple[int, ...] | Flat indices of constant leaves in the pytree. |
| consts | tuple[ndarray, ...] | Cached NumPy snapshots of the constant leaves. |
Source code in src/kups/core/propagator.py
CachePropagator
Bases: Propagator[State]
Propagator that computes a property and caches it in the state.
Evaluates a state property (e.g., neighbor list, energy) and stores the result in the state using a lens-based update.
Attributes:
| Name | Type | Description |
|---|---|---|
| function | StateProperty[State, ResultType] | Function that computes the property |
| update | Update[State, ResultType] | Update function that stores the result in state |
ChangesFn
Bases: Protocol
Protocol for functions that propose changes and a log proposal ratio.
IdentityPropagator
Bases: Propagator[State]
No-op propagator that returns the state unchanged.
Useful as a placeholder or for testing.
LogProbabilityRatioFn
Bases: Protocol
Protocol for computing target density ratios.
Computes log probability ratio of target distribution (e.g., Boltzmann factor).
LoopPropagator
Bases: Propagator[State]
Repeat a propagator multiple times in a loop.
Applies a single propagator repeatedly for either a fixed number of iterations
or a dynamic number determined from the state. Uses jax.lax.while_loop for
efficient compilation.
Attributes:
| Name | Type | Description |
|---|---|---|
| propagator | Propagator[State] | The propagator to repeat |
| repetitions | View[State, Array] \| int | Either a fixed integer or a function extracting repetition count from state |
Example
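A minimal sketch of the looping idea in plain JAX, not the kups implementation. The helper `loop` and the toy propagator `increment` are hypothetical names, and the repetition count here is a fixed integer rather than a view of the state.

```python
import jax
import jax.numpy as jnp


def loop(propagator, repetitions):
    """Repeat `propagator` `repetitions` times with a fresh subkey per iteration."""
    def run(key, state):
        def cond(carry):
            i, _, _ = carry
            return i < repetitions

        def body(carry):
            i, key, state = carry
            key, subkey = jax.random.split(key)
            return i + 1, key, propagator(subkey, state)

        start = (jnp.zeros((), dtype=jnp.int32), key, state)
        _, _, state = jax.lax.while_loop(cond, body, start)
        return state
    return run


def increment(key, x):
    """Toy propagator (hypothetical): ignores the key and adds one."""
    return x + 1.0


out = jax.jit(loop(increment, 5))(jax.random.PRNGKey(0), jnp.zeros(()))
```

Because the count is only read inside `cond`, the same structure would also accept a traced scalar extracted from the state.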
MCMCPropagator
Bases: Propagator[State]
Metropolis-Hastings Monte Carlo propagator with acceptance/rejection.
Supports both single-move and mixed-move scenarios. When multiple
propose_fns are provided, one is selected at random each step
(weighted by weights), and only the corresponding scheduler is updated.
Attributes:
| Name | Type | Description |
|---|---|---|
| patch_fn | PatchFn[State, Changes, Move] | Converts changes to a state patch. |
| propose_fns | tuple[ChangesFn[State, Changes], ...] | Tuple of change proposal functions. |
| log_probability_ratio_fn | LogProbabilityRatioFn[State, Move] | Computes target density ratio (e.g., Boltzmann). |
| parameter_schedulers | tuple[Scheduler[State, Table[SystemId, Array]], ...] | One scheduler per propose_fn, updated selectively. |
| weights | tuple[float, ...] \| None | Selection probabilities per move (unnormalized). None for uniform. |
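The acceptance/rejection logic can be sketched in plain JAX as follows. This is an illustration under simplifying assumptions, not the kups implementation: the names `log_density`, `metropolis_step`, and `run_chain` are hypothetical, the target is a standard normal, and the proposal is symmetric, so the log proposal ratio is zero.

```python
import jax
import jax.numpy as jnp


def log_density(x):
    """Hypothetical target: standard normal, up to an additive constant."""
    return -0.5 * jnp.sum(x**2)


def metropolis_step(key, x, step_size=0.5):
    """One Metropolis-Hastings step with a symmetric Gaussian proposal."""
    propose_key, accept_key = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(propose_key, x.shape)
    # Symmetric proposal: only the target density ratio enters the test.
    log_ratio = log_density(proposal) - log_density(x)
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    return jnp.where(accept, proposal, x), accept


def run_chain(key, x0, n_steps):
    """Scan the step over one subkey per iteration and record the trajectory."""
    def body(x, step_key):
        x, accept = metropolis_step(step_key, x)
        return x, (x, accept)

    _, (samples, accepts) = jax.lax.scan(body, x0, jax.random.split(key, n_steps))
    return samples, accepts


samples, accepts = run_chain(jax.random.PRNGKey(0), jnp.zeros(2), 2000)
```

With an asymmetric proposal, the log proposal ratio returned alongside the changes would be added to `log_ratio` before the comparison.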
PalindromePropagator
Bases: Propagator[State]
Apply propagators forward then backward to preserve detailed balance.
Applies propagators in sequence: [P₁, P₂, ..., Pₙ, Pₙ, ..., P₂, P₁]. This "telescope" pattern ensures that if individual propagators satisfy detailed balance, the combined propagator also does.
Critical for maintaining correct equilibrium distributions in MCMC.
Attributes:
| Name | Type | Description |
|---|---|---|
| propagators | tuple[Propagator[State], ...] | Tuple of propagators to apply palindromically |
Mathematical property
If each Pᵢ satisfies detailed balance, then the composition P₁ ∘ P₂ ∘ ... ∘ Pₙ ∘ Pₙ ∘ ... ∘ P₂ ∘ P₁ also satisfies detailed balance.
Example
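A minimal sketch of the palindromic ordering in plain JAX. The helper `palindrome` and the deterministic stages `add_one` and `double` are hypothetical; they are chosen so that the order of application is visible in the result. How kups distributes random keys across stages is not shown in this page, so the per-stage key split below is an assumption.

```python
import jax
import jax.numpy as jnp


def palindrome(*propagators):
    """Apply P1..Pn and then Pn..P1, each call with its own subkey."""
    order = propagators + propagators[::-1]

    def run(key, state):
        subkeys = jax.random.split(key, len(order))
        for i, propagator in enumerate(order):
            state = propagator(subkeys[i], state)
        return state
    return run


def add_one(key, x):
    return x + 1.0


def double(key, x):
    return x * 2.0


# Order is add_one, double, double, add_one: ((1 + 1) * 2 * 2) + 1 = 9
result = palindrome(add_one, double)(jax.random.PRNGKey(0), jnp.array(1.0))
```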
PatchFn
Bases: Protocol
Protocol for functions that convert move proposals to state patches.
Propagator
Bases: Protocol
Protocol for state evolution functions.
A propagator takes a random key and current state, returning an updated state. Propagators can represent time evolution (MD integrators), Monte Carlo moves, or any other state transformation.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State | | Simulation state type | required |
Example
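A minimal sketch of a function with the `(key, state) -> state` shape the protocol describes. `ToyState` and `random_walk` are hypothetical names used only for illustration and are not part of kups.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical state: particle positions and a step counter."""
    positions: jax.Array
    step: jax.Array


def random_walk(key: jax.Array, state: ToyState) -> ToyState:
    """Propagator-shaped function: perturb positions with Gaussian noise."""
    noise = 0.1 * jax.random.normal(key, state.positions.shape)
    return ToyState(positions=state.positions + noise, step=state.step + 1)


state = ToyState(positions=jnp.zeros((4, 3)), step=jnp.zeros((), dtype=jnp.int32))
new_state = jax.jit(random_walk)(jax.random.PRNGKey(0), state)
```

Because `ToyState` is a `NamedTuple`, JAX treats it as a pytree, which is what lets the function be passed through `jax.jit` unchanged.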
__call__(key, state)
Propagate the state forward.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | JAX PRNG key for stochastic operations | required |
| state | State | Current simulation state | required |
Returns:
| Type | Description |
|---|---|
| State | Updated state after propagation |
ResetOnErrorPropagator
Bases: Propagator[State]
Rollback to previous state if runtime assertions fail.
Wraps a propagator and checks runtime assertions after execution. If any assertion fails, reverts to the original state. Useful for robust simulation where certain configurations are invalid.
Attributes:
| Name | Type | Description |
|---|---|---|
| propagator | Propagator[State] | Base propagator to wrap with error handling |
Example
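A minimal sketch of the rollback idea in plain JAX. It replaces the kups runtime-assertion machinery with an explicit validity predicate, which is an assumption made for illustration; `reset_on_failure`, `big_jump`, and `inside_box` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def reset_on_failure(propagator, is_valid):
    """Keep the propagated state only when `is_valid` holds; else return the input."""
    def run(key, state):
        candidate = propagator(key, state)
        ok = is_valid(candidate)
        # Select leaf by leaf so the wrapper works for any pytree state.
        return jax.tree_util.tree_map(
            lambda new, old: jnp.where(ok, new, old), candidate, state
        )
    return run


def big_jump(key, x):
    """Toy propagator that often leaves the allowed region."""
    return x + 10.0 * jax.random.normal(key, x.shape)


def inside_box(x):
    """Stand-in for an assertion: every coordinate lies in [-1, 1]."""
    return jnp.all(jnp.abs(x) <= 1.0)


safe_step = jax.jit(reset_on_failure(big_jump, inside_box))
x0 = jnp.zeros(3)
x1 = safe_step(jax.random.PRNGKey(0), x0)
```

Whatever the random draw, `x1` is either the accepted candidate or the original `x0`, so it always satisfies the predicate.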
Note
Uses check_assertions which must be called within a with_runtime_assertions context to function properly.
ScheduledPropertyPropagator
Bases: Propagator[State]
Propagator that updates a property according to a schedule.
Reads the scheduling input (e.g., step number) from the state, applies the schedule to compute a new value, and updates the state.
This is useful for time-dependent parameter changes during simulation, such as temperature annealing, pressure ramps, or time step adaptation.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State | | Simulation state type | required |
| Input | | Type of scheduling input (typically Array for step/time) | required |
| Value | | Type of value being scheduled | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| lens | Lens[State, Value] | Lens to access and update the scheduled property |
| input_view | View[State, Input] | View to extract the scheduling input from state |
| schedule | Schedule[Input, Value] | Schedule that computes new values |
Example
```python
import jax.numpy as jnp

from kups.core.propagator import ScheduledPropertyPropagator
from kups.core.schedule import LinearSchedule

# `lens` must already be in scope; its import path is not shown here.

# Temperature annealing from 500 K to 300 K over 10000 steps
temp_propagator = ScheduledPropertyPropagator(
    lens=lens(lambda s: s.temperature),
    input_view=lens(lambda s: s.step).get,
    schedule=LinearSchedule(
        start=jnp.array(500.0),
        end=jnp.array(300.0),
        total_steps=jnp.array(10000),
    ),
)

# In simulation loop:
state = temp_propagator(key, state)
```
See Also
- Schedule: Protocol for scheduling functions
- PropertyScheduler: Non-propagator scheduler
__call__(key, state)
Apply the schedule to update the state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | JAX PRNG key (unused, but required by Propagator protocol) | required |
| state | State | Current simulation state | required |
Returns:
| Type | Description |
|---|---|
| State | Updated state with scheduled property modified |
SequentialPropagator
Bases: Propagator[State]
Apply multiple propagators in sequence.
Chains propagators together, applying each in order with independent random keys.
Attributes:
| Name | Type | Description |
|---|---|---|
| propagators | tuple[Propagator[State], ...] | Tuple of propagators to apply sequentially |
Example
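A minimal sketch of sequential chaining with independent keys, written in plain JAX rather than with the kups class. The helper `sequential` and the toy stages `jitter` and `clip` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def sequential(*propagators):
    """Return a propagator that applies each stage in order, one subkey per stage."""
    def run(key, state):
        subkeys = jax.random.split(key, len(propagators))
        for i, propagator in enumerate(propagators):
            state = propagator(subkeys[i], state)
        return state
    return run


def jitter(key, x):
    """Stochastic stage: add small Gaussian noise."""
    return x + 0.01 * jax.random.normal(key, x.shape)


def clip(key, x):
    """Deterministic stage: the key is accepted but unused."""
    return jnp.clip(x, -1.0, 1.0)


step = jax.jit(sequential(jitter, clip, jitter))
out = step(jax.random.PRNGKey(0), jnp.full((5,), 2.0))
```

Splitting the key once per stage means the two `jitter` calls draw different noise even though they are the same function.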
StateProperty
Bases: Protocol
Protocol for functions that extract properties from states.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State | | Simulation state type | required |
| Property | | Type of property to extract | required |
StatePropertySum
Sum multiple state properties together.
Attributes:
| Name | Type | Description |
|---|---|---|
| properties | tuple[StateProperty[State, Property], ...] | Tuple of property extractors to sum |
SwitchPropagator
Bases: Propagator[State]
Randomly select and apply one propagator from multiple options.
Chooses a propagator based on probabilities and applies it to the state. Useful for hybrid Monte Carlo schemes with multiple move types.
Attributes:
| Name | Type | Description |
|---|---|---|
| propagators | tuple[Propagator[State], ...] | Tuple of propagators to choose from |
| probabilities | View[State, Array] | Function returning selection probabilities for each propagator |
Warning
When vmapped, all propagators are executed and results selected, leading to higher compute costs. Use conditionals if vmap efficiency is critical.
Example
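A minimal sketch of random selection between propagators in plain JAX, using `jax.random.choice` and `jax.lax.switch`. Whether kups uses these exact primitives is an assumption; `switch`, `shift_up`, and `shift_down` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def switch(propagators, probabilities):
    """Pick one propagator per call according to `probabilities(state)`."""
    def run(key, state):
        choose_key, apply_key = jax.random.split(key)
        index = jax.random.choice(
            choose_key, len(propagators), p=probabilities(state)
        )
        # Every branch must return a state with the same structure and dtypes.
        return jax.lax.switch(index, propagators, apply_key, state)
    return run


def shift_up(key, x):
    return x + 1.0


def shift_down(key, x):
    return x - 1.0


mixed = jax.jit(switch((shift_up, shift_down), lambda x: jnp.array([0.5, 0.5])))
out = mixed(jax.random.PRNGKey(0), jnp.zeros(()))
```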
compose_propagators(*propagators)
Compose multiple propagators into a single one by sequentially chaining them.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *propagators | Propagator[S] | Propagators to chain together | () |
Returns:
| Type | Description |
|---|---|
| Propagator[S] | SequentialPropagator that applies each propagator in order |
propagate_and_fix(fn, key, state, *, max_tries=10)
Execute a propagator repeatedly until all assertions pass or retries are exhausted.
On each attempt, failed assertions are repaired via their fix functions. Raises if a failed assertion has no fix function or retries run out.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[Array, State], Result[State, State]] | Assertion-aware propagator produced by :func: | required |
| key | Array | JAX PRNG key. | required |
| state | State | Current simulation state. | required |
| max_tries | int | Maximum number of repair attempts. | 10 |
Returns:
| Type | Description |
|---|---|
| State | Propagated state with all assertions satisfied. |
Raises:
| Type | Description |
|---|---|
| ValueError | If called inside a JAX transform. |
| RuntimeError | If assertions still fail after |
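The retry-and-repair behaviour described above can be sketched in plain Python. Everything here is hypothetical and stands in for the kups `Result` and assertion types, whose interfaces are not shown on this page: `Failure`, `retry_until_valid`, and the toy `step` are illustrative names only, and the PRNG key is left out for brevity.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Failure:
    """Hypothetical failed-assertion record with an optional repair function."""
    message: str
    fix: Optional[Callable[[dict], dict]] = None


def retry_until_valid(step, state, max_tries=10):
    """Re-run `step` from a repaired state until it reports no failures."""
    for _ in range(max_tries):
        new_state, failures = step(state)
        if not failures:
            return new_state
        for failure in failures:
            if failure.fix is None:
                raise RuntimeError(f"no fix available: {failure.message}")
            state = failure.fix(state)
    raise RuntimeError(f"assertions still failing after {max_tries} tries")


def step(state):
    """Toy step that needs `capacity` >= `needed`; the fix doubles capacity."""
    if state["capacity"] < state["needed"]:
        grow = lambda s: {**s, "capacity": 2 * s["capacity"]}
        return state, [Failure("buffer overflow", fix=grow)]
    return {**state, "done": True}, []


# Capacity grows 4 -> 8 -> 16 -> 32, then the step succeeds.
final = retry_until_valid(step, {"capacity": 4, "needed": 30})
```

The loop runs in ordinary Python because the repair may change array shapes, which is consistent with the `ValueError` documented for calls made inside a JAX transform.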
propagator_with_assertions(propagator)
Wrap a propagator to capture assertion results alongside the state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| propagator | Propagator[State] | Propagator to wrap. | required |
Returns:
| Type | Description |
|---|---|
| Callable[[Array, State], Result[State, State]] | Function returning a Result that pairs the new state with assertion metadata. |
propose_mixed(key, state, propose_fns, weights=None)
Compute all proposals eagerly and select one at random.
All propose_fns are evaluated, then jax.lax.select_n picks one.
Returns (selected_changes, selected_log_ratio, which_index).
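A minimal sketch of the evaluate-all-then-select pattern in plain JAX. It is not the kups implementation: `propose_all_then_select`, `small_move`, and `large_move` are hypothetical names, and each toy proposal returns a `(changes, log_ratio)` pair of arrays rather than kups change objects.

```python
import jax
import jax.numpy as jnp


def small_move(key, x):
    """Toy symmetric proposal: small displacement, log proposal ratio 0."""
    return x + 0.1 * jax.random.normal(key, x.shape), jnp.zeros(())


def large_move(key, x):
    """Toy symmetric proposal: large displacement, log proposal ratio 0."""
    return x + 1.0 * jax.random.normal(key, x.shape), jnp.zeros(())


def propose_all_then_select(key, x, propose_fns, weights=None):
    """Evaluate every proposal, then keep the one chosen at random."""
    n = len(propose_fns)
    keys = jax.random.split(key, n + 1)
    p = None
    if weights is not None:
        w = jnp.asarray(weights, dtype=jnp.float32)
        p = w / jnp.sum(w)  # normalize the unnormalized weights
    which = jax.random.choice(keys[0], n, p=p).astype(jnp.int32)
    results = [fn(keys[i + 1], x) for i, fn in enumerate(propose_fns)]
    changes = jax.lax.select_n(which, *[c for c, _ in results])
    log_ratio = jax.lax.select_n(which, *[r for _, r in results])
    return changes, log_ratio, which


changes, log_ratio, which = propose_all_then_select(
    jax.random.PRNGKey(0), jnp.zeros(3), (small_move, large_move), weights=(3.0, 1.0)
)
```

`jax.lax.select_n` requires all candidates to share a shape and dtype, which is why both toy proposals return arrays of identical structure.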
step_counter_propagator(step_lens)
Build a propagator that increments a step counter by 1 each call.
Wraps ScheduledPropertyPropagator with an IncrementSchedule so that the counter stored at step_lens is advanced by 1 on every propagation step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step_lens | Lens[State, Array] | Lens pointing to the integer step-counter array in the state. The array must be broadcastable with an increment of | required |
Returns:
| Type | Description |
|---|---|
| ScheduledPropertyPropagator[State, Array, Array] | A ScheduledPropertyPropagator that increments the counter by 1 each time it is called. |