
Neighbor Lists

Local potentials, classical or machine-learned, all ask the same question inside the inner loop: for each particle, which others sit within r_cut? Answered naively, the cost is O(N²), and at production sizes the neighbor search easily dominates the potential evaluation itself. A good neighbor list turns that scan into something close to O(N) by exploiting the fact that most pairs are too far apart to matter.

JAX adds a second constraint. Inside jax.jit every array has a fixed shape known at trace time, so the neighbor list cannot simply grow when more edges are needed. The data structures have to be pre-sized, the kernels written as gather-scatter over padded buffers, and overflow detected without branching. The neighbor list module is where this machinery is hidden so that the rest of kUPS can treat particles as if they interacted through a clean, fixed-shape edge list.
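The pattern described above (a pre-sized buffer, gather-scatter into it, and an overflow flag computed without branching) can be sketched in plain NumPy. This is an illustrative sketch, not kUPS code. The function name `compact_pairs`, the sentinel value `n`, and the spare "trash" row are assumptions made for the example. Under `jax.jit` the same idea is written with `jnp` and `.at[slot].set(...)` in place of the in-place NumPy write.

```python
import numpy as np

def compact_pairs(positions, cutoff, capacity):
    """All-pairs search written with fixed output shapes (illustrative sketch).

    Returns (pairs, overflow). `pairs` always has shape (capacity, 2); unused
    rows hold the out-of-bounds sentinel n. `overflow` is computed
    arithmetically instead of by branching on the data.
    """
    n = len(positions)
    diff = positions[None, :, :] - positions[:, None, :]   # (n, n, 3)
    dist2 = np.sum(diff * diff, axis=-1)
    mask = (dist2 < cutoff**2) & ~np.eye(n, dtype=bool)    # drop self-pairs
    i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    mask, i, j = mask.ravel(), i.ravel(), j.ravel()

    # Destination slot of each valid pair = running count of valid pairs.
    slot = np.cumsum(mask) - 1
    # Invalid or overflowing pairs are routed to a spare row that is dropped.
    slot = np.where(mask & (slot < capacity), slot, capacity)

    pairs = np.full((capacity + 1, 2), n)                  # sentinel fill
    pairs[slot] = np.stack([i, j], axis=-1)                # scatter
    overflow = mask.sum() > capacity                       # no Python branch
    return pairs[:capacity], overflow
```

Because the output shape depends only on `capacity`, the caller can check `overflow` on the host and re-run with a larger buffer. The kernel itself never changes shape.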

import jax
import jax.numpy as jnp
from jax import Array

from kups.core.capacity import FixedCapacity
from kups.core.data import Index, Table
from kups.core.data.wrappers import WithIndices
from kups.core.lens import identity_lens
from kups.core.neighborlist import (
    CellListNeighborList,
    DenseNearestNeighborList,
    Edges,
    RefineCutoffNeighborList,
    UniversalNeighborlistParameters,
    neighborlist_changes,
)
from kups.core.result import as_result_function
from kups.core.typing import ExclusionId, InclusionId, ParticleId, SystemId
from kups.core.unitcell import TriclinicUnitCell
from kups.core.utils.jax import dataclass


@dataclass
class Points:
    positions: Array
    system: Index[SystemId]
    inclusion: Index[InclusionId]
    exclusion: Index[ExclusionId]


@dataclass
class Systems:
    unitcell: TriclinicUnitCell


def make_state(positions, system_of_particle, box_size=10.0, n_systems=None):
    """Helper: build a (particles, systems) pair with cubic unit cells."""
    n = len(positions)
    if n_systems is None:
        n_systems = int(max(system_of_particle)) + 1
    particles = Table.arange(
        Points(
            positions=jnp.asarray(positions),
            system=Index.new([SystemId(int(i)) for i in system_of_particle]),
            inclusion=Index.new([InclusionId(int(i)) for i in system_of_particle]),
            exclusion=Index.new([ExclusionId(i) for i in range(n)]),
        ),
        label=ParticleId,
    )
    unitcell = TriclinicUnitCell.from_matrix(
        box_size * jnp.eye(3)[None].repeat(n_systems, axis=0)
    )
    systems = Table.arange(Systems(unitcell=unitcell), label=SystemId)
    return particles, systems


def valid_edges(edges):
    """Return the (n_valid, degree) array of in-bounds edges."""
    mask = edges.indices.valid_mask.all(axis=-1)
    return edges.indices.indices[mask]


def edge_set(edges, n_particles):
    """Return the set of (i, j) index pairs, skipping padded rows (index >= n_particles)."""
    raw = edges.indices.indices
    return {
        (int(raw[i, 0]), int(raw[i, 1]))
        for i in range(len(raw))
        if raw[i, 0] < n_particles and raw[i, 1] < n_particles
    }

The output: Edges

Every implementation returns an Edges object. Edges represents hyperedges of arbitrary degree: Edges[Literal[2]] is a set of pairs (the usual output of a nearest-neighbor search), Edges[Literal[3]] is a set of triplets (used for three-body terms such as angle potentials), Edges[Literal[4]] covers dihedrals, and so on. The degree is a static type parameter, so the type system distinguishes a pair edge list from a triplet one and consumers cannot accidentally feed the wrong order into a potential. Each edge carries ParticleId indices into the particle table together with integer shift vectors of shape (Degree - 1, 3) that identify which periodic image of each subsequent particle is being connected to the first.

Shifts rather than wrapped positions are carried on the edge because positions belong to particles, not to pairs. Keeping the integer shift alongside the edge lets consumers reconstruct the absolute separation vector with a small matrix multiply against the current unit cell, which is what Edges.difference_vectors and Edges.absolute_shifts do.
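Reconstructing the separation from an integer shift is a single matrix product. The NumPy sketch below assumes the lattice vectors a, b, c are stored as the rows of a cell matrix, so that a shift `s` translates particle j by `s @ cell`. That row-vector convention and the function name are assumptions, not the kUPS internals. The sketch reproduces the three difference vectors printed in the hand-built Edges example of this section.

```python
import numpy as np

def difference_vector(r_i, r_j, shift, cell):
    """Separation from particle i to the image of particle j selected by `shift`.

    `cell` holds the lattice vectors as rows, so an integer shift s moves
    particle j by s @ cell (s[0] copies of a, s[1] of b, s[2] of c).
    """
    return r_j + np.asarray(shift) @ cell - r_i

cell = 3.0 * np.eye(3)            # cubic box with 3 Å edges
r0 = np.array([0.0, 0.0, 0.0])
r2 = np.array([2.0, 0.0, 0.0])
for s in ([0, 0, 0], [1, 0, 0], [-1, 0, 0]):
    # direct: [2, 0, 0], +a image: [5, 0, 0], -a image: [-1, 0, 0]
    print(s, difference_vector(r0, r2, s, cell))
```

Only the integer shift is stored per edge. If the unit cell changes (for example under a barostat), the same shift still selects the same image with the new cell matrix.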

The fixed-shape contract of JAX shows up as padding. Valid edges fill the first rows; unused rows hold a sentinel index that points out of bounds, mirroring the occupation convention used by Buffered. Downstream code filters on the out-of-bounds flag rather than on a separate boolean mask.
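Filtering on the out-of-bounds flag usually happens inside a reduction rather than by slicing, because slicing would produce a data-dependent shape. The NumPy sketch below makes three assumptions:

- padded rows hold the sentinel index `n`, one past the last particle;
- the harmonic pair energy is made up for illustration;
- the function name is hypothetical.

```python
import numpy as np

def padded_pair_energy(positions, pairs, k=1.0, r0=1.0):
    """Sum a toy harmonic pair energy over a padded (E, 2) index buffer.

    Rows containing the sentinel n (out of bounds) are padding. They are
    gathered with clipped indices so the lookup stays in range, and their
    contribution is then zeroed by the validity mask.
    """
    n = len(positions)
    valid = (pairs < n).all(axis=-1)        # the out-of-bounds flag
    safe = np.clip(pairs, 0, n - 1)         # in-range stand-in indices
    d = positions[safe[:, 1]] - positions[safe[:, 0]]
    r = np.linalg.norm(d, axis=-1)
    energy = 0.5 * k * (r - r0) ** 2        # padding rows produce garbage here...
    return np.sum(np.where(valid, energy, 0.0))  # ...which the mask discards
```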

Below we construct an Edges object by hand for three particles at x = 0, 1, 2 in a cubic box of edge 3 Å, then compute difference vectors for the same 0 → 2 pair with three different shift choices: direct, through the +a image, and through the -a image. The -a route is the minimum-image one.

positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
particles, systems = make_state(positions, [0, 0, 0], box_size=3.0)

direct = Edges(
    indices=Index(particles.keys, jnp.array([[0, 2]])),
    shifts=jnp.zeros((1, 1, 3), dtype=int),
)
plus_a = Edges(
    indices=Index(particles.keys, jnp.array([[0, 2]])),
    shifts=jnp.array([[[1, 0, 0]]]),
)
minus_a = Edges(
    indices=Index(particles.keys, jnp.array([[0, 2]])),
    shifts=jnp.array([[[-1, 0, 0]]]),
)

print("direct  :", direct.difference_vectors(particles, systems)[0, 0])
print("+a image:", plus_a.difference_vectors(particles, systems)[0, 0])
print("-a image:", minus_a.difference_vectors(particles, systems)[0, 0])
direct  : [2. 0. 0.]
+a image: [5. 0. 0.]
-a image: [-1.  0.  0.]

The protocol: NearestNeighborList

Callers never depend on a specific algorithm. They hold something that satisfies NearestNeighborList, a protocol with a single __call__ method that takes a left-hand particle table lh, an optional right-hand table rh, a system table, per-system cutoffs, and an optional index remap, and returns an Edges[Literal[2]]. When rh is None the neighbor list finds pairs within lh, which is the familiar symmetric case driven by MD force evaluations.
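The call signature described above can be restated with `typing.Protocol`. The sketch is a hypothetical restatement: `NeighborListLike`, `AllPairs1D`, and the `Any` placeholders are illustrative, and the kUPS protocol carries the precise table and Edges types. The toy implementation works on 1D coordinates and ignores `rh_index_remap`, treating rh as disjoint from lh.

```python
from typing import Any, Optional, Protocol, runtime_checkable

@runtime_checkable
class NeighborListLike(Protocol):
    """Hypothetical restatement of the NearestNeighborList call signature."""

    def __call__(
        self,
        lh: Any,                               # left-hand particle table
        rh: Optional[Any],                     # None means "pairs within lh"
        systems: Any,                          # per-system data (unit cells)
        cutoffs: Any,                          # per-system cutoff radii
        rh_index_remap: Optional[Any] = None,  # where rh items sit inside lh
    ) -> Any:                                  # kUPS returns Edges[Literal[2]]
        ...

class AllPairs1D:
    """Toy implementation: 1D coordinates, one system, no periodicity.

    It ignores rh_index_remap, i.e. it always treats rh as disjoint from lh.
    """

    def __call__(self, lh, rh, systems, cutoffs, rh_index_remap=None):
        symmetric = rh is None
        rh = lh if symmetric else rh
        return [
            (i, j)
            for i, x in enumerate(lh)
            for j, y in enumerate(rh)
            if abs(x - y) < cutoffs and not (symmetric and i == j)
        ]
```

Any object with a compatible `__call__` satisfies the protocol structurally. This is what lets callers swap the dense, cell-list, and refining implementations without changing their own code.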

The reason the protocol separates lh from rh at all is that most non-MD workloads are asymmetric. A Monte Carlo displacement moves a handful of atoms per step, and the acceptance test only needs the neighbors of those few against the rest of the system, not a rebuild of the entire pair list. A grand canonical insertion queries a candidate position against the current particles without touching the existing pair structure, and a deletion does the same in reverse. A partial refresh after a hybrid move only revisits the atoms that have drifted far enough to matter. In each case lh is small and rh is the full state, and forcing the two to be identical would either inflate lh to the whole particle table (paying the full O(N) or O(N²) cost every step) or push the subset bookkeeping onto the caller. The protocol makes the asymmetric case cheap and explicit; rh=None is just the convenience shortcut for the symmetric one.

The rh_index_remap argument declares whether items in rh are also members of lh. When it is supplied the neighbor list treats the two sets as overlapping: self-pairs are dropped, duplicate interactions arising from the overlap are merged, and every surviving pair is emitted symmetrically (both (i, j) and (j, i) appear) so that downstream code sees the usual undirected pair structure. When the argument is left None the sets are treated as disjoint and the raw directed pair list is produced. This is what lets a Monte Carlo step query only the moved atoms against the full system without double-counting anything that sits in both sides.
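The overlap handling can be sketched with a brute-force NumPy search. The sketch makes three assumptions:

- a single non-periodic system;
- `remap[k]` gives the index in lh of query item k, which is the role rh_index_remap plays above;
- the helper name `query_pairs` is hypothetical.

It restates the described semantics rather than kUPS's kernel, and it uses dynamically sized `np.nonzero` output for brevity. On the five-particle layout used in the example of this section, it returns the same four pairs.

```python
import numpy as np

def query_pairs(lh_pos, rh_pos, cutoff, remap=None):
    """Pairs between lh and rh items closer than `cutoff` (brute force).

    remap=None: lh and rh are disjoint; return raw directed (lh, rh) pairs.
    remap given: rh item k is lh item remap[k]; drop self-pairs, merge
    duplicates, and emit each surviving pair both ways, in lh indices.
    """
    d = lh_pos[:, None, :] - rh_pos[None, :, :]
    close = np.sum(d * d, axis=-1) < cutoff**2        # (n_lh, n_rh)
    i, k = np.nonzero(close)
    if remap is None:
        return {(int(a), int(b)) for a, b in zip(i, k)}
    j = np.asarray(remap)[k]                          # rh index -> lh index
    out = set()
    for a, b in zip(i, j):
        if a != b:                                    # self-pair from the overlap
            out.add((int(a), int(b)))
            out.add((int(b), int(a)))                 # the set merges duplicates
    return out
```

Without the remap, the moved particle would also be matched against its own entry in lh. This is the double-counting the argument exists to prevent.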

The call below first runs the neighbor list symmetrically over five particles on a line (a three-particle chain plus a separate pair). It then runs it again, querying only particle 1 against the full set with rh_index_remap supplied. The second call returns the pairs that touch particle 1, in both directions, and nothing else.

positions = jnp.array(
    [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [2.0, 0.0, 0.0],
        [5.0, 0.0, 0.0],
        [6.0, 0.0, 0.0],
    ]
)
particles, systems = make_state(positions, [0, 0, 0, 0, 0], box_size=20.0)
cutoffs = Table.arange(jnp.array([1.5]), label=SystemId)

nl = DenseNearestNeighborList(
    avg_candidates=FixedCapacity(32),
    avg_edges=FixedCapacity(16),
    avg_image_candidates=FixedCapacity(32),
)

sym = nl(particles, None, systems, cutoffs)
print("symmetric (rh=None):", sorted(edge_set(sym, len(particles))))

# Query only particle 1 against the full set.
query_idx = 1
query = Table(
    keys=(ParticleId(0),),
    data=Points(
        positions=particles.data.positions[query_idx : query_idx + 1],
        system=Index.new([SystemId(0)]),
        inclusion=Index.new([InclusionId(0)]),
        exclusion=Index.new([ExclusionId(query_idx)]),
    ),
)
rh_remap = Index(particles.keys, jnp.array([query_idx]))

asym = nl(particles, query, systems, cutoffs, rh_remap)
print("query only {1} with remap:", sorted(edge_set(asym, len(particles))))
symmetric (rh=None): [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]
query only {1} with remap: [(0, 1), (1, 0), (1, 2), (2, 1)]

Two implementations and how to pick between them

The choice comes down to the ratio of cutoff to box size. DenseNearestNeighborList is O(N²) per system. It is the right choice when the cutoff is close to the box size and spatial partitioning would produce only a handful of cells.

CellListNeighborList is the workhorse. It partitions space into cells roughly the size of the cutoff and hashes each particle into its cell. For each particle it then examines only the 27 cells of the surrounding 3×3×3 block (its own cell plus its 26 neighbors), giving O(N) scaling for uniform density. It needs a unit cell and works best when cutoff/box is below roughly 0.3.

Both agree on the same underlying pair set; the difference is how they search for it.
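The cell-list search described above can be sketched in plain Python and NumPy. The sketch uses dictionaries rather than kUPS's fixed-capacity hash table, and makes these assumptions:

- a cubic periodic box of edge `box`;
- a cutoff below box/2, so the minimum-image convention is valid;
- hypothetical helper names.

The brute-force function is included only as an O(N²) reference to check against.

```python
import itertools
import numpy as np

def min_image(d, box):
    """Wrap a separation vector onto the nearest periodic image (cubic box)."""
    return d - box * np.round(d / box)

def brute_force_pairs(pos, box, rc):
    """O(N²) reference: every ordered pair (i, j), i != j, within rc."""
    n = len(pos)
    return {
        (i, j)
        for i in range(n)
        for j in range(n)
        if i != j and np.sum(min_image(pos[j] - pos[i], box) ** 2) < rc**2
    }

def cell_list_pairs(pos, box, rc):
    """Same pair set, searching only the 3x3x3 block of cells around each cell."""
    m = max(1, int(box // rc))          # cells per side, each at least rc wide
    side = box / m
    cells = {}
    for i, p in enumerate(pos):          # hash each particle into its cell
        key = tuple(int(c) for c in np.floor(p / side).astype(int) % m)
        cells.setdefault(key, []).append(i)

    pairs = set()
    for key, members in cells.items():
        # A set, because with fewer than 3 cells per side the offsets alias.
        block = {
            tuple((k + o) % m for k, o in zip(key, off))
            for off in itertools.product((-1, 0, 1), repeat=3)
        }
        for other in block:
            for j in cells.get(other, []):
                for i in members:
                    if i != j and np.sum(min_image(pos[j] - pos[i], box) ** 2) < rc**2:
                        pairs.add((i, j))
    return pairs

rng = np.random.default_rng(0)
pos = rng.uniform(0.0, 8.0, size=(30, 3))
print(cell_list_pairs(pos, 8.0, 2.5) == brute_force_pairs(pos, 8.0, 2.5))
```

Because every cell is at least rc wide, two particles within rc can differ by at most one cell index along each axis. Scanning the 3×3×3 block is therefore enough to find every pair.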

key = jax.random.key(0)
positions = jax.random.uniform(key, (30, 3), minval=0.0, maxval=8.0)
particles, systems = make_state(positions, [0] * 30, box_size=8.0)
cutoffs = Table.arange(jnp.array([2.5]), label=SystemId)

dense = DenseNearestNeighborList(
    avg_candidates=FixedCapacity(128),
    avg_edges=FixedCapacity(128),
    avg_image_candidates=FixedCapacity(128),
)
cell = CellListNeighborList(
    avg_candidates=FixedCapacity(128),
    avg_edges=FixedCapacity(128),
    cells=FixedCapacity(128),
    avg_image_candidates=FixedCapacity(128),
)

e_den = edge_set(dense(particles, None, systems, cutoffs), len(particles))
e_cel = edge_set(cell(particles, None, systems, cutoffs), len(particles))

print("Dense:   ", len(e_den), "edges")
print("CellList:", len(e_cel), "edges")
print("both agree:", e_den == e_cel)
Dense:    98 edges
CellList: 98 edges
both agree: True

Periodic boundaries and shifts

Particle positions are kept unwrapped. Wrapping would introduce discontinuities that fight autodiff and velocity updates and would not compose cleanly with triclinic cells. Periodicity is carried entirely on the edges through the integer shift vectors: [0, 0, 0] connects two particles directly, [1, 0, 0] connects the first to the image of the second translated by one cell along a.

Internally, the search runs in fractional coordinates so that cell hashing and distance tests can be written in a cell-agnostic way, and the distances that get compared against the cutoff are projected back to Cartesian. Non-periodic systems reuse the same machinery with a box large enough that all shifts come out zero, though very sparse or strictly non-periodic cases are better served by the dense variants.
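The fractional-coordinate step can be sketched in three moves:

1. map positions to fractional coordinates with the inverse cell matrix;
2. round the fractional difference to pick the integer shift;
3. project back to Cartesian before comparing against the cutoff.

Assumptions: the lattice vectors are the rows of `cell`, and the helper name is hypothetical. Rounding fractional differences gives the true minimum image for orthorhombic cells, but not always for strongly skewed triclinic cells.

```python
import numpy as np

def minimum_image_shift(r_i, r_j, cell):
    """Integer shift s such that r_j + s @ cell is (approximately) the image
    of particle j nearest to particle i. Exact for orthorhombic cells."""
    inv = np.linalg.inv(cell)
    f_i, f_j = r_i @ inv, r_j @ inv             # fractional coordinates
    shift = -np.round(f_j - f_i).astype(int)    # undo whole-cell offsets
    separation = r_j + shift @ cell - r_i       # back to Cartesian
    return shift, separation

cell = 3.0 * np.eye(3)
shift, sep = minimum_image_shift(np.array([0.1, 0.0, 0.0]), np.array([2.9, 0.0, 0.0]), cell)
print(shift, sep)   # shift = [-1, 0, 0], separation ≈ [-0.2, 0, 0]
```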

The example below places two particles near opposite faces of a 3 Å cubic box with a 1.0 Å cutoff. Their direct separation is 2.8 Å and would miss the cutoff, but the minimum-image separation is 0.2 Å and the neighbor list finds the edge through the periodic shift.

positions = jnp.array([[0.1, 0.0, 0.0], [2.9, 0.0, 0.0]])
particles, systems = make_state(positions, [0, 0], box_size=3.0)
cutoffs = Table.arange(jnp.array([1.0]), label=SystemId)

nl = CellListNeighborList(
    avg_candidates=FixedCapacity(16),
    avg_edges=FixedCapacity(16),
    cells=FixedCapacity(64),
    avg_image_candidates=FixedCapacity(32),
)
edges = nl(particles, None, systems, cutoffs)
print("pairs:", sorted(edge_set(edges, len(particles))))

# The shift on the first valid edge encodes the cell crossing
mask = edges.indices.valid_mask.all(axis=-1)
first = jnp.argmax(mask)
print("shift :", edges.shifts[first, 0])
print(
    "minimum-image separation:", edges.difference_vectors(particles, systems)[first, 0]
)
pairs: [(0, 1), (1, 0)]
shift : [-1.  0.  0.]
minimum-image separation: [-0.19999981  0.          0.        ]

Capacity management: fixed shapes for a variable problem

The number of neighbors per particle is a runtime quantity. JAX insists on fixed shapes. The Capacity abstraction bridges the two. A neighbor list is parameterized by an avg_edges capacity for the output buffer, an avg_candidates capacity for the intermediate pair list produced by the spatial decomposition, an avg_image_candidates capacity for pairs that need image translation, and a cells capacity for the hash table of the cell list.

These values are upper bounds. If they are exceeded at runtime, a RuntimeAssertion fires inside the compiled kernel and the host-side retry loop resizes the buffer before re-entering. The Runtime Assertions notebook explains that mechanism, so it is not repeated here. What matters in this context is that initial capacities only need to be reasonable; the doubling machinery converges on the right sizes during warmup without further intervention.
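The resize step amounts to doubling a per-particle capacity until the buffer it implies covers what the kernel needed. The sketch below is hypothetical. It assumes the buffer is `avg_edges * n_particles`, as the comment in the example below states, and that the required edge count is known on the host. The actual mechanism is the one described in the Runtime Assertions notebook.

```python
def grow_capacity(avg_edges: int, n_particles: int, required_edges: int) -> int:
    """Double a per-particle edge capacity until the buffer holds every edge.

    Hypothetical helper: assumes buffer size = avg_edges * n_particles.
    """
    avg_edges = max(avg_edges, 1)        # guard against a zero start
    while avg_edges * n_particles < required_edges:
        avg_edges *= 2
    return avg_edges

# 20 particles, 46 edges found, starting from avg_edges=2 (buffer of 40):
print(grow_capacity(2, 20, 46))  # 4, i.e. a buffer of 80
```

Doubling keeps the number of distinct buffer sizes, and hence recompilations, logarithmic in the final size.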

The cell below deliberately under-sizes avg_edges, observes the assertion failure, applies the fix, and shows that the next call succeeds with the resized buffer.

from kups.core.utils.jax import field


@dataclass
class NLParams:
    avg_edges: int = field(static=True)
    avg_candidates: int = field(static=True)
    avg_image_candidates: int = field(static=True)
    cells: int = field(static=True)


key = jax.random.key(1)
positions = jax.random.uniform(key, (20, 3), minval=0.0, maxval=10.0)
particles, systems = make_state(positions, [0] * 20, box_size=10.0)
cutoffs = Table.arange(jnp.array([3.0]), label=SystemId)

# Start with an intentionally small edge capacity. The effective buffer used
# by the kernel is `avg_edges * n_rh_particles`, so `avg_edges=2` with 20
# particles only reserves space for 40 edges, which will be too few.
params = NLParams(avg_edges=2, avg_candidates=64, avg_image_candidates=64, cells=64)


@as_result_function
def run(params):
    nl = CellListNeighborList.new(params, identity_lens(NLParams))
    return nl(particles, None, systems, cutoffs)


print(
    f"initial avg_edges = {params.avg_edges} (buffer = {params.avg_edges * len(particles)})"
)
result = jax.jit(run)(params)
print(
    f"first call: assertions pass = {bool(result.all_assertions_pass)}, "
    f"failed = {len(result.failed_assertions)}"
)

# Apply the fix and retry until the assertion passes.
retries = 0
while not result.all_assertions_pass:
    params = result.fix_or_raise(params)
    result = jax.jit(run)(params)
    retries += 1

print(f"retries needed: {retries}")
print(
    f"final avg_edges = {params.avg_edges} (buffer = {params.avg_edges * len(particles)})"
)
print(
    f"valid edges produced: {int(result.value.indices.valid_mask.all(axis=-1).sum())}"
)
initial avg_edges = 2 (buffer = 40)
first call: assertions pass = False, failed = 1
retries needed: 1
final avg_edges = 4 (buffer = 80)
valid edges produced: 46

Parameter estimation from geometry

Constructing a neighbor list takes a parameter object satisfying one of the Is*Params protocols (IsDenseNeighborlistParams, IsCellListParams, or the superset IsUniversalNeighborlistParams). UniversalNeighborlistParameters is a concrete dataclass that satisfies the universal form, and its estimate classmethod computes starting values from the system geometry by combining the uniform-density estimate ρ · (4π/3) · r_cut³ with the cell count implied by the box. Both quantities are rounded up to a power of two so resizing only happens when the system actually changes regime, and a safety multiplier trades a little memory for fewer warmup retries.
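The uniform-density part of that estimate is easy to check by hand for the example that follows (N = 200, 15 Å cubic box, 3 Å cutoff). The stdlib-only sketch below omits the safety multiplier and the cell-count term, whose exact forms are not shown here. It therefore only illustrates that the bare density estimate rounds up to the same value as the printed avg_edges=8. It is not a reimplementation of `estimate`.

```python
import math

def next_power_of_two(x: float) -> int:
    """Smallest power of two >= x (returns 1 for x <= 1)."""
    return 1 << max(0, math.ceil(math.log2(x)))

n, box, r_cut = 200, 15.0, 3.0
rho = n / box**3                                   # ≈ 0.0593 particles per Å³
expected = rho * (4.0 * math.pi / 3.0) * r_cut**3  # ≈ 6.70 neighbors per particle
print(round(expected, 2), next_power_of_two(expected))  # 6.7 8
```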

N = 200
key = jax.random.key(2)
positions = jax.random.uniform(key, (N, 3), minval=0.0, maxval=15.0)
particles, systems = make_state(positions, [0] * N, box_size=15.0)
cutoffs = Table.arange(jnp.array([3.0]), label=SystemId)
n_per_sys = Table.arange(jnp.array([N]), label=SystemId)

params = UniversalNeighborlistParameters.estimate(n_per_sys, systems, cutoffs)
print(f"N={N}, box=15Å, cutoff=3Å")
print("estimate:", params)
N=200, box=15Å, cutoff=3Å
estimate: UniversalNeighborlistParameters(avg_edges=8, avg_candidates=128, avg_image_candidates=128, cells=64)

Sharing one neighbor list between several potentials

Real simulations combine multiple potentials, each wanting its own cutoff and exclusion rules. Building a separate neighbor list for each duplicates the most expensive part of the step. The refinement classes avoid that.

RefineMaskNeighborList takes a precomputed edge list and filters it using the inclusion and exclusion indices carried on each particle. Pairs whose inclusion labels match and whose exclusion labels differ are kept. No distances are recomputed; only labels are compared. This puts a small contract on the caller. First, the consumer's cutoff must not exceed the base cutoff, otherwise pairs that would be in range for the consumer are missing from the base list. Second, the consumer has to tolerate receiving pairs whose actual distance sits above its own cutoff, because the refiner does not re-check geometry. It must handle them at the energy-function level, for example with a smooth switching function or an explicit cutoff mask. With those two points in mind, this is how non-bonded exclusions are layered on top of a shared geometric base list.
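The label test itself is a couple of gathers and comparisons. The NumPy sketch below assumes per-particle integer arrays `inclusion` and `exclusion` and an (E, 2) array of base pairs. The function name is hypothetical, and unlike the library version it shrinks the output instead of keeping a fixed-size padded buffer.

```python
import numpy as np

def refine_by_labels(pairs, inclusion, exclusion):
    """Keep pairs whose inclusion labels match and whose exclusion labels differ.

    No distances are computed: only per-particle labels are compared.
    """
    i, j = pairs[:, 0], pairs[:, 1]
    keep = (inclusion[i] == inclusion[j]) & (exclusion[i] != exclusion[j])
    return pairs[keep]
```

For example, two particles given the same exclusion label are dropped from each other's list, as one might do for bonded partners. Particles with different inclusion labels never interact.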

RefineCutoffNeighborList is the complement. The base list is built once with the maximum cutoff across all potentials, and each consumer re-checks distances against its own smaller cutoff. The expensive spatial decomposition happens once, the cheap distance filter happens per potential.

The pattern that emerges is a layered design: one base CellListNeighborList at the maximum cutoff with no exclusions, wrapped per potential by whichever refinement suits its needs.
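The per-consumer cutoff filter can be sketched in NumPy: compute distances only for the base pairs and keep those inside the smaller cutoff. The sketch assumes a cubic periodic box with minimum-image separations, and the helper names are hypothetical. The refined set matches a direct search at the short cutoff as long as the short cutoff does not exceed the base one.

```python
import numpy as np

def pairs_within(pos, box, rc):
    """Directed pairs (i, j), i != j, within rc under the minimum image (cubic box)."""
    d = pos[None, :, :] - pos[:, None, :]          # d[i, j] = pos[j] - pos[i]
    d -= box * np.round(d / box)
    close = (np.sum(d * d, axis=-1) < rc**2) & ~np.eye(len(pos), dtype=bool)
    return np.argwhere(close)                      # (E, 2) array of [i, j]

def refine_by_cutoff(pos, box, base_pairs, rc):
    """Re-check only the base pairs against a smaller cutoff rc."""
    d = pos[base_pairs[:, 1]] - pos[base_pairs[:, 0]]
    d -= box * np.round(d / box)
    return base_pairs[np.sum(d * d, axis=-1) < rc**2]

rng = np.random.default_rng(3)
pos = rng.uniform(0.0, 8.0, size=(40, 3))
base = pairs_within(pos, 8.0, 3.0)              # one expensive search at the max cutoff
short = refine_by_cutoff(pos, 8.0, base, 1.5)   # cheap per-potential filter
print(len(base), len(short))
```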

Below we build a base list at a 3.0 Å cutoff, then use RefineCutoffNeighborList to produce a shorter-range edge list at 1.5 Å, and verify against an independent construction.

key = jax.random.key(3)
positions = jax.random.uniform(key, (40, 3), minval=0.0, maxval=8.0)
particles, systems = make_state(positions, [0] * 40, box_size=8.0)

base_cutoff = Table.arange(jnp.array([3.0]), label=SystemId)
short_cutoff = Table.arange(jnp.array([1.5]), label=SystemId)

base_nl = CellListNeighborList(
    avg_candidates=FixedCapacity(256),
    avg_edges=FixedCapacity(256),
    cells=FixedCapacity(128),
    avg_image_candidates=FixedCapacity(256),
)
base_edges = base_nl(particles, None, systems, base_cutoff)

refiner = RefineCutoffNeighborList(
    candidates=base_edges,
    avg_edges=FixedCapacity(256),
)
refined = refiner(particles, None, systems, short_cutoff)
independent = base_nl(particles, None, systems, short_cutoff)

n = len(particles)
print(f"base edges (cutoff 3.0):       {len(edge_set(base_edges, n))}")
print(f"refined edges (cutoff 1.5):    {len(edge_set(refined, n))}")
print(f"independent edges (cutoff 1.5): {len(edge_set(independent, n))}")
print(
    f"refined matches independent:   {edge_set(refined, n) == edge_set(independent, n)}"
)
base edges (cutoff 3.0):       372
refined edges (cutoff 1.5):    34
independent edges (cutoff 1.5): 34
refined matches independent:   True

Incremental changes for Monte Carlo

A full neighbor list rebuild after each local move is wasteful when the question is really which edges appeared and which disappeared as a consequence of the move. neighborlist_changes answers that directly. It takes the full particle table and a set of proposed local changes packaged as a WithIndices, runs a single query against the augmented state, and returns a NeighborListChangesResult whose removed edges touch the old positions and whose added edges touch the new ones.

This plugs straight into the patch system: the acceptance test reads added and removed edges, the state update commits positions and cached sums atomically via a composed patch, and the neighbor list itself is rebuilt on a coarser schedule driven by accumulated displacement.
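Both steps can be restated with a brute-force NumPy sketch. First, the change set: removed edges are the pairs touching the moved particles at their old positions, and added edges are the pairs touching them at their new positions. Second, a Metropolis test: for a pair energy truncated at the cutoff, the energy change is the sum over added edges minus the sum over removed edges, since every untouched edge cancels. The sketch assumes the following; neighborlist_changes itself gets both edge sets from a single query against the augmented state.

- a non-periodic single system;
- directed edge lists, with each pair appearing twice (hence the factor 1/2);
- a made-up Lennard-Jones pair energy;
- hypothetical helper names.

```python
import numpy as np

def edges_touching(pos, moved, cutoff):
    """Directed pairs (i, j) within cutoff that involve a moved particle."""
    d = pos[None, :, :] - pos[:, None, :]
    close = (np.sum(d * d, axis=-1) < cutoff**2) & ~np.eye(len(pos), dtype=bool)
    moved = {int(m) for m in moved}
    return {(i, j) for i, j in np.argwhere(close).tolist() if i in moved or j in moved}

def local_changes(old_pos, new_pos, moved, cutoff):
    """(removed, added): edges of the moved particles before and after the move."""
    return edges_touching(old_pos, moved, cutoff), edges_touching(new_pos, moved, cutoff)

def lj_energy(pos, edges, eps=1.0, sigma=1.0):
    """Toy Lennard-Jones energy over directed edges (each pair listed twice)."""
    if not edges:
        return 0.0
    e = np.array(sorted(edges))
    r = np.linalg.norm(pos[e[:, 1]] - pos[e[:, 0]], axis=-1)
    sr6 = (sigma / r) ** 6
    return float(0.5 * np.sum(4.0 * eps * (sr6**2 - sr6)))

def metropolis(old_pos, new_pos, removed, added, beta, rng):
    """Accept or reject a local move using only the changed edges."""
    delta = lj_energy(new_pos, added) - lj_energy(old_pos, removed)
    accepted = delta <= 0.0 or rng.random() < np.exp(-beta * delta)
    return bool(accepted), delta

old = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
new = old.copy()
new[1] = [4.5, 0.0, 0.0]
removed, added = local_changes(old, new, [1], cutoff=1.5)
print(sorted(removed), sorted(added))  # [(0, 1), (1, 0)] [(1, 2), (2, 1)]
accepted, delta = metropolis(old, new, removed, added, beta=1.0, rng=np.random.default_rng(0))
```

On the move in the example that follows, the change set matches what neighborlist_changes reports. The toy potential then rejects the move because it places particle 1 at 0.5 σ from particle 2.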

The example below moves a single particle from (1, 0, 0) to (4.5, 0, 0), so the edge to its old neighbor at the origin disappears and a new edge to the neighbor at (5, 0, 0) appears. Running neighborlist_changes returns exactly those.

positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
particles, systems = make_state(positions, [0, 0, 0], box_size=10.0)
cutoffs = Table.arange(jnp.array([1.5]), label=SystemId)

nl = DenseNearestNeighborList(
    avg_candidates=FixedCapacity(32),
    avg_edges=FixedCapacity(32),
    avg_image_candidates=FixedCapacity(32),
)

new_positions = jnp.array([[4.5, 0.0, 0.0]])
moved_idx = jnp.array([1])

rh = Table(
    keys=(ParticleId(0),),
    data=Points(
        positions=new_positions,
        system=Index.new([SystemId(0)]),
        inclusion=Index.new([InclusionId(0)]),
        exclusion=Index.new([ExclusionId(1)]),
    ),
)
rh_with_indices = WithIndices(Index(particles.keys, moved_idx), rh)

changes = neighborlist_changes(nl, particles, rh_with_indices, systems, cutoffs)
print("removed:", sorted(edge_set(changes.removed, len(particles))))
print("added  :", sorted(edge_set(changes.added, len(particles))))
removed: [(0, 1), (1, 0)]
added  : [(1, 2), (2, 1)]

Where to go next

Most users interact with the neighbor list indirectly through the potential factories. When a simulation is slow, it is usually worth opening this layer and checking three things: whether the cutoff-to-box ratio justifies the current implementation, whether the capacity parameters have stabilized, and whether several potentials could share a base list through the refinement classes. Each is a self-contained lever that can be pulled without changing the rest of the simulation.

The Potentials notebook shows how these pieces connect to energy evaluation, and the Runtime Assertions notebook covers the host-side retry loop that makes capacity growth painless.