Skip to content

kups.potential.mliap.torch

PyTorch model bridge for ML interatomic potentials.

This module provides a bridge for using PyTorch-based ML potentials from JAX via TorchModuleWrapper.

Example
from kups.potential.mliap.torch import load_mace_wrapper

# Load MACE model with forces
wrapper = load_mace_wrapper("model.model")
result = wrapper(node_attrs, positions, edge_index, batch, ptr, shifts, cell)
energy, forces = result["energy"], result["forces"]

# Energy only (faster for MC)
wrapper = load_mace_wrapper("model.model", compute_force=False)
result = wrapper(...)
energy = result["energy"]

Requires the torch_dev dependency group: uv sync --group torch_dev

TorchMACEModel

MACE model container for PyTorch models via TorchModuleWrapper.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `species_to_index` | `Array` | Mapping from atomic number to MACE species index. |
| `cutoff` | `Table[SystemId, Array]` | Model cutoff radius per system. |
| `num_mace_species` | `int` | Number of species the MACE model was trained on. |
| `wrapper` | `TorchModuleWrapper` | `TorchModuleWrapper` bridging PyTorch to JAX. |
| `compute_virials` | `bool` | Whether to compute virials for stress. |

Source code in src/kups/potential/mliap/torch/mace.py
@dataclass
class TorchMACEModel:
    """MACE model container for PyTorch models via TorchModuleWrapper.

    Groups the array data a PyTorch-backed MACE potential reads from the
    state (species mapping, per-system cutoffs) together with configuration
    fields declared ``static=True`` (species count, the PyTorch bridge
    object, and the virial flag).

    Attributes:
        species_to_index: Mapping from atomic number to MACE species index.
        cutoff: Model cutoff radius per system.
        num_mace_species: Number of species the MACE model was trained on.
        wrapper: TorchModuleWrapper bridging PyTorch to JAX.
        compute_virials: Whether to compute virials for stress.
    """

    # Lookup from atomic number to MACE species index.
    # NOTE(review): presumably indexed directly by atomic number — confirm layout.
    species_to_index: Array
    # Cutoff radius keyed per system.
    cutoff: Table[SystemId, Array]
    # The fields below are declared static=True; NOTE(review): presumably this
    # keeps them out of the traced/array part of the pytree — confirm against
    # the project's `field` helper.
    num_mace_species: int = field(static=True)
    wrapper: TorchModuleWrapper = field(static=True)
    compute_virials: bool = field(static=True, default=False)

load_mace_wrapper(model_path, device='cuda', compute_force=True, compute_virials=False, dtype='float32')

Load a PyTorch MACE model and wrap it for JAX computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_path` | `str \| Path` | Path to the MACE `.model` file. | *required* |
| `device` | `str` | Device to load the model onto. | `'cuda'` |
| `compute_force` | `bool` | Whether to compute forces. | `True` |
| `compute_virials` | `bool` | Whether to compute virials for stress. | `False` |
| `dtype` | `Literal['float32', 'float64']` | Model precision: `"float32"` or `"float64"`. | `'float32'` |

Returns:

| Type | Description |
| --- | --- |
| `TorchModuleWrapper` | `TorchModuleWrapper` containing `MACEModule`. |

Source code in src/kups/potential/mliap/torch/mace.py
def load_mace_wrapper(
    model_path: str | Path,
    device: str = "cuda",
    compute_force: bool = True,
    compute_virials: bool = False,
    dtype: Literal["float32", "float64"] = "float32",
) -> TorchModuleWrapper:
    """Load a PyTorch MACE model and wrap it for JAX computation.

    Args:
        model_path: Path to the MACE ``.model`` file.
        device: Torch device to load the model onto, e.g. ``"cuda"``,
            ``"cuda:1"`` or ``"cpu"`` (default: ``"cuda"``).
        compute_force: Whether to compute forces (default: True).
        compute_virials: Whether to compute virials for stress (default: False).
        dtype: Model precision, ``"float32"`` (default) or ``"float64"``.

    Returns:
        TorchModuleWrapper containing MACEModule.

    Raises:
        ValueError: If ``dtype`` is neither ``"float32"`` nor ``"float64"``.
        RuntimeError: If a CUDA device is requested but CUDA is not available.
        FileNotFoundError: If ``model_path`` does not exist.

    Warning:
        The file is loaded with ``torch.load(..., weights_only=False)``, which
        runs the pickle loader and can execute arbitrary code. Only load
        model files from trusted sources.
    """
    # Reject unknown precisions explicitly instead of silently using float32.
    if dtype not in ("float32", "float64"):
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected 'float32' or 'float64'."
        )

    # Parse the device string so indexed CUDA devices ("cuda:0", "cuda:1", ...)
    # are checked as well, not only the bare "cuda" spelling.
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            f"Device {device!r} requested but CUDA is not available. "
            "Use device='cpu' or ensure CUDA is properly installed."
        )

    path = Path(model_path)
    if not path.exists():
        raise FileNotFoundError(f"MACE model not found: {model_path}")

    # SECURITY: weights_only=False unpickles arbitrary objects (needed here
    # because the loaded object is used as a full module below), so this can
    # execute code embedded in the file. Never pass untrusted model files.
    model = torch.load(path, weights_only=False, map_location=device)
    model.eval()

    # Cast parameters/buffers to the requested floating-point precision.
    model = model.double() if dtype == "float64" else model.float()

    module = MACEModule(
        model, compute_force=compute_force, compute_virials=compute_virials
    )
    # NOTE(review): requires_grad follows compute_force only; if virials are
    # obtained via autograd they may also need it — confirm against
    # TorchModuleWrapper / MACEModule semantics.
    return TorchModuleWrapper(module, requires_grad=compute_force)

make_torch_mace_potential(particles_view, systems_view, neighborlist_view, model, cutoffs_view, compute_virials=False, patch_idx_view=None, out_cache_lens=None)

make_torch_mace_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: Literal[False] = False,
    patch_idx_view: View[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
) -> Potential[State, Array, EmptyType, Patch[State]]
make_torch_mace_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: Literal[True],
    patch_idx_view: View[
        State, PotentialOut[PositionAndUnitCell, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[PositionAndUnitCell, EmptyType]
    ]
    | None = None,
) -> Potential[
    State, PositionAndUnitCell, EmptyType, Patch[State]
]

Create kUPS potential from PyTorch MACE model.

Forces are computed by PyTorch natively (not JAX autodiff). Hessians are NOT supported (returns EmptyType).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `particles_view` | `View[State, Table[ParticleId, P]]` | Extracts particle data from the state. | *required* |
| `systems_view` | `View[State, Table[SystemId, S]]` | Extracts system data (unit cell) from the state. | *required* |
| `neighborlist_view` | `View[State, NNList]` | Extracts the neighbor list from the state. | *required* |
| `model` | `View[State, TorchMACEModel] \| TorchMACEModel` | `TorchMACEModel` instance, or a view to the model in the state. | *required* |
| `cutoffs_view` | `View[State, Table[SystemId, Array]]` | Extracts cutoffs as `Table[SystemId, Array]`. | *required* |
| `compute_virials` | `bool` | Whether to compute virials for stress. | `False` |
| `patch_idx_view` | `Any \| None` | Cached output index structure (optional). | `None` |
| `out_cache_lens` | `Any \| None` | Cache location lens (optional). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Any` | kUPS `Potential` backed by the PyTorch MACE model. |

Source code in src/kups/potential/mliap/torch/mace.py
def make_torch_mace_potential[
    State,
    P: IsTorchMACEParticles,
    S: HasUnitCell,
    NNList: NearestNeighborList,
](
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: bool = False,
    patch_idx_view: Any | None = None,
    out_cache_lens: Any | None = None,
) -> Any:
    """Create kUPS potential from PyTorch MACE model.

    Forces are computed by PyTorch natively (not JAX autodiff).
    Hessians are NOT supported (returns EmptyType).

    Args:
        particles_view: Extracts particle data from state
        systems_view: Extracts system data (unit cell) from state
        neighborlist_view: Extracts neighbor list from state
        model: TorchMACEModel instance or view to model in state
        cutoffs_view: Extracts cutoffs as ``Indexed[SystemId, Array]``
        compute_virials: Whether to compute virials for stress (default: False)
        patch_idx_view: Cached output index structure (optional)
        out_cache_lens: Cache location lens (optional)

    Returns:
        kUPS ``Potential`` backed by the PyTorch MACE model.
    """
    if isinstance(model, TorchMACEModel):
        model_view: View[State, TorchMACEModel] = constant(model)
    else:
        model_view = model

    if compute_virials:

        def virial_fn(inp):
            return torch_mace_model_fn(inp, compute_virials=True)

        return make_mliap_potential(
            model_fn=virial_fn,
            particles_view=particles_view,
            systems_view=systems_view,
            neighborlist_view=neighborlist_view,
            model_view=model_view,
            cutoffs_view=cutoffs_view,
            patch_idx_view=patch_idx_view,
            out_cache_lens=out_cache_lens,
        )

    def forces_fn(inp):
        return torch_mace_model_fn(inp, compute_virials=False)

    return make_mliap_potential(
        model_fn=forces_fn,
        particles_view=particles_view,
        systems_view=systems_view,
        neighborlist_view=neighborlist_view,
        model_view=model_view,
        cutoffs_view=cutoffs_view,
        patch_idx_view=patch_idx_view,
        out_cache_lens=out_cache_lens,
    )