kups.potential.mliap.local
Local machine-learned interatomic potential (MLIAP) with message passing.
This module provides infrastructure for local MLIAPs that use a single round of message passing. The key feature is efficient incremental energy updates during Monte Carlo simulations by caching node embeddings and aggregated messages.
Key components:
- LocalMLIAPData: Configuration holding model functions and cache
- LocalMLIAPComposer: Composes state into inputs for energy evaluation
- local_mliap_energy: Computes energy with automatic full/incremental dispatch
The architecture follows a message-passing neural network pattern:
1. Node initialization: atomic_numbers → node embeddings
2. Edge function: (node_i, node_j, r_ij) → messages
3. Message aggregation: sum messages per node
4. Readout: (node_emb, msg_sum) → per-atom energies
Example
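```python
import jax.numpy as jnp
from kups.potential.mliap.local import LocalMLIAPCache, LocalMLIAPData

# embedding_table, mlp, and linear are user-supplied model components.
def init_fn(atomic_numbers):
    return embedding_table[atomic_numbers]
def edge_fn(n1, n2, r):
    return mlp(r) * n2
def readout_fn(emb, msg):
    return linear(emb + msg)

config = LocalMLIAPData(
    cutoff=jnp.array([6.0]),
    init_function=init_fn,
    edge_function=edge_fn,
    readout_function=readout_fn,
    cache=LocalMLIAPCache(...),
)
```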
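The example above leaves embedding_table, mlp, and linear undefined. For a self-contained picture of the four steps, the toy forward pass below runs one round of message passing with jax.ops.segment_sum. Everything in it (the random weights, the exp(-r) edge weighting, the (receiver, sender) edge convention, and the helper energy_per_system) is a hypothetical stand-in rather than kups code, and periodic images are ignored.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy parameters: 3 element types, embedding width 4.
EMBED_DIM = 4
key_emb, key_w = jax.random.split(jax.random.PRNGKey(0))
embedding_table = jax.random.normal(key_emb, (3, EMBED_DIM))
readout_w = jax.random.normal(key_w, (EMBED_DIM,))


def init_fn(atomic_numbers):
    # 1. Node initialization: (n_atoms,) -> (n_atoms, EMBED_DIM)
    return embedding_table[atomic_numbers]


def edge_fn(n_i, n_j, r_ij):
    # 2. Edge function: neighbour embedding damped by distance -> (n_edges, EMBED_DIM)
    dist = jnp.linalg.norm(r_ij, axis=-1, keepdims=True)
    return jnp.exp(-dist) * n_j


def readout_fn(emb, msg_sum):
    # 4. Readout: per-atom energies, (n_atoms,)
    return (emb + msg_sum) @ readout_w


def energy_per_system(atomic_numbers, positions, edges, system_idx, n_systems):
    """edges: (n_edges, 2) integer array of (receiver i, sender j) pairs."""
    n_atoms = atomic_numbers.shape[0]
    emb = init_fn(atomic_numbers)
    i, j = edges[:, 0], edges[:, 1]
    r_ij = positions[j] - positions[i]  # plain displacement, no periodic images
    messages = edge_fn(emb[i], emb[j], r_ij)
    # 3. Aggregation: sum incoming messages onto each receiving atom
    msg_sum = jax.ops.segment_sum(messages, i, num_segments=n_atoms)
    e_atom = readout_fn(emb, msg_sum)
    e_sys = jax.ops.segment_sum(e_atom, system_idx, num_segments=n_systems)
    return e_sys, emb, msg_sum


# Two systems: a bonded pair (atoms 0, 1) and an isolated atom (atom 2).
z = jnp.array([0, 1, 2])
pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
edges = jnp.array([[0, 1], [1, 0]])  # both directions of the 0-1 pair
e_sys, emb, msg_sum = energy_per_system(z, pos, edges, jnp.array([0, 0, 1]), n_systems=2)
```

The returned emb and msg_sum are the two per-atom quantities that LocalMLIAPCache keeps (node_init and msg_sum), which is what makes the incremental path possible.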
EdgeFunction
Bases: Protocol
Protocol for edge/message function.
Computes edge messages (n_edges, msg_dim) from source/target node
embeddings (n_edges, embed_dim) and displacement vectors (n_edges, 3).
Source code in src/kups/potential/mliap/local.py
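A minimal sketch of a function that satisfies the EdgeFunction shape contract, assuming it is called positionally as (source, target, displacement); the actual parameter names of the protocol are not shown on this page. The cosine envelope, which makes messages vanish smoothly at the cutoff, and the factory make_cosine_edge_fn are illustrative choices, not part of kups. Here msg_dim equals embed_dim, matching the (n_atoms, embed_dim) shape that ReadoutFunction expects for aggregated messages.

```python
from typing import Protocol

import jax.numpy as jnp
from jax import Array


class EdgeFn(Protocol):
    """Hypothetical local mirror of the documented EdgeFunction contract."""

    def __call__(self, source: Array, target: Array, r: Array) -> Array: ...


def make_cosine_edge_fn(cutoff: float) -> EdgeFn:
    def edge_fn(source: Array, target: Array, r: Array) -> Array:
        # source, target: (n_edges, embed_dim); r: (n_edges, 3)
        dist = jnp.linalg.norm(r, axis=-1)
        # 1 at dist = 0, decays smoothly to 0 at the cutoff, 0 beyond it.
        envelope = jnp.where(
            dist < cutoff, 0.5 * (jnp.cos(jnp.pi * dist / cutoff) + 1.0), 0.0
        )
        return envelope[:, None] * source * target  # (n_edges, embed_dim)

    return edge_fn


edge_fn = make_cosine_edge_fn(cutoff=6.0)
src = jnp.ones((2, 4))
dst = jnp.full((2, 4), 2.0)
r = jnp.array([[0.0, 0.0, 0.0], [7.0, 0.0, 0.0]])  # second edge lies beyond the cutoff
msgs = edge_fn(src, dst, r)
```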
IsLocalMLIAPGraphParticles
Bases: IsLocalMLIAPParticleData, IsRadiusGraphPoints, Protocol
Combined protocol for local MLIAP particles in radius graph context.
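A radius graph pairs each atom with the neighbours inside the cutoff of its own system. The library obtains these edges from its NearestNeighborList; the hypothetical radius_graph below is only a brute-force stand-in (O(N²), and not jittable because of boolean indexing). It applies the minimum-image convention for a general cell given as row lattice vectors, which is only valid when each cutoff is below half the smallest perpendicular width of its cell.

```python
import jax.numpy as jnp


def radius_graph(positions, system_idx, cell, cutoff):
    """Directed (i, j) edges with |r_ij| < cutoff, plus their displacements.

    positions: (n_atoms, 3); system_idx: (n_atoms,);
    cell: (n_systems, 3, 3), lattice vectors as rows; cutoff: (n_systems,).
    """
    n = positions.shape[0]
    i, j = jnp.meshgrid(jnp.arange(n), jnp.arange(n), indexing="ij")
    i, j = i.ravel(), j.ravel()
    h = cell[system_idx[i]]  # (n_pairs, 3, 3)
    d = positions[j] - positions[i]
    # Cartesian -> fractional (d = f @ h), wrap to the nearest image, and back.
    frac = jnp.einsum("pk,pkl->pl", d, jnp.linalg.inv(h))
    frac = frac - jnp.round(frac)
    d = jnp.einsum("pk,pkl->pl", frac, h)
    dist = jnp.linalg.norm(d, axis=-1)
    keep = (i != j) & (system_idx[i] == system_idx[j]) & (dist < cutoff[system_idx[i]])
    return jnp.stack([i[keep], j[keep]], axis=1), d[keep]


cell = 10.0 * jnp.eye(3)[None]  # one cubic system, 10 Å box
pos = jnp.array([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]])  # 1 Å apart through the boundary
edges, disp = radius_graph(pos, jnp.array([0, 0]), cell, jnp.array([2.0]))
```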
IsLocalMLIAPParticleData
Bases: HasPositionsAndAtomicNumbers, HasSystemIndex, Protocol
Protocol for particle data required by LocalMLIAP.
Must provide positions, atomic numbers, and system index.
IsLocalMLIAPState
Bases: Protocol
Protocol for states providing all inputs for the local MLIAP potential.
LocalMLIAPCache
Cache for incremental energy updates.
Stores intermediate values from the last full computation to enable efficient incremental updates when only a subset of atoms change.
Attributes:
| Name | Type | Description |
|---|---|---|
| node_init | Array | Cached node embeddings from init_function, shape (n_atoms, embed_dim) |
| msg_sum | Array | Cached aggregated messages per node, shape (n_atoms, embed_dim) |

LocalMLIAPComposer
Composes simulation state into LocalMLIAP input.
Extracts particles, edges, and model config from state, handling both full computation (patch=None) and incremental updates (patch provided).
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State | | Simulation state type | required |
| P | IsLocalMLIAPGraphParticles | Particle data type (positions + system + inclusion/exclusion + atomic numbers) | required |
| S | HasUnitCell | System data type (unit cell + cutoff) | required |
| Ptch | Patch | Patch type for incremental updates | required |

Attributes:
| Name | Type | Description |
|---|---|---|
| particles | View[State, Table[ParticleId, P]] | View to extract indexed particle data from state |
| systems | View[State, Table[SystemId, S]] | View to extract indexed system data from state |
| neighborlist | View[State, NearestNeighborList] | View to extract full neighbor list from state |
| model | Lens[State, LocalMLIAPData] | Lens to access model config in state |
| probe | Probe[State, Ptch, IsRadiusGraphProbe[P]] \| None | Probe to detect particle changes from patch |

__call__(state, patch)
Compose state and patch into LocalMLIAP input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current simulation state | required |
| patch | Ptch \| None | Proposed changes (None for full computation) | required |

Returns:
| Type | Description |
|---|---|
| Sum[LocalMLIAPInput[State, P, S]] | Sum containing a single LocalMLIAPInput summand |

LocalMLIAPData
Configuration for a local MLIAP model.
Bundles the model functions, cutoff, and cache together with a lens for updating the cache in the state.
Attributes:
| Name | Type | Description |
|---|---|---|
| cutoff | Table[SystemId, Array] | Interaction cutoff radius [Å], one value per system |
| init_function | NodeInitFunction | Maps atomic numbers to node embeddings |
| edge_function | EdgeFunction | Computes messages from node pairs and displacements |
| readout_function | ReadoutFunction | Computes per-atom energies from embeddings and messages |
| cache | LocalMLIAPCache | Cached values for incremental updates |

from_zip_file(zip_file, n_atoms)
staticmethod
Load a local MLIAP model from a zip archive.
Expects node_init.jax, edge.jax, readout.jax,
and metadata.json (with cutoff and precision keys).
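A minimal sketch of the archive layout check, using only the standard library. The four member names and the cutoff/precision keys come from the description above; the value formats (a float cutoff, a "float32" precision string) and the helper read_mliap_metadata are assumptions. The serialization of the .jax members is not described here, so the sketch does not try to load them.

```python
import json
import tempfile
import zipfile
from pathlib import Path
from typing import Union

REQUIRED = ("node_init.jax", "edge.jax", "readout.jax", "metadata.json")


def read_mliap_metadata(zip_file: Union[str, Path]) -> dict:
    """Check the archive layout and return the parsed metadata.json.

    Hypothetical helper: it does not deserialize the .jax function files.
    """
    with zipfile.ZipFile(zip_file) as zf:
        missing = [name for name in REQUIRED if name not in zf.namelist()]
        if missing:
            raise ValueError(f"archive is missing: {missing}")
        meta = json.loads(zf.read("metadata.json"))
    for key in ("cutoff", "precision"):
        if key not in meta:
            raise ValueError(f"metadata.json lacks the {key!r} key")
    return meta


# Build a throwaway archive with placeholder payloads and read it back.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.zip"
    with zipfile.ZipFile(path, "w") as zf:
        for name in REQUIRED[:3]:
            zf.writestr(name, b"")  # placeholder, not a real serialized function
        zf.writestr("metadata.json", json.dumps({"cutoff": 6.0, "precision": "float32"}))
    meta = read_mliap_metadata(path)
```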
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| zip_file | str \| Path | Path to the zip archive. | required |
| n_atoms | int | Number of atoms to allocate cache for. | required |

Returns:
| Type | Description |
|---|---|
| LocalMLIAPData | Loaded local MLIAP model with initialized cache. |

LocalMLIAPInput
Input bundle for local MLIAP energy computation.
Contains all data needed to compute energies, supporting both full computation and incremental updates.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State | | Simulation state type | required |
| P | IsLocalMLIAPGraphParticles | Particle data type (positions, atomic numbers, system, inclusion/exclusion) | required |
| S | HasUnitCell | System data type (must have unit cell) | required |

Attributes:
| Name | Type | Description |
|---|---|---|
| point_cloud | PointCloud[P, S] | Current particle positions and systems |
| point_cloud_changes | WithIndices[ParticleId, P] \| None | Changed particles for incremental update (None for full) |
| edges | Edges[Literal[2]] | Current edges within cutoff |
| edges_deleted | Edges[Literal[2]] \| None | Old edges to remove for incremental update (None for full) |
| config | LocalMLIAPData | Model configuration with functions and cache |
| cache_lens | Lens[State, LocalMLIAPCache] | Lens to access/update cache in state |

LocalMLIAPPatch
Bases: Patch[State]
Patch to update the MLIAP cache in state.
Applied after energy computation to update cached node embeddings and aggregated messages for systems where moves were accepted.
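A minimal sketch of the per-system masking this patch performs, assuming the cache is stored per atom and accept holds one boolean per system. Cache and apply_cache_patch are hypothetical stand-ins, not the kups types: the accept flag of each atom's system is gathered through the system index and used to choose between the proposed and the old cache rows.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Cache(NamedTuple):
    """Hypothetical stand-in for LocalMLIAPCache (NamedTuples are JAX pytrees)."""

    node_init: jax.Array  # (n_atoms, embed_dim)
    msg_sum: jax.Array  # (n_atoms, embed_dim)


def apply_cache_patch(old: Cache, new: Cache, system_idx, accept) -> Cache:
    # accept: (n_systems,) bool -> per-atom mask (n_atoms, 1) via the system index
    atom_accept = accept[system_idx][:, None]
    # Take the proposed value where the atom's system accepted, else keep the old one.
    return jax.tree_util.tree_map(
        lambda n, o: jnp.where(atom_accept, n, o), new, old
    )


old = Cache(jnp.zeros((3, 2)), jnp.zeros((3, 2)))
new = Cache(jnp.ones((3, 2)), jnp.ones((3, 2)))
system_idx = jnp.array([0, 0, 1])  # atoms 0, 1 in system 0; atom 2 in system 1
patched = apply_cache_patch(old, new, system_idx, jnp.array([True, False]))
```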
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State | | Simulation state type | required |

Attributes:
| Name | Type | Description |
|---|---|---|
| cache | LocalMLIAPCache | New cache values to apply |
| system_idx | Index[SystemId] | System index for masking |
| lens | Lens[State, LocalMLIAPCache] | Lens to access cache in state |

__call__(state, accept)
Apply cache update to state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current simulation state | required |
| accept | Accept | Boolean mask per system indicating accepted moves | required |

Returns:
| Type | Description |
|---|---|
| State | Updated state with new cache values where mask is True |

NodeInitFunction
Bases: Protocol
Protocol for node initialization function.
Maps atomic numbers (n_atoms,) to node embeddings (n_atoms, embed_dim).
ReadoutFunction
Bases: Protocol
Protocol for readout function.
Computes per-atom energies (n_atoms,) from node embeddings and
aggregated messages, both of shape (n_atoms, embed_dim).
local_mliap_energy(inp)
Compute MLIAP energy with automatic full/incremental dispatch.
Automatically chooses between full computation and incremental update based on whether point_cloud_changes is provided.
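A minimal sketch of this dispatch, assuming it is an ordinary Python check on whether point_cloud_changes is None. ToyInput, energy_full, energy_update, and energy are hypothetical stand-ins that only return labels; the update path mirrors the documented requirement that both point_cloud_changes and edges_deleted be set.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyInput:
    """Hypothetical stand-in for LocalMLIAPInput (only the dispatch fields)."""

    point_cloud_changes: Optional[Any] = None
    edges_deleted: Optional[Any] = None


def energy_full(inp: ToyInput) -> str:
    return "full"


def energy_update(inp: ToyInput) -> str:
    # Mirrors the documented contract: both fields must be set.
    assert inp.point_cloud_changes is not None and inp.edges_deleted is not None
    return "update"


def energy(inp: ToyInput) -> str:
    # A plain Python branch on None: under jax.jit it is decided at trace time,
    # so the full and incremental paths compile as separate programs.
    if inp.point_cloud_changes is None:
        return energy_full(inp)
    return energy_update(inp)
```

Because None and an array give different pytree structures, a branch like this costs nothing at run time under jit; each path is simply traced once.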
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inp | LocalMLIAPInput[State, P, S] | Input bundle with point cloud, edges, and config | required |

Returns:
| Type | Description |
|---|---|
| WithPatch[Table[SystemId, Energy], Patch[State]] | Total energy per system and patch to update cache |

local_mliap_energy_full(inp)
Compute full MLIAP energy from scratch.
Performs complete message passing: initializes node embeddings, computes all edge messages, aggregates, and applies readout. Updates the cache with new values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inp | LocalMLIAPInput[State, P, S] | Input containing point cloud, edges, and model config | required |

Returns:
| Type | Description |
|---|---|
| WithPatch[Table[SystemId, Energy], Patch[State]] | Total energy per system and patch to update cache |

local_mliap_energy_update(inp)
Compute MLIAP energy incrementally using cached values.
Only recomputes embeddings and messages for changed atoms, subtracting old contributions and adding new ones. This is much cheaper than a full computation when only a small subset of atoms changes, because only the changed atoms and the edges that touch them are reprocessed.
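The subtract-then-add rule can be checked on a toy system: moving one atom changes only the messages on edges that touch it, so removing their old contributions from the cached msg_sum and adding the new ones must reproduce a full recomputation, even when the move changes the edge set. A minimal sketch with hypothetical helper names and hand-written edge lists standing in for the old neighbor list (the role of edges_deleted) and the new one; since only positions change, the cached node embeddings are reused as-is.

```python
import jax
import jax.numpy as jnp


def edge_messages(emb, pos, edges):
    # Message received by i from j: neighbour embedding damped by distance.
    i, j = edges[:, 0], edges[:, 1]
    dist = jnp.linalg.norm(pos[j] - pos[i], axis=-1, keepdims=True)
    return jnp.exp(-dist) * emb[j]


def aggregate(emb, pos, edges):
    msgs = edge_messages(emb, pos, edges)
    return jax.ops.segment_sum(msgs, edges[:, 0], num_segments=emb.shape[0])


def touching(edges, k):
    # Edges whose receiver or sender is the moved atom k (eager boolean indexing).
    return edges[(edges[:, 0] == k) | (edges[:, 1] == k)]


emb = jnp.eye(3)  # 3 atoms, embed_dim 3; a displacement leaves embeddings unchanged
old_pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
new_pos = old_pos.at[2].set(jnp.array([0.0, 1.5, 0.0]))  # move atom 2
old_edges = jnp.array([[0, 1], [1, 0], [0, 2], [2, 0], [1, 2], [2, 1]])
new_edges = jnp.array([[0, 1], [1, 0], [0, 2], [2, 0]])  # 2 has left 1's cutoff (1.8)

msg_sum_cached = aggregate(emb, old_pos, old_edges)  # from the last full pass

# Subtract the old contributions of edges touching atom 2, add the new ones.
k = 2
msg_sum_incremental = (
    msg_sum_cached
    - aggregate(emb, old_pos, touching(old_edges, k))
    + aggregate(emb, new_pos, touching(new_edges, k))
)
msg_sum_full = aggregate(emb, new_pos, new_edges)
```

After this step only atoms whose msg_sum changed (the moved atom and its old and new neighbours) need fresh per-atom energies from the readout.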
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inp | LocalMLIAPInput[State, P, S] | Input with point_cloud_changes and edges_deleted set | required |

Returns:
| Type | Description |
|---|---|
| WithPatch[Table[SystemId, Energy], Patch[State]] | Total energy per system and patch to update cache |

Raises:
| Type | Description |
|---|---|
| AssertionError | If point_cloud_changes or edges_deleted is None |

make_local_mliap_from_state(state, probe=None, gradient_lens=EMPTY_LENS, hessian_lens=EMPTY_LENS, hessian_idx_view=EMPTY_LENS, out_idx_view=None)
```python
make_local_mliap_from_state(
    state: Lens[State, IsLocalMLIAPState[MaybeCached[LocalMLIAPData, Any]]],
    probe: None = None,
    gradient_lens: Lens[LocalMLIAPInput, Gradient] = EMPTY_LENS,
    hessian_lens: Lens[Gradient, Hessian] = EMPTY_LENS,
    hessian_idx_view: Lens[State, Hessian] = EMPTY_LENS,
    out_idx_view: None = None,
) -> Potential[State, Gradient, Hessian, Patch]

make_local_mliap_from_state(
    state: Lens[State, IsLocalMLIAPState[HasCache[LocalMLIAPData, PotentialOut]]],
    probe: Probe[State, Ptch, IsRadiusGraphProbe[IsLocalMLIAPGraphParticles]],
    gradient_lens: Lens[LocalMLIAPInput, Gradient] = EMPTY_LENS,
    hessian_lens: Lens[Gradient, Hessian] = EMPTY_LENS,
    hessian_idx_view: Lens[State, Hessian] = EMPTY_LENS,
    out_idx_view: Lens[State, PotentialOut[Gradient, Hessian]] | None = None,
) -> Potential[State, Gradient, Hessian, Ptch]
```
Create a local MLIAP potential from a typed state, optionally with incremental updates.
Convenience wrapper around
make_local_mliap_potential.
When probe is None, extracts views from a state satisfying
IsLocalMLIAPState.
When probe is provided, additionally wires the PotentialOut cache for
efficient incremental caching across Monte Carlo steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | Lens into the sub-state providing particles, unit cell, neighbor list, and local MLIAP model. | required |
| probe | Any | Detects which particles and neighbor-list edges changed since the last step. | None |
| gradient_lens | Any | Specifies which gradients to compute (e.g., forces). | EMPTY_LENS |
| hessian_lens | Any | Specifies which Hessians to compute. | EMPTY_LENS |
| hessian_idx_view | Any | Index structure for Hessian updates. | EMPTY_LENS |
| out_idx_view | Any | Index into the cached output for partial updates. Only used when probe is provided. | None |

Returns:
| Type | Description |
|---|---|
| Any | Configured local MLIAP Potential. |

make_local_mliap_potential(particles_view, systems_view, cutoffs_view, neighborlist_view, model_lens, probe, gradient_lens, hessian_lens, hessian_idx_view, patch_idx_view=None, out_cache_lens=None)
Create a local MLIAP potential with single message passing.
Constructs a potential from model functions (init, edge, readout) with support for efficient incremental updates via caching.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| particles_view | View[State, Table[ParticleId, P]] | Extracts indexed particle data (positions, atomic numbers) | required |
| systems_view | View[State, Table[SystemId, S]] | Extracts indexed system data (unit cell) | required |
| cutoffs_view | View[State, Table[SystemId, Array]] | Extracts the per-system cutoff radii | required |
| neighborlist_view | View[State, NearestNeighborList] | Extracts full neighbor list | required |
| model_lens | Lens[State, LocalMLIAPData] | Lens to access LocalMLIAPData containing model functions and cache | required |
| probe | Probe[State, Ptch, IsRadiusGraphProbe[P]] \| None | Detects particle changes and provides updated/old neighbor lists for incremental updates | required |
| gradient_lens | Lens[LocalMLIAPInput[State, P, S], Gradients] | Specifies which gradients to compute | required |
| hessian_lens | Lens[Gradients, Hessians] | Specifies which Hessians to compute | required |
| hessian_idx_view | View[State, Hessians] | Index structure for Hessian computation | required |
| patch_idx_view | View[State, PotentialOut[Gradients, Hessians]] \| None | Index structure for cached output updates (optional) | None |
| out_cache_lens | Lens[State, PotentialOut[Gradients, Hessians]] \| None | Lens to cache location for incremental updates (optional) | None |

Returns:
| Type | Description |
|---|---|
| Potential[State, Gradients, Hessians, Ptch] | Configured local MLIAP Potential |