Propagators
A Propagator is a pure function that takes a PRNG key and a state and returns an updated state: (key, state) → state. MD integrators, Monte Carlo moves, force evaluations, neighbor list updates, and logging steps are all propagators. They share one interface, so they compose freely.
```python
from dataclasses import replace

import jax
import jax.numpy as jnp
from jax import Array

from kups.core.lens import Lens, lens
from kups.core.propagator import LoopPropagator, compose_propagators
from kups.core.utils.jax import dataclass


@dataclass
class State:
    position: Array
    velocity: Array


# A propagator is any callable with signature (key, state) -> state
def drift(key: Array, state: State) -> State:
    return replace(state, position=state.position + state.velocity * 0.1)


def kick(key: Array, state: State) -> State:
    force = -state.position  # harmonic restoring force
    return replace(state, velocity=state.velocity + force * 0.1)


state = State(position=jnp.array([1.0, 0.0]), velocity=jnp.array([0.0, 1.0]))
key = jax.random.key(0)
new_state = drift(key, state)
print("after drift:", new_state.position)
```

```
after drift: [1.  0.1]
```
Why This Interface?
The (key, state) → state signature is minimal, but it buys three properties that matter for scientific computing.
Reproducibility. Every source of randomness flows through the explicit PRNG key. There is no hidden global state, no thread-local RNG, and no seed that silently changes between runs. Given the same key and state, a propagator draws the same random numbers on any hardware and at any parallelism level, so a run can be replayed exactly, up to floating-point differences between backends.
Composability. Because every propagator has the same signature, they snap together without adaptation. An NVE simulation is a velocity Verlet step built from a momentum half-step, a position step, a force evaluation, and another momentum half-step. To turn NVE into NVT, insert a thermostat propagator into that sequence. The momentum step, position step, and force evaluation stay the same. Each piece is developed, tested, and understood on its own; the composition defines the simulation.
JAX compatibility. Pure functions over pytrees are what JAX's transformations (jit, grad, vmap) operate on. Propagators follow this contract naturally, so they get hardware acceleration and automatic differentiation without special handling.
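All three properties can be checked directly on a toy propagator. The sketch below is self-contained and does not use kups: it assumes a NamedTuple state (NamedTuples are JAX pytrees without extra registration) so it runs with JAX alone, and `jitter` is a hypothetical propagator that drifts positions and perturbs velocities.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # NamedTuples are JAX pytrees out of the box, so no registration is needed
    position: jax.Array
    velocity: jax.Array


def jitter(key: jax.Array, state: State) -> State:
    """Hypothetical toy propagator: drift positions, then add velocity noise."""
    position = state.position + state.velocity * 0.1
    noise = jax.random.normal(key, state.velocity.shape) * 0.01
    return State(position=position, velocity=state.velocity + noise)


state = State(position=jnp.zeros(2), velocity=jnp.ones(2))

# Reproducibility: the key is the only source of randomness
a = jitter(jax.random.key(42), state)
b = jitter(jax.random.key(42), state)  # same key -> identical draw
c = jitter(jax.random.key(7), state)   # different key -> different draw
print(bool(jnp.array_equal(a.velocity, b.velocity)))  # True
print(bool(jnp.array_equal(a.velocity, c.velocity)))  # False

# jit + vmap: compile once, then advance 4 replicas with 4 independent keys
keys = jax.random.split(jax.random.key(0), 4)
replicas = State(position=jnp.zeros((4, 2)), velocity=jnp.ones((4, 2)))
batched = jax.vmap(jax.jit(jitter))(keys, replicas)
print(batched.position.shape)  # (4, 2)


# grad: differentiate a scalar of the output with respect to an input leaf
def total_position(velocity):
    return jitter(jax.random.key(0), State(jnp.zeros(2), velocity)).position.sum()


print(jax.grad(total_position)(jnp.ones(2)))
```

Nothing in `jitter` is specific to any transformation; jit, vmap, and grad all apply because it is a pure function over a pytree.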
Sequential Composition
The simplest composition: apply propagators one after another. The output state of one becomes the input of the next. compose_propagators handles the PRNG key splitting automatically, giving each sub-propagator an independent key.
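```python
# NVE: the kick-drift-kick pattern of velocity Verlet. In this toy version both
# kicks use the full 0.1 step; true velocity Verlet uses half-step kicks.
nve_step = jax.jit(compose_propagators(kick, drift, kick))
new_state = nve_step(key, state)
print("position:", new_state.position)
print("velocity:", new_state.velocity)
```

```
position: [0.99 0.1 ]
velocity: [-0.199  0.99 ]
```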
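The listing relies on compose_propagators to split the key. As a rough sketch of what such a combinator has to do, assuming one `jax.random.split` per call, here is a self-contained stand-in; `compose` is a hypothetical helper, not kups' implementation, and the NamedTuple state is an assumption made so the block runs without kups. Because kick and drift ignore their keys, it reproduces the NVE numbers above.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array
    velocity: jax.Array


Propagator = Callable[[jax.Array, State], State]


def compose(*propagators: Propagator) -> Propagator:
    """Hypothetical stand-in for compose_propagators: one subkey per stage."""

    def step(key: jax.Array, state: State) -> State:
        subkeys = jax.random.split(key, len(propagators))
        for i, prop in enumerate(propagators):
            state = prop(subkeys[i], state)  # output of one stage feeds the next
        return state

    return step


def kick(key: jax.Array, state: State) -> State:
    return state._replace(velocity=state.velocity - state.position * 0.1)


def drift(key: jax.Array, state: State) -> State:
    return state._replace(position=state.position + state.velocity * 0.1)


state = State(position=jnp.array([1.0, 0.0]), velocity=jnp.array([0.0, 1.0]))
new = jax.jit(compose(kick, drift, kick))(jax.random.key(0), state)
print(new.position, new.velocity)
```

Splitting once into independent subkeys, rather than reusing the parent key, is what keeps two stochastic stages in the same sequence from drawing correlated noise.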
To go from NVE to NVT, add a thermostat. The kick, drift, and kick propagators are reused without modification.
```python
def thermostat(key: Array, state: State) -> State:
    """Toy thermostat: perturbs velocities with small Gaussian noise."""
    noise = jax.random.normal(key, state.velocity.shape) * 0.01
    return replace(state, velocity=state.velocity + noise)


# NVT: insert thermostat into the sequence
nvt_step = jax.jit(compose_propagators(kick, drift, thermostat, kick))
s_nve = nve_step(key, state)
s_nvt = nvt_step(key, state)
print("NVE velocity:", s_nve.velocity)
print("NVT velocity:", s_nvt.velocity)
```

```
NVE velocity: [-0.199  0.99 ]
NVT velocity: [-0.20643845  0.99637485]
```
The NVT step reuses the same kick and drift propagators. The only difference is the thermostat inserted in the middle. Going from NVT to NPT follows the same pattern: add a barostat propagator. The existing pieces never change.
Other Composition Patterns
Sequential composition covers most cases, but not all.
Loop repeats a propagator N times. LoopPropagator wraps jax.lax.while_loop for JIT compatibility, which makes it useful for inner loops such as multiple MC moves per outer step.
```python
loop = LoopPropagator(propagator=nve_step, repetitions=100)
final = loop(key, state)
print("after 100 NVE steps:", final.position)
```

```
after 100 NVE steps: [-0.01678006  0.7087816 ]
```
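LoopPropagator's internals are not shown here. One common approach, sketched below under that assumption, is to carry the key alongside the state through a structured JAX loop and split it on every iteration so each repetition sees a fresh key. The sketch uses jax.lax.fori_loop for brevity; `loop` is a hypothetical helper, not kups' API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array
    velocity: jax.Array


def drift(key, state):
    return state._replace(position=state.position + state.velocity * 0.1)


def loop(propagator, repetitions: int):
    """Hypothetical loop combinator: applies `propagator` `repetitions` times."""

    def step(key, state):
        def body(i, carry):
            k, s = carry
            k, subkey = jax.random.split(k)  # fresh key for every iteration
            return k, propagator(subkey, s)

        _, state = jax.lax.fori_loop(0, repetitions, body, (key, state))
        return state

    return step


state = State(position=jnp.zeros(2), velocity=jnp.ones(2))
out = jax.jit(loop(drift, 100))(jax.random.key(0), state)
print(out.position)  # close to [10. 10.] (float32 accumulation)
```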
Switch — SwitchPropagator randomly selects one propagator from a set, weighted by probabilities. Useful for hybrid MC schemes that mix translation, rotation, and exchange moves.
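A minimal sketch of the switch idea, assuming the selection is drawn with jax.random.categorical and dispatched with jax.lax.switch so it stays JIT-compatible. `switch`, `translate`, and `flip` are hypothetical toy names, not SwitchPropagator's real constructor or moves.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array


def translate(key, state):
    # Toy "translation" move: small random displacement
    step = 0.1 * jax.random.normal(key, state.position.shape)
    return state._replace(position=state.position + step)


def flip(key, state):
    # Toy deterministic move: mirror every coordinate
    return state._replace(position=-state.position)


def switch(propagators, probabilities):
    """Hypothetical switch combinator: apply one randomly chosen propagator."""
    logits = jnp.log(jnp.asarray(probabilities))

    def step(key, state):
        select_key, move_key = jax.random.split(key)
        index = jax.random.categorical(select_key, logits)
        # lax.switch traces every branch; all must return the same pytree structure
        return jax.lax.switch(index, propagators, move_key, state)

    return step


state = State(position=jnp.array([1.0, 2.0]))
always_flip = jax.jit(switch([translate, flip], [0.0, 1.0]))
print(always_flip(jax.random.key(0), state).position)  # [-1. -2.]
```

Using separate subkeys for the selection and for the move keeps the choice of move statistically independent of the randomness inside it.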
Palindrome — PalindromePropagator applies propagators forward then backward: [P₁, P₂, ..., Pₙ, Pₙ, ..., P₂, P₁]. Useful for maintaining detailed balance or time-reversibility.
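The palindromic order can be built on top of any sequential combinator. In the sketch below, `compose` and `palindrome` are hypothetical helpers, not kups' API, and the two non-commuting toy propagators make the ordering visible.

```python
import jax
import jax.numpy as jnp


def compose(*propagators):
    """Hypothetical sequential combinator that gives each stage its own subkey."""

    def step(key, state):
        subkeys = jax.random.split(key, len(propagators))
        for i, prop in enumerate(propagators):
            state = prop(subkeys[i], state)
        return state

    return step


def palindrome(*propagators):
    """Hypothetical palindrome: P1, ..., Pn followed by Pn, ..., P1."""
    return compose(*propagators, *reversed(propagators))


def double(key, x):
    return 2.0 * x


def increment(key, x):
    return x + 1.0


# double, increment, increment, double: 1 -> 2 -> 3 -> 4 -> 8
print(palindrome(double, increment)(jax.random.key(0), jnp.array(1.0)))  # 8.0
```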
Potentials as Propagators
A Potential computes energies and gradients from the state but does not update it. PotentialAsPropagator bridges the gap: it calls the potential, then patches the gradients back into the state.
The separation exists because potentials have a richer interface. They return energies, gradients, and optionally Hessians. They compose by summation (Lennard-Jones + Coulomb + bonded), not by sequencing. The propagator adapter is where the potential world meets the integrator world: gradients get written to state, and from there the integrator reads them.
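The adapter pattern fits in a few lines. Everything here is an illustrative assumption: `potential_as_propagator` is a hypothetical function, not kups' PotentialAsPropagator; the toy potentials return only energies; and the gradient comes from jax.grad rather than from a potential's own gradient output.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array
    gradient: jax.Array  # written by the adapter, read by integrators


def harmonic(position):
    return 0.5 * jnp.sum(position**2)


def linear_field(position):
    return jnp.sum(2.0 * position)


def total_energy(position):
    # Potentials compose by summation, not by sequencing
    return harmonic(position) + linear_field(position)


def potential_as_propagator(energy: Callable[[jax.Array], jax.Array]):
    """Hypothetical adapter: evaluate the potential, write its gradient to state."""

    def step(key, state: State) -> State:
        return state._replace(gradient=jax.grad(energy)(state.position))

    return step


state = State(position=jnp.array([1.0, -2.0]), gradient=jnp.zeros(2))
out = potential_as_propagator(total_energy)(jax.random.key(0), state)
print(out.gradient)  # [3. 0.]  (x + 2 for each coordinate)
```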
Building an MD Step
The building blocks are small propagators: MomentumStep (reads forces, writes momenta), PositionStep (reads momenta, writes positions), and PotentialAsPropagator (reads positions, writes gradients). An NVE velocity Verlet step composes them as:
NVE = [MomentumStep(½Δt), PositionStep(Δt), Potential, MomentumStep(½Δt)]
For NVT with a Langevin thermostat (BAOAB splitting), a StochasticStep is inserted in the middle. The momentum and position steps are reused without modification:
NVT = [MomentumStep(½Δt), PositionStep(½Δt), StochasticStep, PositionStep(½Δt), Potential, MomentumStep(½Δt)]
For NPT, a StochasticCellRescalingStep barostat joins the sequence. Everything else stays the same.
Factory functions like make_velocity_verlet_step and make_baoab_langevin_step build these compositions. They are convenience wrappers around SequentialPropagator over the same primitives. Implementing a new thermostat means writing one propagator and inserting it into the sequence.
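To make the reuse concrete, here is a toy, self-contained sketch of the NVE and BAOAB sequences on a unit-mass harmonic oscillator. The names `momentum_step`, `position_step`, `force_step`, and `stochastic_step` are hypothetical stand-ins for MomentumStep, PositionStep, PotentialAsPropagator, and StochasticStep; their real constructors and lens arguments are not shown in this section. The O step uses the standard exact Ornstein-Uhlenbeck velocity update v ← e^(−γΔt) v + sqrt((1 − e^(−2γΔt)) kT/m) ξ with m = 1.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array
    velocity: jax.Array  # unit mass, so velocity equals momentum
    force: jax.Array


def energy(x):
    return 0.5 * jnp.sum(x**2)  # toy harmonic potential


def compose(*props):
    # Hypothetical sequential combinator: one subkey per stage
    def step(key, s):
        keys = jax.random.split(key, len(props))
        for i, p in enumerate(props):
            s = p(keys[i], s)
        return s

    return step


def momentum_step(dt):  # B: reads forces, writes velocities
    return lambda key, s: s._replace(velocity=s.velocity + dt * s.force)


def position_step(dt):  # A: reads velocities, writes positions
    return lambda key, s: s._replace(position=s.position + dt * s.velocity)


def force_step(key, s):  # reads positions, writes forces
    return s._replace(force=-jax.grad(energy)(s.position))


def stochastic_step(dt, gamma, kT):  # O: exact Ornstein-Uhlenbeck velocity update
    c1 = jnp.exp(-gamma * dt)
    c2 = jnp.sqrt((1.0 - c1**2) * kT)

    def step(key, s):
        noise = jax.random.normal(key, s.velocity.shape)
        return s._replace(velocity=c1 * s.velocity + c2 * noise)

    return step


dt = 0.05
nve = compose(momentum_step(dt / 2), position_step(dt), force_step, momentum_step(dt / 2))


def nvt(gamma, kT):
    # BAOAB: the same B and A pieces as NVE, with the O step inserted in the middle
    return compose(
        momentum_step(dt / 2), position_step(dt / 2), stochastic_step(dt, gamma, kT),
        position_step(dt / 2), force_step, momentum_step(dt / 2),
    )


x0 = jnp.array([1.0, 0.0])
s0 = State(position=x0, velocity=jnp.array([0.0, 1.0]), force=-jax.grad(energy)(x0))
key = jax.random.key(0)
print(jax.jit(nvt(gamma=1.0, kT=1.0))(key, s0).position)
```

With γ = 0 the O step reduces to the identity, and two half position steps add up to one full step, so the NVT sequence reproduces the NVE one. That is a quick sanity check that the pieces really are shared.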
The Simulation Loop
The simulation loop lives on the host side. It calls the propagator repeatedly, checks runtime assertions, and applies fixes if needed (see the Runtime Assertions tutorial). During warmup, buffer sizes stabilize through the fix-and-retry mechanism. During production, assertions pass on every step and the loop runs at full speed. The key_chain utility provides a stream of independent PRNG keys for the loop.
```python
from kups.core.utils.jax import key_chain

# A minimal simulation loop
chain = key_chain(key)
state = State(position=jnp.array([1.0, 0.0]), velocity=jnp.array([0.0, 1.0]))
for _ in range(500):
    state = nve_step(next(chain), state)
print("after 500 steps:", state.position)
```

```
after 500 steps: [-0.08380504  0.70638824]
```
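The chain is consumed on the host with next(). kups' key_chain may be implemented differently; one plausible way to build such a stream, sketched here, is a generator that keeps a private key and splits it on every call. `simple_key_chain` is a hypothetical name.

```python
import jax


def simple_key_chain(key: jax.Array):
    """Hypothetical infinite stream of independent PRNG keys."""
    while True:
        key, subkey = jax.random.split(key)  # keep one half, hand out the other
        yield subkey


chain = simple_key_chain(jax.random.key(0))
k1, k2 = next(chain), next(chain)
print(jax.random.uniform(k1), jax.random.uniform(k2))  # two different draws

# The stream itself is reproducible: the same seed yields the same sequence
again = simple_key_chain(jax.random.key(0))
print(bool(jax.random.uniform(next(again)) == jax.random.uniform(k1)))  # True
```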
Lenses in Propagators
Propagators are generic over the state type. A momentum step does not know whether positions live at state.position or state.coords. Lenses bridge that gap: each propagator accepts lenses that tell it where to find and write its data. The integrator is written once; lenses adapt it to any state layout.
```python
def make_kick[S](pos_lens: Lens[S, Array], vel_lens: Lens[S, Array], dt: float):
    def propagator(key: Array, state: S) -> S:
        force = -pos_lens.get(state)
        return vel_lens.set(state, vel_lens.get(state) + force * dt)

    return propagator


def make_drift[S](pos_lens: Lens[S, Array], vel_lens: Lens[S, Array], dt: float):
    def propagator(key: Array, state: S) -> S:
        return pos_lens.set(state, pos_lens.get(state) + vel_lens.get(state) * dt)

    return propagator


def make_verlet[S](pos: Lens[S, Array], vel: Lens[S, Array], dt: float):
    return compose_propagators(
        make_kick(pos, vel, dt), make_drift(pos, vel, dt), make_kick(pos, vel, dt)
    )


# Same integrator, different state types
@dataclass
class AltState:
    coords: Array
    momenta: Array


step_a = make_verlet(
    lens(lambda s: s.position, cls=State), lens(lambda s: s.velocity, cls=State), 0.1
)
step_b = make_verlet(
    lens(lambda s: s.coords, cls=AltState), lens(lambda s: s.momenta, cls=AltState), 0.1
)
s1 = State(position=jnp.array([1.0, 0.0]), velocity=jnp.array([0.0, 1.0]))
s2 = AltState(coords=jnp.array([1.0, 0.0]), momenta=jnp.array([0.0, 1.0]))
print("State: ", step_a(key, s1).position)
print("AltState:", step_b(key, s2).coords)
```

```
State:  [0.99 0.1 ]
AltState: [0.99 0.1 ]
```
Both calls produce the same result. The integrator logic (make_kick, make_drift) is written once. The lenses adapt it to State with position/velocity fields and to AltState with coords/momenta fields.
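The lenses above come from kups.core.lens, whose lens(..., cls=...) builds a setter from a getter and a class. To show only what a lens has to provide, here is a minimal pure-Python sketch assuming an explicit getter and a copy-returning setter; `SimpleLens` and `field_lens` are hypothetical names, and this sketch does not reproduce how kups derives the setter.

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")
A = TypeVar("A")


@dataclass(frozen=True)
class SimpleLens(Generic[S, A]):
    """Hypothetical lens: a getter paired with a non-mutating setter."""

    get: Callable[[S], A]
    set: Callable[[S, A], S]


def field_lens(name: str) -> SimpleLens:
    # Focus on one dataclass field; set returns a modified copy
    return SimpleLens(
        get=lambda s: getattr(s, name),
        set=lambda s, value: replace(s, **{name: value}),
    )


@dataclass(frozen=True)
class State:
    position: float
    velocity: float


@dataclass(frozen=True)
class AltState:
    coords: float
    momenta: float


def make_drift(pos: SimpleLens, vel: SimpleLens, dt: float):
    def propagator(key, state):  # key is unused by this deterministic step
        return pos.set(state, pos.get(state) + vel.get(state) * dt)

    return propagator


drift_a = make_drift(field_lens("position"), field_lens("velocity"), 0.5)
drift_b = make_drift(field_lens("coords"), field_lens("momenta"), 0.5)
print(drift_a(None, State(1.0, 2.0)))     # State(position=2.0, velocity=2.0)
print(drift_b(None, AltState(1.0, 2.0)))  # AltState(coords=2.0, momenta=2.0)
```

The propagator only ever calls get and set, which is why the same make_drift works for both layouts.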