kups.potential.mliap.torch.mace
MACE model integration via TorchModuleWrapper.
This module provides wrappers for using PyTorch MACE models directly from JAX via the TorchModuleWrapper bridge, including full kUPS potential integration.
Example
For kUPS integration:

```python
from kups.potential.mliap.torch.mace import TorchMACEModel, make_torch_mace_potential

# Create TorchMACEModel from wrapper
model = TorchMACEModel(species_to_index, cutoff, num_species, wrapper)

# Create kUPS potential (forces only).
# `...` stands for the remaining arguments (e.g. cutoffs_view), omitted here.
potential = make_torch_mace_potential(
    particles_view, systems_view, neighborlist_view, model, ...
)

# With virial/stress support
potential = make_torch_mace_potential(
    ..., compute_virials=True, ...
)
```
Requires the torch_dev dependency group: uv sync --group torch_dev
IsTorchMACEParticles
Bases: IsRadiusGraphPoints, HasAtomicNumbers, Protocol
Particle protocol for PyTorch MACE models.
Source code in src/kups/potential/mliap/torch/mace.py
MACEModule
Bases: Module
Wraps a MACE model for JAX interop via TorchModuleWrapper.
Supports energy-only, energy+forces, and energy+forces+virials modes via the compute_force and compute_virials flags.
Source code in src/kups/potential/mliap/torch/mace.py
__init__(mace_model, compute_force=True, compute_virials=False)
Initialise MACEModule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mace_model` | `Module` | Underlying PyTorch MACE model. | required |
| `compute_force` | `bool` | Whether to compute forces. | `True` |
| `compute_virials` | `bool` | Whether to compute virials. | `False` |
Source code in src/kups/potential/mliap/torch/mace.py
forward(node_attrs, positions, edge_index, batch, ptr, shifts, cell=None)
Run MACE forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `node_attrs` | `Tensor` | One-hot species encoding. | required |
| `positions` | `Tensor` | Atom positions. | required |
| `edge_index` | `Tensor` | Edge index array of shape (2, n_edges). | required |
| `batch` | `Tensor` | System index per atom. | required |
| `ptr` | `Tensor` | Cumulative atom counts per system. | required |
| `shifts` | `Tensor` | Absolute shift vectors per edge. | required |
| `cell` | `Tensor \| None` | Unit cell lattice vectors (optional). | `None` |
Returns:
| Type | Description |
|---|---|
| `dict[str, Tensor]` | Dictionary mapping output names to tensors. |
Source code in src/kups/potential/mliap/torch/mace.py
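To make the index-style arguments of `forward` concrete, the following is a minimal pure-Python sketch that builds `batch`, `ptr`, `edge_index`, and per-edge `shifts` for a toy batch of two systems. It assumes the PyTorch Geometric convention that `ptr` starts at 0 and has one more entry than there are systems, and that an absolute shift is an integer periodic-image offset multiplied by the lattice vectors (taken here as the rows of `cell`). The helper names `batch_and_ptr` and `absolute_shift` are hypothetical and not part of kUPS; the real model takes tensors rather than lists.

```python
from itertools import accumulate


def batch_and_ptr(atoms_per_system):
    """Return (batch, ptr) for systems concatenated along the atom axis."""
    # batch[i] is the index of the system that atom i belongs to.
    batch = [s for s, n in enumerate(atoms_per_system) for _ in range(n)]
    # ptr[s] is the offset of the first atom of system s (assumed leading 0).
    ptr = [0, *accumulate(atoms_per_system)]
    return batch, ptr


def absolute_shift(image_offset, cell):
    """Convert an integer image offset (n_a, n_b, n_c) to a Cartesian shift."""
    # Assumption: the rows of `cell` are the lattice vectors a, b, c.
    return [sum(image_offset[k] * cell[k][d] for k in range(3)) for d in range(3)]


# Two systems with 3 and 2 atoms.
batch, ptr = batch_and_ptr([3, 2])

# Shape (2, n_edges): row 0 holds one endpoint of each edge, row 1 the other.
edge_index = [[0, 1, 3], [1, 2, 4]]

# Orthorhombic toy cell and one image offset per edge.
cell = [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 6.0]]
image_offsets = [(0, 0, 0), (1, 0, 0), (0, -1, 0)]
shifts = [absolute_shift(offset, cell) for offset in image_offsets]

print(batch, ptr, shifts)
```

Atoms 0 to 2 belong to system 0 and atoms 3 to 4 to system 1, so the slice `ptr[s]:ptr[s + 1]` selects the atoms of system `s`.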
TorchMACEModel
MACE model container for PyTorch models via TorchModuleWrapper.
Attributes:
| Name | Type | Description |
|---|---|---|
| `species_to_index` | `Array` | Mapping from atomic number to MACE species index. |
| `cutoff` | `Table[SystemId, Array]` | Model cutoff radius per system. |
| `num_mace_species` | `int` | Number of species the MACE model was trained on. |
| `wrapper` | `TorchModuleWrapper` | TorchModuleWrapper bridging PyTorch to JAX. |
| `compute_virials` | `bool` | Whether to compute virials for stress. |
Source code in src/kups/potential/mliap/torch/mace.py
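The `species_to_index` mapping and the one-hot `node_attrs` input fit together roughly as in the sketch below. This is an illustration only: the dense lookup-table layout, the `-1` sentinel for unsupported elements, the ordering of species by ascending atomic number, and the helper names are all assumptions, not the kUPS implementation.

```python
def build_species_to_index(model_atomic_numbers, max_z=118):
    """Dense lookup table: table[Z] is the species index, or -1 if Z is unknown."""
    table = [-1] * (max_z + 1)
    # Assumption: species indices follow ascending atomic number.
    for index, z in enumerate(sorted(model_atomic_numbers)):
        table[z] = index
    return table


def one_hot(indices, num_species):
    """One row per atom, with a single 1.0 in the column of its species."""
    return [[1.0 if j == i else 0.0 for j in range(num_species)] for i in indices]


model_species = [1, 6, 8]  # hypothetical model trained on H, C, O
species_to_index = build_species_to_index(model_species)

atomic_numbers = [8, 1, 1]  # a water molecule: O, H, H
indices = [species_to_index[z] for z in atomic_numbers]
if min(indices) < 0:
    raise ValueError("structure contains an element the model was not trained on")

node_attrs = one_hot(indices, num_species=len(model_species))
print(indices, node_attrs)
```

Here `len(model_species)` plays the role of `num_mace_species`: it fixes the width of every one-hot row.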
load_mace_wrapper(model_path, device='cuda', compute_force=True, compute_virials=False, dtype='float32')
Load a PyTorch MACE model and wrap it for JAX computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_path` | `str \| Path` | Path to the MACE `.model` file. | required |
| `device` | `str` | Device to load the model onto. | `'cuda'` |
| `compute_force` | `bool` | Whether to compute forces. | `True` |
| `compute_virials` | `bool` | Whether to compute virials for stress. | `False` |
| `dtype` | `Literal['float32', 'float64']` | Model precision, either `"float32"` or `"float64"`. | `'float32'` |
Returns:
| Type | Description |
|---|---|
| `TorchModuleWrapper` | TorchModuleWrapper containing MACEModule. |
Source code in src/kups/potential/mliap/torch/mace.py
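A possible way to call `load_mace_wrapper` is sketched below. The keyword names and defaults come from the signature above; the helper functions, the `need_stress` and `double_precision` options, and the `mace.model` path are hypothetical placeholders. The kUPS import is deferred into the function body so that defining the helpers does not require the `torch_dev` dependency group.

```python
from pathlib import Path


def mace_wrapper_kwargs(model_path, *, need_stress=False,
                        double_precision=False, device="cuda"):
    """Collect keyword arguments for load_mace_wrapper (hypothetical helper)."""
    return {
        "model_path": Path(model_path),
        "device": device,
        "compute_force": True,
        # Virials are only requested when stress is needed downstream.
        "compute_virials": need_stress,
        "dtype": "float64" if double_precision else "float32",
    }


def load_wrapper(model_path, **options):
    """Load the wrapper; needs kUPS with the torch_dev group installed."""
    from kups.potential.mliap.torch.mace import load_mace_wrapper

    return load_mace_wrapper(**mace_wrapper_kwargs(model_path, **options))


# Placeholder path; building the kwargs does not touch the file system.
kwargs = mace_wrapper_kwargs("mace.model", need_stress=True)
print(kwargs)
```

Keeping the argument assembly separate from the load call makes it easy to check that the `compute_virials` setting used for the wrapper matches the one later passed to `make_torch_mace_potential`.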
make_torch_mace_potential(particles_view, systems_view, neighborlist_view, model, cutoffs_view, compute_virials=False, patch_idx_view=None, out_cache_lens=None)
```python
make_torch_mace_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: Literal[False] = False,
    patch_idx_view: View[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
) -> Potential[State, Array, EmptyType, Patch[State]]

make_torch_mace_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: Literal[True],
    patch_idx_view: View[
        State, PotentialOut[PositionAndUnitCell, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[PositionAndUnitCell, EmptyType]
    ]
    | None = None,
) -> Potential[
    State, PositionAndUnitCell, EmptyType, Patch[State]
]
```
Create kUPS potential from PyTorch MACE model.
Forces are computed by PyTorch natively (not JAX autodiff). Hessians are NOT supported (returns EmptyType).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `particles_view` | `View[State, Table[ParticleId, P]]` | Extracts particle data from state. | required |
| `systems_view` | `View[State, Table[SystemId, S]]` | Extracts system data (unit cell) from state. | required |
| `neighborlist_view` | `View[State, NNList]` | Extracts neighbor list from state. | required |
| `model` | `View[State, TorchMACEModel] \| TorchMACEModel` | TorchMACEModel instance or view to model in state. | required |
| `cutoffs_view` | `View[State, Table[SystemId, Array]]` | Extracts cutoffs as a `Table[SystemId, Array]`. | required |
| `compute_virials` | `bool` | Whether to compute virials for stress. | `False` |
| `patch_idx_view` | `Any \| None` | Cached output index structure (optional). | `None` |
| `out_cache_lens` | `Any \| None` | Cache location lens (optional). | `None` |
Returns:
| Type | Description |
|---|---|
| `Any` | kUPS `Potential`. |
Source code in src/kups/potential/mliap/torch/mace.py
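The `compute_virials` option exists so that stress can be derived from the virial. As a minimal sketch of that final step, the code below assumes the sign convention in which the virial is `W = -dE/dε` and the stress is `σ = -W / V`, with `V` the cell volume. Other codes use the opposite sign, so the convention should be checked against the model's own output. The helper names are hypothetical and the numbers are made up for illustration.

```python
def det3(m):
    """Determinant of a 3x3 matrix given as nested lists."""
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def stress_from_virial(virial, cell):
    """Stress tensor from a 3x3 virial, assuming sigma = -W / V."""
    # The absolute value makes the volume independent of cell handedness.
    volume = abs(det3(cell))
    return [[-w / volume for w in row] for row in virial]


# Cubic toy cell with edge length 2, so the volume is 8.
cell = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
virial = [[-8.0, 0.0, 0.0], [0.0, -4.0, 0.0], [0.0, 0.0, -2.0]]

stress = stress_from_virial(virial, cell)
print(stress)
```

With energies in eV and lengths in Å, the resulting stress is in eV/Å³.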
torch_mace_model_fn(inp, *, compute_virials=False)
Model function for PyTorch MACE models.
Returns PotentialOut with forces (and optionally virials) computed by PyTorch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inp` | `TorchMACEInput[P, S]` | Graph potential input containing the MACE model and graph. | required |
| `compute_virials` | `bool` | Whether to include virial gradients in the output. | `False` |
Returns:
| Type | Description |
|---|---|
| `WithPatch[PotentialOut[Array, EmptyType], IdPatch] \| WithPatch[PotentialOut[PositionAndUnitCell, EmptyType], IdPatch]` | PotentialOut with forces (and optionally virials), wrapped with an identity patch. |