Skip to content

kups.potential.mliap.torch.mace

MACE model integration via TorchModuleWrapper.

This module provides wrappers for using PyTorch MACE models directly from JAX via the TorchModuleWrapper bridge, including full kUPS potential integration.

Example
from kups.potential.mliap.torch import load_mace_wrapper

# Load and wrap a MACE model (with forces)
# Defaults: device="cuda", dtype="float32"; pass device="cpu" when CUDA is unavailable.
wrapper = load_mace_wrapper("path/to/mace.model")
# Arguments follow MACEModule.forward: node_attrs, positions, edge_index, batch,
# ptr, shifts, and an optional cell.
result = wrapper(node_attrs, positions, edge_index, batch, ptr, shifts, cell)
# "forces" is only present because compute_force defaults to True.
energy, forces = result["energy"], result["forces"]
For full kUPS potential integration:
from kups.potential.mliap.torch.mace import TorchMACEModel, make_torch_mace_potential

# Create TorchMACEModel from wrapper
# Positional order is species_to_index, cutoff (a per-system Table), num_mace_species, wrapper.
model = TorchMACEModel(species_to_index, cutoff, num_species, wrapper)

# Create kUPS potential (forces only)
# `...` elides the remaining arguments (e.g. cutoffs_view).
potential = make_torch_mace_potential(
    particles_view, systems_view, neighborlist_view, model, ...
)

# With virial/stress support
# Sketch only: `...` stands for the elided arguments, so this is not runnable as written.
# The wrapper must also return virials (e.g. load_mace_wrapper(..., compute_virials=True)),
# otherwise torch_mace_model_fn rejects the output.
potential = make_torch_mace_potential(
    ..., compute_virials=True, ...
)

Requires the `torch_dev` dependency group (install it with `uv sync --group torch_dev`).

IsTorchMACEParticles

Bases: IsRadiusGraphPoints, HasAtomicNumbers, Protocol

Particle protocol for PyTorch MACE models.

Source code in src/kups/potential/mliap/torch/mace.py
class IsTorchMACEParticles(IsRadiusGraphPoints, HasAtomicNumbers, Protocol):
    """Structural particle type accepted by the PyTorch MACE potential.

    The intersection of ``IsRadiusGraphPoints`` and ``HasAtomicNumbers``;
    it declares no members of its own.
    """

MACEModule

Bases: Module

Wraps a MACE model for JAX interop via TorchModuleWrapper.

Supports energy-only, energy+forces, and energy+forces+virials modes via the compute_force and compute_virials flags.

Source code in src/kups/potential/mliap/torch/mace.py
class MACEModule(torch.nn.Module):
    """Wraps a MACE model for JAX interop via TorchModuleWrapper.

    Supports energy-only, energy+forces, and energy+forces+virials modes
    via the compute_force and compute_virials flags.
    """

    def __init__(
        self,
        mace_model: torch.nn.Module,
        compute_force: bool = True,
        compute_virials: bool = False,
    ) -> None:
        """Initialise MACEModule.

        Args:
            mace_model: Underlying PyTorch MACE model.
            compute_force: Whether to compute forces.
            compute_virials: Whether to compute virials.
        """
        super().__init__()
        self.mace = mace_model
        self.mace.eval()
        self.compute_force = compute_force
        self.compute_virials = compute_virials

    def forward(
        self,
        node_attrs: torch.Tensor,
        positions: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        ptr: torch.Tensor,
        shifts: torch.Tensor,
        cell: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Run MACE forward pass.

        Args:
            node_attrs: One-hot species encoding.
            positions: Atom positions.
            edge_index: Edge index array (2, n_edges).
            batch: System index per atom.
            ptr: Cumulative atom counts per system.
            shifts: Absolute shift vectors per edge.
            cell: Unit cell lattice vectors (optional).

        Returns:
            Dictionary with ``"energy"`` and optionally ``"forces"`` /
            ``"virials"`` tensors.
        """
        input_dict = {
            "node_attrs": node_attrs,
            "positions": positions,
            "edge_index": edge_index,
            "batch": batch,
            "ptr": ptr,
            "shifts": shifts,
            "cell": cell,
        }
        output = self.mace(
            input_dict,
            compute_force=self.compute_force,
            compute_virials=self.compute_virials,
        )

        result: dict[str, torch.Tensor] = {"energy": output["energy"].detach()}
        if self.compute_force:
            result["forces"] = output["forces"].detach()
        if self.compute_virials:
            result["virials"] = output["virials"].detach()
        return result

__init__(mace_model, compute_force=True, compute_virials=False)

Initialise MACEModule.

Parameters:

Name Type Description Default
mace_model Module

Underlying PyTorch MACE model.

required
compute_force bool

Whether to compute forces.

True
compute_virials bool

Whether to compute virials.

False
Source code in src/kups/potential/mliap/torch/mace.py
def __init__(
    self,
    mace_model: torch.nn.Module,
    compute_force: bool = True,
    compute_virials: bool = False,
) -> None:
    """Store the MACE model (in eval mode) and the output flags.

    Args:
        mace_model: PyTorch MACE model to wrap; it is switched to eval mode.
        compute_force: Request forces from the model and return them.
        compute_virials: Request virials from the model and return them.
    """
    super().__init__()
    # Module.eval() returns the module itself, so the mode switch and the
    # submodule registration fit in a single assignment.
    self.mace = mace_model.eval()
    self.compute_force, self.compute_virials = compute_force, compute_virials

forward(node_attrs, positions, edge_index, batch, ptr, shifts, cell=None)

Run MACE forward pass.

Parameters:

Name Type Description Default
node_attrs Tensor

One-hot species encoding.

required
positions Tensor

Atom positions.

required
edge_index Tensor

Edge index array (2, n_edges).

required
batch Tensor

System index per atom.

required
ptr Tensor

Cumulative atom counts per system.

required
shifts Tensor

Absolute shift vectors per edge.

required
cell Tensor | None

Unit cell lattice vectors (optional).

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with "energy" and optionally "forces" / "virials" tensors.

Source code in src/kups/potential/mliap/torch/mace.py
def forward(
    self,
    node_attrs: torch.Tensor,
    positions: torch.Tensor,
    edge_index: torch.Tensor,
    batch: torch.Tensor,
    ptr: torch.Tensor,
    shifts: torch.Tensor,
    cell: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Run MACE forward pass.

    Args:
        node_attrs: One-hot species encoding.
        positions: Atom positions.
        edge_index: Edge index array (2, n_edges).
        batch: System index per atom.
        ptr: Cumulative atom counts per system.
        shifts: Absolute shift vectors per edge.
        cell: Unit cell lattice vectors (optional).

    Returns:
        Dictionary with ``"energy"`` and optionally ``"forces"`` /
        ``"virials"`` tensors.
    """
    input_dict = {
        "node_attrs": node_attrs,
        "positions": positions,
        "edge_index": edge_index,
        "batch": batch,
        "ptr": ptr,
        "shifts": shifts,
        "cell": cell,
    }
    output = self.mace(
        input_dict,
        compute_force=self.compute_force,
        compute_virials=self.compute_virials,
    )

    result: dict[str, torch.Tensor] = {"energy": output["energy"].detach()}
    if self.compute_force:
        result["forces"] = output["forces"].detach()
    if self.compute_virials:
        result["virials"] = output["virials"].detach()
    return result

TorchMACEModel

MACE model container for PyTorch models via TorchModuleWrapper.

Attributes:

Name Type Description
species_to_index Array

Mapping from atomic number to MACE species index.

cutoff Table[SystemId, Array]

Model cutoff radius per system.

num_mace_species int

Number of species the MACE model was trained on.

wrapper TorchModuleWrapper

TorchModuleWrapper bridging PyTorch to JAX.

compute_virials bool

Whether to compute virials for stress.

Source code in src/kups/potential/mliap/torch/mace.py
@dataclass
class TorchMACEModel:
    """MACE model container for PyTorch models via TorchModuleWrapper.

    Bundles the species lookup and per-system cutoffs with the wrapped PyTorch
    model so it can be passed to ``make_torch_mace_potential`` either directly
    or through a state view.

    Attributes:
        species_to_index: Mapping from atomic number to MACE species index
            (presumably an array indexed by atomic number; confirm).
        cutoff: Model cutoff radius per system.
        num_mace_species: Number of species the MACE model was trained on.
        wrapper: TorchModuleWrapper bridging PyTorch to JAX.
        compute_virials: Whether to compute virials for stress.
    """

    # Array-valued fields.
    species_to_index: Array
    cutoff: Table[SystemId, Array]
    # Fields declared with static=True; presumably treated as static metadata
    # rather than array leaves by the project's ``field`` helper (confirm).
    num_mace_species: int = field(static=True)
    wrapper: TorchModuleWrapper = field(static=True)
    # NOTE(review): nothing here checks that this flag matches the
    # compute_virials setting the wrapper was built with; keep them in sync.
    compute_virials: bool = field(static=True, default=False)

load_mace_wrapper(model_path, device='cuda', compute_force=True, compute_virials=False, dtype='float32')

Load a PyTorch MACE model and wrap it for JAX computation.

Parameters:

Name Type Description Default
model_path str | Path

Path to the MACE .model file

required
device str

Device to load the model onto (default: "cuda")

'cuda'
compute_force bool

Whether to compute forces (default: True)

True
compute_virials bool

Whether to compute virials for stress (default: False)

False
dtype Literal['float32', 'float64']

Model precision - "float32" (default) or "float64"

'float32'

Returns:

Type Description
TorchModuleWrapper

TorchModuleWrapper containing MACEModule.

Source code in src/kups/potential/mliap/torch/mace.py
def load_mace_wrapper(
    model_path: str | Path,
    device: str = "cuda",
    compute_force: bool = True,
    compute_virials: bool = False,
    dtype: Literal["float32", "float64"] = "float32",
) -> TorchModuleWrapper:
    """Load a PyTorch MACE model and wrap it for JAX computation.

    Args:
        model_path: Path to the MACE ``.model`` file.
        device: Torch device to load the model onto (default: ``"cuda"``).
            Any CUDA device string (``"cuda"``, ``"cuda:1"``, ...) is checked
            for CUDA availability before loading.
        compute_force: Whether to compute forces (default: True).
        compute_virials: Whether to compute virials for stress (default: False).
        dtype: Model precision, ``"float32"`` (default) or ``"float64"``.

    Returns:
        TorchModuleWrapper containing a MACEModule.

    Raises:
        ValueError: If ``dtype`` is neither ``"float32"`` nor ``"float64"``.
        RuntimeError: If a CUDA device is requested but CUDA is not available.
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Reject unknown precisions instead of silently falling back to float32.
    if dtype not in ("float32", "float64"):
        raise ValueError(f"dtype must be 'float32' or 'float64', got {dtype!r}")

    # torch.device normalises "cuda", "cuda:0", torch.device("cuda"), ...;
    # an exact == "cuda" comparison would let indexed CUDA devices skip this
    # check and fail later inside torch.load with a less helpful error.
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            f"Device {device!r} requested but CUDA is not available. "
            "Use device='cpu' or ensure CUDA is properly installed."
        )

    path = Path(model_path)
    if not path.exists():
        raise FileNotFoundError(f"MACE model not found: {model_path}")

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which is
    # needed here because the file holds a full module (it is used directly as
    # one below). Only load model files from trusted sources.
    model = torch.load(path, weights_only=False, map_location=device)
    model.eval()
    model = model.double() if dtype == "float64" else model.float()

    module = MACEModule(
        model, compute_force=compute_force, compute_virials=compute_virials
    )
    # NOTE(review): requires_grad follows compute_force only. If virials are
    # requested without forces, confirm the wrapper still enables the autograd
    # that the MACE virial computation presumably needs.
    return TorchModuleWrapper(module, requires_grad=compute_force)

make_torch_mace_potential(particles_view, systems_view, neighborlist_view, model, cutoffs_view, compute_virials=False, patch_idx_view=None, out_cache_lens=None)

make_torch_mace_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: Literal[False] = False,
    patch_idx_view: View[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
) -> Potential[State, Array, EmptyType, Patch[State]]
make_torch_mace_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: Literal[True],
    patch_idx_view: View[
        State, PotentialOut[PositionAndUnitCell, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[PositionAndUnitCell, EmptyType]
    ]
    | None = None,
) -> Potential[
    State, PositionAndUnitCell, EmptyType, Patch[State]
]

Create kUPS potential from PyTorch MACE model.

Forces are computed by PyTorch natively (not JAX autodiff). Hessians are NOT supported (returns EmptyType).

Parameters:

Name Type Description Default
particles_view View[State, Table[ParticleId, P]]

Extracts particle data from state

required
systems_view View[State, Table[SystemId, S]]

Extracts system data (unit cell) from state

required
neighborlist_view View[State, NNList]

Extracts neighbor list from state

required
model View[State, TorchMACEModel] | TorchMACEModel

TorchMACEModel instance or view to model in state

required
cutoffs_view View[State, Table[SystemId, Array]]

Extracts per-system cutoffs as Table[SystemId, Array]

required
compute_virials bool

Whether to compute virials for stress (default: False)

False
patch_idx_view Any | None

Cached output index structure (optional)

None
out_cache_lens Any | None

Cache location lens (optional)

None

Returns:

Type Description
Any

kUPS Potential backed by the PyTorch MACE model.

Source code in src/kups/potential/mliap/torch/mace.py
def make_torch_mace_potential[
    State,
    P: IsTorchMACEParticles,
    S: HasUnitCell,
    NNList: NearestNeighborList,
](
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMACEModel] | TorchMACEModel,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_virials: bool = False,
    patch_idx_view: Any | None = None,
    out_cache_lens: Any | None = None,
) -> Any:
    """Create kUPS potential from PyTorch MACE model.

    Forces are computed by PyTorch natively (not JAX autodiff).
    Hessians are NOT supported (returns EmptyType).

    Args:
        particles_view: Extracts particle data from state
        systems_view: Extracts system data (unit cell) from state
        neighborlist_view: Extracts neighbor list from state
        model: TorchMACEModel instance or view to model in state
        cutoffs_view: Extracts cutoffs as ``Indexed[SystemId, Array]``
        compute_virials: Whether to compute virials for stress (default: False)
        patch_idx_view: Cached output index structure (optional)
        out_cache_lens: Cache location lens (optional)

    Returns:
        kUPS ``Potential`` backed by the PyTorch MACE model.
    """
    if isinstance(model, TorchMACEModel):
        model_view: View[State, TorchMACEModel] = constant(model)
    else:
        model_view = model

    if compute_virials:

        def virial_fn(inp):
            return torch_mace_model_fn(inp, compute_virials=True)

        return make_mliap_potential(
            model_fn=virial_fn,
            particles_view=particles_view,
            systems_view=systems_view,
            neighborlist_view=neighborlist_view,
            model_view=model_view,
            cutoffs_view=cutoffs_view,
            patch_idx_view=patch_idx_view,
            out_cache_lens=out_cache_lens,
        )

    def forces_fn(inp):
        return torch_mace_model_fn(inp, compute_virials=False)

    return make_mliap_potential(
        model_fn=forces_fn,
        particles_view=particles_view,
        systems_view=systems_view,
        neighborlist_view=neighborlist_view,
        model_view=model_view,
        cutoffs_view=cutoffs_view,
        patch_idx_view=patch_idx_view,
        out_cache_lens=out_cache_lens,
    )

torch_mace_model_fn(inp, *, compute_virials=False)

torch_mace_model_fn(
    inp: TorchMACEInput[P, S],
    *,
    compute_virials: Literal[False] = False,
) -> WithPatch[PotentialOut[Array, EmptyType], IdPatch]
torch_mace_model_fn(
    inp: TorchMACEInput[P, S],
    *,
    compute_virials: Literal[True],
) -> WithPatch[
    PotentialOut[PositionAndUnitCell, EmptyType], IdPatch
]

Model function for PyTorch MACE models.

Returns PotentialOut with forces (and optionally virials) computed by PyTorch.

Parameters:

Name Type Description Default
inp TorchMACEInput[P, S]

Graph potential input containing the MACE model and graph.

required
compute_virials bool

Whether to include virial gradients in the output.

False

Returns:

Type Description
WithPatch[PotentialOut[Array, EmptyType], IdPatch] | WithPatch[PotentialOut[PositionAndUnitCell, EmptyType], IdPatch]

WithPatch containing PotentialOut with energy, gradients, and an identity patch.

Source code in src/kups/potential/mliap/torch/mace.py
def torch_mace_model_fn[
    P: IsTorchMACEParticles,
    S: HasUnitCell,
](
    inp: TorchMACEInput[P, S],
    *,
    compute_virials: bool = False,
) -> (
    WithPatch[PotentialOut[Array, EmptyType], IdPatch]
    | WithPatch[PotentialOut[PositionAndUnitCell, EmptyType], IdPatch]
):
    """Model function for PyTorch MACE models.

    Returns ``PotentialOut`` with forces (and optionally virials) computed by
    PyTorch.

    Args:
        inp: Graph potential input containing the MACE model and graph.
        compute_virials: Whether to include virial gradients in the output.

    Returns:
        ``WithPatch`` containing ``PotentialOut`` with energy, gradients, and
        an identity patch.
    """
    result = _call_mace_wrapper(inp)

    if compute_virials:
        assert result.virials is not None, "Model must have compute_virials=True"
        gradients = PositionAndUnitCell(
            positions=Table(inp.graph.particles.keys, -result.forces),
            unitcell=Table(
                inp.graph.systems.keys,
                TriclinicUnitCell.from_matrix(result.virials),
            ),
        )
        return WithPatch(
            PotentialOut(Table.arange(result.energy, label=SystemId), gradients, EMPTY),
            IdPatch(),
        )
    else:
        return WithPatch(
            PotentialOut(
                Table.arange(result.energy, label=SystemId), -result.forces, EMPTY
            ),
            IdPatch(),
        )