kups.potential.mliap.torch
PyTorch model bridge for ML interatomic potentials.
This module provides a bridge for using PyTorch-based ML potentials from JAX via TorchModuleWrapper.
Example
    from kups.potential.mliap.torch import load_mace_wrapper

    # Load MACE model with forces
    wrapper = load_mace_wrapper("model.model")
    result = wrapper(node_attrs, positions, edge_index, batch, ptr, shifts, cell)
    energy, forces = result["energy"], result["forces"]

    # Energy only (faster for MC)
    wrapper = load_mace_wrapper("model.model", compute_force=False)
    result = wrapper(...)
    energy = result["energy"]
Requires the `torch_dev` dependency group. Install it with `uv sync --group torch_dev`.
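The `batch` and `ptr` arguments in the example call are not described on this page. The sketch below assumes they follow the PyTorch Geometric-style batching that MACE uses: `batch` holds the system index of each atom and `ptr` holds the offset of each system's first atom. The helper name `make_batch_and_ptr` is hypothetical and not part of kUPS; plain Python lists stand in for arrays.

```python
from itertools import accumulate


def make_batch_and_ptr(atoms_per_system):
    """Build PyG-style `batch` and `ptr` index lists from per-system atom counts.

    Assumption: the wrapper expects the same batching layout as MACE's
    PyTorch Geometric inputs. This helper is illustrative, not a kUPS API.
    """
    # ptr[i] is the offset of the first atom of system i;
    # the last entry is the total number of atoms.
    ptr = [0, *accumulate(atoms_per_system)]
    # batch[j] is the index of the system that atom j belongs to.
    batch = [
        system
        for system, count in enumerate(atoms_per_system)
        for _ in range(count)
    ]
    return batch, ptr


# Two systems: the first with 3 atoms, the second with 2 atoms.
batch, ptr = make_batch_and_ptr([3, 2])
```

Under this layout, the atoms of system `i` occupy the index range `ptr[i]:ptr[i + 1]`.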
TorchMACEModel
MACE model container for PyTorch models via TorchModuleWrapper.
Attributes:
| Name | Type | Description |
|---|---|---|
| `species_to_index` | `Array` | Mapping from atomic number to MACE species index. |
| `cutoff` | `Table[SystemId, Array]` | Model cutoff radius per system. |
| `num_mace_species` | `int` | Number of species the MACE model was trained on. |
| `wrapper` | `TorchModuleWrapper` | TorchModuleWrapper bridging PyTorch to JAX. |
| `compute_virials` | `bool` | Whether to compute virials for stress. |
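To illustrate what a mapping from atomic number to species index can look like, here is a minimal sketch. It assumes a dense lookup table indexed by atomic number, species indices assigned in ascending atomic-number order, and `-1` as a marker for unsupported elements. All three choices are assumptions made for illustration, and `build_species_to_index` is a hypothetical helper, not the kUPS implementation.

```python
def build_species_to_index(model_atomic_numbers, max_atomic_number=118):
    """Return a table where table[Z] is the species index for atomic number Z.

    Assumptions (not confirmed by the kUPS docs): indices follow ascending
    atomic number, and unsupported elements are marked with -1.
    """
    table = [-1] * (max_atomic_number + 1)
    for index, z in enumerate(sorted(model_atomic_numbers)):
        table[z] = index
    return table


# Hypothetical model trained on H (1), C (6), and O (8).
species_to_index = build_species_to_index([1, 6, 8])
# Count of supported species, analogous to `num_mace_species`.
num_mace_species = sum(1 for index in species_to_index if index >= 0)
```

Looking up `species_to_index[6]` gives the index for carbon, while an element the model was not trained on, such as helium (`species_to_index[2]`), yields the `-1` marker.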
Source code in src/kups/potential/mliap/torch/mace.py
load_mace_wrapper(model_path, device='cuda', compute_force=True, compute_virials=False, dtype='float32')
Load a PyTorch MACE model and wrap it for JAX computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_path` | `str \| Path` | Path to the MACE .model file | required |
| `device` | `str` | Device to load the model onto (default: "cuda") | `'cuda'` |
| `compute_force` | `bool` | Whether to compute forces (default: True) | `True` |
| `compute_virials` | `bool` | Whether to compute virials for stress (default: False) | `False` |
| `dtype` | `Literal['float32', 'float64']` | Model precision: "float32" (default) or "float64" | `'float32'` |
Returns:
| Type | Description |
|---|---|
| `TorchModuleWrapper` | TorchModuleWrapper containing MACEModule. |
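The parameters above can be combined per use case. The sketch below uses a hypothetical helper, `mace_load_kwargs`, that only assembles keyword arguments. The keyword names and defaults come from the table above; the `"cpu"` fallback assumes the `device` string is passed through to PyTorch. The actual call is left commented out because it needs kUPS, PyTorch, and a model file.

```python
def mace_load_kwargs(need_forces=True, need_stress=False,
                     has_cuda=True, double_precision=False):
    """Hypothetical helper: assemble keyword arguments for load_mace_wrapper.

    Keyword names mirror the documented parameters. The "cpu" fallback is
    an assumption about which device strings are accepted.
    """
    return {
        "device": "cuda" if has_cuda else "cpu",
        "compute_force": need_forces,
        "compute_virials": need_stress,
        "dtype": "float64" if double_precision else "float32",
    }


# Energy-only Monte Carlo run on a machine without a GPU.
mc_kwargs = mace_load_kwargs(need_forces=False, has_cuda=False)

# Requires kups, torch, and a trained model file:
# wrapper = load_mace_wrapper("model.model", **mc_kwargs)
```

Calling the helper with no arguments reproduces the documented defaults.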
Source code in src/kups/potential/mliap/torch/mace.py
make_torch_mace_potential(particles_view, systems_view, neighborlist_view, model, cutoffs_view, compute_virials=False, patch_idx_view=None, out_cache_lens=None)
    make_torch_mace_potential(
        particles_view: View[State, Table[ParticleId, P]],
        systems_view: View[State, Table[SystemId, S]],
        neighborlist_view: View[State, NNList],
        model: View[State, TorchMACEModel] | TorchMACEModel,
        cutoffs_view: View[State, Table[SystemId, Array]],
        compute_virials: Literal[False] = False,
        patch_idx_view: View[State, PotentialOut[Array, EmptyType]] | None = None,
        out_cache_lens: Lens[State, PotentialOut[Array, EmptyType]] | None = None,
    ) -> Potential[State, Array, EmptyType, Patch[State]]

    make_torch_mace_potential(
        particles_view: View[State, Table[ParticleId, P]],
        systems_view: View[State, Table[SystemId, S]],
        neighborlist_view: View[State, NNList],
        model: View[State, TorchMACEModel] | TorchMACEModel,
        cutoffs_view: View[State, Table[SystemId, Array]],
        compute_virials: Literal[True],
        patch_idx_view: View[State, PotentialOut[PositionAndUnitCell, EmptyType]] | None = None,
        out_cache_lens: Lens[State, PotentialOut[PositionAndUnitCell, EmptyType]] | None = None,
    ) -> Potential[State, PositionAndUnitCell, EmptyType, Patch[State]]
Create a kUPS potential from a PyTorch MACE model.
Forces are computed natively by PyTorch, not by JAX autodiff. Hessians are NOT supported (returns EmptyType).
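The two signatures listed for this function differ only in `compute_virials` and in the second type parameter of the returned `Potential` (`Array` versus `PositionAndUnitCell`). This is the standard `typing.overload` pattern keyed on a `Literal` flag. The toy function below shows the same pattern in isolation; `gradient_kind` and its string return values are hypothetical stand-ins, not kUPS names.

```python
from typing import Literal, overload


@overload
def gradient_kind(compute_virials: Literal[False] = False) -> Literal["positions"]: ...
@overload
def gradient_kind(compute_virials: Literal[True]) -> Literal["positions+unit_cell"]: ...


def gradient_kind(compute_virials: bool = False) -> str:
    """Toy stand-in: the flag selects what the gradients are taken with respect to."""
    # With virials, gradients cover positions and the unit cell;
    # without them, positions only.
    return "positions+unit_cell" if compute_virials else "positions"


default_kind = gradient_kind()
virial_kind = gradient_kind(True)
```

A static type checker resolves each call to the matching overload, so code that passes `compute_virials=True` sees the richer return type without casts.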
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `particles_view` | `View[State, Table[ParticleId, P]]` | Extracts particle data from state | required |
| `systems_view` | `View[State, Table[SystemId, S]]` | Extracts system data (unit cell) from state | required |
| `neighborlist_view` | `View[State, NNList]` | Extracts neighbor list from state | required |
| `model` | `View[State, TorchMACEModel] \| TorchMACEModel` | TorchMACEModel instance or view to model in state | required |
| `cutoffs_view` | `View[State, Table[SystemId, Array]]` | Extracts the per-system cutoffs from state | required |
| `compute_virials` | `bool` | Whether to compute virials for stress (default: False) | `False` |
| `patch_idx_view` | `Any \| None` | Cached output index structure (optional) | `None` |
| `out_cache_lens` | `Any \| None` | Cache location lens (optional) | `None` |
Returns:
| Type | Description |
|---|---|
| `Any` | kUPS potential |
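The `compute_virials` option exists so that stress can be derived from the virials. As a minimal sketch of that relation, the code below uses the sign convention of the MACE PyTorch code, where stress equals the negated virial divided by the cell volume. Whether kUPS applies the same convention and units is an assumption here, and both helper functions are hypothetical.

```python
def cell_volume(cell):
    """Volume of a cell given as three row vectors a, b, c: |a . (b x c)|."""
    a, b, c = cell
    cross = (
        b[1] * c[2] - b[2] * c[1],
        b[2] * c[0] - b[0] * c[2],
        b[0] * c[1] - b[1] * c[0],
    )
    return abs(sum(ai * xi for ai, xi in zip(a, cross)))


def stress_from_virial(virial, cell):
    """Stress = -virial / volume (MACE sign convention, assumed for kUPS)."""
    volume = cell_volume(cell)
    return [[-v / volume for v in row] for row in virial]


# Cubic cell with edge length 2 and an isotropic virial.
cell = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
virial = [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0]]
stress = stress_from_virial(virial, cell)
```

The result is in energy per volume, in whatever units the model's energies and the cell lengths use.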