Potentials

A Potential computes the energy of a system and, optionally, its gradients and Hessians. Unlike a Propagator, a potential does not modify the state. It evaluates the current configuration and returns a structured result. Potentials compose by summation (LJ + Coulomb + bonds), while propagators compose by sequencing. The two worlds meet at PotentialAsPropagator.
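
The two composition modes can be sketched in plain Python. This is a conceptual illustration only, not the kups API: summed potentials evaluate the same state and add their outputs, while sequenced propagators thread the state through each step.

```python
# Conceptual sketch, not the kups API: potentials compose by summation,
# propagators compose by sequencing.

def lj(state):        # stand-in energy functions: read the state, return a number
    return 1.0

def coulomb(state):
    return 2.0

def sum_potentials_sketch(*potentials):
    # Summed potentials evaluate the same state and add their results.
    return lambda state: sum(p(state) for p in potentials)

def sequence_sketch(*propagators):
    # Sequenced propagators each transform the state in turn.
    def run(state):
        for p in propagators:
            state = p(state)
        return state
    return run

total = sum_potentials_sketch(lj, coulomb)
step = sequence_sketch(lambda s: s + 1, lambda s: s * 2)
print(total("state"), step(1))  # 3.0 4
```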

import jax.numpy as jnp
from jax import Array

from kups.core.capacity import FixedCapacity
from kups.core.data import Index, Table
from kups.core.data.wrappers import WithCache, WithIndices
from kups.core.lens import bind, identity_lens, lens
from kups.core.neighborlist import AllDenseNearestNeighborList
from kups.core.patch import IndexLensPatch
from kups.core.potential import EMPTY, PotentialOut, sum_potentials
from kups.core.typing import ExclusionId, InclusionId, Label, ParticleId, SystemId
from kups.core.unitcell import TriclinicUnitCell
from kups.core.utils.jax import dataclass
from kups.potential.classical.lennard_jones import (
    LennardJonesParameters,
    make_lennard_jones_from_state,
)

The Interface

A Potential takes a state and an optional Patch (for incremental MC updates) and returns a WithPatch[PotentialOut, Patch]. The PotentialOut contains per-system energies, gradients, and Hessians. The Patch is a recipe for writing results back into the state, which the caller applies when ready.

class Potential[State, Gradients, Hessians, StatePatch](Protocol):
    def __call__(self, state: State, patch: StatePatch | None = None) -> WithPatch[PotentialOut, StatePatch]: ...

PotentialOut

PotentialOut holds total_energies (one per system), gradients, and hessians. It supports addition, so when two potentials are summed, their outputs add element-wise. No special composition logic needed.

a = PotentialOut(
    total_energies=Table.arange(jnp.array([1.0]), label=SystemId),
    gradients=jnp.array([0.5]),
    hessians=EMPTY,
)
b = PotentialOut(
    total_energies=Table.arange(jnp.array([2.0]), label=SystemId),
    gradients=jnp.array([0.3]),
    hessians=EMPTY,
)
c = a + b
print("energies:", c.total_energies.data)
print("gradients:", c.gradients)
energies: [3.]
gradients: [0.8]

Setting Up a System

To evaluate a potential, we need particles (positions, species labels, system assignment), a unit cell, a neighbor list, and potential parameters. Let's build a minimal argon system with three particles. We use TriclinicUnitCell for the simulation box, AllDenseNearestNeighborList as the neighbor list, and LennardJonesParameters for the pair parameters.

@dataclass
class ParticleData:
    positions: Array
    labels: Index[Label]
    system: Index[SystemId]
    inclusion: Index[InclusionId]
    exclusion: Index[ExclusionId]


@dataclass
class SystemData:
    unitcell: TriclinicUnitCell


@dataclass
class LJState:
    particles: Table[ParticleId, ParticleData]
    systems: Table[SystemId, SystemData]
    neighborlist: AllDenseNearestNeighborList
    lj_parameters: LennardJonesParameters


positions = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
n = len(positions)

particles = Table.arange(
    ParticleData(
        positions=positions,
        labels=Index.new([Label("Ar")] * n),
        system=Index.new([SystemId(0)] * n),
        inclusion=Index.new([InclusionId(0)] * n),
        exclusion=Index.new([ExclusionId(i) for i in range(n)]),
    ),
    label=ParticleId,
)

systems = Table.arange(
    SystemData(unitcell=TriclinicUnitCell.from_matrix(10.0 * jnp.eye(3)[None])),
    label=SystemId,
)

lj_params = LennardJonesParameters.from_dict(
    cutoff=5.0,
    parameters={"Ar": (1.0, 0.5)},  # (sigma [Å], epsilon [eV])
    mixing_rule="lorentz_berthelot",
)

nl = AllDenseNearestNeighborList(
    avg_edges=FixedCapacity(n * n),
    avg_image_candidates=FixedCapacity(n * n),
)

state = LJState(
    particles=particles,
    systems=systems,
    neighborlist=nl,
    lj_parameters=lj_params,
)
print("particles:", len(state.particles))
print("systems:", len(state.systems))
particles: 3
systems: 1

The state satisfies the IsLJState protocol: it has particles, systems, neighborlist, and lj_parameters fields. The AllDenseNearestNeighborList generates all particle pairs (O(N²)), which is fine for small test systems. Production simulations use cell-list-based neighbor lists like CellListNeighborList.
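
What "all pairs" means can be sketched with plain jax.numpy. This is a minimal illustration of O(N²) candidate generation, not the AllDenseNearestNeighborList implementation, and it ignores periodic images (which the real neighbor list handles via image candidates).

```python
import jax.numpy as jnp

# Minimal sketch of all-pairs candidate generation (no periodic images).
positions = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
cutoff = 5.0
n = len(positions)

# Every ordered pair (i, j) with i != j: O(N^2) candidates.
i, j = jnp.meshgrid(jnp.arange(n), jnp.arange(n), indexing="ij")
i, j = i.ravel(), j.ravel()
distances = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
edges = (i != j) & (distances < cutoff)

print("directed edges within cutoff:", int(edges.sum()))  # 6
```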

Evaluating the Potential

make_lennard_jones_from_state wires together the graph constructor, energy function, and autodiff gradient machinery from a single Lens into the state. It is a shorthand that assumes the state follows the naming conventions (particles, systems, neighborlist, and lj_parameters fields).

potential = make_lennard_jones_from_state(identity_lens(LJState))
out = potential(state)
print("energy:", out.data.total_energies.data)
energy: [-0.06542206]

Under the hood, this built a RadiusGraphConstructor to assemble the pairwise graph from the neighbor list, a LocalGraphSumComposer to plan evaluations (full or incremental), and a PotentialFromEnergy wrapping the LJ energy function with jax.grad.

Gradients via Automatic Differentiation

Potentials implement only the energy function; gradients are computed automatically by jax.grad inside PotentialFromEnergy. Pass compute_position_and_unitcell_gradients=True to get position gradients (forces) and unit cell gradients (for stress).

potential_with_grads = make_lennard_jones_from_state(
    identity_lens(LJState),
    compute_position_and_unitcell_gradients=True,
)
out = potential_with_grads(state)

print("energy:", out.data.total_energies.data)
print("position gradients (forces = -grad):")
print(out.data.gradients.positions.data)
energy: [-0.06542206]
position gradients (forces = -grad):
[[-0.09082031 -0.09082031  0.        ]
 [ 0.0966568  -0.00583649  0.        ]
 [-0.00583649  0.0966568   0.        ]]

The gradient structure is PositionAndUnitCell containing positions (a Table[ParticleId, Array]) and unitcell (a Table[SystemId, ...]). Forces are the negation of the position gradient.
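
The mechanism can be sketched without any kups machinery. This is not the PotentialFromEnergy implementation, just the jax.grad pattern it wraps, using the same sigma = 1.0, epsilon = 0.5 argon parameters but ignoring the cutoff and periodic images.

```python
import jax
import jax.numpy as jnp

def total_energy(positions, sigma=1.0, epsilon=0.5):
    # Minimal all-pairs Lennard-Jones energy (no cutoff, no periodic images).
    n = len(positions)
    diffs = positions[:, None, :] - positions[None, :, :]
    # Add the identity to the squared distances so the diagonal stays
    # differentiable; those self-terms are masked out below.
    r = jnp.sqrt(jnp.sum(diffs**2, axis=-1) + jnp.eye(n))
    sr6 = (sigma / r) ** 6
    pair = 4.0 * epsilon * (sr6**2 - sr6)
    mask = 1.0 - jnp.eye(n)
    return 0.5 * jnp.sum(pair * mask)  # 0.5: each pair is counted twice

positions = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
grads = jax.grad(total_energy)(positions)
forces = -grads  # forces are the negated position gradient
print(grads[0])  # ≈ [-0.0908, 0., 0.]
```

The x-component reproduces the 0-1 pair contribution seen in the gradient output above.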

Composing Potentials

sum_potentials combines multiple potentials. Energies, gradients, and patches all compose automatically.

# Same potential twice = double energy (toy example of composition)
doubled = sum_potentials(potential, potential)

out_single = potential(state)
out_double = doubled(state)
print("single:", out_single.data.total_energies.data)
print("double:", out_double.data.total_energies.data)
single: [-0.06542206]
double: [-0.13084412]

In a real simulation, this would combine LJ + Coulomb + bonded terms. Each sub-potential has its own parameters, graph constructor, and energy function, but sum_potentials handles the composition transparently.

Incremental Updates for Monte Carlo

In Monte Carlo, each step moves only a few particles. Recomputing the full energy from scratch would cost work proportional to the system size, even though only a handful of edges changed. The incremental update machinery computes only the delta:

  1. Build the graph for the old configuration of the moved particles (weight = -1)
  2. Build the graph for the new configuration (weight = +1)
  3. Add the cached total from the previous step

Result: E_new = E_cached - E_affected_old + E_affected_new
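
The identity is easy to verify on a toy pair energy with plain jax.numpy (no kups machinery): when only particle 1 moves, subtracting its old incident edges and adding the new ones reproduces a full recomputation.

```python
import jax.numpy as jnp

def pair_energy(r):
    return 1.0 / r**2  # toy pair term, stands in for Lennard-Jones

def total_energy(positions):
    # Full sum over all unordered pairs.
    e = 0.0
    for a in range(len(positions)):
        for b in range(a + 1, len(positions)):
            e += pair_energy(jnp.linalg.norm(positions[a] - positions[b]))
    return e

def incident_energy(positions, k):
    # Energy of only the edges touching particle k.
    e = 0.0
    for b in range(len(positions)):
        if b != k:
            e += pair_energy(jnp.linalg.norm(positions[k] - positions[b]))
    return e

old = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
new = old.at[1].set(jnp.array([3.0, 0.0, 0.0]))  # move particle 1

e_cached = total_energy(old)
e_new = e_cached - incident_energy(old, 1) + incident_energy(new, 1)
print(bool(jnp.allclose(e_new, total_energy(new))))  # True
```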

To enable this, the LJ parameters are wrapped in WithCache, which stores the most recent PotentialOut alongside the parameters. A Probe tells the potential which particles changed. The return type of the probe is potential-specific. Pair potentials expect IsRadiusGraphProbe[P], a structural protocol requiring particles, neighborlist_before, and neighborlist_after fields. Following the [Conventions](conventions.ipynb) chapter, we'll satisfy that protocol with a plain @dataclass. Let's set this up.

Setting Up a Cached State

We redefine the state with WithCache wrapping the LJ parameters. The cache starts empty and gets populated on the first full evaluation.

@dataclass
class CachedLJState:
    particles: Table[ParticleId, ParticleData]
    systems: Table[SystemId, SystemData]
    neighborlist: AllDenseNearestNeighborList
    lj_parameters: WithCache[LennardJonesParameters, PotentialOut]


empty_cache = PotentialOut(
    total_energies=Table.arange(jnp.zeros(1), label=SystemId),
    gradients=EMPTY,
    hessians=EMPTY,
)

cached_state = CachedLJState(
    particles=particles,
    systems=systems,
    neighborlist=nl,
    lj_parameters=WithCache(lj_params, empty_cache),
)


# LJ's probe must return an object satisfying the
# `IsRadiusGraphProbe[IsLJGraphParticles]` protocol. Three fields:
# `particles` (which particles moved, as a `WithIndices`),
# `neighborlist_before` and `neighborlist_after` (so the potential can
# subtract affected old edges and add affected new edges). We declare a
# plain `@dataclass`; structural typing does the rest, matching the
# `@dataclass` + `Has*`/`Is*` protocol style from the Conventions chapter.
@dataclass
class LJProbe:
    particles: WithIndices[ParticleId, ParticleData]
    neighborlist_before: AllDenseNearestNeighborList
    neighborlist_after: AllDenseNearestNeighborList


def probe(state: CachedLJState, patch) -> LJProbe:
    new_state = patch(
        state, state.systems.set_data(jnp.ones(len(state.systems), dtype=bool))
    )
    # `Index.integer` is the public constructor for "a subset of an
    # existing key-space": `jnp.array([1])` selects particle 1 out of the
    # full ParticleId(0..n-1) range.
    pidx = Index.integer(jnp.array([1]), n=n, label=ParticleId)
    return LJProbe(
        particles=WithIndices(pidx, new_state.particles.at(pidx).get()),
        neighborlist_before=state.neighborlist,
        neighborlist_after=new_state.neighborlist,
    )


potential_inc = make_lennard_jones_from_state(identity_lens(CachedLJState), probe=probe)

Step 1: Full Evaluation (Populate the Cache)

The first call uses patch=None for a full evaluation. The returned patch writes the energy into the cache.

out_full = potential_inc(cached_state, patch=None)
accept = out_full.data.total_energies.set_data(jnp.ones(1, dtype=bool))
cached_state = out_full.patch(cached_state, accept)

print("cached energy:", cached_state.lj_parameters.cache.total_energies.data)
cached energy: [-0.06542206]

Step 2: Incremental Evaluation

Now we move particle 1 from [2, 0, 0] to [3, 0, 0] via an IndexLensPatch. The potential evaluates only the edges affected by the move and adds the delta to the cached total.

# Build a patch that moves particle 1
new_positions = positions.at[1].set(jnp.array([3.0, 0.0, 0.0]))
new_particles = Table.arange(
    ParticleData(
        positions=new_positions,
        labels=Index.new([Label("Ar")] * n),
        system=Index.new([SystemId(0)] * n),
        inclusion=Index.new([InclusionId(0)] * n),
        exclusion=Index.new([ExclusionId(i) for i in range(n)]),
    ),
    label=ParticleId,
)
patch = IndexLensPatch(
    data=new_particles,
    mask_idx=Index((SystemId(0),), jnp.zeros(1, dtype=int)),
    lens=lens(lambda s: s.particles, cls=CachedLJState),
)

# Incremental evaluation: uses cached total + delta
out_inc = potential_inc(cached_state, patch=patch)
e_incremental = out_inc.data.total_energies.data
print("incremental energy:", e_incremental)
incremental energy: [-0.03441136]

Step 3: Verify Against Full Recomputation

The incremental result should match a full evaluation on the updated state. Let's confirm.

# Full recomputation on updated state
new_state = (
    bind(cached_state).focus(lambda s: s.particles.data.positions).set(new_positions)
)
e_full = potential_inc(new_state, patch=None).data.total_energies.data

print("incremental:      ", e_incremental)
print("full recomputation:", e_full)
print("match:", jnp.allclose(e_incremental, e_full))
incremental:       [-0.03441136]
full recomputation: [-0.03441136]
match: True

The two energies match. The incremental path computed only the edges touching particle 1, while the full path recomputed all edges. In a system with thousands of particles, this difference is the key to efficient Monte Carlo sampling.

Not all potentials support incremental updates. Machine-learned force fields typically require the full graph, so they fall back to full recomputation automatically. This is transparent to the caller.

Timing: Full vs. Incremental

With 3 particles, both paths are trivially fast. Let's build a larger system (200 particles) to see the real difference.

import jax

N = 200
key = jax.random.key(42)
big_positions = jax.random.uniform(key, (N, 3)) * 10.0

big_particles = Table.arange(
    ParticleData(
        positions=big_positions,
        labels=Index.new([Label("Ar")] * N),
        system=Index.new([SystemId(0)] * N),
        inclusion=Index.new([InclusionId(0)] * N),
        exclusion=Index.new([ExclusionId(i) for i in range(N)]),
    ),
    label=ParticleId,
)

big_nl = AllDenseNearestNeighborList(
    avg_edges=FixedCapacity(N * N),
    avg_image_candidates=FixedCapacity(N * N),
)

big_state = CachedLJState(
    particles=big_particles,
    systems=systems,
    neighborlist=big_nl,
    lj_parameters=WithCache(lj_params, empty_cache),
)


# Same `LJProbe` dataclass, same `Index.integer` subset selector, just a
# different `N` and a different moved particle (particle 0 here).
def big_probe(state, patch) -> LJProbe:
    new_state = patch(
        state, state.systems.set_data(jnp.ones(len(state.systems), dtype=bool))
    )
    pidx = Index.integer(jnp.array([0]), n=N, label=ParticleId)
    return LJProbe(
        particles=WithIndices(pidx, new_state.particles.at(pidx).get()),
        neighborlist_before=state.neighborlist,
        neighborlist_after=new_state.neighborlist,
    )


big_potential = make_lennard_jones_from_state(
    identity_lens(CachedLJState), probe=big_probe
)

# Full evaluation to populate cache
out = big_potential(big_state, patch=None)
accept = out.data.total_energies.set_data(jnp.ones(1, dtype=bool))
big_state = out.patch(big_state, accept)

# Build patch: move particle 0
big_new_positions = big_positions.at[0].set(jnp.array([5.0, 5.0, 5.0]))
big_new_particles = Table.arange(
    ParticleData(
        positions=big_new_positions,
        labels=Index.new([Label("Ar")] * N),
        system=Index.new([SystemId(0)] * N),
        inclusion=Index.new([InclusionId(0)] * N),
        exclusion=Index.new([ExclusionId(i) for i in range(N)]),
    ),
    label=ParticleId,
)
big_patch = IndexLensPatch(
    data=big_new_particles,
    mask_idx=Index((SystemId(0),), jnp.zeros(1, dtype=int)),
    lens=lens(lambda s: s.particles, cls=CachedLJState),
)

big_new_state = (
    bind(big_state).focus(lambda s: s.particles.data.positions).set(big_new_positions)
)

# JIT compile both paths
full_eval = jax.jit(lambda s: big_potential(s, patch=None))
inc_eval = jax.jit(lambda s: big_potential(s, patch=big_patch))

# Warm up: trigger compilation
_ = full_eval(big_new_state).data.total_energies.data.block_until_ready()
_ = inc_eval(big_state).data.total_energies.data.block_until_ready()
print(f"Ready: {N} particles, compiled both paths")
Ready: 200 particles, compiled both paths
print("Full recomputation:")
%time _ = full_eval(big_new_state).data.total_energies.data.block_until_ready()

print("\nIncremental (1 particle moved):")
%time _ = inc_eval(big_state).data.total_energies.data.block_until_ready()
Full recomputation:
CPU times: user 1.71 s, sys: 56 ms, total: 1.77 s
Wall time: 682 ms

Incremental (1 particle moved):
CPU times: user 15.7 ms, sys: 6.17 ms, total: 21.9 ms
Wall time: 9.42 ms

Tojax: Machine-Learned Force Fields

Machine-learned potentials (MACE, UMA, ORB) are integrated via Tojax. The model is exported from PyTorch to JAX, then wrapped as a standard kUPS Potential via TojaxedMliap and make_tojaxed_from_state:

from kups.potential.mliap.tojax import TojaxedMliap, make_tojaxed_from_state

model = TojaxedMliap.from_zip_file("model.zip")
potential = make_tojaxed_from_state(
    state=identity_lens(MyState),
    model=model,
)

The MLFF uses the same RadiusGraphConstructor and neighbor list infrastructure as classical pair potentials. The cutoff comes from the model metadata. MLFFs typically do not support incremental updates since the model needs the full graph as input. When called with a patch, the potential falls back to full recomputation automatically.

From the integrator's perspective, there is no difference between an MLFF and a classical potential. Both implement Potential and return PotentialOut.

Available Force Fields

See the kups.potential package reference for the full list of classical and machine-learned potentials. Classical potentials use either a RadiusGraphConstructor (pair potentials like Lennard-Jones, Coulomb, Ewald) or an EdgeSetGraphConstructor (bonded terms like harmonic bonds, angles, dihedrals).
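
The distinction can be sketched with plain jax.numpy. This is a conceptual illustration, not the kups graph constructors: a radius-graph term derives its edges from distances at evaluation time, while an edge-set term evaluates a fixed, explicit topology.

```python
import jax.numpy as jnp

positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])

def radius_pair_energy(positions, cutoff=1.5):
    # Radius-graph style: edges are whichever pairs fall within the cutoff.
    e = 0.0
    for a in range(len(positions)):
        for b in range(a + 1, len(positions)):
            r = jnp.linalg.norm(positions[a] - positions[b])
            if r < cutoff:
                e += 1.0 / r  # toy pair term
    return e

# Edge-set style: the bond topology is fixed data, independent of geometry.
bonds = jnp.array([[0, 1], [1, 2]])

def harmonic_bond_energy(positions, bonds, r0=1.0, k=1.0):
    r = jnp.linalg.norm(positions[bonds[:, 0]] - positions[bonds[:, 1]], axis=-1)
    return jnp.sum(0.5 * k * (r - r0) ** 2)

print(float(radius_pair_energy(positions)), float(harmonic_bond_energy(positions, bonds)))
```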