Patches¶
Simulation code often needs to propose a change to the state before deciding whether to apply it. Monte Carlo is the canonical example: propose a move, compute the energy change, accept or reject based on the Metropolis criterion. Potentials also need staged updates — many cache quantities (a total energy, a structure factor, a lookup table) that must update atomically with the state changes they depend on. Doing this by hand is error-prone when many independent systems are batched and each can accept or reject independently.
The patch system is a small, composable abstraction for these conditional, batched state updates.
from dataclasses import replace
import jax
import jax.numpy as jnp
from jax import Array
from kups.core.data import Index, Table
from kups.core.lens import lens
from kups.core.patch import (
ComposedPatch,
ExplicitPatch,
IdPatch,
IndexLensPatch,
WithPatch,
)
from kups.core.typing import ParticleId, SystemId
from kups.core.utils.jax import dataclass
The Patch Protocol¶
A Patch is a function that takes a state and a per-system acceptance mask, and returns a modified state. The acceptance mask is a Table[SystemId, bool]: one boolean per system deciding whether the change applies there.
class Patch[State](Protocol):
    def __call__(self, state: State, accept: Table[SystemId, bool]) -> State: ...
Constructing a patch does not change anything. The patch is data. Only calling it with a state and an acceptance mask produces the new state.
Let's build a small state with two systems: particles 0 and 1 in system 0, particle 2 in system 1.
@dataclass
class PData:
positions: Array
system: Index[SystemId]
@dataclass
class State:
particles: Table[ParticleId, PData]
cached_energy: Table[SystemId, Array]
particles = Table.arange(
PData(
positions=jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]),
system=Index.new([SystemId(0), SystemId(0), SystemId(1)]),
),
label=ParticleId,
)
state = State(
particles=particles,
cached_energy=Table(keys=(SystemId(0), SystemId(1)), data=jnp.array([10.0, 20.0])),
)
print("positions:", state.particles.data.positions)
print("cached energy:", state.cached_energy.data)
positions: [[0. 0. 0.]
[1. 0. 0.]
[2. 0. 0.]]
cached energy: [10. 20.]
IdPatch: The No-Op¶
IdPatch returns the state unchanged regardless of the acceptance mask. It is the identity element for patch composition. Potentials return IdPatch when they have no cached intermediates to update — the energy is computed and returned, but nothing in the state needs to change.
accept_all = Table(keys=(SystemId(0), SystemId(1)), data=jnp.array([True, True]))
result = IdPatch()(state, accept_all)
print("same state:", result is state)
print("positions unchanged:", result.particles.data.positions)
same state: True
positions unchanged: [[0. 0. 0.]
[1. 0. 0.]
[2. 0. 0.]]
IndexLensPatch: The Workhorse¶
IndexLensPatch is the most common patch type. It combines three pieces:
- data: the new values to write
- mask_idx: a pytree whose Index leaves tell which system each row belongs to
- lens: a Lens pointing to where in the state the data lives
When called with an acceptance mask, it blends new and old values per row. Rows whose system accepted get the new data; rows whose system rejected keep the old data.
Let's propose a move for particles 1 and 2, then apply it with different acceptance patterns.
new_data = PData(
positions=jnp.array([[0.0, 0.0, 0.0], [9.9, 0.0, 0.0], [8.8, 0.0, 0.0]]),
system=Index.new([SystemId(0), SystemId(0), SystemId(1)]),
)
mask_idx = PData(
positions=Index.new([SystemId(0), SystemId(0), SystemId(1)]), # type: ignore
system=Index.new([SystemId(0), SystemId(0), SystemId(1)]),
)
pos_patch = IndexLensPatch(
data=new_data,
mask_idx=mask_idx,
lens=lens(lambda s: s.particles.data, cls=State),
)
# Accept only system 0
accept_0 = Table(keys=(SystemId(0), SystemId(1)), data=jnp.array([True, False]))
print("accept [T, F]:", pos_patch(state, accept_0).particles.data.positions)
# Accept only system 1
accept_1 = Table(keys=(SystemId(0), SystemId(1)), data=jnp.array([False, True]))
print("accept [F, T]:", pos_patch(state, accept_1).particles.data.positions)
# Accept both
print("accept [T, T]:", pos_patch(state, accept_all).particles.data.positions)
accept [T, F]: [[0. 0. 0. ]
[9.9 0. 0. ]
[2. 0. 0. ]]
accept [F, T]: [[0. 0. 0. ]
[1. 0. 0. ]
[8.8 0. 0. ]]
accept [T, T]: [[0. 0. 0. ]
[9.9 0. 0. ]
[8.8 0. 0. ]]
Each row kept its old value or received the new value based on the acceptance of its system. In a batched MC simulation, this is how independent chains commit their own moves without interfering with each other.
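Under the hood, the per-row blend reduces to a gather followed by a where. The following is an illustrative plain-jnp sketch of that mechanic, not the library's internals; the variable names are ours:

```python
import jax.numpy as jnp

# Sketch of the blend an IndexLensPatch performs, using the same
# positions as above. `row_system` plays the role of mask_idx: it maps
# each row to the system it belongs to.
old = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
new = jnp.array([[0.0, 0.0, 0.0], [9.9, 0.0, 0.0], [8.8, 0.0, 0.0]])
row_system = jnp.array([0, 0, 1])  # which system each row belongs to
accept = jnp.array([True, False])  # per-system acceptance mask

# Gather the per-system decision down to a per-row decision, then blend.
row_accept = accept[row_system]  # shape (3,): [True, True, False]
blended = jnp.where(row_accept[:, None], new, old)
# Rows of system 0 take the new values; the row of system 1 keeps the old.
```

The gather-then-where pattern is what makes the blend a single vectorized operation regardless of how many systems are batched.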
ExplicitPatch: Custom Updates¶
When none of the built-in patch types fit, ExplicitPatch is the escape hatch. It takes an arbitrary payload and a function apply_fn(state, payload, accept) -> state.
@dataclass
class StateWithCounter:
counter: Array # per-system counter
def increment(state, payload, accept):
return replace(state, counter=state.counter + jnp.where(accept.data, payload, 0))
counter_state = StateWithCounter(counter=jnp.array([0, 0]))
step_patch = ExplicitPatch(payload=jnp.array([1, 1]), apply_fn=increment)
result = step_patch(counter_state, accept_0)
print("accept [T, F]:", result.counter)
result = step_patch(counter_state, accept_all)
print("accept [T, T]:", result.counter)
accept [T, F]: [1 0]
accept [T, T]: [1 1]
ComposedPatch: Sequencing¶
ComposedPatch applies a sequence of patches one after another, all with the same acceptance mask. This is how cached potentials keep the state consistent: the position update and every cache update that depends on it get bundled together under one acceptance decision.
Here we compose a position update with a cached-energy update. Each patch does its own thing; the composition applies both atomically per system.
energy_patch = IndexLensPatch(
data=jnp.array([15.5, 25.5]),
mask_idx=Index.new([SystemId(0), SystemId(1)]),
lens=lens(lambda s: s.cached_energy.data, cls=State),
)
composed = ComposedPatch(patches=(pos_patch, energy_patch))
# Accept only system 0: positions and energy update there, but not in system 1
result = composed(state, accept_0)
print("positions:", result.particles.data.positions)
print("cached energy:", result.cached_energy.data)
positions: [[0. 0. 0. ]
[9.9 0. 0. ]
[2. 0. 0. ]]
cached energy: [15.5 20. ]
System 0 got both its new position and its new cached energy. System 1 kept everything as it was. The cache and the positions stay in sync because they're updated under the same mask.
WithPatch: Pairing Data and Patches¶
Most operations return both a result (e.g., an energy) and a patch (e.g., a cache update). WithPatch packages them together. A Potential returns WithPatch[PotentialOut, Patch]: the data is the energy and gradients, the patch writes the cache.
WithPatch supports __add__ when the data type is addable. Adding two WithPatch objects sums the data and composes the patches. This is what makes potentials compose cleanly — summing two WithPatch outputs gives the combined energy and a composed cache update, all in one object.
w1 = WithPatch(data=jnp.array([1.0, 2.0]), patch=pos_patch)
w2 = WithPatch(data=jnp.array([0.5, 0.3]), patch=energy_patch)
combined = w1 + w2
print("summed data:", combined.data)
print("composed patch:", type(combined.patch).__name__)
# Apply the composed patch
result = combined.patch(state, accept_all)
print("positions:", result.particles.data.positions)
print("cached energy:", result.cached_energy.data)
summed data: [1.5 2.3]
composed patch: ComposedPatch
positions: [[0. 0. 0. ]
[9.9 0. 0. ]
[8.8 0. 0. ]]
cached energy: [15.5 25.5]
.map_data and .map_patch transform one side while keeping the other. .compose_patch(other) chains an additional patch after the current one. Together these make WithPatch a convenient building block for anything that produces a result alongside a conditional state update.
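The semantics of these combinators can be sketched with a minimal stand-in class. This is an illustration of the contract, not the library's implementation; here a "patch" is simplified to a plain function state -> state, whereas kups patches also take an acceptance mask:

```python
from dataclasses import dataclass, replace
from typing import Any, Callable

# Minimal sketch of the WithPatch combinators (illustrative, not kups code).
@dataclass
class MiniWithPatch:
    data: Any
    patch: Callable[[Any], Any]

    def map_data(self, f):
        # Transform the result, keep the patch untouched.
        return replace(self, data=f(self.data))

    def map_patch(self, f):
        # Transform the patch, keep the data untouched.
        return replace(self, patch=f(self.patch))

    def compose_patch(self, other):
        # Chain an additional patch after the current one.
        return replace(self, patch=lambda s: other(self.patch(s)))

w = MiniWithPatch(data=3.0, patch=lambda s: s + 1)
w2 = w.map_data(lambda d: d * 2).compose_patch(lambda s: s * 10)
print(w2.data)      # 6.0
print(w2.patch(0))  # (0 + 1) * 10 = 10
```

Because each combinator touches only one side of the pair, they can be chained freely without the data and patch halves getting out of sync.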
Probes: Reading State + Patch Together¶
A Probe is the dual of a patch: instead of writing changes into the state, it reads information about a proposed change out of the state.
Probes tell a potential what it needs to know about a proposed change so it can decide how to evaluate cheaply. The exact output R is potential-specific. Each potential documents its own probe-return protocol. Pair potentials (LJ, Coulomb, Ewald real-space) expect an IsRadiusGraphProbe[P], which exposes the indices of moved particles along with old and new neighbor lists so the graph can be restricted to affected edges. See the [Potentials][potentials.ipynb] chapter for a worked example. A potential that always recomputes from scratch needs no probe at all.
Here is a simple probe that counts affected particles. It is trivial, but it illustrates the shape of the interface.
def count_affected(state, patch: IndexLensPatch) -> int:
leaves = jax.tree.leaves(patch.mask_idx, is_leaf=lambda x: isinstance(x, Index))
return max(leaf.indices.size for leaf in leaves if isinstance(leaf, Index))
print("affected particles:", count_affected(state, pos_patch))
affected particles: 3
The MCMC Propagator¶
The MCMCPropagator is a direct application of the patch system. One step does four things:
1. Propose. A proposal function draws random changes (e.g., particle displacements) and returns them along with a log proposal ratio. A patch function converts the changes into a Patch — typically an IndexLensPatch that would update the particle table.
2. Evaluate. A log-probability-ratio function (the Metropolis Boltzmann factor) is called with the state and the proposed patch. It returns WithPatch[LogProbabilityRatio, Patch]: the data is the per-system log ratio; the returned patch bundles any cache updates the potentials need to commit on acceptance.
3. Decide. For each system, compare proposal_log_ratio + target_log_ratio against a uniform random number. This gives a per-system boolean acceptance mask.
4. Apply. Call the proposal patch with the mask to commit position changes. Then call the evaluation patch with the same mask to commit the cache updates. Accepted systems end up with new positions and a consistent cache; rejected systems keep everything, including their cache.
The per-system acceptance mask is what makes batched parallel chains work. Each system in the state is an independent Markov chain: it accepts or rejects its own proposed move without coordinating with the others. The patch machinery ensures the position update and every dependent cache update commit together — or not at all.
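The four steps can be sketched end to end with plain jnp on a toy 1-D harmonic potential. This is a self-contained illustration of the propose/evaluate/decide/apply cycle, not the MCMCPropagator itself; the function and its parameters are ours, the proposal is symmetric (zero log proposal ratio), and beta is fixed at 1:

```python
import jax
import jax.numpy as jnp

def mcmc_step(key, positions, system, n_systems, step=0.5):
    """One batched Metropolis step on E(x) = sum(x**2) per system (toy sketch)."""
    k_move, k_accept = jax.random.split(key)
    # 1. Propose: a random displacement for every particle.
    proposed = positions + step * jax.random.normal(k_move, positions.shape)
    # 2. Evaluate: per-system energy change, summed over each system's rows.
    de = jax.ops.segment_sum(proposed**2 - positions**2, system, n_systems)
    # 3. Decide: per-system Metropolis test, log(u) < -dE (symmetric proposal).
    accept = jnp.log(jax.random.uniform(k_accept, (n_systems,))) < -de
    # 4. Apply: each row commits or reverts based on its own system's decision.
    return jnp.where(accept[system], proposed, positions), accept

positions = jnp.array([0.0, 1.0, 2.0])
system = jnp.array([0, 0, 1])  # particles 0, 1 in system 0; particle 2 in system 1
new_positions, accept = mcmc_step(jax.random.PRNGKey(0), positions, system, 2)
```

Step 4 is exactly the IndexLensPatch blend from earlier: the per-system mask is gathered to a per-row mask, and each row keeps or commits accordingly.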