Skip to content

kups.potential.mliap.torch

PyTorch ML interatomic potentials.

A universal interface mirroring tojax: each torch MLFF backend only fills in a torch.nn.Module whose forward consumes AtomGraphInput and returns {"energy", "position_gradients", "cell_gradients"}. All graph extraction, padding, and kUPS Potential wiring is shared.

Example
from kups.potential.mliap.torch import load_mace, make_torch_mliap_from_state

model = load_mace("mace.model", compute_cell_gradients=True)
potential = make_torch_mliap_from_state(
    state_lens, compute_position_and_cell_gradients=True,
)

Requires the torch_dev dependency group: uv sync --group torch_dev.

AtomGraphInput

Bases: TypedDict

Universal input schema shared by all torch MLFF backends.

Mirrors the JAX AtomGraphInput. Shapes use N atoms, B systems, and E edges; the atom and system axes are each padded by one extra entry to work around backends that cannot handle empty graphs.

Source code in src/kups/potential/mliap/torch/interface.py
class AtomGraphInput(TypedDict):
    """Graph-batch schema consumed by every torch MLFF backend.

    Torch-side counterpart of the JAX
    [AtomGraphInput][kups.potential.mliap.tojax.AtomGraphInput]. In the shape
    annotations below ``N`` counts atoms, ``B`` counts systems and ``E``
    counts edges. The atom and system axes each carry one extra padding
    entry so that backends which cannot handle empty graphs still receive a
    valid input.
    """

    # Cartesian atom positions.
    pos: Array  # (N, 3)
    # Atomic number Z of every atom.
    atomic_numbers: Array  # (N,)
    # Lattice matrix of every system.
    cell: Array  # (B, 3, 3)
    # Periodic-boundary flags of every system.
    pbc: Array  # (B, 3)
    # Atom indices of the two end points of every edge.
    edge_index: Array  # (2, E)
    # Periodic image offset of every edge.
    cell_offsets: Array  # (E, 3) integer multiples of cell vectors
    # Index of the system each atom belongs to.
    batch: Array  # (N,)
    # Per-system charge.
    charge: Array  # (B,)
    # Per-system spin.
    spin: Array  # (B,)

IsTorchMliapParticles

Bases: IsRadiusGraphPoints, HasAtomicNumbers, Protocol

Particle protocol for torch MLFF models.

Source code in src/kups/potential/mliap/torch/interface.py
class IsTorchMliapParticles(IsRadiusGraphPoints, HasAtomicNumbers, Protocol):
    """Structural type for particle data accepted by torch MLFF models.

    Declares no members of its own; it only combines the
    ``IsRadiusGraphPoints`` and ``HasAtomicNumbers`` protocols.
    """

IsTorchMliapState

Bases: Protocol

Protocol for states providing all inputs for a torch MLFF potential.

Source code in src/kups/potential/mliap/torch/interface.py
class IsTorchMliapState(Protocol):
    """Structural type for a state that can feed a torch MLFF potential.

    Any object exposing the four read-only properties below satisfies the
    protocol.
    """

    @property
    def particles(self) -> Table[ParticleId, IsTorchMliapParticles]:
        """Particle table keyed by ``ParticleId``."""

    @property
    def systems(self) -> Table[SystemId, HasCell]:
        """System table keyed by ``SystemId``; every entry has a cell."""

    @property
    def neighborlist(self) -> NearestNeighborList:
        """Neighbor list attached to the state."""

    @property
    def torch_mliap_model(self) -> TorchMliap:
        """The ``TorchMliap`` container holding the wrapped torch model."""

MACEModule

Bases: Module

Adapter: AtomGraphInput → MACE PyG-style input → energy + gradients.

Wraps a MACE nn.Module and translates the universal graph input into the (node_attrs, positions, edge_index, batch, ptr, shifts, cell) tuple that MACE expects. Returns gradients of energy w.r.t. positions (and optionally cell vectors).

Attributes:

Name Type Description
mace

Underlying MACE nn.Module.

species_to_index Tensor

Buffer mapping atomic number Z → MACE species index (0..num_species-1).

num_species

Number of species the MACE model was trained on.

compute_cell_gradients

Whether to compute cell gradients (stress).

Source code in src/kups/potential/mliap/torch/mace.py
class MACEModule(torch.nn.Module):
    """Adapter: ``AtomGraphInput`` → MACE PyG-style input → energy + gradients.

    Wraps a MACE ``nn.Module`` and translates the universal graph input into
    the (``node_attrs``, ``positions``, ``edge_index``, ``batch``, ``ptr``,
    ``shifts``, ``cell``) tuple that MACE expects. Returns gradients of energy
    w.r.t. positions (and optionally cell vectors).

    Attributes:
        mace: Underlying MACE ``nn.Module``.
        species_to_index: Buffer mapping atomic number ``Z`` → MACE species
            index (0..``num_species``-1).
        num_species: Number of species the MACE model was trained on.
        compute_cell_gradients: Whether to compute cell gradients (stress).
    """

    species_to_index: torch.Tensor

    def __init__(
        self,
        mace_model: torch.nn.Module,
        species_to_index: torch.Tensor,
        num_species: int,
        compute_cell_gradients: bool = False,
    ) -> None:
        """Initialise ``MACEModule``.

        Args:
            mace_model: Underlying MACE ``nn.Module``.
            species_to_index: Tensor mapping ``Z`` → MACE index. Indexed by
                atomic number; entries for unsupported ``Z`` are ignored.
            num_species: Number of species the MACE model was trained on.
            compute_cell_gradients: Whether to compute cell gradients.
        """
        super().__init__()
        self.mace = mace_model
        # The adapter is inference-only: put the wrapped model in eval mode.
        self.mace.eval()
        self.register_buffer("species_to_index", species_to_index.to(dtype=torch.int64))
        self.num_species = num_species
        self.compute_cell_gradients = compute_cell_gradients

    def forward(self, input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run MACE on a universal ``AtomGraphInput`` and return gradients.

        Args:
            input: Dict matching the universal ``AtomGraphInput`` schema.

        Returns:
            Dict with ``"energy"`` ``(B,)``, ``"position_gradients"`` ``(N, 3)``,
            and optionally ``"cell_gradients"`` ``(B, 3, 3)``.
        """
        # MACE is loaded at a fixed precision (float32 or float64 — see
        # ``load_mace(dtype=...)``); align every float input to that dtype
        # rather than whatever JAX hands us.
        model_dtype = next(self.mace.parameters()).dtype

        pos = input["pos"].to(model_dtype)
        cell = input["cell"].to(model_dtype)
        cell_offsets = input["cell_offsets"].to(model_dtype)

        # Normalise every index tensor to int64, exactly as ``UMAModule``
        # does. PyTorch's scatter/index kernels require int64 indices, and the
        # incoming integer dtype is not guaranteed (it is whatever the caller
        # hands over). The cast is a no-op when the input is already int64.
        species = input["atomic_numbers"].to(dtype=torch.int64)
        batch = input["batch"].to(dtype=torch.int64)
        edge_index = input["edge_index"].to(dtype=torch.int64)

        n_atoms = pos.shape[0]
        n_sys = cell.shape[0]

        # kUPS's neighbor list is a fixed-size buffer; ``indices_in`` maps
        # unused slots to the OOB sentinel ``len(keys) == n_atoms``. Drop
        # padded edges before indexing into atom-shaped tensors.
        valid_edge = (edge_index < n_atoms).all(dim=0)
        edge_index = edge_index[:, valid_edge]
        cell_offsets = cell_offsets[valid_edge]

        # ``ptr`` holds the cumulative atom count per system, starting at 0.
        counts = torch.bincount(batch, minlength=n_sys)
        ptr = torch.cat([batch.new_zeros(1), counts.cumsum(0)])
        # ``species_to_index`` is registered as a CPU buffer at construction
        # time. ``TorchModuleWrapper`` calls us once with mock tensors on the
        # device of the wrapped MACE model (typically cuda) without first
        # calling ``self.to(device)``, so the indexing would straddle devices.
        # Co-locate the lookup table with the input on every call.
        species_to_index = self.species_to_index.to(species.device)
        node_attrs = F.one_hot(species_to_index[species], self.num_species).to(
            model_dtype
        )
        cell_per_edge = cell[batch[edge_index[0]]]
        # cell_offsets (E,3) integer multiples → absolute Å via per-edge cell:
        # shifts[e, j] = Σ_i cell_offsets[e, i] * cell_per_edge[e, i, j]
        # (``cell_offsets`` is already in ``model_dtype``, no further cast.)
        shifts = (cell_offsets.unsqueeze(1) @ cell_per_edge).squeeze(1)

        out = self.mace(
            {
                "node_attrs": node_attrs,
                "positions": pos,
                "edge_index": edge_index,
                "batch": batch,
                "ptr": ptr,
                "shifts": shifts,
                # MACE's ``prepare_graph`` reads ``unit_shifts`` (the integer
                # cell-offset multiples) when ``compute_virials=True`` to build
                # the strain-perturbed graph in ``get_symmetric_displacement``.
                "unit_shifts": cell_offsets,
                "cell": cell,
            },
            compute_force=True,
            compute_virials=self.compute_cell_gradients,
        )

        # MACE reports forces; the contract here is gradients, hence the sign.
        forces = out["forces"]
        result: dict[str, torch.Tensor] = {
            "energy": out["energy"].detach(),
            "position_gradients": (-forces).detach(),
        }
        if self.compute_cell_gradients:
            # MACE's ``virials`` = -sym(pos_virial + cell^T @ ∂E/∂h).
            # Negate to get the symmetric-strain virial, then invert.
            virial = -out["virials"]
            cell_grad = lattice_gradient_from_virial(
                forces=forces,
                positions=pos,
                batch=batch,
                cell=cell,
                virial=virial,
            )
            result["cell_gradients"] = cell_grad.detach()
        return result

__init__(mace_model, species_to_index, num_species, compute_cell_gradients=False)

Initialise MACEModule.

Parameters:

Name Type Description Default
mace_model Module

Underlying MACE nn.Module.

required
species_to_index Tensor

Tensor mapping Z → MACE index. Indexed by atomic number; entries for unsupported Z are ignored.

required
num_species int

Number of species the MACE model was trained on.

required
compute_cell_gradients bool

Whether to compute cell gradients.

False
Source code in src/kups/potential/mliap/torch/mace.py
def __init__(
    self,
    mace_model: torch.nn.Module,
    species_to_index: torch.Tensor,
    num_species: int,
    compute_cell_gradients: bool = False,
) -> None:
    """Set up the adapter around an already-constructed MACE model.

    Args:
        mace_model: The MACE ``nn.Module`` to wrap; it is switched to eval
            mode here.
        species_to_index: Lookup tensor indexed by atomic number ``Z`` that
            yields the MACE species index. Entries for unsupported ``Z`` are
            ignored. Stored as an int64 buffer.
        num_species: Number of species the MACE model was trained on.
        compute_cell_gradients: Whether to compute cell gradients.
    """
    super().__init__()

    # Wrapped model, forced into inference mode.
    mace_model.eval()
    self.mace = mace_model

    # Z → species-index table, kept as an int64 buffer on the module.
    lookup = species_to_index.to(dtype=torch.int64)
    self.register_buffer("species_to_index", lookup)

    # Plain configuration values.
    self.compute_cell_gradients = compute_cell_gradients
    self.num_species = num_species

forward(input)

Run MACE on a universal AtomGraphInput and return gradients.

Parameters:

Name Type Description Default
input dict[str, Tensor]

Dict matching the universal AtomGraphInput schema.

required

Returns:

Type Description
dict[str, Tensor]

Dict with "energy" (B,), "position_gradients" (N, 3), and optionally "cell_gradients" (B, 3, 3).

Source code in src/kups/potential/mliap/torch/mace.py
def forward(self, input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Run MACE on a universal ``AtomGraphInput`` and return gradients.

    Args:
        input: Dict matching the universal ``AtomGraphInput`` schema.

    Returns:
        Dict with ``"energy"`` ``(B,)``, ``"position_gradients"`` ``(N, 3)``,
        and optionally ``"cell_gradients"`` ``(B, 3, 3)``.
    """
    # MACE is loaded at a fixed precision (float32 or float64 — see
    # ``load_mace(dtype=...)``); align every float input to that dtype
    # rather than whatever JAX hands us.
    model_dtype = next(self.mace.parameters()).dtype

    pos = input["pos"].to(model_dtype)
    cell = input["cell"].to(model_dtype)
    cell_offsets = input["cell_offsets"].to(model_dtype)

    # Normalise every index tensor to int64, exactly as ``UMAModule``
    # does. PyTorch's scatter/index kernels require int64 indices, and the
    # incoming integer dtype is not guaranteed (it is whatever the caller
    # hands over). The cast is a no-op when the input is already int64.
    species = input["atomic_numbers"].to(dtype=torch.int64)
    batch = input["batch"].to(dtype=torch.int64)
    edge_index = input["edge_index"].to(dtype=torch.int64)

    n_atoms = pos.shape[0]
    n_sys = cell.shape[0]

    # kUPS's neighbor list is a fixed-size buffer; ``indices_in`` maps
    # unused slots to the OOB sentinel ``len(keys) == n_atoms``. Drop
    # padded edges before indexing into atom-shaped tensors.
    valid_edge = (edge_index < n_atoms).all(dim=0)
    edge_index = edge_index[:, valid_edge]
    cell_offsets = cell_offsets[valid_edge]

    # ``ptr`` holds the cumulative atom count per system, starting at 0.
    counts = torch.bincount(batch, minlength=n_sys)
    ptr = torch.cat([batch.new_zeros(1), counts.cumsum(0)])
    # ``species_to_index`` is registered as a CPU buffer at construction
    # time. ``TorchModuleWrapper`` calls us once with mock tensors on the
    # device of the wrapped MACE model (typically cuda) without first
    # calling ``self.to(device)``, so the indexing would straddle devices.
    # Co-locate the lookup table with the input on every call.
    species_to_index = self.species_to_index.to(species.device)
    node_attrs = F.one_hot(species_to_index[species], self.num_species).to(
        model_dtype
    )
    cell_per_edge = cell[batch[edge_index[0]]]
    # cell_offsets (E,3) integer multiples → absolute Å via per-edge cell:
    # shifts[e, j] = Σ_i cell_offsets[e, i] * cell_per_edge[e, i, j]
    # (``cell_offsets`` is already in ``model_dtype``, no further cast.)
    shifts = (cell_offsets.unsqueeze(1) @ cell_per_edge).squeeze(1)

    out = self.mace(
        {
            "node_attrs": node_attrs,
            "positions": pos,
            "edge_index": edge_index,
            "batch": batch,
            "ptr": ptr,
            "shifts": shifts,
            # MACE's ``prepare_graph`` reads ``unit_shifts`` (the integer
            # cell-offset multiples) when ``compute_virials=True`` to build
            # the strain-perturbed graph in ``get_symmetric_displacement``.
            "unit_shifts": cell_offsets,
            "cell": cell,
        },
        compute_force=True,
        compute_virials=self.compute_cell_gradients,
    )

    # MACE reports forces; the contract here is gradients, hence the sign.
    forces = out["forces"]
    result: dict[str, torch.Tensor] = {
        "energy": out["energy"].detach(),
        "position_gradients": (-forces).detach(),
    }
    if self.compute_cell_gradients:
        # MACE's ``virials`` = -sym(pos_virial + cell^T @ ∂E/∂h).
        # Negate to get the symmetric-strain virial, then invert.
        virial = -out["virials"]
        cell_grad = lattice_gradient_from_virial(
            forces=forces,
            positions=pos,
            batch=batch,
            cell=cell,
            virial=virial,
        )
        result["cell_gradients"] = cell_grad.detach()
    return result

TorchMliap

Container for a torch MLFF wired into JAX.

Attributes:

Name Type Description
cutoff Table[SystemId, Array]

Per-system cutoff radius [Å].

wrapper TorchModuleWrapper

TorchModuleWrapper over the MLFF module.

compute_cell_gradients bool

Whether the module returns "cell_gradients".

Source code in src/kups/potential/mliap/torch/interface.py
@dataclass
class TorchMliap:
    """Bundle that exposes a torch MLFF to the JAX side of kUPS.

    Attributes:
        cutoff: Per-system cutoff radius [Å].
        wrapper: ``TorchModuleWrapper`` around the MLFF module.
        compute_cell_gradients: Whether the module returns ``"cell_gradients"``.
    """

    cutoff: Table[SystemId, Array]
    wrapper: TorchModuleWrapper = field(static=True)
    compute_cell_gradients: bool = field(static=True, default=False)

    @staticmethod
    def from_module(
        module: torch.nn.Module,
        cutoff: float,
        compute_cell_gradients: bool = False,
    ) -> "TorchMliap":
        """Build a ``TorchMliap`` from a module that returns energy and gradients.

        Args:
            module: torch ``nn.Module`` satisfying ``TorchMliapForward``.
            cutoff: Interaction cutoff radius [Å]. It is stored as a
                one-entry table under ``SystemId(0)``.
            compute_cell_gradients: Whether the module returns
                ``"cell_gradients"`` for stress computation.

        Returns:
            Configured ``TorchMliap`` ready for use with the kUPS interface.
        """
        # Wrap first, with gradient tracking requested from the wrapper.
        wrapped_module = TorchModuleWrapper(module, requires_grad=True)
        cutoff_table = Table((SystemId(0),), jnp.array([cutoff], float))
        return TorchMliap(
            cutoff=cutoff_table,
            wrapper=wrapped_module,
            compute_cell_gradients=compute_cell_gradients,
        )

    def call(self, input: AtomGraphInput) -> dict[str, Array]:
        """Forward a prepared ``AtomGraphInput`` to the wrapped module."""
        outputs = self.wrapper(input)
        return outputs

call(input)

Call the wrapped module on a prepared AtomGraphInput.

Source code in src/kups/potential/mliap/torch/interface.py
def call(self, input: AtomGraphInput) -> dict[str, Array]:
    """Forward a prepared ``AtomGraphInput`` to the wrapped module."""
    outputs = self.wrapper(input)
    return outputs

from_module(module, cutoff, compute_cell_gradients=False) staticmethod

Wrap a torch.nn.Module that returns energy and gradients.

Parameters:

Name Type Description Default
module Module

torch nn.Module satisfying TorchMliapForward.

required
cutoff float

Interaction cutoff radius [Å].

required
compute_cell_gradients bool

Whether the module returns "cell_gradients" for stress computation.

False

Returns:

Type Description
'TorchMliap'

Configured TorchMliap ready for use with the kUPS interface.

Source code in src/kups/potential/mliap/torch/interface.py
@staticmethod
def from_module(
    module: torch.nn.Module,
    cutoff: float,
    compute_cell_gradients: bool = False,
) -> "TorchMliap":
    """Build a ``TorchMliap`` from a module that returns energy and gradients.

    Args:
        module: torch ``nn.Module`` satisfying ``TorchMliapForward``.
        cutoff: Interaction cutoff radius [Å]. It is stored as a one-entry
            table under ``SystemId(0)``.
        compute_cell_gradients: Whether the module returns
            ``"cell_gradients"`` for stress computation.

    Returns:
        Configured ``TorchMliap`` ready for use with the kUPS interface.
    """
    # Wrap first, with gradient tracking requested from the wrapper.
    wrapped_module = TorchModuleWrapper(module, requires_grad=True)
    cutoff_table = Table((SystemId(0),), jnp.array([cutoff], float))
    return TorchMliap(
        cutoff=cutoff_table,
        wrapper=wrapped_module,
        compute_cell_gradients=compute_cell_gradients,
    )

TorchMliapForward

Bases: Protocol

Forward contract for a torch MLFF module.

The module must accept an AtomGraphInput dict and return a dict with:

  • "energy": (B,) per-system total energies.
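  • "position_gradients": (N, 3) gradient of the energy with respect to the positions, $\partial E / \partial r$.
  • "cell_gradients": (B, 3, 3) gradient of the energy with respect to the cell vectors, $\partial E / \partial h$, required only when compute_cell_gradients=True.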

Outputs are gradients (not forces); adapters around models that natively produce forces/virials negate appropriately inside the module.

Source code in src/kups/potential/mliap/torch/interface.py
class TorchMliapForward(Protocol):
    r"""Call signature every torch MLFF module has to provide.

    A conforming module is called with an ``AtomGraphInput`` dict and hands
    back a dict containing:

    - ``"energy"``: ``(B,)`` per-system total energies.
    - ``"position_gradients"``: ``(N, 3)`` :math:`\partial E / \partial r`.
    - ``"cell_gradients"``: ``(B, 3, 3)`` :math:`\partial E / \partial h`;
      only needed when ``compute_cell_gradients=True``.

    All outputs are gradients, not forces. An adapter around a model that
    natively yields forces or virials applies the sign change itself, inside
    the module.
    """

    def __call__(self, input: AtomGraphInput) -> dict[str, Array]:
        """Evaluate the model on ``input`` and return the output dict."""

UMAModule

Bases: Module

Adapter: AtomGraphInput → fairchem AtomicData → energy + gradients.

Wraps a fairchem MLIPPredictUnit and translates the universal graph input into the AtomicData object UMA expects. Returns gradients of energy w.r.t. positions (and optionally w.r.t. cell vectors).

The wrapped predict-unit holds its own torch module and manages its own device placement; this adapter intentionally does not register it as a submodule (it is not an nn.Module).

Attributes:

Name Type Description
predict_unit

fairchem MLIPPredictUnit (held by reference).

task_name

UMA inference head to route every system to.

compute_cell_gradients

Whether to also return "cell_gradients".

Note

UMA's stress is the symmetrized strain virial V_ij / volume from a joint symmetric strain on positions and cell (cf. fairchem.core.models.uma.outputs.compute_forces_and_stress). We invert the position contribution and the cell^T factor to recover the raw lattice gradient ∂E/∂h — see lattice_gradient_from_virial. The antisymmetric part of cell^T @ ∂E/∂h is unrecoverable from a symmetric-strain virial alone; for physical models with rotational invariance it is zero, so the recovered ∂E/∂h is exact.

Source code in src/kups/potential/mliap/torch/uma.py
class UMAModule(torch.nn.Module):
    """Adapter: ``AtomGraphInput`` → fairchem ``AtomicData`` → energy + gradients.

    Wraps a fairchem ``MLIPPredictUnit`` and translates the universal graph
    input into the ``AtomicData`` object UMA expects. Returns gradients of
    energy w.r.t. positions (and optionally w.r.t. cell vectors).

    The wrapped predict-unit holds its own torch module and manages its own
    device placement; this adapter intentionally does not register it as a
    submodule (it is not an ``nn.Module``).

    Attributes:
        predict_unit: fairchem ``MLIPPredictUnit`` (held by reference).
        task_name: UMA inference head to route every system to.
        compute_cell_gradients: Whether to also return ``"cell_gradients"``.

    Note:
        UMA's ``stress`` is the symmetrized strain virial ``V_ij /
        volume`` from a joint symmetric strain on positions and cell
        (cf. ``fairchem.core.models.uma.outputs.compute_forces_and_stress``).
        We invert the position contribution and the ``cell^T`` factor to
        recover the raw lattice gradient ``∂E/∂h`` — see
        ``lattice_gradient_from_virial``. The antisymmetric part of
        ``cell^T @ ∂E/∂h`` is unrecoverable from a symmetric-strain virial
        alone; for physical models with rotational invariance it is zero,
        so the recovered ``∂E/∂h`` is exact.
    """

    def __init__(
        self,
        predict_unit: Any,
        task_name: UMATaskName | str = "omat",
        compute_cell_gradients: bool = False,
    ) -> None:
        """Initialise ``UMAModule``.

        Args:
            predict_unit: fairchem ``MLIPPredictUnit`` (already loaded onto a
                device).
            task_name: UMA inference head (e.g. ``"omat"``, ``"omol"``).
            compute_cell_gradients: Whether to compute cell gradients (stress).
        """
        super().__init__()
        # PredictUnit is not an nn.Module; keep as plain attribute so
        # ``module.to(device)`` does not try to traverse it.
        self.predict_unit = predict_unit
        self.task_name = str(task_name)
        self.compute_cell_gradients = compute_cell_gradients

    def forward(self, input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run UMA on a universal ``AtomGraphInput`` and return gradients.

        Args:
            input: Dict matching the universal ``AtomGraphInput`` schema.

        Returns:
            Dict with ``"energy"`` ``(B,)``, ``"position_gradients"`` ``(N, 3)``,
            and optionally ``"cell_gradients"`` ``(B, 3, 3)``.
        """
        from fairchem.core.datasets.atomic_data import (  # pyright: ignore[reportMissingImports]
            AtomicData,
        )

        # Keep AtomicData inputs at the device JAX/DLPack handed us. We must
        # NOT pre-move them to ``predict_unit.device``: the first call into
        # ``predict_unit.predict`` triggers ``_lazy_init`` which calls
        # ``prepare_for_inference(data, ...)`` while the underlying model is
        # still on cpu (``move_to_device`` runs after ``prepare_for_inference``
        # inside ``_lazy_init``). Pre-moving data to cuda makes UMA's MOLE
        # merge embed cuda indices into cpu weights and crash. The
        # predict-unit itself does ``data.to(self.device).clone()`` internally,
        # so any device-transition cost is paid there.
        #
        # We DO cast every float input to the predict-unit's inference dtype
        # (typically float32) before constructing AtomicData. Otherwise turbo
        # mode's MOLE merge — which runs inside ``_lazy_init.prepare_for_inference``,
        # *before* predict_unit's own dtype cast — pairs float64 data with
        # float32 model weights inside an einsum.
        inf_dtype = self.predict_unit.inference_settings.base_precision_dtype
        pos = input["pos"].to(dtype=inf_dtype)
        species = input["atomic_numbers"].to(dtype=torch.int64)
        cell = input["cell"].to(dtype=inf_dtype)
        # ``clone`` because ``batch`` is modified in place further down.
        batch = input["batch"].to(dtype=torch.int64).clone()
        edge_index = input["edge_index"].to(dtype=torch.int64)
        cell_offsets = -input["cell_offsets"].to(dtype=inf_dtype)
        pbc = input["pbc"]
        charge = input["charge"].to(dtype=inf_dtype)
        spin = input["spin"].to(dtype=inf_dtype)

        n_atoms = pos.shape[0]
        n_sys = cell.shape[0]

        # ``TorchModuleWrapper._get_output_info`` calls us once with all-zero
        # mock tensors to infer output shapes. UMA's turbo path merges its
        # MOLE experts on the *first* real ``predict()`` call (in
        # ``_lazy_init``) and then asserts every subsequent call has the same
        # composition. If we let the mock call through, the merge bakes in a
        # "100% H" composition that conflicts with the real structure. Detect
        # the mock by ``atomic_numbers.sum() == 0`` and return dummies of the
        # right shape/dtype without invoking ``predict_unit`` — that way the
        # first real call is what triggers ``_lazy_init`` with the real
        # composition.
        if int(species.abs().sum().item()) == 0:
            # Dummy outputs use the same inference dtype looked up above.
            dev = pos.device
            dummy: dict[str, torch.Tensor] = {
                "energy": torch.zeros(n_sys, dtype=inf_dtype, device=dev),
                "position_gradients": torch.zeros(
                    n_atoms, 3, dtype=inf_dtype, device=dev
                ),
            }
            if self.compute_cell_gradients:
                dummy["cell_gradients"] = torch.zeros(
                    n_sys, 3, 3, dtype=inf_dtype, device=dev
                )
            return dummy

        # kUPS's neighbor list is a fixed-size buffer; ``indices_in`` maps
        # unused slots to the OOB sentinel ``len(keys) == n_atoms``. Drop
        # those padded edges before handing the graph to UMA — otherwise the
        # species/position gather hits true out-of-bounds indices and CUDA
        # asserts.
        valid_edge = (edge_index < n_atoms).all(dim=0)
        edge_index = edge_index[:, valid_edge]
        cell_offsets = cell_offsets[valid_edge]

        # Pin the last atom to the last system so ``batch.max() + 1 == n_sys``
        # (validates AtomicData). For ``sorted_by_system`` real input this is
        # a no-op (the last atom is already in the last system).
        batch[-1] = n_sys - 1

        # First-real-call path: ``predict_unit._lazy_init`` runs
        # ``prepare_for_inference`` *before* moving the model to its target
        # device (``self.predict_unit.device``). If we pass cuda tensors at
        # that point, MOLE merge tries to gather cuda indices through cpu
        # weights and crashes. Move data to cpu only for this one call —
        # ``predict_unit.predict`` re-moves to its own device inside, so the
        # actual model forward still runs on cuda. From the second call
        # onward, data can stay on whatever device JAX handed us.
        # NOTE(review): the spelling ``lazy_model_intialized`` presumably
        # mirrors the attribute name on the predict unit — confirm against
        # fairchem before "correcting" it; a wrong name silently falls back
        # to ``True`` and skips this branch.
        if not getattr(self.predict_unit, "lazy_model_intialized", True):
            pos = pos.cpu()
            species = species.cpu()
            cell = cell.cpu()
            batch = batch.cpu()
            edge_index = edge_index.cpu()
            cell_offsets = cell_offsets.cpu()
            pbc = pbc.cpu()
            charge = charge.cpu()
            spin = spin.cpu()

        # Per-system atom and edge counts; an edge belongs to the system of
        # its first end point.
        natoms = torch.bincount(batch, minlength=n_sys)
        edge_batch = batch[edge_index[0]]
        nedges = torch.bincount(edge_batch, minlength=n_sys)
        fixed = torch.zeros(n_atoms, dtype=torch.int64, device=pos.device)
        tags = torch.zeros(n_atoms, dtype=torch.int64, device=pos.device)

        data = AtomicData(
            pos=pos,
            atomic_numbers=species,
            cell=cell,
            pbc=pbc,
            natoms=natoms,
            edge_index=edge_index,
            cell_offsets=cell_offsets.to(pos.dtype),
            nedges=nedges,
            charge=charge,
            spin=spin,
            fixed=fixed,
            tags=tags,
            batch=batch,
            sid=[""] * n_sys,
            dataset=[self.task_name] * n_sys,
        )

        preds = self.predict_unit.predict(data)

        forces = preds["forces"]
        # ``predict_unit`` places outputs on its own device, which may differ
        # from the input device (cpu mock vs cuda predict-unit, single- vs
        # multi-gpu). Pin our post-processing tensors to the output device so
        # ``stress * volume`` and ``lattice_gradient_from_virial`` stay on a
        # single device.
        out_device = forces.device
        result: dict[str, torch.Tensor] = {
            "energy": preds["energy"].detach(),
            "position_gradients": (-forces).detach(),
        }
        if self.compute_cell_gradients:
            stress = preds["stress"]
            # UMA flattens stress to (B, 9); reshape to (B, 3, 3) if needed.
            # ``reshape`` (not ``view``) so a non-contiguous stress tensor is
            # handled instead of raising.
            if stress.dim() == 2 and stress.shape[-1] == 9:
                stress = stress.reshape(-1, 3, 3)
            cell_d = cell.to(out_device)
            volume = torch.linalg.det(cell_d).abs()
            virial = stress * volume.view(-1, 1, 1)
            cell_grad = lattice_gradient_from_virial(
                forces=forces,
                positions=pos.to(out_device),
                batch=batch.to(out_device),
                cell=cell_d,
                virial=virial,
            )
            result["cell_gradients"] = cell_grad.detach()
        return result

__init__(predict_unit, task_name='omat', compute_cell_gradients=False)

Initialise UMAModule.

Parameters:

Name Type Description Default
predict_unit Any

fairchem MLIPPredictUnit (already loaded onto a device).

required
task_name UMATaskName | str

UMA inference head (e.g. "omat", "omol").

'omat'
compute_cell_gradients bool

Whether to compute cell gradients (stress).

False
Source code in src/kups/potential/mliap/torch/uma.py
def __init__(
    self,
    predict_unit: Any,
    task_name: UMATaskName | str = "omat",
    compute_cell_gradients: bool = False,
) -> None:
    """Set up the adapter around an existing fairchem predict unit.

    Args:
        predict_unit: fairchem ``MLIPPredictUnit`` (already loaded onto a
            device). It is stored by reference.
        task_name: UMA inference head (e.g. ``"omat"``, ``"omol"``); stored
            as a plain ``str``.
        compute_cell_gradients: Whether to compute cell gradients (stress).
    """
    super().__init__()

    # Plain configuration values.
    self.compute_cell_gradients = compute_cell_gradients
    self.task_name = str(task_name)

    # The predict unit is not an ``nn.Module``. Holding it as an ordinary
    # attribute keeps ``module.to(device)`` from trying to traverse it.
    self.predict_unit = predict_unit

forward(input)

Run UMA on a universal AtomGraphInput and return gradients.

Parameters:

Name Type Description Default
input dict[str, Tensor]

Dict matching the universal AtomGraphInput schema.

required

Returns:

Type Description
dict[str, Tensor]

Dict with "energy" (B,), "position_gradients" (N, 3), and optionally "cell_gradients" (B, 3, 3).

Source code in src/kups/potential/mliap/torch/uma.py
def forward(self, input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Run UMA on a universal ``AtomGraphInput`` and return gradients.

    The call proceeds in this order:

    1. Cast float inputs to the predict unit's inference dtype and index
       inputs to ``int64``.
    2. If every atomic number is zero (the wrapper's shape-inference mock
       call), return zero tensors of the right shape without touching the
       predict unit.
    3. Drop padded edges whose indices are ``>= N``.
    4. Overwrite the last atom's system index with ``B - 1``.
    5. On the first real call only, move all inputs to cpu.
    6. Build a fairchem ``AtomicData`` and call ``predict_unit.predict``.
    7. Convert forces to position gradients (sign flip) and, when
       ``compute_cell_gradients`` is set, convert the predicted stress to
       a lattice gradient via ``lattice_gradient_from_virial``.

    Args:
        input: Dict matching the universal ``AtomGraphInput`` schema. The
            keys read here are ``"pos"``, ``"atomic_numbers"``, ``"cell"``,
            ``"batch"``, ``"edge_index"``, ``"cell_offsets"``, ``"pbc"``,
            ``"charge"`` and ``"spin"``.

    Returns:
        Dict with ``"energy"`` ``(B,)``, ``"position_gradients"`` ``(N, 3)``,
        and optionally ``"cell_gradients"`` ``(B, 3, 3)``. On the mock path
        the tensors are zeros on the input device; on the real path they
        are detached predict-unit outputs.
    """
    from fairchem.core.datasets.atomic_data import (  # pyright: ignore[reportMissingImports]
        AtomicData,
    )

    # Keep AtomicData inputs at the device JAX/DLPack handed us. We must
    # NOT pre-move them to ``predict_unit.device``: the first call into
    # ``predict_unit.predict`` triggers ``_lazy_init`` which calls
    # ``prepare_for_inference(data, ...)`` while the underlying model is
    # still on cpu (``move_to_device`` runs after ``prepare_for_inference``
    # inside ``_lazy_init``). Pre-moving data to cuda makes UMA's MOLE
    # merge embed cuda indices into cpu weights and crash. The
    # predict-unit itself does ``data.to(self.device).clone()`` internally,
    # so any device-transition cost is paid there.
    #
    # We DO cast every float input to the predict-unit's inference dtype
    # (typically float32) before constructing AtomicData. Otherwise turbo
    # mode's MOLE merge — which runs inside ``_lazy_init.prepare_for_inference``,
    # *before* predict_unit's own dtype cast — pairs float64 data with
    # float32 model weights inside an einsum.
    inf_dtype = self.predict_unit.inference_settings.base_precision_dtype
    pos = input["pos"].to(dtype=inf_dtype)
    species = input["atomic_numbers"].to(dtype=torch.int64)
    cell = input["cell"].to(dtype=inf_dtype)
    # Cloned because the last entry is overwritten in place further down;
    # without the copy that write could land in the caller's tensor.
    batch = input["batch"].to(dtype=torch.int64).clone()
    edge_index = input["edge_index"].to(dtype=torch.int64)
    # NOTE(review): the offsets are negated relative to the incoming
    # ``AtomGraphInput`` — presumably to match fairchem's cell-offset sign
    # convention; confirm against the AtomicData documentation.
    cell_offsets = -input["cell_offsets"].to(dtype=inf_dtype)
    # ``pbc`` is passed through without any dtype cast.
    pbc = input["pbc"]
    charge = input["charge"].to(dtype=inf_dtype)
    spin = input["spin"].to(dtype=inf_dtype)

    n_atoms = pos.shape[0]
    n_sys = cell.shape[0]

    # ``TorchModuleWrapper._get_output_info`` calls us once with all-zero
    # mock tensors to infer output shapes. UMA's turbo path merges its
    # MOLE experts on the *first* real ``predict()`` call (in
    # ``_lazy_init``) and then asserts every subsequent call has the same
    # composition. If we let the mock call through, the merge bakes in a
    # "100% H" composition that conflicts with the real structure. Detect
    # the mock by ``atomic_numbers.sum() == 0`` and return dummies of the
    # right shape/dtype without invoking ``predict_unit`` — that way the
    # first real call is what triggers ``_lazy_init`` with the real
    # composition.
    if int(species.abs().sum().item()) == 0:
        out_dtype = self.predict_unit.inference_settings.base_precision_dtype
        dev = pos.device
        result: dict[str, torch.Tensor] = {
            "energy": torch.zeros(n_sys, dtype=out_dtype, device=dev),
            "position_gradients": torch.zeros(
                n_atoms, 3, dtype=out_dtype, device=dev
            ),
        }
        if self.compute_cell_gradients:
            result["cell_gradients"] = torch.zeros(
                n_sys, 3, 3, dtype=out_dtype, device=dev
            )
        return result

    # kUPS's neighbor list is a fixed-size buffer; ``indices_in`` maps
    # unused slots to the OOB sentinel ``len(keys) == n_atoms``. Drop
    # those padded edges before handing the graph to UMA — otherwise the
    # species/position gather hits true out-of-bounds indices and CUDA
    # asserts.
    valid_edge = (edge_index < n_atoms).all(dim=0)
    edge_index = edge_index[:, valid_edge]
    cell_offsets = cell_offsets[valid_edge]

    # Pin the last atom to the last system so ``batch.max() + 1 == n_sys``
    # (validates AtomicData). For ``sorted_by_system`` real input this is
    # a no-op (the last atom is already in the last system).
    batch[-1] = n_sys - 1

    # First-real-call path: ``predict_unit._lazy_init`` runs
    # ``prepare_for_inference`` *before* moving the model to its target
    # device (``self.predict_unit.device``). If we pass cuda tensors at
    # that point, MOLE merge tries to gather cuda indices through cpu
    # weights and crashes. Move data to cpu only for this one call —
    # ``predict_unit.predict`` re-moves to its own device inside, so the
    # actual model forward still runs on cuda. From the second call
    # onward, data can stay on whatever device JAX handed us.
    #
    # NOTE(review): the attribute name is spelled ``intialized`` here,
    # presumably mirroring the attribute on the fairchem predict unit —
    # confirm before "fixing" the spelling. Because ``getattr`` defaults to
    # ``True``, a missing or renamed attribute silently skips this branch.
    if not getattr(self.predict_unit, "lazy_model_intialized", True):
        pos = pos.cpu()
        species = species.cpu()
        cell = cell.cpu()
        batch = batch.cpu()
        edge_index = edge_index.cpu()
        cell_offsets = cell_offsets.cpu()
        pbc = pbc.cpu()
        charge = charge.cpu()
        spin = spin.cpu()

    # Per-system atom and edge counts; an edge is attributed to the system
    # of its first endpoint (``edge_index[0]``). ``fixed`` and ``tags`` are
    # passed to AtomicData as all-zero per-atom arrays.
    natoms = torch.bincount(batch, minlength=n_sys)
    edge_batch = batch[edge_index[0]]
    nedges = torch.bincount(edge_batch, minlength=n_sys)
    fixed = torch.zeros(n_atoms, dtype=torch.int64, device=pos.device)
    tags = torch.zeros(n_atoms, dtype=torch.int64, device=pos.device)

    # Every system is labelled with the configured task name; ``sid`` is an
    # empty string per system.
    data = AtomicData(
        pos=pos,
        atomic_numbers=species,
        cell=cell,
        pbc=pbc,
        natoms=natoms,
        edge_index=edge_index,
        cell_offsets=cell_offsets.to(pos.dtype),
        nedges=nedges,
        charge=charge,
        spin=spin,
        fixed=fixed,
        tags=tags,
        batch=batch,
        sid=[""] * n_sys,
        dataset=[self.task_name] * n_sys,
    )

    preds = self.predict_unit.predict(data)

    forces = preds["forces"]
    # ``predict_unit`` places outputs on its own device, which may differ
    # from the input device (cpu mock vs cuda predict-unit, single- vs
    # multi-gpu). Pin our post-processing tensors to the output device so
    # ``stress * volume`` and ``lattice_gradient_from_virial`` stay on a
    # single device.
    out_device = forces.device
    # Position gradients are the negated forces.
    result: dict[str, torch.Tensor] = {
        "energy": preds["energy"].detach(),
        "position_gradients": (-forces).detach(),
    }
    if self.compute_cell_gradients:
        stress = preds["stress"]
        # UMA flattens stress to (B, 9); reshape to (B, 3, 3) if needed.
        if stress.dim() == 2 and stress.shape[-1] == 9:
            stress = stress.view(-1, 3, 3)
        cell_d = cell.to(out_device)
        # Scale stress by the cell volume |det(cell)| to obtain the virial
        # expected by ``lattice_gradient_from_virial``.
        volume = torch.linalg.det(cell_d).abs()
        virial = stress * volume.view(-1, 1, 1)
        cell_grad = lattice_gradient_from_virial(
            forces=forces,
            positions=pos.to(out_device),
            batch=batch.to(out_device),
            cell=cell_d,
            virial=virial,
        )
        result["cell_gradients"] = cell_grad.detach()
    return result

lattice_gradient_from_virial(forces, positions, batch, cell, virial)

Recover ∂E/∂h from a symmetric-strain virial.

Many torch MLFF backends (MACE, UMA, …) return a virial or stress quantity that encodes the gradient of energy under a symmetric infinitesimal strain applied jointly to positions and cell:

r_b → r_b + r_b @ ε          (per atom b)
h_s → h_s + h_s @ ε          (per system s)

The virial returned by the backend is then virial = sym(pos_virial + cell_virial) where:

pos_virial[s, j, k]  = Σ_{b∈s} (∂E/∂r_b)_j · (r_b)_k
cell_virial          = cell^T @ ∂E/∂h
sym(M)               = (M + M^T) / 2

Given forces (= -∂E/∂r), positions, batch, cell, and the virial, this function reconstructs pos_virial from -forces ⊗ positions, subtracts it, and solves cell^T @ (∂E/∂h) = cell_virial for the raw lattice gradient. Assumes cell^T @ ∂E/∂h is symmetric (rotational invariance of the energy); the antisymmetric part is unrecoverable from the symmetric-strain virial alone.

Parameters:

Name Type Description Default
forces 'torch.Tensor'

(N, 3) = -∂E/∂r.

required
positions 'torch.Tensor'

(N, 3).

required
batch 'torch.Tensor'

(N,) int system index per atom.

required
cell 'torch.Tensor'

(B, 3, 3).

required
virial 'torch.Tensor'

(B, 3, 3) symmetric strain virial as defined above.

required

Returns:

Type Description
'torch.Tensor'

(B, 3, 3) ∂E/∂h at fixed positions.

Source code in src/kups/potential/mliap/torch/interface.py
def lattice_gradient_from_virial(
    forces: "torch.Tensor",
    positions: "torch.Tensor",
    batch: "torch.Tensor",
    cell: "torch.Tensor",
    virial: "torch.Tensor",
) -> "torch.Tensor":
    """Extract the raw lattice gradient ``∂E/∂h`` from a strain virial.

    Torch MLFF backends such as MACE or UMA report a virial/stress that is
    the energy derivative under a *symmetric infinitesimal strain* ``ε``
    acting on positions and cell together:

        r_b → r_b + r_b @ ε          (per atom b)
        h_s → h_s + h_s @ ε          (per system s)

    so that ``virial = sym(pos_virial + cell_virial)`` with

        pos_virial[s, j, k]  = Σ_{b∈s} (∂E/∂r_b)_j · (r_b)_k
        cell_virial          = cell^T @ ∂E/∂h
        sym(M)               = (M + M^T) / 2

    The position contribution is rebuilt from ``-forces ⊗ positions``,
    symmetrised and subtracted from ``virial``; the remainder is treated as
    ``cell^T @ ∂E/∂h`` and the linear system is solved for ``∂E/∂h``. This
    relies on ``cell^T @ ∂E/∂h`` being symmetric (rotational invariance of
    the energy) — an antisymmetric component cannot be recovered from a
    symmetric-strain virial.

    Args:
        forces: ``(N, 3)`` forces, i.e. ``-∂E/∂r``.
        positions: ``(N, 3)`` atom positions.
        batch: ``(N,)`` integer system index of every atom.
        cell: ``(B, 3, 3)`` cell matrices.
        virial: ``(B, 3, 3)`` symmetric-strain virial as defined above.

    Returns:
        ``(B, 3, 3)`` lattice gradient ``∂E/∂h`` at fixed positions.
    """
    # Inputs can arrive at mixed precision (e.g. a backend that predicts in
    # its own inference dtype while cell/positions keep another). Promote
    # everything to float64 if any operand is float64, otherwise to float32,
    # so the final ``torch.linalg.solve`` never sees a Float/Double mix.
    operands = (forces, positions, cell, virial)
    has_double = any(t.dtype == torch.float64 for t in operands)
    work_dtype = torch.float64 if has_double else torch.float32
    forces, positions, cell, virial = (t.to(work_dtype) for t in operands)

    # Per-atom outer product (∂E/∂r)_j (r)_k, summed into its system.
    energy_grad = -forces
    outer_per_atom = energy_grad[:, :, None] * positions[:, None, :]  # (N, 3, 3)
    pos_virial = torch.zeros(
        cell.shape[0], 3, 3, dtype=positions.dtype, device=positions.device
    ).index_add(0, batch, outer_per_atom)

    # Remove the symmetrised position part; what is left is cell^T @ ∂E/∂h.
    cell_virial = virial - 0.5 * (pos_virial + pos_virial.transpose(-2, -1))

    # A singular ``cell^T`` (notably the all-zero mock tensors used by
    # ``TorchModuleWrapper`` for output-shape inference) would make
    # ``torch.linalg.solve`` raise, and CUDA's lstsq drivers need full rank
    # too. Swap in the identity for those systems; their output carries no
    # meaning and is thrown away by the wrapper's mock pass.
    lhs = cell.transpose(-2, -1)
    degenerate = torch.linalg.det(lhs).abs() < 1e-12
    identity = torch.eye(3, dtype=cell.dtype, device=cell.device)
    lhs = torch.where(degenerate[:, None, None], identity, lhs)
    return torch.linalg.solve(lhs, cell_virial)

load_mace(model_path, device='cuda', dtype='float32', compute_cell_gradients=False, cutoff=None)

Load a PyTorch MACE .model into a universal TorchMliap.

Parameters:

Name Type Description Default
model_path str | Path

Path to a MACE .model checkpoint.

required
device str

Device to load the model onto.

'cuda'
dtype Literal['float32', 'float64']

Model precision — "float32" (default) or "float64".

'float32'
compute_cell_gradients bool

Whether to also compute virials/stress.

False
cutoff float | None

Cutoff radius [Å]. When None, read from model.r_max.

None

Returns:

Type Description
TorchMliap

TorchMliap ready to be wired into the kUPS interface.

Source code in src/kups/potential/mliap/torch/mace.py
def load_mace(
    model_path: str | Path,
    device: str = "cuda",
    dtype: Literal["float32", "float64"] = "float32",
    compute_cell_gradients: bool = False,
    cutoff: float | None = None,
) -> TorchMliap:
    """Load a PyTorch MACE ``.model`` into a universal ``TorchMliap``.

    Args:
        model_path: Path to a MACE ``.model`` checkpoint. The file is
            unpickled with ``weights_only=False``, so it must come from a
            trusted source.
        device: Device to load the model onto. Any CUDA device string
            (``"cuda"``, ``"cuda:0"``, …) is checked against CUDA
            availability before the checkpoint is touched.
        dtype: Model precision — ``"float32"`` (default) or ``"float64"``.
        compute_cell_gradients: Whether to also compute virials/stress.
        cutoff: Cutoff radius [Å]. When ``None``, read from ``model.r_max``.

    Returns:
        ``TorchMliap`` ready to be wired into the kUPS interface.

    Raises:
        ValueError: If ``dtype`` is neither ``"float32"`` nor ``"float64"``.
        RuntimeError: If a CUDA device is requested but CUDA is not
            available (or ``device`` is not a valid torch device).
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Anything other than the two supported names would otherwise silently
    # fall through to float32 in the precision switch below (including a
    # ``torch.dtype`` object passed by mistake).
    if dtype not in ("float32", "float64"):
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected 'float32' or 'float64'."
        )

    # Compare on the parsed device *type* so indexed devices such as
    # "cuda:0" are covered, not only the bare string "cuda".
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            f"Device '{device}' requested but CUDA is not available. "
            "Use device='cpu' or ensure CUDA is properly installed."
        )

    path = Path(model_path)
    if not path.exists():
        raise FileNotFoundError(f"MACE model not found: {model_path}")

    # SECURITY: ``weights_only=False`` performs a full pickle load and can
    # execute arbitrary code embedded in the checkpoint. MACE ``.model``
    # files store the whole module, so this is required — only load
    # checkpoints from trusted sources.
    model = torch.load(path, weights_only=False, map_location=device)
    model.eval()
    model = model.double() if dtype == "float64" else model.float()
    # Re-broadcast to the target device: some MACE/e3nn TorchScript submodules
    # carry Wigner-3j buffers that ``map_location`` and ``.float()`` don't
    # consistently move, and ``TorchModuleWrapper``'s mock-inference call
    # invokes the module before any ``.to(device)`` could rectify this.
    model = model.to(device)

    species_to_index, num_species = _build_species_to_index(model)
    if cutoff is None:
        cutoff = float(cast(torch.Tensor, model.r_max).item())

    module = MACEModule(
        model,
        species_to_index=species_to_index,
        num_species=num_species,
        compute_cell_gradients=compute_cell_gradients,
    )
    return TorchMliap.from_module(
        module, cutoff=cutoff, compute_cell_gradients=compute_cell_gradients
    )

load_uma(model_path, device='cuda', task_name='omat', compute_cell_gradients=False, cutoff=6.0, inference_settings='default')

Load a Meta FAIR Chemistry UMA checkpoint into a TorchMliap.

Parameters:

Name Type Description Default
model_path str | Path

Path to a UMA .pt checkpoint (e.g. uma-s-1.2.pt).

required
device str

Device to load the model onto.

'cuda'
task_name UMATaskName | str

UMA inference head — "omat" (materials), "omol" (molecules), "oc20" (catalysis), "odac" (MOFs / DAC), "omc" (molecular crystals).

'omat'
compute_cell_gradients bool

Whether to also return cell gradients (stress). See UMAModule for convention caveats.

False
cutoff float

Cutoff radius [Å]. UMA-s-1.2 defaults to 6.0.

6.0
inference_settings str

Forwarded to fairchem.core.units.mlip_unit.load_predict_unit — "default" or "turbo".

'default'

Returns:

Type Description
TorchMliap

TorchMliap ready to be wired into the kUPS interface.

Raises:

Type Description
ImportError

If fairchem-core>=2.0 is not installed.

Source code in src/kups/potential/mliap/torch/uma.py
def load_uma(
    model_path: str | Path,
    device: str = "cuda",
    task_name: UMATaskName | str = "omat",
    compute_cell_gradients: bool = False,
    cutoff: float = 6.0,
    inference_settings: str = "default",
) -> TorchMliap:
    """Load a Meta FAIR Chemistry UMA checkpoint into a ``TorchMliap``.

    Args:
        model_path: Path to a UMA ``.pt`` checkpoint (e.g. ``uma-s-1.2.pt``).
        device: Device to load the model onto. Any CUDA device string
            (``"cuda"``, ``"cuda:0"``, …) is checked against CUDA
            availability before the checkpoint is touched.
        task_name: UMA inference head — ``"omat"`` (materials),
            ``"omol"`` (molecules), ``"oc20"`` (catalysis),
            ``"odac"`` (MOFs / DAC), ``"omc"`` (molecular crystals).
        compute_cell_gradients: Whether to also return cell gradients
            (stress). See ``UMAModule`` for convention caveats.
        cutoff: Cutoff radius [Å]. UMA-s-1.2 defaults to 6.0.
        inference_settings: Forwarded to
            ``fairchem.core.units.mlip_unit.load_predict_unit`` —
            ``"default"`` or ``"turbo"``.

    Returns:
        ``TorchMliap`` ready to be wired into the kUPS interface.

    Raises:
        RuntimeError: If a CUDA device is requested but CUDA is not
            available (or ``device`` is not a valid torch device).
        FileNotFoundError: If ``model_path`` does not exist.
        ImportError: If ``fairchem-core>=2.0`` is not installed.
    """
    # Compare on the parsed device *type* so indexed devices such as
    # "cuda:0" are covered, not only the bare string "cuda".
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            f"Device '{device}' requested but CUDA is not available. "
            "Use device='cpu' or ensure CUDA is properly installed."
        )

    path = Path(model_path)
    if not path.exists():
        raise FileNotFoundError(f"UMA model not found: {model_path}")

    # Standard-library import; kept out of the guarded block so the
    # ``ImportError`` handler only covers the optional fairchem dependency.
    from dataclasses import replace

    try:
        from fairchem.core.units.mlip_unit import (  # pyright: ignore[reportMissingImports]
            load_predict_unit,
        )
        from fairchem.core.units.mlip_unit.api.inference import (  # pyright: ignore[reportMissingImports]
            guess_inference_settings,
        )
    except ImportError as exc:
        raise ImportError(
            "Loading UMA requires fairchem-core>=2.0. "
            "Install with `uv sync --extra uma`."
        ) from exc

    # Resolve the named settings to a concrete ``InferenceSettings`` and
    # force external graph generation: kUPS already maintains the radius
    # graph (with the exact same cutoff we pass to UMA), so there's no
    # reason to recompute it inside the model. UMA's internal
    # ``radius_graph_pbc_v2`` also has compile/SymInt issues that go away
    # entirely when ``otf_graph=False``.
    settings = guess_inference_settings(inference_settings)
    settings = replace(settings, external_graph_gen=True)

    predict_unit = load_predict_unit(
        path=str(path),
        device=device,  # pyright: ignore[reportArgumentType]
        inference_settings=settings,
    )
    module = UMAModule(
        predict_unit,
        task_name=task_name,
        compute_cell_gradients=compute_cell_gradients,
    )
    return TorchMliap.from_module(
        module, cutoff=cutoff, compute_cell_gradients=compute_cell_gradients
    )

make_torch_mliap_from_state(state, *, compute_position_and_cell_gradients=False)

make_torch_mliap_from_state(
    state: Lens[State, IsTorchMliapState],
    *,
    compute_position_and_cell_gradients: Literal[
        False
    ] = ...,
) -> Potential[State, Array, EmptyType, Any]
make_torch_mliap_from_state(
    state: Lens[State, IsTorchMliapState],
    *,
    compute_position_and_cell_gradients: Literal[True],
) -> Potential[State, PositionAndCell, EmptyType, Any]

Create a torch MLFF potential from a typed state.

Parameters:

Name Type Description Default
state Any

Lens into a sub-state providing particles, systems, neighbor list, and torch MLFF model.

required
compute_position_and_cell_gradients bool

When True, exposes both position and cell gradients. Requires the underlying TorchMliap.compute_cell_gradients to be True.

False

Returns:

Type Description
Any

Configured torch MLFF Potential.

Source code in src/kups/potential/mliap/torch/interface.py
def make_torch_mliap_from_state(
    state: Any,
    *,
    compute_position_and_cell_gradients: bool = False,
) -> Any:
    """Build a torch MLFF potential by focusing a typed state lens.

    The lens is focused onto ``particles``, ``systems``, ``neighborlist``,
    ``torch_mliap_model`` and ``torch_mliap_model.cutoff``; the resulting
    views are handed to ``make_torch_mliap_potential``.

    Args:
        state: Lens into a sub-state that provides particles, systems, a
            neighbor list, and the torch MLFF model.
        compute_position_and_cell_gradients: If ``True``, the potential
            exposes position *and* cell gradients. The underlying
            ``TorchMliap.compute_cell_gradients`` must then be ``True``.

    Returns:
        Configured torch MLFF ``Potential``.
    """
    particles = state.focus(lambda s: s.particles)
    systems = state.focus(lambda s: s.systems)
    neighborlist = state.focus(lambda s: s.neighborlist)
    model = state.focus(lambda s: s.torch_mliap_model)
    cutoffs = state.focus(lambda s: s.torch_mliap_model.cutoff)
    return make_torch_mliap_potential(
        particles,
        systems,
        neighborlist,
        model,
        cutoffs,
        compute_cell_gradients=compute_position_and_cell_gradients,
    )

make_torch_mliap_potential(particles_view, systems_view, neighborlist_view, model, cutoffs_view, compute_cell_gradients=False, patch_idx_view=None, out_cache_lens=None)

make_torch_mliap_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMliap] | TorchMliap,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_cell_gradients: Literal[False] = False,
    patch_idx_view: View[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[Array, EmptyType]
    ]
    | None = None,
) -> Potential[State, Array, EmptyType, Patch[State]]
make_torch_mliap_potential(
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NNList],
    model: View[State, TorchMliap] | TorchMliap,
    cutoffs_view: View[State, Table[SystemId, Array]],
    compute_cell_gradients: Literal[True],
    patch_idx_view: View[
        State, PotentialOut[PositionAndCell, EmptyType]
    ]
    | None = None,
    out_cache_lens: Lens[
        State, PotentialOut[PositionAndCell, EmptyType]
    ]
    | None = None,
) -> Potential[
    State, PositionAndCell, EmptyType, Patch[State]
]

Create a kUPS Potential from a TorchMliap.

Forces (and optionally stress) are computed inside the torch module; the kUPS side just routes the precomputed gradients through DirectPotential.

Parameters:

Name Type Description Default
particles_view Any

Extracts particle data from state.

required
systems_view Any

Extracts system data (cell) from state.

required
neighborlist_view Any

Extracts neighbor list from state.

required
model Any

TorchMliap instance or view to model in state.

required
cutoffs_view Any

Extracts cutoffs as Table[SystemId, Array].

required
compute_cell_gradients bool

When True, exposes cell gradients (i.e. stress). The wrapped module must produce "cell_gradients".

False
patch_idx_view Any | None

Cached output index structure (optional).

None
out_cache_lens Any | None

Cache location lens (optional).

None

Returns:

Type Description
Any

Configured Potential backed by the torch MLFF.

Source code in src/kups/potential/mliap/torch/interface.py
def make_torch_mliap_potential(
    particles_view: Any,
    systems_view: Any,
    neighborlist_view: Any,
    model: Any,
    cutoffs_view: Any,
    compute_cell_gradients: bool = False,
    patch_idx_view: Any | None = None,
    out_cache_lens: Any | None = None,
) -> Any:
    """Wire a ``TorchMliap`` into a kUPS ``Potential``.

    The torch module itself produces forces (and, if requested, stress);
    this function only forwards those precomputed gradients through
    ``make_direct_mliap_potential``.

    Args:
        particles_view: Extracts particle data from state.
        systems_view: Extracts system data (cell) from state.
        neighborlist_view: Extracts neighbor list from state.
        model: Either a ``TorchMliap`` instance (wrapped into a constant
            view) or a view onto the model stored in state.
        cutoffs_view: Extracts cutoffs as ``Table[SystemId, Array]``.
        compute_cell_gradients: If ``True``, cell gradients (i.e. stress)
            are exposed as well; the wrapped module must then produce
            ``"cell_gradients"``.
        patch_idx_view: Cached output index structure (optional).
        out_cache_lens: Cache location lens (optional).

    Returns:
        Configured ``Potential`` backed by the torch MLFF.
    """
    if isinstance(model, TorchMliap):
        model_view = constant(model)
    else:
        model_view = model

    # Freeze the flag as a plain bool for the closure below.
    with_cell = bool(compute_cell_gradients)

    def model_fn(inp: Any) -> Any:
        return torch_mliap_model_fn(inp, compute_cell_gradients=with_cell)

    return make_direct_mliap_potential(
        model_fn=model_fn,
        particles_view=particles_view,
        systems_view=systems_view,
        neighborlist_view=neighborlist_view,
        model_view=model_view,
        cutoffs_view=cutoffs_view,
        patch_idx_view=patch_idx_view,
        out_cache_lens=out_cache_lens,
    )

torch_mliap_model_fn(inp, *, compute_cell_gradients=False)

torch_mliap_model_fn(
    inp: TorchMliapInput[P, S],
    *,
    compute_cell_gradients: Literal[False] = False,
) -> WithPatch[PotentialOut[Array, EmptyType], IdPatch]
torch_mliap_model_fn(
    inp: TorchMliapInput[P, S],
    *,
    compute_cell_gradients: Literal[True],
) -> WithPatch[
    PotentialOut[PositionAndCell, EmptyType], IdPatch
]

Run a TorchMliap on a graph input and package the result.

Parameters:

Name Type Description Default
inp TorchMliapInput[P, S]

Graph potential input bundling the model and graph.

required
compute_cell_gradients bool

Whether to wrap "cell_gradients" into a PositionAndCell gradients structure.

False

Returns:

Type Description
WithPatch[PotentialOut[Array, EmptyType], IdPatch] | WithPatch[PotentialOut[PositionAndCell, EmptyType], IdPatch]

WithPatch containing PotentialOut with energy, gradients, and an identity patch.

Source code in src/kups/potential/mliap/torch/interface.py
def torch_mliap_model_fn[
    P: IsTorchMliapParticles,
    S: HasCell,
](
    inp: TorchMliapInput[P, S],
    *,
    compute_cell_gradients: bool = False,
) -> (
    WithPatch[PotentialOut[Array, EmptyType], IdPatch]
    | WithPatch[PotentialOut[PositionAndCell, EmptyType], IdPatch]
):
    """Run a ``TorchMliap`` on a graph input and package the result.

    Args:
        inp: Graph potential input bundling the model and graph.
        compute_cell_gradients: Whether to wrap ``"cell_gradients"`` into a
            ``PositionAndCell`` gradients structure.

    Returns:
        ``WithPatch`` containing ``PotentialOut`` with energy, gradients, and
        an identity patch.
    """
    graph, sort_order = inp.graph.sorted_by_system(
        sort_edges=True, return_sort_order=True
    )
    unsort_order = jnp.argsort(sort_order)

    input_dict = _prepare_torch_inputs(graph)
    result = inp.parameters.call(input_dict)

    # Torch backends may run at a different (typically lower) precision than
    # the JAX side (e.g. UMA's predict-unit casts to float32 internally;
    # MACE may be loaded as float32 while JAX runs in x64). Pin every output
    # to the JAX input ``pos`` dtype here so adapters don't need to think
    # about precision and downstream ``lax.scan``/optax pipelines see
    # consistent types.
    out_dtype = input_dict["pos"].dtype
    energy = result["energy"].astype(out_dtype)
    pos_grad = result["position_gradients"][unsort_order].astype(out_dtype)
    energy_table = Table.arange(energy, label=SystemId)

    if compute_cell_gradients:
        cell_grad = result["cell_gradients"].astype(out_dtype)
        # Preserve the input cell/frame type: project the raw ∂E/∂h onto
        # the frame's parameter space via its ``from_matrix`` classmethod,
        # then swap in the new frame on a copy of the input cell.
        input_cell = inp.graph.systems.data.cell
        new_frame = input_cell.frame.from_matrix(cell_grad)
        new_cell = bind(input_cell, lambda c: c.frame).set(new_frame)
        gradients = PositionAndCell(
            positions=Table(inp.graph.particles.keys, pos_grad),
            cell=Table(inp.graph.systems.keys, new_cell),
        )
        return WithPatch(
            PotentialOut(energy_table, gradients, EMPTY),
            IdPatch(),
        )
    return WithPatch(
        PotentialOut(energy_table, pos_grad, EMPTY),
        IdPatch(),
    )