kups.potential.mliap.direct
Factory for graph-based MLIAPs whose models output gradients directly.
Bridges a torch- or JAX-side model_fn that returns a PotentialOut (energy + gradients + hessians) into a kUPS Potential via DirectPotential.
This module covers the "direct" branch only: the model produces the gradients (forces, virials, …) itself. For energy-only models that should be differentiated via JAX autodiff, construct PotentialFromEnergy directly — see tojax for that pattern.
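The sketch below makes the distinction concrete in JAX, outside kUPS. It compares a toy pair energy differentiated with `jax.grad` (the autodiff route) against a hand-written "direct" function that returns the energy and forces itself.

The harmonic pair model, `K`, `R0` and all function names are hypothetical. What the sketch takes from the text above is the idea that a direct model's gradients must agree with what autodiff would produce, with ∂E/∂r = -forces.

```python
import jax
import jax.numpy as jnp

K, R0 = 2.0, 1.0  # hypothetical spring constant and rest length


def pair_energy(positions):
    """Energy-only model: harmonic bonds between all pairs i < j."""
    i, j = jnp.triu_indices(positions.shape[0], k=1)
    d = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    return 0.5 * K * jnp.sum((d - R0) ** 2)


# Autodiff route: JAX derives dE/dr from the energy.
autodiff_grad = jax.grad(pair_energy)


def direct_model(positions):
    """Direct route: the model returns (energy, forces) itself."""
    i, j = jnp.triu_indices(positions.shape[0], k=1)
    rij = positions[i] - positions[j]
    d = jnp.linalg.norm(rij, axis=-1)
    energy = 0.5 * K * jnp.sum((d - R0) ** 2)
    # Analytic dE/dr_i for each pair; atom j receives the opposite sign.
    pair_grad = (K * (d - R0) / d)[:, None] * rij
    grad = jnp.zeros_like(positions).at[i].add(pair_grad).at[j].add(-pair_grad)
    return energy, -grad  # forces = -dE/dr


positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 0.8, 0.3]])
energy, forces = direct_model(positions)
# The gradient slot of PotentialOut would hold -forces, which must match autodiff.
assert jnp.allclose(-forces, autodiff_grad(positions), atol=1e-5)
```

In the direct route the derivative code lives inside the model (for example a torch model's own backward pass). This is why this module does not differentiate anything itself.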
Example
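```python
from kups.potential.mliap.direct import make_direct_mliap_potential
# DirectMliapInput, WithPatch, PotentialOut, EmptyType, EMPTY, IdPatch, Array: import from their kUPS/JAX modules.

def my_forces_fn(inp: DirectMliapInput) -> WithPatch[PotentialOut[Array, EmptyType], IdPatch]:
    energy, forces = model(inp.graph)  # `model` is your own graph model
    # The gradient slot holds dE/dr = -forces; EMPTY means no hessians.
    return WithPatch(PotentialOut(energy, -forces, EMPTY), IdPatch())

potential = make_direct_mliap_potential(my_forces_fn, ...)
```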
DirectMliapFn
Bases: Protocol
Protocol for a direct MLIAP model function.
Returns a PotentialOut that bundles the energy, the gradients and (optionally) the hessians for one graph input. Conventional Gradients payloads:

- Array: position gradients only (∂E/∂r).
- PositionAndCell: position + cell gradients (forces + stress).
- EmptyType: no gradients. In that case the autodiff path (PotentialFromEnergy) is more natural; this module is for the gradient-producing case.
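The JAX sketch below illustrates the difference between the two gradient-producing payloads. For a toy periodic pair energy it computes a position gradient of shape (N, 3) and a cell gradient of shape (3, 3).

Several parts are assumptions made for brevity:

- `PositionAndCellSketch` is a hypothetical stand-in, not kUPS's PositionAndCell, whose fields may be named and laid out differently.
- The minimum-image energy and the rows-are-lattice-vectors convention are illustrative choices.
- `jax.grad` stands in for derivative code that a real direct model would run itself.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

K, R0 = 2.0, 1.0  # hypothetical spring constant and rest length


class PositionAndCellSketch(NamedTuple):
    """Hypothetical stand-in for a position + cell gradient payload."""
    positions: jax.Array  # dE/dr, shape (N, 3)
    cell: jax.Array  # dE/dh at fixed Cartesian positions, shape (3, 3)


def periodic_energy(positions, cell):
    """Harmonic pair energy under the minimum-image convention.

    Assumes the rows of `cell` are the lattice vectors.
    """
    i, j = jnp.triu_indices(positions.shape[0], k=1)
    frac = (positions[i] - positions[j]) @ jnp.linalg.inv(cell)  # Cartesian -> fractional
    frac = frac - jnp.round(frac)  # nearest periodic image
    rij = frac @ cell  # fractional -> Cartesian
    d = jnp.linalg.norm(rij, axis=-1)
    return 0.5 * K * jnp.sum((d - R0) ** 2)


def direct_payload(positions, cell):
    """Return (energy, gradients); jax.grad stands in for the model's own derivative code."""
    grad_r, grad_h = jax.grad(periodic_energy, argnums=(0, 1))(positions, cell)
    return periodic_energy(positions, cell), PositionAndCellSketch(grad_r, grad_h)


positions = jnp.array(
    [[0.1, 0.1, 0.1], [2.9, 0.2, 0.1], [1.5, 1.5, 1.5], [0.5, 2.8, 1.0]]
)
cell = 3.0 * jnp.eye(3)
energy, grads = direct_payload(positions, cell)
print(grads.positions.shape, grads.cell.shape)  # (4, 3) (3, 3)
```

A stress can be derived from the cell gradient, but the exact formula depends on the sign and normalization conventions of the code consuming it. The sketch therefore stops at the raw ∂E/∂h.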
make_direct_mliap_potential(model_fn, particles_view, systems_view, neighborlist_view, model_view, cutoffs_view, *, patch_idx_view=None, out_cache_lens=None)
Wrap a direct-gradient model_fn into a kUPS Potential.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_fn | DirectMliapFn[Model, Gradients, Hessians, P, S, Ptch] | Direct MLIAP function; see DirectMliapFn. | required |
| particles_view | View[State, Table[ParticleId, P]] | View to extract particles from state. | required |
| systems_view | View[State, Table[SystemId, S]] | View to extract systems (cell) from state. | required |
| neighborlist_view | View[State, NearestNeighborList] | View to extract the neighbor list from state. | required |
| model_view | View[State, Model] | View to extract the model from state. | required |
| cutoffs_view | View[State, Table[SystemId, Array]] | View to extract the per-system cutoffs from state. | required |
| patch_idx_view | View[State, PotentialOut[Gradients, Hessians]] \| None | View for cached output indices (optional). | None |
| out_cache_lens | Lens[State, PotentialOut[Gradients, Hessians]] \| None | Lens for the output cache (optional). | None |
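The parameter table is easier to read with the view/lens idea in mind. The plain-Python sketch below assumes two things:

- A View behaves like a read-only accessor `state -> value`.
- A Lens behaves like a get/set pair that returns an updated copy of an immutable state.

`Lens`, `ToyState`, `wrap_direct_model`, `toy_model_fn` and `quadratic` are hypothetical illustrations. They are not kUPS's View, Lens or the internals of make_direct_mliap_potential. The neighbor-list view and patch_idx_view are omitted for brevity.

```python
from dataclasses import dataclass, replace
from operator import attrgetter
from typing import Any, Callable, Generic, TypeVar

S = TypeVar("S")
A = TypeVar("A")


@dataclass(frozen=True)
class Lens(Generic[S, A]):
    """Hypothetical lens: read a field, or return an updated copy of the state."""
    get: Callable[[S], A]
    set: Callable[[S, A], S]


@dataclass(frozen=True)
class ToyState:
    """Hypothetical state; a real kUPS state holds tables, a neighbor list, etc."""
    positions: tuple
    cell: float
    cutoff: float
    model: Callable[..., Any]
    cache: Any = None


# Views modeled as plain getter functions (state -> value).
particles_view = attrgetter("positions")
systems_view = attrgetter("cell")
model_view = attrgetter("model")
cutoffs_view = attrgetter("cutoff")
out_cache_lens = Lens(get=attrgetter("cache"), set=lambda s, v: replace(s, cache=v))


def wrap_direct_model(model_fn, particles_view, systems_view, model_view,
                      cutoffs_view, out_cache_lens=None):
    """Hypothetical wrapper: read inputs via views, call model_fn, cache via the lens."""
    def potential(state):
        out = model_fn(model_view(state), particles_view(state),
                       systems_view(state), cutoffs_view(state))
        if out_cache_lens is not None:
            state = out_cache_lens.set(state, out)  # new state; the input is untouched
        return out, state
    return potential


def toy_model_fn(model, positions, cell, cutoff):
    """Direct contract: the model returns (energy, gradients); cell is unused here."""
    return model(positions, cutoff)


def quadratic(xs, rc):
    # E = sum of x**2 over coordinates inside the cutoff; dE/dx = 2x there, 0 outside.
    energy = sum(x * x for x in xs if abs(x) < rc)
    grads = tuple(2 * x if abs(x) < rc else 0.0 for x in xs)
    return energy, grads


state = ToyState(positions=(0.5, -1.0, 3.0), cell=10.0, cutoff=2.0, model=quadratic)
potential = wrap_direct_model(toy_model_fn, particles_view, systems_view,
                              model_view, cutoffs_view, out_cache_lens)
out, new_state = potential(state)
print(out)  # (1.25, (1.0, -2.0, 0.0))
```

Keeping extraction behind views lets one model function serve states of different shapes. Routing the cache through a lens keeps the state immutable, which fits JAX-style functional updates.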
Returns:
| Type | Description |
|---|---|
| Potential[State, Gradients, Hessians, Patch[State]] | Configured kUPS Potential. |