kups.potential.mliap.tojax
Jaxified machine learning interatomic potential interface.
This module integrates generic JAX-exported machine learning force field (MLFF) models through the `AtomGraphInput` / `EnergyFn` protocol. It supports periodic systems with graph-based atomic representations.
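For periodic systems, an atom graph typically connects every pair of atoms closer than the model cutoff, including pairs that interact across the cell boundary. The sketch below illustrates this with a hypothetical `radius_graph` helper, which is not part of kups and is not how `NearestNeighborList` is necessarily built. It assumes an orthorhombic cell given by its three edge lengths and uses the minimum-image convention, which is only valid when the cutoff is below half the shortest edge.

```python
import itertools


def radius_graph(positions, box, cutoff):
    """Return directed edges (i, j, shift) for all atom pairs closer than cutoff.

    Hypothetical helper for illustration. Assumes an orthorhombic box given by
    its three edge lengths and cutoff < min(box) / 2, so the minimum-image
    convention finds each neighbor exactly once. The displacement of edge
    (i, j) is r_j + shift * box - r_i.
    """
    if cutoff >= min(box) / 2:
        raise ValueError("minimum image requires cutoff < half the box length")
    edges = []
    for i, j in itertools.permutations(range(len(positions)), 2):
        shift = []
        dist2 = 0.0
        for axis in range(3):
            d = positions[j][axis] - positions[i][axis]
            # Integer image offset that wraps d into [-L/2, L/2].
            n = -round(d / box[axis])
            d += n * box[axis]
            shift.append(n)
            dist2 += d * d
        if dist2 < cutoff * cutoff:
            edges.append((i, j, tuple(shift)))
    return edges
```

For example, atoms at x = 0.5 and x = 9.5 in a 10 Å cubic box are only 1 Å apart through the boundary, so with a 2 Å cutoff they are connected by the edges `(0, 1, (-1, 0, 0))` and `(1, 0, (1, 0, 0))`.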
AtomGraphInput
Bases: TypedDict
Typed dictionary for jaxified model graph input.
Source code in src/kups/potential/mliap/tojax.py
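The keys of `AtomGraphInput` are not listed on this page. As an illustration of a typed graph input and an energy function over it, the sketch below uses a hypothetical `ToyGraphInput` with common graph-network fields (positions, species, edge endpoints, periodic shifts) and a toy energy callable. Neither reflects the actual `AtomGraphInput` or `EnergyFn` definitions. The parameter list mirrors the fact that `TojaxedMliap` stores its parameters as a list.

```python
from typing import Callable, TypedDict


class ToyGraphInput(TypedDict):
    """Hypothetical stand-in for AtomGraphInput; the real keys may differ."""

    positions: list[tuple[float, float, float]]  # Cartesian coordinates [Å]
    species: list[int]  # atomic numbers
    senders: list[int]  # source atom index of each edge
    receivers: list[int]  # target atom index of each edge
    shifts: list[tuple[int, int, int]]  # periodic image offset of each edge


# Hypothetical energy-function signature: (parameters, graph) -> energy.
ToyEnergyFn = Callable[[list[float], ToyGraphInput], float]


def edge_count_energy(params: list[float], graph: ToyGraphInput) -> float:
    # Toy model: a constant energy contribution per directed edge.
    return params[0] * len(graph["senders"])
```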
EnergyFn
IsTojaxedState
Bases: Protocol
Protocol for states providing all inputs for the jaxified potential.
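Because `IsTojaxedState` is a `Protocol`, any state type that has the required members satisfies it without inheriting from it. The following minimal sketch shows that structural typing. It uses a hypothetical `HasToyMliapInputs` protocol whose member names (`particles`, `unit_cell`, `neighborlist`, `model`) are illustrative and not taken from kups. Note that `runtime_checkable` only checks that the members exist, not their types.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasToyMliapInputs(Protocol):
    """Hypothetical analogue of IsTojaxedState; member names are illustrative."""

    particles: object
    unit_cell: object
    neighborlist: object
    model: object


@dataclass
class MyState:
    # No inheritance from the protocol: having the members is enough.
    particles: object
    unit_cell: object
    neighborlist: object
    model: object


@dataclass
class IncompleteState:
    particles: object  # missing the other members, so it does not match
```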
TojaxedMliap
Jaxified model container.
Attributes:

| Name | Type | Description |
|---|---|---|
| `cutoff` | `Table[SystemId, Array]` | Model cutoff radius [Å]. |
| `params` | `list[Array]` | Model parameters as a list of arrays. |
| `model` | `Exported` | Exported JAX model. |
call(input)
Call the jaxified model on the given input.
from_zip_file(zip_file)
staticmethod
Load a jaxified model from a zip archive.
The archive must contain `model.jax`, `metadata.json` (with a `cutoff` key), and `params.msgpack`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `zip_file` | `str \| Path` | Path to the zip archive. | required |

Returns:

| Type | Description |
|---|---|
| `TojaxedMliap` | Loaded jaxified model. |
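The following minimal standard-library sketch checks the archive layout that `from_zip_file` expects and reads the `cutoff` key from `metadata.json`. `read_archive_cutoff` is a hypothetical helper, not the kups loader. It assumes the cutoff is stored as a single number. It also does not deserialize `model.jax` or `params.msgpack`, which presumably need JAX's `jax.export` deserialization and a msgpack decoder respectively.

```python
import io
import json
import zipfile
from pathlib import Path
from typing import BinaryIO, Union

REQUIRED_MEMBERS = ("model.jax", "metadata.json", "params.msgpack")


def read_archive_cutoff(zip_file: Union[str, Path, BinaryIO]) -> float:
    """Check that the three expected members exist and return the stored cutoff.

    Hypothetical helper: it does not deserialize the model or its parameters.
    """
    with zipfile.ZipFile(zip_file) as zf:
        names = set(zf.namelist())
        missing = [m for m in REQUIRED_MEMBERS if m not in names]
        if missing:
            raise ValueError(f"archive is missing {missing}")
        metadata = json.loads(zf.read("metadata.json"))
    if "cutoff" not in metadata:
        raise KeyError("metadata.json has no 'cutoff' key")
    return float(metadata["cutoff"])


# Usage: a throwaway in-memory archive with placeholder payloads.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("model.jax", b"placeholder")
    zf.writestr("metadata.json", json.dumps({"cutoff": 5.0}))
    zf.writestr("params.msgpack", b"placeholder")
buffer.seek(0)
print(read_archive_cutoff(buffer))  # 5.0
```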
make_tojaxed_from_state(state, *, compute_position_and_unitcell_gradients=False)
Create a jaxified potential from a typed state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `Any` | Lens into the sub-state providing particles, unit cell, neighbor list, and jaxified model. | required |
| `compute_position_and_unitcell_gradients` | `bool` | When `True`, also compute gradients of the energy with respect to particle positions and the unit cell. | `False` |

Returns:

| Type | Description |
|---|---|
| `Any` | Configured jaxified `Potential`. |
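The toy sketch below (not kups code) illustrates what a flag such as `compute_position_and_unitcell_gradients` toggles. The energy is always evaluated, while position gradients are computed only on request. It assumes a hypothetical 1D harmonic pair energy and uses central finite differences in place of JAX automatic differentiation. It omits unit-cell gradients, the stress-related term, entirely. Forces are the negative of the returned gradients.

```python
def pair_energy(positions, r0=1.0):
    """Toy 1D harmonic pair energy: sum over i < j of (|x_i - x_j| - r0) ** 2."""
    energy = 0.0
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            energy += (abs(positions[i] - positions[j]) - r0) ** 2
    return energy


def energy_and_gradients(positions, compute_gradients=False, h=1e-5):
    """Return (energy, gradients); gradients is None unless requested.

    Central finite differences stand in for automatic differentiation.
    """
    energy = pair_energy(positions)
    if not compute_gradients:
        return energy, None  # cheapest path: a single energy evaluation
    gradients = []
    for k in range(len(positions)):
        plus, minus = list(positions), list(positions)
        plus[k] += h
        minus[k] -= h
        gradients.append((pair_energy(plus) - pair_energy(minus)) / (2 * h))
    return energy, gradients
```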
make_tojaxed_potential(particles_view, systems_view, neighborlist_view, model, cutoffs_view, gradient_lens, hessian_lens, hessian_idx_view, patch_idx_view=None, out_cache_lens=None)
Create a jaxified machine learning potential.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `particles_view` | `View[State, Table[ParticleId, IsTojaxedParticles]]` | Extracts particle data (positions, species). | required |
| `systems_view` | `View[State, Table[SystemId, HasUnitCell]]` | Extracts system data (unit cell). | required |
| `neighborlist_view` | `View[State, NearestNeighborList]` | Extracts the neighbor list. | required |
| `model` | `View[State, TojaxedMliap] \| TojaxedMliap` | Jaxified model instance, or a view to the model in the state. | required |
| `cutoffs_view` | `View[State, Table[SystemId, Array]]` | Extracts per-system cutoffs as a `Table[SystemId, Array]`. | required |
| `gradient_lens` | `Lens[JaxifiedInput, Gradients]` | Lens specifying which gradients to compute. | required |
| `hessian_lens` | `Lens[Gradients, Hessians]` | Lens specifying which Hessians to compute. | required |
| `hessian_idx_view` | `View[State, Hessians]` | View to the Hessian index structure. | required |
| `patch_idx_view` | `View[State, PotentialOut[Gradients, Hessians]] \| None` | View to the cached output index structure. | `None` |
| `out_cache_lens` | `Lens[State, PotentialOut[Gradients, Hessians]] \| None` | Lens to the cache location. | `None` |

Returns:

| Type | Description |
|---|---|
| `PotentialFromEnergy[State, JaxifiedInput, Gradients, Hessians, Any]` | Jaxified potential. |
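`make_tojaxed_potential` is configured through `View` and `Lens` arguments whose kups definitions are not shown on this page. The sketch below illustrates the usual reading of those terms: a view reads one part of a state, and a lens reads that part and returns an updated copy of the state. It uses hypothetical `ToyLens` and `ToyState` types and is not kups' implementation.

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")
A = TypeVar("A")

# A view only reads: a function from the whole state to one part of it.
View = Callable[[S], A]


@dataclass(frozen=True)
class ToyLens(Generic[S, A]):
    """Hypothetical lens: a getter plus a setter that returns a new state."""

    get: Callable[[S], A]
    set: Callable[[S, A], S]


@dataclass(frozen=True)
class ToyState:
    positions: tuple
    cutoff: float


cutoff_view: View[ToyState, float] = lambda s: s.cutoff
positions_lens: ToyLens[ToyState, tuple] = ToyLens(
    get=lambda s: s.positions,
    set=lambda s, p: replace(s, positions=p),  # never mutates the original
)
```

Keeping updates functional (returning a new state rather than mutating) matches how JAX code usually threads state through transformed functions.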
tojaxed_energy(inp)
Compute energy using a jaxified model.
Prepares graph data and calls the exported model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inp` | `JaxifiedInput` | Graph potential input containing the jaxified model and graph data. | required |

Returns:

| Type | Description |
|---|---|
| `WithPatch[Table[SystemId, Energy], IdPatch]` | Per-system energies. |
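`tojaxed_energy` returns one energy per system. This sketch assumes the model produces per-atom contributions, although the exported model may instead return per-system values directly. When several systems are batched into one graph, those contributions can be reduced with a segment sum keyed by each atom's system index. The pure-Python `segment_sum` below is a hypothetical illustration. In JAX, the corresponding primitive is `jax.ops.segment_sum`.

```python
def segment_sum(values, segment_ids, num_segments):
    """Sum the entries of `values` that share a segment id.

    Pure-Python analogue of jax.ops.segment_sum, for illustration only.
    """
    totals = [0.0] * num_segments
    for value, segment in zip(values, segment_ids):
        totals[segment] += value
    return totals


# Two systems batched into one graph: atoms 0-1 are system 0, atoms 2-3 are system 1.
atom_energies = [-1.0, -2.0, -0.5, -0.5]  # hypothetical per-atom contributions
atom_system_ids = [0, 0, 1, 1]
print(segment_sum(atom_energies, atom_system_ids, num_segments=2))  # [-3.0, -1.0]
```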