
kups.potential.mliap.direct

Factory for graph-based MLIAPs whose models output gradients directly.

Bridges a torch- or JAX-side model_fn that returns a PotentialOut (energy, gradients and, optionally, hessians) into a kUPS Potential via DirectPotential.

This module covers the "direct" branch only: the model produces the gradients (forces, virials, …) itself. For energy-only models that should be differentiated via JAX autodiff, construct PotentialFromEnergy directly — see tojax for that pattern.
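In the direct branch, the model must return ∂E/∂r itself, and that result has to match what differentiating the energy would give. The sketch below shows this with a hypothetical harmonic bond in plain Python, with no kUPS or JAX involved. Central finite differences stand in for the autodiff path, and the names `energy`, `direct_model` and `finite_difference_grad` are illustrative only.

```python
import math

# Hypothetical toy model: a harmonic bond between two particles.
K, R0 = 2.0, 1.0


def energy(r1, r2):
    """Energy only: what an autodiff (PotentialFromEnergy-style) path starts from."""
    d = math.dist(r1, r2)
    return 0.5 * K * (d - R0) ** 2


def direct_model(r1, r2):
    """'Direct' branch: the model returns the energy AND dE/dr itself."""
    d = math.dist(r1, r2)
    e = 0.5 * K * (d - R0) ** 2
    # dE/dr1 = K (d - R0) (r1 - r2) / d, and dE/dr2 = -dE/dr1.
    scale = K * (d - R0) / d
    g1 = [scale * (a - b) for a, b in zip(r1, r2)]
    g2 = [-x for x in g1]
    return e, [g1, g2]


def finite_difference_grad(f, r1, r2, h=1e-6):
    """Central differences, standing in for the autodiff gradient."""
    points = [list(r1), list(r2)]
    grads = []
    for i in range(2):
        row = []
        for k in range(3):
            plus = [p[:] for p in points]
            minus = [p[:] for p in points]
            plus[i][k] += h
            minus[i][k] -= h
            row.append((f(*plus) - f(*minus)) / (2 * h))
        grads.append(row)
    return grads


r1, r2 = [0.0, 0.0, 0.0], [1.5, 0.0, 0.0]
e, direct_grads = direct_model(r1, r2)
auto_grads = finite_difference_grad(energy, r1, r2)
```

A direct model earns its keep when it can produce these gradients more cheaply than differentiating through it. It is only correct if the two results agree.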

Example
from kups.potential.mliap.direct import make_direct_mliap_potential

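# DirectMliapInput, WithPatch, PotentialOut, IdPatch, EMPTY, Array and EmptyType
# are kUPS/JAX names; their imports are omitted here.
def my_forces_fn(inp: DirectMliapInput) -> WithPatch[PotentialOut[Array, EmptyType], IdPatch]:
    # `model` is a placeholder for your own network returning (energy, forces).
    energy, forces = model(inp.graph)
    # The Array gradient payload is dE/dr, i.e. the negated forces; no hessians (EMPTY).
    return WithPatch(PotentialOut(energy, -forces, EMPTY), IdPatch())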

potential = make_direct_mliap_potential(my_forces_fn, ...)
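The sign convention in the example above is easy to get wrong, so here is a self-contained rendition that can be run without kUPS. PotentialOut, WithPatch, IdPatch, EMPTY and the input type are replaced by hypothetical dataclass stand-ins whose field names are assumptions. The toy model has E = |r|²/2, so F = -r and the expected gradient payload is r itself.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical stand-ins for the kUPS types used in the example above;
# the real classes may have different fields and behaviour.


@dataclass(frozen=True)
class PotentialOut:
    energy: Any
    gradients: Any
    hessians: Any


@dataclass(frozen=True)
class IdPatch:
    """Stand-in for the identity patch."""


@dataclass(frozen=True)
class WithPatch:
    data: Any
    patch: Any


@dataclass(frozen=True)
class Input:
    """Stand-in for DirectMliapInput; only the `graph` field is modelled."""
    graph: list


EMPTY = None  # stand-in for the EMPTY sentinel (no hessians)


def model(graph):
    """Hypothetical model: E = |r|^2 / 2 summed over particles, so F = -r."""
    energy = 0.5 * sum(x * x for pos in graph for x in pos)
    forces = [[-x for x in pos] for pos in graph]
    return energy, forces


def my_forces_fn(inp):
    energy, forces = model(inp.graph)
    # The Array gradient payload is dE/dr, i.e. the negated forces.
    gradients = [[-f for f in row] for row in forces]
    return WithPatch(PotentialOut(energy, gradients, EMPTY), IdPatch())
```

Calling `my_forces_fn(Input(graph=[[1.0, 2.0, 0.0]]))` gives an energy of 2.5 and gradients equal to the input positions, which is what ∂E/∂r should be for this toy energy.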

DirectMliapFn

Bases: Protocol

Protocol for a direct MLIAP model function.

Returns a PotentialOut that bundles energy, gradients and (optionally) hessians for one graph input. Conventional Gradients payloads:

  • Array: position gradients only (∂E/∂r).
  • PositionAndCell: position + cell gradients (forces + stress).
  • EmptyType: no gradients — but in that case the autodiff path (PotentialFromEnergy) is more natural; this module is for the gradient-producing case.
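The three payload kinds differ mainly in shape. Below is a rough shape check for a single system. It is built on a hypothetical PositionAndCell stand-in whose field names (`positions`, `cell`) are assumptions, not the kUPS definition, and it uses `None` to stand in for EmptyType.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PositionAndCell:
    """Hypothetical stand-in; the kUPS field names are assumptions."""
    positions: list  # (N, 3): dE/dr for each particle
    cell: list  # (3, 3): dE/dh for one system's cell matrix


def shape(x):
    """Shape of a nested list, e.g. [[0, 0, 0]] -> (1, 3)."""
    if isinstance(x, list):
        return (len(x),) + (shape(x[0]) if x else ())
    return ()


def classify_gradients(grads, n_particles):
    """Name the payload kind and check its shape (single system assumed)."""
    if grads is None:  # stand-in for EmptyType
        return "empty"
    if isinstance(grads, PositionAndCell):
        if shape(grads.positions) != (n_particles, 3) or shape(grads.cell) != (3, 3):
            raise ValueError("expected (N, 3) position and (3, 3) cell gradients")
        return "positions+cell"
    if shape(grads) != (n_particles, 3):
        raise ValueError("expected (N, 3) position gradients")
    return "positions"
```

With several systems in one graph, the cell gradient would presumably carry a leading system axis. This sketch does not model that case.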
Source code in src/kups/potential/mliap/direct.py
class DirectMliapFn[
    Model,
    Gradients,
    Hessians,
    P: HasPositionsAndSystemIndex,
    S: HasCell,
    Ptch: Patch,
](Protocol):
    """Protocol for a direct MLIAP model function.

    Returns a ``PotentialOut`` that bundles energy, gradients and (optionally)
    hessians for one graph input. Conventional ``Gradients`` payloads:

    - ``Array``: position gradients only (``∂E/∂r``).
    - ``PositionAndCell``: position + cell gradients (forces + stress).
    - ``EmptyType``: no gradients — but in that case the autodiff path
      ([PotentialFromEnergy][kups.potential.common.energy.PotentialFromEnergy])
      is more natural; this module is for the gradient-producing case.
    """

    def __call__(
        self, inp: DirectMliapInput[Model, P, S]
    ) -> WithPatch[PotentialOut[Gradients, Hessians], Ptch]: ...
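Because DirectMliapFn is a Protocol with only `__call__`, any plain function or callable object with a matching signature satisfies it structurally. No subclassing is needed. The sketch uses a simplified, non-generic stand-in, `DirectFnLike`, marked `runtime_checkable`. Note that at runtime `isinstance` only checks that `__call__` exists. Argument and return types are enforced by a static type checker, not by this check.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DirectFnLike(Protocol):
    """Simplified, non-generic stand-in for DirectMliapFn."""

    def __call__(self, inp: Any) -> Any: ...


def plain_function(inp):
    return inp


class CallableModel:
    def __call__(self, inp):
        return inp


class NotCallable:
    pass
```

Both `plain_function` and `CallableModel()` pass the runtime check. `NotCallable()` fails because it has no `__call__`.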

make_direct_mliap_potential(model_fn, particles_view, systems_view, neighborlist_view, model_view, cutoffs_view, *, patch_idx_view=None, out_cache_lens=None)

Wrap a direct-gradient model_fn into a kUPS Potential.

Parameters:

  • model_fn (DirectMliapFn[Model, Gradients, Hessians, P, S, Ptch], required): Direct MLIAP function; see DirectMliapFn.
  • particles_view (View[State, Table[ParticleId, P]], required): View to extract particles from state.
  • systems_view (View[State, Table[SystemId, S]], required): View to extract systems (cell) from state.
  • neighborlist_view (View[State, NearestNeighborList], required): View to extract neighbor list from state.
  • model_view (View[State, Model], required): View to extract model from state.
  • cutoffs_view (View[State, Table[SystemId, Array]], required): View to extract cutoffs as Table[SystemId, Array].
  • patch_idx_view (View[State, PotentialOut[Gradients, Hessians]] | None, default None): View for cached output indices (optional).
  • out_cache_lens (Lens[State, PotentialOut[Gradients, Hessians]] | None, default None): Lens for output cache (optional).
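Most parameters are views (read-only accessors into the state), and the output cache uses a lens (an accessor that can also write). The sketch below illustrates that distinction with a hypothetical `State`, a view modelled as a plain getter function, and a minimal `Lens` with a copying setter. The actual kUPS View and Lens types may be defined differently.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Generic, TypeVar

S = TypeVar("S")
A = TypeVar("A")


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    """Hypothetical lens: a getter plus a functional (copying) setter."""
    get: Callable[[S], A]
    set: Callable[[S, A], S]


@dataclass(frozen=True)
class State:
    positions: tuple
    cutoff: float
    cache: Any = None


def positions_view(state: State) -> tuple:
    """A view is read-only here: just a function State -> value."""
    return state.positions


cache_lens: Lens[State, Any] = Lens(
    get=lambda s: s.cache,
    set=lambda s, value: replace(s, cache=value),  # returns a new State
)
```

Writing through `cache_lens.set` returns a new `State` and leaves the original untouched, which suits immutable, JAX-style state.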

Returns:

  • Potential[State, Gradients, Hessians, Patch[State]]: Configured kUPS Potential backed by DirectPotential.

Source code in src/kups/potential/mliap/direct.py
def make_direct_mliap_potential[
    Model,
    State,
    Gradients,
    Hessians,
    P: IsRadiusGraphPoints,
    S: HasCell,
    Ptch: Patch,
](
    model_fn: DirectMliapFn[Model, Gradients, Hessians, P, S, Ptch],
    particles_view: View[State, Table[ParticleId, P]],
    systems_view: View[State, Table[SystemId, S]],
    neighborlist_view: View[State, NearestNeighborList],
    model_view: View[State, Model],
    cutoffs_view: View[State, Table[SystemId, Array]],
    *,
    patch_idx_view: View[State, PotentialOut[Gradients, Hessians]] | None = None,
    out_cache_lens: Lens[State, PotentialOut[Gradients, Hessians]] | None = None,
) -> Potential[State, Gradients, Hessians, Patch[State]]:
    """Wrap a direct-gradient ``model_fn`` into a kUPS ``Potential``.

    Args:
        model_fn: Direct MLIAP function — see
            [DirectMliapFn][kups.potential.mliap.direct.DirectMliapFn].
        particles_view: View to extract particles from state.
        systems_view: View to extract systems (cell) from state.
        neighborlist_view: View to extract neighbor list from state.
        model_view: View to extract model from state.
        cutoffs_view: View to extract cutoffs as ``Table[SystemId, Array]``.
        patch_idx_view: View for cached output indices (optional).
        out_cache_lens: Lens for output cache (optional).

    Returns:
        Configured kUPS ``Potential`` backed by ``DirectPotential``.
    """
    composer = FullGraphSumComposer(
        RadiusGraphConstructor(
            particles=particles_view,
            systems=systems_view,
            cutoffs=cutoffs_view,
            neighborlist=neighborlist_view,
            probe=None,
        ),
        model_view,
    )
    return DirectPotential(
        direct_potential_fn=model_fn,
        composer=composer,
        cache_lens=out_cache_lens,
        patch_idx_view=patch_idx_view,
    )
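make_direct_mliap_potential delegates graph building to RadiusGraphConstructor and aggregation to FullGraphSumComposer. The pure-Python sketch below shows what those two roles might look like. It assumes that the constructor keeps pairs within each system's cutoff and that the composer sums contributions per system. The real classes use the neighbor list and the cell (including periodic images), which this brute-force version ignores. The per-pair model is hypothetical.

```python
import math
from collections import defaultdict


def radius_graph(positions, system_index, cutoffs):
    """Directed edges (i, j), i != j, within the same system and its cutoff.

    Brute force O(N^2); ignores periodic images and the neighbor list.
    """
    edges = []
    for i, (ri, si) in enumerate(zip(positions, system_index)):
        for j, (rj, sj) in enumerate(zip(positions, system_index)):
            if i != j and si == sj and math.dist(ri, rj) < cutoffs[si]:
                edges.append((i, j))
    return edges


def sum_per_system(per_particle, system_index):
    """Sum per-particle contributions into per-system totals."""
    totals = defaultdict(float)
    for value, s in zip(per_particle, system_index):
        totals[s] += value
    return dict(totals)


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
system_index = [0, 0, 0, 1]
cutoffs = {0: 1.5, 1: 1.5}  # plays the role of Table[SystemId, Array]

edges = radius_graph(positions, system_index, cutoffs)
# Hypothetical model: -1 per pair, split evenly over the two directed edges.
per_particle = [0.0] * len(positions)
for i, _ in edges:
    per_particle[i] += -0.5
energies = sum_per_system(per_particle, system_index)
```

Only particles 0 and 1 are within the cutoff of each other. Particle 3 coincides with particle 0 in space, but it belongs to a different system, so no edge is formed between them.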