kups.potential.mliap.torch.interface
Universal PyTorch MLFF interface.
Mirrors the JAX tojax interface for PyTorch models. Each MLFF backend only needs to provide a torch.nn.Module whose forward consumes the universal AtomGraphInput and returns a dict with "energy", "position_gradients", and, optionally, "cell_gradients". This module handles all graph extraction, padding, and wiring into a kUPS Potential.
Example

```python
from kups.potential.mliap.torch.interface import (
    TorchMliap,
    make_torch_mliap_from_state,
)

# A backend provides a Module with the universal forward contract
# (my_module is a placeholder for your backend's torch.nn.Module):
model = TorchMliap.from_module(my_module, cutoff=6.0, compute_cell_gradients=True)

# Wire into a kUPS Potential (state_lens is a placeholder lens into your state):
potential = make_torch_mliap_from_state(
    state_lens, compute_position_and_cell_gradients=True,
)
```
Requires the torch_dev dependency group: `uv sync --group torch_dev`.
AtomGraphInput
Bases: TypedDict
Universal input schema shared by all torch MLFF backends.
Mirrors the JAX AtomGraphInput.
Shapes use N atoms, B systems, and E edges (each padded by one
extra atom/system to work around backends that cannot handle empty graphs).
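A minimal sketch of this padding idea, written in NumPy for brevity. The helper `pad_one`, its argument names (`positions`, `atomic_numbers`, `batch`, `cells`), and the choice of a dummy atom at the origin with atomic number 0, placed in a dummy system with an identity cell, are all hypothetical; the actual AtomGraphInput fields and padding values are not listed on this page, and edge padding is not shown.

```python
import numpy as np


def pad_one(positions, atomic_numbers, batch, cells):
    """Hypothetical illustration: append one dummy atom and one dummy system.

    positions: (N, 3), atomic_numbers: (N,), batch: (N,) system index per atom,
    cells: (B, 3, 3). The dummy atom sits at the origin with atomic number 0
    (an assumption) and belongs to the new dummy system, so real systems are
    untouched and even an empty input becomes a non-empty graph.
    """
    n_systems = cells.shape[0]
    positions = np.concatenate([positions, np.zeros((1, 3))], axis=0)
    atomic_numbers = np.concatenate(
        [atomic_numbers, np.zeros(1, dtype=atomic_numbers.dtype)]
    )
    # The dummy atom points at the dummy system, whose index is n_systems.
    batch = np.concatenate([batch, np.array([n_systems], dtype=batch.dtype)])
    cells = np.concatenate([cells, np.eye(3)[None]], axis=0)
    return positions, atomic_numbers, batch, cells
```

With zero real atoms and systems, the padded arrays still have N = 1 and B = 1, which is the property that backends unable to handle empty graphs rely on.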
Source code in src/kups/potential/mliap/torch/interface.py
IsTorchMliapParticles
Bases: IsRadiusGraphPoints, HasAtomicNumbers, Protocol
Particle protocol for torch MLFF models.
IsTorchMliapState
Bases: Protocol
Protocol for states providing all inputs for a torch MLFF potential.
TorchMliap
Container for a torch MLFF wired into JAX.
Attributes:

| Name | Type | Description |
|---|---|---|
| cutoff | Table[SystemId, Array] | Per-system cutoff radius [Å]. |
| wrapper | TorchModuleWrapper | |
| compute_cell_gradients | bool | Whether the module returns "cell_gradients". |
__call__(input)
from_module(module, cutoff, compute_cell_gradients=False) (staticmethod)
Wrap a torch.nn.Module that returns energy and gradients.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | torch.nn.Module to wrap; it must follow the TorchMliapForward contract. | required |
| cutoff | float | Interaction cutoff radius [Å]. | required |
| compute_cell_gradients | bool | Whether the module returns "cell_gradients". | False |

Returns:

| Type | Description |
|---|---|
| 'TorchMliap' | Configured TorchMliap. |
TorchMliapForward
Bases: Protocol
Forward contract for a torch MLFF module.
The module must accept an AtomGraphInput dict and return a dict with:
- "energy": shape (B,), per-system total energies.
- "position_gradients": shape (N, 3), $\partial E / \partial r$.
- "cell_gradients": shape (B, 3, 3), $\partial E / \partial h$; required only when compute_cell_gradients=True.
Outputs are gradients, not forces: adapters around models that natively produce forces or virials must negate (and, where needed, convert) them inside the module.
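A small sketch of the sign convention and output contract above. Both helpers are hypothetical (not part of kups): `from_forces` packages a forces-native model's output by negating forces into $\partial E / \partial r$, and `check_forward_output` checks the documented key names and shapes. NumPy arrays are used here for a self-contained example; the shape check only relies on `.shape`, which torch tensors also provide.

```python
import numpy as np


def from_forces(energy, forces, cell_gradients=None):
    """Hypothetical adapter step for a model that natively returns forces.

    The contract expects gradients, so forces F are negated: dE/dr = -F.
    cell_gradients (dE/dh) are passed through unchanged when provided.
    """
    out = {"energy": energy, "position_gradients": -forces}
    if cell_gradients is not None:
        out["cell_gradients"] = cell_gradients
    return out


def check_forward_output(out, n_atoms, n_systems, compute_cell_gradients=False):
    """Raise ValueError if `out` does not match the documented output shapes."""
    expected = {"energy": (n_systems,), "position_gradients": (n_atoms, 3)}
    if compute_cell_gradients:
        expected["cell_gradients"] = (n_systems, 3, 3)
    for key, shape in expected.items():
        if key not in out:
            raise ValueError(f"missing output key {key!r}")
        if tuple(out[key].shape) != shape:
            raise ValueError(
                f"{key!r} has shape {tuple(out[key].shape)}, expected {shape}"
            )
```

Keeping the negation inside the adapter means the kUPS side never has to know whether a backend natively reports forces or gradients.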
lattice_gradient_from_virial(forces, positions, batch, cell, virial)
Recover ∂E/∂h from a symmetric-strain virial.
Many torch MLFF backends (MACE, UMA, …) return a virial or stress quantity that encodes the gradient of energy under a symmetric infinitesimal strain applied jointly to positions and cell:
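$$
r_b \to r_b + r_b\,\varepsilon \quad (\text{per atom } b), \qquad h_s \to h_s + h_s\,\varepsilon \quad (\text{per system } s,\ h_s \text{ its cell})
$$

The virial returned by the backend is then $\mathrm{virial} = \mathrm{sym}(\mathrm{pos\_virial} + \mathrm{cell\_virial})$, where

$$
\mathrm{pos\_virial}_{s,jk} = \sum_{b \in s} \left(\frac{\partial E}{\partial r_b}\right)_j (r_b)_k, \qquad
\mathrm{cell\_virial}_s = h_s^\top \frac{\partial E}{\partial h_s}, \qquad
\mathrm{sym}(M) = \tfrac{1}{2}\left(M + M^\top\right).
$$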
Given forces (= -∂E/∂r), positions, batch, cell, and the virial, this
function reconstructs pos_virial from -forces ⊗ positions, subtracts
it, and solves cell^T @ (∂E/∂h) = cell_virial for the raw lattice
gradient. Assumes cell^T @ ∂E/∂h is symmetric (rotational invariance
of the energy); the antisymmetric part is unrecoverable from the
symmetric-strain virial alone.
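A NumPy sketch of the same algebra, i.e. $\partial E / \partial h = (\mathrm{cell}^\top)^{-1}\left(\mathrm{virial} - \mathrm{sym}(\mathrm{pos\_virial})\right)$. The name `lattice_gradient_from_virial_np` and the shape conventions (forces and positions `(N, 3)`, batch an `(N,)` integer array of system indices, cell and virial `(B, 3, 3)`) are assumptions for illustration; the documented function works on torch tensors. The sketch follows the sign convention written above exactly, so check your backend's virial convention before reusing it.

```python
import numpy as np


def lattice_gradient_from_virial_np(forces, positions, batch, cell, virial):
    """Sketch: recover dE/dh (B, 3, 3) from a symmetric-strain virial."""
    n_systems = cell.shape[0]
    grad_r = -forces  # dE/dr = -forces
    # Per-atom outer products (dE/dr_b)_j * (r_b)_k, shape (N, 3, 3).
    per_atom = np.einsum("nj,nk->njk", grad_r, positions)
    # Accumulate atoms into their systems to get pos_virial, shape (B, 3, 3).
    pos_virial = np.zeros((n_systems, 3, 3))
    np.add.at(pos_virial, batch, per_atom)
    pos_virial_sym = 0.5 * (pos_virial + np.swapaxes(pos_virial, -1, -2))
    # What remains is sym(cell_virial), assumed equal to cell^T @ dE/dh.
    cell_virial = virial - pos_virial_sym
    # Solve cell^T @ X = cell_virial for X = dE/dh, batched over systems.
    return np.linalg.solve(np.swapaxes(cell, -1, -2), cell_virial)
```

Because only the symmetric part survives in the virial, any antisymmetric component of $\mathrm{cell}^\top \partial E / \partial h$ is lost, which is why the rotational-invariance assumption is required.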
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| forces | 'torch.Tensor' | | required |
| positions | 'torch.Tensor' | | required |
| batch | 'torch.Tensor' | | required |
| cell | 'torch.Tensor' | | required |
| virial | 'torch.Tensor' | | required |

Returns:

| Type | Description |
|---|---|
| 'torch.Tensor' | Raw lattice gradient ∂E/∂h. |
make_torch_mliap_from_state(state, *, compute_position_and_cell_gradients=False)
Create a torch MLFF potential from a typed state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | Lens into a sub-state providing particles, systems, neighbor list, and torch MLFF model. | required |
| compute_position_and_cell_gradients | bool | When True, compute gradients with respect to both positions and the cell. | False |

Returns:

| Type | Description |
|---|---|
| Any | Configured torch MLFF Potential. |
make_torch_mliap_potential(particles_view, systems_view, neighborlist_view, model, cutoffs_view, compute_cell_gradients=False, patch_idx_view=None, out_cache_lens=None)
```python
@overload
def make_torch_mliap_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMliap] | TorchMliap,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_cell_gradients: Literal[False] = False,
    patch_idx_view: View[State, PotentialOut[Array, EmptyType]] | None = None,
    out_cache_lens: Lens[State, PotentialOut[Array, EmptyType]] | None = None,
) -> Potential[State, Array, EmptyType, Patch[State]]: ...


@overload
def make_torch_mliap_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMliap] | TorchMliap,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_cell_gradients: Literal[True],
    patch_idx_view: View[State, PotentialOut[PositionAndCell, EmptyType]] | None = None,
    out_cache_lens: Lens[State, PotentialOut[PositionAndCell, EmptyType]] | None = None,
) -> Potential[State, PositionAndCell, EmptyType, Patch[State]]: ...
```
Create a kUPS Potential from a TorchMliap.
Forces (and optionally stress) are computed inside the torch module; the
kUPS side just routes the precomputed gradients through DirectPotential.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| particles_view | Any | Extracts particle data from state. | required |
| systems_view | Any | Extracts system data (cell) from state. | required |
| neighborlist_view | Any | Extracts neighbor list from state. | required |
| model | Any | The TorchMliap, or a view extracting it from state. | required |
| cutoffs_view | Any | Extracts cutoffs as Table[SystemId, Array]. | required |
| compute_cell_gradients | bool | When True, also return cell gradients (PositionAndCell output). | False |
| patch_idx_view | Any \| None | Cached output index structure (optional). | None |
| out_cache_lens | Any \| None | Cache location lens (optional). | None |

Returns:

| Type | Description |
|---|---|
| Any | Configured Potential. |
torch_mliap_model_fn(inp, *, compute_cell_gradients=False)
Run a TorchMliap on a graph input and package the result.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inp | TorchMliapInput[P, S] | Graph potential input bundling the model and graph. | required |
| compute_cell_gradients | bool | Whether to wrap the gradients as PositionAndCell (including cell gradients). | False |

Returns:

| Type | Description |
|---|---|
| WithPatch[PotentialOut[Array, EmptyType], IdPatch] \| WithPatch[PotentialOut[PositionAndCell, EmptyType], IdPatch] | Packaged potential output, wrapped with an identity patch. |