Monte Carlo Moves
MD alone does not cover every ensemble worth simulating. The grand canonical ensemble needs discrete insertions and deletions that MD cannot produce, volume and temperature swaps in replica exchange require unphysical proposals between chains, and sampling across high barriers is often faster with non-local moves than with short-time-step dynamics. The usual remedy is Metropolis-Hastings, and kUPS ships Monte Carlo machinery that plugs into the same state, lens, and patch abstractions used by the integrators.
The library is batched-first: every system in the particle table is an independent Markov chain. One call to the MC propagator advances all of them by one step in parallel, each with its own per-system acceptance decision, step widths, and move statistics. This matters for free-energy work, where hundreds of chains are evolved in one jit call, and for small systems, where compilation cost would otherwise dominate the run time.
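The batched layout can be illustrated without kUPS at all. The following is a minimal NumPy sketch, not the library's implementation: particles from all systems live in one flat array, a `sys_idx` array (a name chosen here for illustration) maps each particle to its chain, energies are reduced per system, and a per-system acceptance mask is broadcast back to the particles.

```python
import numpy as np

n_systems, n_per_system = 2, 3
# Flat particle table: sys_idx[i] is the chain that particle i belongs to.
sys_idx = np.repeat(np.arange(n_systems), n_per_system)  # [0 0 0 1 1 1]

rng = np.random.default_rng(0)
positions = rng.uniform(-2.0, 2.0, size=(n_systems * n_per_system, 3))
proposed = positions + 0.1 * rng.normal(size=positions.shape)


def energy_per_system(pos):
    """Toy harmonic energy, reduced to one value per chain."""
    e = np.zeros(n_systems)
    np.add.at(e, sys_idx, np.sum(pos**2, axis=-1))  # scatter-add by system
    return e


# One independent decision per chain (hard-coded here for illustration).
accept = np.array([True, False])

# Broadcast the per-system mask to particles: system 0 moves, system 1 stays.
new_positions = np.where(accept[sys_idx][:, None], proposed, positions)
print(energy_per_system(positions), energy_per_system(new_positions))
```

Because the mask is indexed by system, a rejection in one chain never touches the particles of another, which is what makes the chains independent even though they share one array.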
from dataclasses import replace
import jax
import jax.numpy as jnp
from jax import Array
from kups.core.data import Index, Table
from kups.core.lens import lens
from kups.core.parameter_scheduler import (
    ParameterSchedulerState,
    acceptance_target_schedule,
)
from kups.core.patch import ExplicitPatch, IdPatch, WithPatch
from kups.core.potential import (
    EMPTY,
    CachedPotential,
    PotentialOut,
    empty_patch_idx_view,
)
from kups.core.propagator import MCMCPropagator, propose_mixed
from kups.core.schedule import PropertyScheduler
from kups.core.typing import GroupId, ParticleId, SystemId
from kups.core.unitcell import TriclinicUnitCell
from kups.core.utils.jax import dataclass, key_chain
from kups.mcmc.moves import ParticlePositionChanges, ParticleTranslationMove
from kups.mcmc.probability import BoltzmannLogProbabilityRatio
@dataclass
class Particles:
    positions: Array
    system: Index[SystemId]
    group: Index[GroupId]

@dataclass
class SystemData:
    unitcell: TriclinicUnitCell

@dataclass
class State:
    particles: Table[ParticleId, Particles]
    systems: Table[SystemId, SystemData]
    temperature: Table[SystemId, Array]
    cached_energy: PotentialOut
    translation_params: Table[SystemId, ParameterSchedulerState]
def make_state(
    n_systems=2,
    n_particles_per_system=3,
    box=10.0,
    T=300.0,
    step_width=0.5,
    history_length=100,
    seed=0,
):
    key = jax.random.key(seed)
    n = n_systems * n_particles_per_system
    positions = jax.random.uniform(key, (n, 3), minval=-2.0, maxval=2.0)
    sys_ids = jnp.repeat(jnp.arange(n_systems), n_particles_per_system)
    particles = Table.arange(
        Particles(
            positions=positions,
            system=Index.new([SystemId(int(i)) for i in sys_ids]),
            group=Index.new([GroupId(j) for j in range(n)]),
        ),
        label=ParticleId,
    )
    systems = Table.arange(
        SystemData(
            unitcell=TriclinicUnitCell.from_matrix(
                box * jnp.eye(3)[None].repeat(n_systems, axis=0)
            )
        ),
        label=SystemId,
    )
    return State(
        particles=particles,
        systems=systems,
        temperature=Table.arange(jnp.full(n_systems, T), label=SystemId),
        cached_energy=PotentialOut(
            total_energies=Table.arange(jnp.zeros(n_systems), label=SystemId),
            gradients=EMPTY,
            hessians=EMPTY,
        ),
        translation_params=Table.arange(
            ParameterSchedulerState.create(
                n_systems,
                initial_value=step_width,
                history_length=history_length,
            ),
            label=SystemId,
        ),
    )
def harmonic_potential(state: State, patch=None):
    """Toy potential: U = sum_i |r_i|^2 per system (harmonic well at origin)."""
    if patch is None:
        new_state = state
    else:
        n_sys = len(state.systems)
        accept_all = Table.arange(jnp.ones(n_sys, dtype=bool), label=SystemId)
        new_state = patch(state, accept_all)
    positions = new_state.particles.data.positions
    sys_idx = new_state.particles.data.system.indices
    n_sys = new_state.particles.data.system.num_labels
    e_per_particle = jnp.sum(positions**2, axis=-1)
    energies = jnp.zeros(n_sys).at[sys_idx].add(e_per_particle)
    return WithPatch(
        PotentialOut(
            total_energies=Table.arange(energies, label=SystemId),
            gradients=EMPTY,
            hessians=EMPTY,
        ),
        IdPatch(),
    )

cached_potential = CachedPotential(
    potential=harmonic_potential,
    cache=lens(lambda s: s.cached_energy, cls=State),
    patch_idx_view=empty_patch_idx_view,
)
def make_position_patch(proposal: ParticlePositionChanges) -> ExplicitPatch:
    """Patch that writes `proposal.new_positions` at `proposal.particle_ids`."""

    def apply(state, payload, accept):
        selected = state.particles[payload.particle_ids]
        new_particles = state.particles.update_if(
            accept,
            payload.particle_ids,
            Particles(
                positions=payload.new_positions,
                system=selected.system,
                group=selected.group,
            ),
        )
        return replace(state, particles=new_particles)

    return ExplicitPatch(payload=proposal, apply_fn=apply)

def patch_fn(key, state, proposal):
    return make_position_patch(proposal)

def prime_cache(state):
    """Populate the energy cache with the full evaluation on the current state."""
    result = cached_potential(state, patch=None)
    n_sys = len(state.systems)
    accept_all = Table.arange(jnp.ones(n_sys, dtype=bool), label=SystemId)
    return result.patch(state, accept_all)
state = prime_cache(make_state())
print("systems: ", len(state.systems))
print("particles/system: ", len(state.particles) // len(state.systems))
print("initial energies: ", state.cached_energy.total_energies.data)
systems: 2
particles/system: 3
initial energies: [11.330361 17.65512 ]
The Metropolis-Hastings step
Every MC step in kUPS has the same four-stage shape. First a proposal is drawn: a function of type ChangesFn takes a PRNG key and the current state and returns a pair of a change payload and a per-system log proposal ratio. The change payload is whatever structured description the move needs, for example a ParticlePositionChanges carrying the selected indices and their new coordinates. The log proposal ratio is zero for symmetric moves and carries the forward/backward asymmetry otherwise.
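In the textbook Metropolis-Hastings convention, the log proposal ratio is the log density of the reverse proposal minus that of the forward proposal. The expression below is a sketch under that assumption; the notation q(x' | x) for the density of proposing x' from x is introduced here, and the sign convention is inferred from the fact that the two log ratios are summed in the decision stage, not confirmed against the kUPS source.

```latex
\log r_{\text{prop}}(x \to x') = \log q(x \mid x') - \log q(x' \mid x),
\qquad
q(x' \mid x) = q(x \mid x') \;\Longrightarrow\; \log r_{\text{prop}} = 0 .
```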
Next the proposal is turned into a patch. A PatchFn takes a PRNG key, the current state, and the change payload and produces a Patch; in this notebook the patch writes the new positions into the particle table. At this point the state has not yet changed; the patch is a recipe that will be applied conditionally.
The evaluation stage runs a LogProbabilityRatioFn against the current state and the proposed patch. This returns a WithPatch whose data is the per-system log target ratio (for example the Boltzmann factor) and whose own patch commits any cached intermediates that the energy evaluation produced. Cache updates and position updates are kept as separate patches so that a rejected move leaves the cache untouched.
The final stage decides and applies. The decision is a single log_p_ratio > log(u) comparison per system, where u is drawn uniformly from [0, 1) and log_p_ratio is the sum of the proposal and target log ratios. The resulting boolean table is the acceptance mask. The proposal patch and the evaluation patch are then called with the same mask: accepted systems see both their positions and their caches advance, rejected systems see neither change.
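The same four stages can be written out in plain NumPy. This is a minimal sketch under stated assumptions, not kUPS code: the names `apply_patch`, `cached_energy`, and `kT` are made up for illustration, the potential is the toy harmonic well, and the "patch" is just a closure that writes new positions where the mask is true.

```python
import numpy as np

rng = np.random.default_rng(0)
n_systems, n_per_system, kT = 2, 3, 1.0
sys_idx = np.repeat(np.arange(n_systems), n_per_system)
positions = rng.uniform(-2.0, 2.0, size=(n_systems * n_per_system, 3))


def energy(pos):
    """Toy harmonic energy per system."""
    e = np.zeros(n_systems)
    np.add.at(e, sys_idx, np.sum(pos**2, axis=-1))
    return e


cached_energy = energy(positions)

# 1. Propose: one particle per system, symmetric Gaussian displacement.
picked = np.arange(n_systems) * n_per_system + rng.integers(n_per_system, size=n_systems)
new_pos = positions[picked] + 0.5 * rng.normal(size=(n_systems, 3))
log_q_ratio = np.zeros(n_systems)  # symmetric move


# 2. Patch: a recipe that only writes where `accept` is True.
def apply_patch(pos, accept):
    out = pos.copy()
    out[picked[accept]] = new_pos[accept]
    return out


# 3. Evaluate: trial energies with the patch applied everywhere.
trial_energy = energy(apply_patch(positions, np.ones(n_systems, dtype=bool)))
log_target_ratio = (cached_energy - trial_energy) / kT

# 4. Decide, then apply positions and cache with the same mask.
log_u = np.log(rng.uniform(size=n_systems))
accept = log_q_ratio + log_target_ratio > log_u
positions_after = apply_patch(positions, accept)
cache_after = np.where(accept, trial_energy, cached_energy)
print(accept, cache_after)
```

Using one mask for both updates is the point of the design: the cached energy always matches the positions it was computed from, whether a system accepted or rejected.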
The cell below walks through those four stages by hand on the two-system state we just built.
chain = key_chain(jax.random.key(2))
move = ParticleTranslationMove[State](
    positions=lambda s: s.particles,
    systems=lambda s: s.systems,
    step_width=lambda s: Table(
        s.translation_params.keys, s.translation_params.data.value
    ),
)
changes, move_log_ratio = move(next(chain), state)
print("picked particles (one per system):", changes.particle_ids.indices)
print(
    "proposed displacements:",
    changes.new_positions - state.particles[changes.particle_ids].positions,
)
print("move log ratio (symmetric move): ", move_log_ratio.data)
patch = patch_fn(next(chain), state, changes)
boltzmann = BoltzmannLogProbabilityRatio(
    temperature=lambda s: s.temperature,
    potential=cached_potential,
)
density = boltzmann(state, patch)
log_p_ratio = move_log_ratio.data + density.data.data
print("density log ratio:", density.data.data)
print("combined log ratio:", log_p_ratio)
u = jax.random.uniform(next(chain), (len(state.systems),))
accept_mask = log_p_ratio > jnp.log(u)
accept = Table.arange(accept_mask, label=SystemId)
print("accept per system:", accept.data)
new_state = patch(state, accept)
new_state = density.patch(new_state, accept)
print("position moved for each particle:")
print(
    jnp.linalg.norm(
        new_state.particles.data.positions - state.particles.data.positions, axis=-1
    )
)
picked particles (one per system): [0 5]
proposed displacements: [[-0.30631447 -0.51468146 -0.37817073]
[ 0.23070872 0.01131606 0.08038354]]
move log ratio (symmetric move): [0. 0.]
density log ratio: [ 79.6235 -40.119335]
combined log ratio: [ 79.6235 -40.119335]
accept per system: [ True False]
position moved for each particle:
[0.70833516 0. 0. 0. 0. 0. ]
ChangesFn and MonteCarloMove
ChangesFn is the narrow protocol every move satisfies: it maps a key and a state to a proposal and a log ratio. MonteCarloMove is a convenience abstract base that formalises this and is the type callers see in the move classes.
The library ships a handful of concrete moves. ParticleTranslationMove displaces one randomly chosen particle per system by a Gaussian-distributed vector whose width is read from a lens into the state. GroupTranslationMove and GroupRotationMove do the same for a rigid molecular group. ReinsertionMove picks up a group and drops it at a uniformly random position and orientation, which is useful for escaping local traps. ExchangeMove is the grand canonical pair: with equal probability it proposes inserting a fresh motif into a free slot of a Buffered particle table or deleting an existing one.
Each concrete move is a dataclass holding a handful of lenses into the state. The move itself contains no physics beyond geometry; the acceptance decision is handled by the probability function, which is what makes the same moves compose cleanly across ensembles.
Calling a move twice with different keys gives distinct proposals. The log proposal ratio is zero because the translation distribution is symmetric.
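Why a fixed-width Gaussian displacement has a zero log proposal ratio can be checked numerically. The sketch below is plain NumPy rather than kUPS, and it assumes the textbook convention of reverse log density minus forward log density; `gaussian_log_q` and `width_at` are hypothetical helpers written for this illustration.

```python
import numpy as np


def gaussian_log_q(x_to, x_from, width):
    """Log density of proposing x_to from x_from with an isotropic Gaussian."""
    d = x_to - x_from
    dim = d.shape[-1]
    return (-0.5 * np.sum(d**2, axis=-1) / width**2
            - 0.5 * dim * np.log(2.0 * np.pi * width**2))


rng = np.random.default_rng(0)
x = rng.uniform(-2.0, 2.0, size=(2, 3))  # one picked particle per system
x_new = x + 0.5 * rng.normal(size=x.shape)

# Fixed width: reverse and forward densities coincide, so the ratio vanishes.
symmetric_log_ratio = gaussian_log_q(x, x_new, 0.5) - gaussian_log_q(x_new, x, 0.5)


# Hypothetical state-dependent width: the two densities no longer match.
def width_at(pos):
    return 0.1 + np.linalg.norm(pos, axis=-1)


asymmetric_log_ratio = (gaussian_log_q(x, x_new, width_at(x_new))
                        - gaussian_log_q(x_new, x, width_at(x)))
print(symmetric_log_ratio, asymmetric_log_ratio)
```

A move whose width depends on the current configuration would therefore have to return a nonzero log ratio for the chain to stay unbiased.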
changes_a, ratio_a = move(jax.random.key(10), state)
changes_b, ratio_b = move(jax.random.key(20), state)
print("picked A:", changes_a.particle_ids.indices)
print("picked B:", changes_b.particle_ids.indices)
print(
    "displacement A:",
    changes_a.new_positions - state.particles[changes_a.particle_ids].positions,
)
print(
    "displacement B:",
    changes_b.new_positions - state.particles[changes_b.particle_ids].positions,
)
print("log ratios always zero:", ratio_a.data, ratio_b.data)
picked A: [1 3]
picked B: [2 5]
displacement A: [[-0.3695712 -1.0207106 -0.1616838 ]
[ 0.04383624 1.0724673 -1.8488913 ]]
displacement B: [[ 0.5721789 0.27600145 0.21787739]
[-0.64267755 0.48727584 0.14450502]]
log ratios always zero: [0. 0.] [0. 0.]
LogProbabilityRatioFn: the acceptance criterion
The acceptance criterion is decoupled from the move. A move only proposes geometry; the LogProbabilityRatioFn is what gives the proposed patch a probabilistic interpretation. The library ships three of them.
BoltzmannLogProbabilityRatio is the canonical-ensemble criterion. It reads the temperature from a lens, evaluates the cached potential on the old state and on the patched state, and returns (U_old - U_new) / (k_B T) per system. Because it goes through a CachedPotential, accepted moves commit the new energies to the cache through the returned WithPatch rather than recomputing them on the next step. LogFugacityRatio adds the chemical-potential term for grand canonical moves. MuVTLogProbabilityRatio composes the two for a full GCMC acceptance.
Below we reuse the Boltzmann ratio built earlier for our toy harmonic potential and evaluate it on two hand-built patches that rescale the position of particle 0. A move that pushes a particle farther from the origin raises the energy and gives a negative log ratio; a move toward the origin lowers the energy and gives a positive log ratio. The last line repeats the outward move at ten times the temperature.
def one_particle_patch(idx, new_pos):
    return make_position_patch(
        ParticlePositionChanges(
            particle_ids=Index(state.particles.keys, jnp.array([idx])),
            new_positions=new_pos,
        )
    )
current = state.particles.data.positions[0]
outward = one_particle_patch(0, current[None] * 1.5)
inward = one_particle_patch(0, current[None] * 0.5)
print("outward log ratio:", boltzmann(state, outward).data.data)
print("inward log ratio:", boltzmann(state, inward).data.data)
print("at higher T (scaled x10):")
hot_state = replace(
    state, temperature=Table.arange(state.temperature.data * 10, label=SystemId)
)
print(" outward log ratio:", boltzmann(hot_state, outward).data.data)
outward log ratio: [-353.9926 0. ]
inward log ratio: [212.39551 0. ]
at higher T (scaled x10):
outward log ratio: [-35.39926 0. ]
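The printed values can be checked by hand, assuming only the toy potential U = Σ|r_i|² and the (U_old - U_new) / (k_B T) form quoted above. Scaling one particle's position r by a factor s changes only that particle's term:

```latex
\begin{aligned}
\Delta U &= U_{\text{new}} - U_{\text{old}} = (s^2 - 1)\,|r|^2,
&\qquad
\log\frac{\pi_{\text{new}}}{\pi_{\text{old}}} &= -\frac{\Delta U}{k_B T} = \frac{(1 - s^2)\,|r|^2}{k_B T}, \\
s = 1.5:\quad \log\frac{\pi_{\text{new}}}{\pi_{\text{old}}} &= -1.25\,\frac{|r|^2}{k_B T},
&\qquad
s = 0.5:\quad \log\frac{\pi_{\text{new}}}{\pi_{\text{old}}} &= +0.75\,\frac{|r|^2}{k_B T}.
\end{aligned}
```

The ratio of the outward to the inward value should therefore be -1.25 / 0.75 = -5/3, and indeed -353.9926 / 212.39551 ≈ -1.6667. Multiplying T by ten divides the log ratio by ten, which gives the -35.39926 on the last line. The second entry is zero in every case because no particle of the second system is touched.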
MCMCPropagator: the step machine
MCMCPropagator is the object that runs the four-stage cycle. It holds a tuple of propose functions, a patch-builder, a probability function, a tuple of step-width schedulers that mirrors the propose functions, and optional selection weights. Its __call__ is one full Metropolis-Hastings step.
Because the propagator satisfies the standard Propagator interface, it can be composed with integrator propagators, loops, and palindromes from the propagators notebook without any special treatment. A hybrid MC/MD cycle is a SequentialPropagator of a velocity-Verlet block followed by an MCMCPropagator block; a long equilibration with thinning is a LoopPropagator wrapping the MC step.
The cell below assembles everything, runs 500 MC steps, and tracks the per-system energy and acceptance rate.
scheduler = PropertyScheduler(
    lens=lens(lambda s: s.translation_params, cls=State),
    schedule=Table.transform(acceptance_target_schedule),
)
propagator = MCMCPropagator(
    patch_fn=patch_fn,
    propose_fns=(move,),
    log_probability_ratio_fn=boltzmann,
    parameter_schedulers=(scheduler,),
)
@jax.jit
def run(state, keys):
    def body(state, key):
        new_state = propagator(key, state)
        # Proxy: a system counts as accepted if any of its particles moved.
        # Exact would be to expose the internal acceptance mask from MCMCPropagator.
        moved = jnp.any(
            new_state.particles.data.positions != state.particles.data.positions,
            axis=-1,
        )
        sys_idx = state.particles.data.system.indices
        accepted = jnp.zeros(len(state.systems), dtype=bool).at[sys_idx].max(moved)
        return new_state, (new_state.cached_energy.total_energies.data, accepted)

    return jax.lax.scan(body, state, keys)
fresh = prime_cache(make_state())
N = 500
keys = jax.random.split(jax.random.key(42), N)
final_state, (energies, acceptances) = run(fresh, keys)
print("initial energy per system:", fresh.cached_energy.total_energies.data)
print("final energy per system:", final_state.cached_energy.total_energies.data)
print("mean acceptance per system:", acceptances.mean(axis=0))
print("final translation step width:", final_state.translation_params.data.value)
initial energy per system: [11.330361 17.65512 ]
final energy per system: [0.07537573 0.08348013]
mean acceptance per system: [0.06 0.1 ]
final translation step width: [0.31046063 0.31046063]
Mixing move types
A realistic MC simulation rarely uses a single move. Translations equilibrate local packing, rotations relax orientations, reinsertions escape metastable basins, and exchanges sample particle number. MCMCPropagator supports this directly: its propose_fns field is a tuple and each step picks one of the moves according to the weights.
Internally this goes through propose_mixed, which evaluates every propose function against the same PRNG key and then uses jax.lax.select_n to pick the active one. Evaluating all proposals up front trades a little compute for common-subexpression elimination under jit, which usually pays for itself when the moves share geometric primitives. Only the scheduler belonging to the selected move is advanced each step.
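The evaluate-all-then-select pattern can be sketched in plain NumPy. This is an illustration of the idea only, not the `propose_mixed` implementation: the function `propose_mixed_sketch` and the two toy proposals are hypothetical, and indexing into a stacked array stands in for `jax.lax.select_n`.

```python
import numpy as np


def bold(x, noise):
    return x + 2.0 * noise


def timid(x, noise):
    return x + 0.05 * noise


def propose_mixed_sketch(rng, x, proposals, weights):
    """Evaluate every proposal on the same random draw, then keep one."""
    noise = rng.normal(size=x.shape)  # shared randomness for all moves
    candidates = np.stack([p(x, noise) for p in proposals])
    probs = np.asarray(weights, dtype=float)
    probs = probs / probs.sum()  # weights need not be normalised
    which = rng.choice(len(proposals), p=probs)
    return candidates[which], int(which)


rng = np.random.default_rng(7)
x = np.zeros((2, 3))
picks = [propose_mixed_sketch(rng, x, (bold, timid), (1.0, 3.0))[1]
         for _ in range(2000)]
print("fraction of timid picks:", picks.count(1) / len(picks))
```

With weights (1, 3) the timid move should be chosen in roughly three quarters of the calls; the ten-call sample in the notebook is too short to show that ratio precisely.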
Below we call propose_mixed directly on two copies of the translation move with different step widths to see how move selection works in isolation.
bold = ParticleTranslationMove[State](
    positions=lambda s: s.particles,
    systems=lambda s: s.systems,
    step_width=lambda s: Table.arange(jnp.full(len(s.systems), 2.0), label=SystemId),
)
timid = ParticleTranslationMove[State](
    positions=lambda s: s.particles,
    systems=lambda s: s.systems,
    step_width=lambda s: Table.arange(jnp.full(len(s.systems), 0.05), label=SystemId),
)
# Call 10 times with weights (1, 3): timid should be picked ~3x as often.
picks = []
for k in jax.random.split(jax.random.key(7), 10):
    _, _, which = propose_mixed(k, state, (bold, timid), weights=(1.0, 3.0))
    picks.append(int(which))
print("which move was picked each call (0=bold, 1=timid):", picks)
print("timid picked", picks.count(1), "times out of 10")
which move was picked each call (0=bold, 1=timid): [0, 1, 0, 1, 0, 1, 1, 1, 1, 1]
timid picked 7 times out of 10
Step-width scheduling and acceptance tuning
Monte Carlo step widths are part of the state, not hard-coded constants. Each move reads its width through a lens into a ParameterSchedulerState field. After a step, the propagator runs the corresponding PropertyScheduler with the acceptance_target_schedule to adjust the width toward a target acceptance ratio (the default target is around one half, which is close to optimal for isotropic proposals).
Because the scheduler mutates state through the same lens machinery that reads the width, the tuning is automatic and jit-compatible. In a multi-move propagator each move has its own scheduler, so translations, rotations, and reinsertions land on their own equilibrium step widths independently.
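The kind of feedback rule such a scheduler applies can be sketched in plain NumPy. The update below (rescale the width by the ratio of observed to target acceptance once per window, clipped to a factor of two) is an assumption chosen for illustration; it is not the actual `acceptance_target_schedule`, and `tune_widths` is a hypothetical helper. Each system is a 1D chain in the potential U = x²/2 with k_B T = 1.

```python
import numpy as np


def tune_widths(n_systems=4, n_windows=40, window=100, target=0.5, seed=0):
    """Adapt one step width per system toward a target acceptance rate."""
    rng = np.random.default_rng(seed)
    x = np.zeros(n_systems)  # one 1D chain per system
    width = np.full(n_systems, 10.0)  # deliberately far too wide
    rate = np.zeros(n_systems)
    for _ in range(n_windows):
        n_accepted = np.zeros(n_systems)
        for _ in range(window):
            x_new = x + width * rng.normal(size=n_systems)
            log_ratio = 0.5 * (x**2 - x_new**2)  # (U_old - U_new) / kT
            accept = log_ratio > np.log(rng.uniform(size=n_systems))
            x = np.where(accept, x_new, x)
            n_accepted += accept
        rate = n_accepted / window
        # Too few acceptances shrink the width, too many grow it.
        width = width * np.clip(rate / target, 0.5, 2.0)
    return width, rate


width, rate = tune_widths()
print("final widths:", width)
print("acceptance in last window:", rate)
```

Each system adapts its own width from its own acceptance history, which mirrors the per-system step widths kept in the state.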
To see the scheduler in action we deliberately start the step width too high (so moves are almost always rejected) and let the propagator dial it down.
too_big = prime_cache(make_state(step_width=10.0, history_length=20))
keys = jax.random.split(jax.random.key(99), 2000)
tuned_state, (_, accepted) = run(too_big, keys)
print("initial step width (per system): ", too_big.translation_params.data.value)
print("final step width (per system): ", tuned_state.translation_params.data.value)
print("acceptance over first 200 steps: ", accepted[:200].mean(axis=0))
print("acceptance over last 200 steps: ", accepted[-200:].mean(axis=0))
initial step width (per system): [10. 10.]
final step width (per system): [0.1030743 0.1030743]
acceptance over first 200 steps: [0.025 0.025]
acceptance over last 200 steps: [0.465 0.505]
Factory constructors for common combinations
Writing the wiring by hand is mechanical. The library ships factory functions that take a lens into the state, a patch function, and a probability function and return a fully wired MCMCPropagator. make_group_translation_mcmc_propagator and make_group_rotation_mcmc_propagator build propagators for a single move type. make_displacement_mcmc_propagator combines translation, rotation, and reinsertion under one propose_mixed with configurable weights. make_exchange_mcmc_propagator and make_gcmc_mcmc_propagator do the same for grand canonical setups.
When the state layout follows the usual naming convention (particles, groups, systems, translation_params, rotation_params, and so on, as defined in the IsMCMCMoveState protocol), these factories are the right entry points. Dropping down to the class constructors, as we did here, is only necessary when the state stores things under unusual names; this notebook did so to keep the example self-contained.
Where to go next
The acceptance criteria need an energy function, which is the subject of the Potentials notebook. The patch composition that commits position and cache updates together is covered in the Patches notebook. The neighbor list reuses the same proposal/patch shape to answer the question "which edges changed" for an MC move; see the Neighbor Lists notebook for neighborlist_changes, which is the cheap counterpart to a full rebuild during the acceptance test.