
kups.core.neighborlist

Neighbor list construction and edge representations for molecular systems.

This module provides multiple neighbor list algorithms for finding interacting pairs of particles within cutoff distances, with different performance and accuracy trade-offs.

Core Components

  • Edges: Represents connections between particles with periodic shifts
  • NearestNeighborList: Protocol for neighbor search implementations
  • Pipeline: Selector → mask sequence → compactor

Neighbor List Implementations

Primary Implementations

  1. CellListNeighborList (Recommended when cutoff << box size)

    • O(N) complexity using spatial hashing
    • Best when cutoff / box_size < 0.3 (cutoff much smaller than box)
    • Honors the cell's per-axis periodic mask (bulk and bounded non-periodic)
  2. DenseNearestNeighborList

    • O(N²/K) complexity (K = number of systems)
    • Best when cutoff / box_size ~ 1 (cutoff comparable to box)
  3. AllDenseNearestNeighborList

    • O(N²) complexity across all systems
    • Only for single-system simulations or testing
    • Crosses system boundaries (use with caution!)

Refinement Implementations

These let one expensive base neighbor list be shared across multiple potentials.

  1. RefineMaskNeighborList: apply different inclusion/exclusion masks to precomputed edges.
  2. RefineCutoffNeighborList: refine precomputed edges with new cutoff distances.
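Refining by cutoff only works if the new cutoff is no larger than the one used to build the shared edges, because refinement can only remove edges. A minimal sketch of that idea in plain Python follows. `refine_cutoff` is a hypothetical helper, not the kups API. It assumes non-periodic positions and uses the same strict `<` comparison as `DistanceCutoffMask`.

```python
def refine_cutoff(edges, positions, new_cutoff):
    """Filter precomputed (i, j) edges down to a smaller cutoff.

    Hypothetical helper: it only removes edges, so `new_cutoff` must not
    exceed the cutoff that was used to build `edges`.
    """
    cutoff_sq = new_cutoff ** 2
    kept = []
    for i, j in edges:
        # Squared Euclidean distance (non-periodic for simplicity).
        d_sq = sum((a - b) ** 2 for a, b in zip(positions[i], positions[j]))
        if d_sq < cutoff_sq:
            kept.append((i, j))
    return kept


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 4.0, 0.0)]
base_edges = [(0, 1), (0, 2), (1, 2)]  # built once with cutoff 5.0
short_edges = refine_cutoff(base_edges, positions, 2.0)  # for a shorter-ranged potential
```

One expensive search at the largest cutoff then feeds several cheaper per-potential filters.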

Pipeline Primitives

Every neighbor list above is a Pipeline of a CandidateSelector, a tuple of Mask criteria, and a Compactor. Users wanting custom behavior can compose their own pipeline directly.
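The selector → masks → compactor flow can be pictured with a toy pipeline. This sketch assumes the masks are AND-ed into a single `keep` array, which fits the "accumulated keep mask" wording used for `Compactor` below. It has not been checked against the real `Pipeline` source, and `MiniPipeline` is a hypothetical name.

```python
from dataclasses import dataclass
from typing import Callable, Sequence

Pair = tuple[int, int]


@dataclass
class MiniPipeline:
    """Toy selector -> masks -> compactor pipeline (hypothetical, not the kups API)."""

    selector: Callable[[], list[Pair]]
    masks: Sequence[Callable[[list[Pair]], list[bool]]]
    compactor: Callable[[list[bool], list[Pair]], list[Pair]]

    def __call__(self) -> list[Pair]:
        candidates = self.selector()
        keep = [True] * len(candidates)
        for mask in self.masks:
            # Masks are AND-ed, so their order does not change the result.
            keep = [k and m for k, m in zip(keep, mask(candidates))]
        return self.compactor(keep, candidates)


xs = [0.0, 0.5, 3.0]  # 1-D positions
n = len(xs)
pipeline = MiniPipeline(
    selector=lambda: [(i, j) for i in range(n) for j in range(n) if i != j],
    masks=(
        lambda cands: [abs(xs[i] - xs[j]) < 1.0 for i, j in cands],  # distance cutoff
        lambda cands: [i < j for i, j in cands],  # keep each unordered pair once
    ),
    compactor=lambda keep, cands: [c for k, c in zip(keep, cands) if k],
)
edges = pipeline()
```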

AllDenseNearestNeighborList

Dense O(N²) neighbor list considering all pairs across all systems.

This implementation generates all possible particle pairs without spatial optimization. It is only suitable for very small systems or testing.

Warning: This crosses system boundaries! Only use for single-system simulations. For multiple systems, use DenseNearestNeighborList instead.

Complexity: O(N²) where N is the total number of particles across all systems.
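As a reference point for the O(N²) cost, here is a brute-force search over every ordered pair. This is a sketch, not the kups implementation. It assumes one orthorhombic, fully periodic box with `cutoff < box/2`, so the minimum-image shift is enough and no extra images are needed. It returns both orientations `(i, j)` and `(j, i)`, which is a choice made for this sketch.

```python
import itertools


def all_pairs_within(positions, box, cutoff):
    """O(N^2) reference search over every ordered pair (i, j), i != j.

    Assumes an orthorhombic, fully periodic box with cutoff < min(box) / 2.
    """
    cutoff_sq = cutoff * cutoff
    edges = []
    for i, j in itertools.permutations(range(len(positions)), 2):
        d_sq = 0.0
        for a, b, length in zip(positions[i], positions[j], box):
            d = b - a
            d -= length * round(d / length)  # minimum-image convention
            d_sq += d * d
        if d_sq < cutoff_sq:
            edges.append((i, j))
    return edges


pts = [(0.5, 0.0, 0.0), (9.5, 0.0, 0.0), (5.0, 5.0, 5.0)]
pairs = all_pairs_within(pts, (10.0, 10.0, 10.0), 2.0)  # 0 and 1 meet across the boundary
```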

Attributes:

| Name | Type | Description |
|---|---|---|
| avg_edges | Capacity[int] | Capacity manager for edge array. |
| avg_image_candidates | Capacity[int] | Capacity manager for image candidate pairs. |

Example
```python
# Construct from state and a lens to the neighbor list parameters:
nl = AllDenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

# Or, if the state implements IsNeighborListState:
nl = AllDenseNearestNeighborList.from_state(state)

edges = nl(particles, None, systems, cutoffs, None)
```
Source code in src/kups/core/neighborlist/all_dense.py
@dataclass
class AllDenseNearestNeighborList:
    """Dense O(N²) neighbor list considering all pairs across all systems.

    This implementation generates all possible particle pairs without spatial
    optimization. It is only suitable for very small systems or testing.

    **Warning**: This crosses system boundaries! Only use for single-system
    simulations. For multiple systems, use
    [DenseNearestNeighborList][kups.core.neighborlist.DenseNearestNeighborList]
    instead.

    Complexity: O(N²) where N is the total number of particles across all systems.

    Attributes:
        avg_edges: Capacity manager for edge array.
        avg_image_candidates: Capacity manager for image candidate pairs.

    Example:
        ```python
        # Construct from state and a lens to the neighbor list parameters:
        nl = AllDenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

        # Or, if the state implements IsNeighborListState:
        nl = AllDenseNearestNeighborList.from_state(state)

        edges = nl(particles, None, systems, cutoffs, None)
        ```
    """

    avg_edges: Capacity[int]
    avg_image_candidates: Capacity[int]

    @classmethod
    def new[S](
        cls, state: S, lens: Lens[S, IsAllDenseNeighborListParams]
    ) -> AllDenseNearestNeighborList:
        params = lens.get(state)
        return AllDenseNearestNeighborList(
            avg_edges=LensCapacity(params.avg_edges, lens.focus(lambda x: x.avg_edges)),
            avg_image_candidates=LensCapacity(
                params.avg_image_candidates,
                lens.focus(lambda x: x.avg_image_candidates),
            ),
        )

    @classmethod
    def from_state(
        cls, state: IsNeighborListState[IsAllDenseNeighborListParams]
    ) -> AllDenseNearestNeighborList:
        return cls.new(state, lens(lambda s: s.neighborlist_params))

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        if lh.data.inclusion.num_labels >= 2:
            logging.warning(
                "AllDenseNearestNeighborList is intended for single-system simulations. "
                "Performance may be degraded when using multiple systems. "
                "Consider using DenseNearestNeighborList or CellListNeighborList instead."
            )
        rh_size = rh.size if rh is not None else lh.size
        cutoffs = Table.broadcast_to(cutoffs, systems)
        pipeline = Pipeline[Literal[2]](
            selector=AllDenseSelector(
                cutoffs=cutoffs,
                max_image_candidates=self.avg_image_candidates.multiply(rh_size),
            ),
            masks=(
                InBoundsMask(),
                InclusionMatchMask(),
                RemapDedupMask(),
                DistanceCutoffMask(cutoffs=cutoffs),
                ExclusionMask(),
            ),
            compactor=ReduceCompactor(avg_edges=self.avg_edges.multiply(rh_size)),
        )
        return pipeline(lh, rh, systems, rh_index_remap)

AllDenseSelector

Selector that emits every (i, j) pair across all systems.

Source code in src/kups/core/neighborlist/all_dense.py
@dataclass
class AllDenseSelector:
    """Selector that emits every ``(i, j)`` pair across all systems."""

    cutoffs: Table[SystemId, Array]
    max_image_candidates: Capacity[int]

    def __call__(self, ctx: PipelineContext) -> CandidateBatch[Literal[2]]:
        candidates = _all_subselect(ctx.lh, ctx.rh, ctx.systems)
        return replicate_for_images(
            candidates,
            ctx.lh,
            ctx.rh,
            ctx.systems,
            self.cutoffs,
            self.max_image_candidates,
        )

CandidateBatch

Bases: NamedTuple

Candidate set of degree D carried through the pipeline.

Reuses Edges[D] for the (indices, shifts) layout (indices shape (n, D), shifts shape (n, D-1, 3)); adds the is_minimum_image flag that ExclusionMask needs to keep non-minimum periodic copies of excluded pairs.

Auto-registered as a JAX PyTree because it is a NamedTuple.

Attributes:

| Name | Type | Description |
|---|---|---|
| edges | Edges[D] | Candidate edges (indices + fractional shifts). |
| is_minimum_image | Array | (n,) bool: True where the candidate's shift equals the minimum-image shift; False for non-MIC replicated copies emitted by selectors that handle PBC image expansion. |
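To see where `is_minimum_image` comes from, consider a 1-D periodic box whose cutoff exceeds half the box length. Here more than one image of `j` can lie within the cutoff of `i`, and only one of those images is the minimum image. This is a sketch under the assumption that an integer shift `s` places `j` at `xj + s * length`. `replicate_with_flags` is a hypothetical helper. `ExclusionMask` would drop only the `True`-flagged copy of an excluded pair.

```python
import math


def replicate_with_flags(xi, xj, length, cutoff):
    """Enumerate periodic images of j seen from i in a 1-D box.

    Returns (shift, is_minimum_image) for every image within the cutoff.
    Hypothetical illustration, not the kups replicate_for_images.
    """
    d = xj - xi
    mic_shift = -round(d / length)  # shift that yields the nearest image
    n_max = math.ceil(cutoff / length) + 1
    out = []
    for s in range(-n_max, n_max + 1):
        if abs(d + s * length) < cutoff:
            out.append((s, s == mic_shift))
    return out


copies = replicate_with_flags(0.0, 1.0, 4.0, 3.5)  # cutoff > length / 2: two images
```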

Source code in src/kups/core/neighborlist/types.py
class CandidateBatch[D: int](NamedTuple):
    """Candidate set of degree ``D`` carried through the pipeline.

    Reuses [`Edges[D]`][kups.core.neighborlist.edges.Edges] for the
    `(indices, shifts)` layout (`indices` shape `(n, D)`,
    `shifts` shape `(n, D-1, 3)`); adds the
    ``is_minimum_image`` flag that
    [`ExclusionMask`][kups.core.neighborlist.masks.ExclusionMask] needs to
    keep non-minimum periodic copies of excluded pairs.

    Auto-registered as a JAX PyTree because it is a NamedTuple.

    Attributes:
        edges: Candidate edges (indices + fractional shifts).
        is_minimum_image: ``(n,)`` bool — True where the candidate's shift
            equals the minimum-image shift; False for non-MIC replicated
            copies emitted by selectors that handle PBC image expansion.
    """

    edges: Edges[D]
    is_minimum_image: Array

    @property
    def lh_idx(self) -> Array:
        """Pair-specific: raw lh-side index array of shape ``(n,)``. Only meaningful for ``D == 2``."""
        return self.edges.indices.indices[:, 0]

    @property
    def rh_idx(self) -> Array:
        """Pair-specific: raw rh-side index array of shape ``(n,)``. Only meaningful for ``D == 2``."""
        return self.edges.indices.indices[:, 1]

lh_idx property

Pair-specific: raw lh-side index array of shape (n,). Only meaningful for D == 2.

rh_idx property

Pair-specific: raw rh-side index array of shape (n,). Only meaningful for D == 2.

CandidateSelector

Bases: Protocol

Produces a CandidateBatch[D] from the pipeline context.

Owns all candidate-set construction, including any periodic-image replication required when the cutoff exceeds half of the cell's perpendicular width along some axis, i.e. when max(cutoff/perp_axis) > 0.5.
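For a triclinic cell, the relevant length is the perpendicular width: the distance between a pair of opposite faces. This is `V / |b × c|` for the face spanned by `b` and `c`, and it can be shorter than `|a|`. The sketch below computes these widths and applies the `> 0.5` test from the text. The helper names are hypothetical, and we assume this is how `perp_axis` is meant.

```python
import math


def cross(u, v):
    return (u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0])


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def perpendicular_widths(a, b, c):
    """Distance between each pair of opposite faces of the cell spanned by a, b, c."""
    volume = abs(dot(a, cross(b, c)))
    return tuple(
        volume / math.sqrt(dot(nrm, nrm))
        for nrm in (cross(b, c), cross(c, a), cross(a, b))
    )


def needs_image_replication(cell, cutoff):
    """True when the minimum image alone can miss neighbors (cutoff > width / 2)."""
    return max(cutoff / w for w in perpendicular_widths(*cell)) > 0.5


cubic = ((10, 0, 0), (0, 10, 0), (0, 0, 10))
skewed = ((10, 0, 0), (5, 10, 0), (0, 0, 10))  # |a| = 10, but its width is ~8.94
```

With a 4.6 Å cutoff, the cubic cell needs no replication, but the skewed cell does, even though every lattice vector is at least 10 Å long.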

Source code in src/kups/core/neighborlist/types.py
class CandidateSelector[D: int](Protocol):
    """Produces a ``CandidateBatch[D]`` from the pipeline context.

    Owns all candidate-set construction, including any PBC image
    replication required when ``max(cutoff/perp_axis) > 0.5``.
    """

    def __call__(self, ctx: PipelineContext) -> CandidateBatch[D]: ...

CellListNeighborList

Efficient O(N) neighbor list using spatial hashing with cell lists.

This is the recommended implementation when the cutoff is much smaller than the box size. It divides space into a grid of cells and only checks pairs in neighboring cells, achieving linear scaling with system size.

Honors the cell's per-axis periodic mask: stencil offsets that cross a non-periodic face are routed to an out-of-bounds bin that matches no key, and minimum-image shifts are zero on non-periodic axes. When every axis is periodic (decided at trace time via all(periodic)), the code path is byte-identical to the original fully periodic implementation, so PBC kernels incur no overhead.

Complexity: O(N) for well-distributed particles where cutoff << box size. Efficiency improves as cutoff/box ratio decreases.

Attributes:

| Name | Type | Description |
|---|---|---|
| avg_candidates | Capacity[int] | Capacity for candidate pair storage (from cell list). |
| avg_edges | Capacity[int] | Capacity for final edge array. |
| cells | Capacity[int] | Capacity for cell hash table (grows with box_size³/cutoff³). |
| avg_image_candidates | Capacity[int] | Capacity for image candidate pairs. |

Algorithm
  1. Partition space into grid cells of size ~cutoff
  2. Hash each particle to its cell
  3. For each particle, check only the 27 neighboring cells in 3D (its own cell plus the 26 adjacent ones)
  4. Filter candidates by actual distance
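The four steps above can be sketched in plain Python. This is a simplified illustration, not the kups kernel. It assumes an orthorhombic, fully periodic box, a cutoff below half the box length, and positions already wrapped into `[0, L)`. It uses a dictionary as the hash table instead of fixed-capacity arrays, and `cell_list_pairs` is a hypothetical name.

```python
import itertools
from collections import defaultdict


def cell_list_pairs(positions, box, cutoff):
    """Cell-list neighbor search in an orthorhombic, fully periodic box.

    Assumes cutoff < min(box) / 2 and positions wrapped into [0, L).
    Returns sorted ordered pairs (i, j), i != j, with MIC distance < cutoff.
    """
    # 1. Partition space: each cell is at least `cutoff` wide, so neighbors
    #    always sit in the same or an adjacent cell.
    n_cells = [max(1, int(length // cutoff)) for length in box]
    # 2. Hash each particle to its integer cell coordinate.
    cells = defaultdict(list)
    for idx, pos in enumerate(positions):
        key = tuple(
            min(int(x / length * n), n - 1) for x, length, n in zip(pos, box, n_cells)
        )
        cells[key].append(idx)
    cutoff_sq = cutoff * cutoff
    pairs = []
    for key, members in cells.items():
        # 3. Visit the 27-cell stencil; the set avoids double visits
        #    when an axis has fewer than 3 cells.
        stencil = {
            tuple((k + o) % n for k, o, n in zip(key, off, n_cells))
            for off in itertools.product((-1, 0, 1), repeat=3)
        }
        for other in stencil:
            for i in members:
                for j in cells.get(other, ()):
                    if i == j:
                        continue
                    # 4. Filter candidates by the actual minimum-image distance.
                    d_sq = 0.0
                    for a, b, length in zip(positions[i], positions[j], box):
                        d = b - a
                        d -= length * round(d / length)
                        d_sq += d * d
                    if d_sq < cutoff_sq:
                        pairs.append((i, j))
    return sorted(pairs)


pts = [(0.5, 0.5, 0.5), (9.5, 0.5, 0.5), (5.0, 5.0, 5.0), (6.0, 5.0, 5.0)]
found = cell_list_pairs(pts, (10.0, 10.0, 10.0), 2.5)
```

Only particles in the 27 stencil cells are compared, so the work per particle stays roughly constant as the box grows at fixed density.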
When to use
  • When cutoff/box_size << 1 (cutoff much smaller than box)
  • Typically cutoff/box < 0.3 for good efficiency
  • On non-periodic axes positions must lie inside [0, L) in real coordinates (the caller's invariant; out-of-range positions are silently routed to the OOB bin)
Example
```python
# Example: 10 Å cutoff in 50 Å box → cutoff/box = 0.2 -- Good for CellList
nl = CellListNeighborList.new(state, lens(lambda s: s.nl_params))

# Or, if the state implements IsNeighborListState:
nl = CellListNeighborList.from_state(state)

edges = nl(particles, None, systems, cutoffs, None)
```
Source code in src/kups/core/neighborlist/cell_list.py
@dataclass
class CellListNeighborList:
    """Efficient O(N) neighbor list using spatial hashing with cell lists.

    This is the recommended implementation when the cutoff is much smaller than
    the box size. It divides space into a grid of cells and only checks pairs in
    neighboring cells, achieving linear scaling with system size.

    Honors the cell's per-axis ``periodic`` mask: stencil offsets that cross a
    non-periodic face are routed to an out-of-bounds bin (no key matches), and
    minimum-image shifts are zero on non-periodic axes. The fully-periodic path
    is byte-identical to the original (gated at trace time on ``all(periodic)``)
    so PBC kernels see no overhead.

    Complexity: O(N) for well-distributed particles where cutoff << box size.
    Efficiency improves as cutoff/box ratio decreases.

    Attributes:
        avg_candidates: Capacity for candidate pair storage (from cell list).
        avg_edges: Capacity for final edge array.
        cells: Capacity for cell hash table (grows with box_size³/cutoff³).
        avg_image_candidates: Capacity for image candidate pairs.

    Algorithm:
        1. Partition space into grid cells of size ~cutoff
        2. Hash each particle to its cell
        3. For each particle, check only neighboring 27 cells (3D)
        4. Filter candidates by actual distance

    When to use:
        - When cutoff/box_size << 1 (cutoff much smaller than box)
        - Typically cutoff/box < 0.3 for good efficiency
        - On non-periodic axes positions must lie inside ``[0, L)`` in real
          coordinates (the caller's invariant; out-of-range positions are
          silently routed to the OOB bin)

    Example:
        ```python
        # Example: 10 Å cutoff in 50 Å box → cutoff/box = 0.2 -- Good for CellList
        nl = CellListNeighborList.new(state, lens(lambda s: s.nl_params))

        # Or, if the state implements IsNeighborListState:
        nl = CellListNeighborList.from_state(state)

        edges = nl(particles, None, systems, cutoffs, None)
        ```
    """

    avg_candidates: Capacity[int]
    avg_edges: Capacity[int]
    cells: Capacity[int]
    avg_image_candidates: Capacity[int]

    @classmethod
    def new[S](cls, state: S, lens: Lens[S, IsCellListParams]) -> CellListNeighborList:
        params = lens.get(state)
        return CellListNeighborList(
            avg_candidates=LensCapacity(
                params.avg_candidates, lens.focus(lambda x: x.avg_candidates)
            ),
            avg_edges=LensCapacity(params.avg_edges, lens.focus(lambda x: x.avg_edges)),
            avg_image_candidates=LensCapacity(
                params.avg_image_candidates,
                lens.focus(lambda x: x.avg_image_candidates),
            ),
            cells=LensCapacity(params.cells, lens.focus(lambda x: x.cells), base=1),
        )

    @classmethod
    def from_state(
        cls, state: IsNeighborListState[IsCellListParams]
    ) -> CellListNeighborList:
        return cls.new(state, lens(lambda s: s.neighborlist_params))

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        rh_size = rh.size if rh is not None else lh.size
        cutoffs = Table.broadcast_to(cutoffs, systems)
        pipeline = Pipeline[Literal[2]](
            selector=CellListSelector(
                cutoffs=cutoffs,
                max_cells=self.cells,
                max_candidates=self.avg_candidates.multiply(rh_size),
                max_image_candidates=self.avg_image_candidates.multiply(rh_size),
            ),
            masks=(
                InBoundsMask(),
                InclusionMatchMask(),
                RemapDedupMask(),
                DistanceCutoffMask(cutoffs=cutoffs),
                ExclusionMask(),
            ),
            compactor=ReduceCompactor(avg_edges=self.avg_edges.multiply(rh_size)),
        )
        return pipeline(lh, rh, systems, rh_index_remap)

CellListSelector

Selector for the cell-list algorithm.

Emits raw candidates via spatial hashing, then replicates them once per required periodic image when max(cutoff/perp) > 0.5.

Source code in src/kups/core/neighborlist/cell_list.py
@dataclass
class CellListSelector:
    """Selector for the cell-list algorithm.

    Calls the raw spatial-hash candidate emission, then replicates per image
    multiplicity when ``max(cutoff/perp) > 0.5``.
    """

    cutoffs: Table[SystemId, Array]
    max_cells: Capacity[int]
    max_candidates: Capacity[int]
    max_image_candidates: Capacity[int]

    def __call__(self, ctx: PipelineContext) -> CandidateBatch[Literal[2]]:
        candidates = _cell_list_subselect(
            ctx.lh,
            ctx.rh,
            ctx.systems,
            cutoffs=self.cutoffs.data,
            max_num_cells=self.max_cells,
            max_num_candidates=self.max_candidates,
        )
        return replicate_for_images(
            candidates,
            ctx.lh,
            ctx.rh,
            ctx.systems,
            self.cutoffs,
            self.max_image_candidates,
        )

Compactor

Bases: Protocol

Produces final Edges[D] from the accumulated keep mask.
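Under JIT, array shapes must be static, so a compactor has to pack surviving candidates into a buffer of fixed size. The sketch below shows one way to do that in plain Python: it pads unused slots and reports overflow. We assume this is roughly the role the `avg_edges` capacity plays for `ReduceCompactor`, but we have not verified that. `compact` and its sentinel `fill` value are hypothetical.

```python
def compact(keep, candidates, capacity, fill=(-1, -1)):
    """Pack kept candidates into a fixed-size buffer (static-shape style).

    Hypothetical sketch: returns the padded buffer, the number of valid
    entries, and whether capacity was exceeded (so a caller could grow it).
    """
    kept = [c for k, c in zip(keep, candidates) if k]
    overflow = len(kept) > capacity
    buffer = kept[:capacity] + [fill] * max(0, capacity - len(kept))
    return buffer, min(len(kept), capacity), overflow


buf, n_valid, overflowed = compact([True, False, True], [(0, 1), (0, 2), (1, 2)], 4)
```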

Source code in src/kups/core/neighborlist/types.py
class Compactor[D: int](Protocol):
    """Produces final ``Edges[D]`` from the accumulated ``keep`` mask."""

    def __call__(
        self, keep: Array, batch: CandidateBatch[D], ctx: PipelineContext
    ) -> Edges[D]: ...

DenseNearestNeighborList

Dense O(N²) neighbor list respecting system boundaries.

This implementation generates all particle pairs within each system separately, avoiding cross-system interactions. Efficient when the cutoff is comparable to the box size (cutoff/box ~ 1).

Complexity: O(N²/K²) pairs per system, i.e. O(N²/K) in total, where N is the total number of particles and K is the number of (roughly equally sized) systems.
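The N²/K total comes from pairing only within each system: K systems of N/K particles give K·(N/K)² = N²/K pairs. The following sketch counts those pairs by grouping particle indices by system id. `per_system_pairs` is a hypothetical helper that ignores cutoffs and shifts.

```python
import itertools
from collections import defaultdict


def per_system_pairs(system_ids):
    """All ordered (i, j) pairs with i != j whose particles share a system."""
    groups = defaultdict(list)
    for idx, sys_id in enumerate(system_ids):
        groups[sys_id].append(idx)
    pairs = []
    for members in groups.values():
        pairs.extend(itertools.permutations(members, 2))
    return sorted(pairs)


small = per_system_pairs([0, 0, 1, 1, 1])  # no (0, 2): different systems
four_systems = per_system_pairs([k for k in range(4) for _ in range(25)])
```

For 100 particles in 4 systems, this gives 4 · 25 · 24 = 2400 ordered pairs, versus 100 · 99 = 9900 for a search across all systems.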

Attributes:

| Name | Type | Description |
|---|---|---|
| avg_candidates | Capacity[int] | Capacity for candidate pair storage. |
| avg_edges | Capacity[int] | Capacity for final edge array. |
| avg_image_candidates | Capacity[int] | Capacity for image candidate pairs. |

When to use
  • When cutoff/box_size ~ 1 (cutoff comparable to box dimensions)
  • Small box relative to cutoff (few cells would fit)
  • Non-periodic systems
Example
```python
# Example: 15 Å cutoff in 20 Å box → cutoff/box = 0.75
nl = DenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

# Or, if the state implements IsNeighborListState:
nl = DenseNearestNeighborList.from_state(state)

edges = nl(particles, None, systems, cutoffs, None)
```
Source code in src/kups/core/neighborlist/dense.py
@dataclass
class DenseNearestNeighborList:
    """Dense O(N²) neighbor list respecting system boundaries.

    This implementation generates all particle pairs within each system
    separately, avoiding cross-system interactions. Efficient when the cutoff
    is comparable to the box size (cutoff/box ~ 1).

    Complexity: O(N² / K²) where N is total particles and K is number of systems.

    Attributes:
        avg_candidates: Capacity for candidate pair storage.
        avg_edges: Capacity for final edge array.
        avg_image_candidates: Capacity for image candidate pairs.

    When to use:
        - When cutoff/box_size ~ 1 (cutoff comparable to box dimensions)
        - Small box relative to cutoff (few cells would fit)
        - Non-periodic systems

    Example:
        ```python
        # Example: 15 Å cutoff in 20 Å box → cutoff/box = 0.75
        nl = DenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

        # Or, if the state implements IsNeighborListState:
        nl = DenseNearestNeighborList.from_state(state)

        edges = nl(particles, None, systems, cutoffs, None)
        ```
    """

    avg_candidates: Capacity[int]
    avg_edges: Capacity[int]
    avg_image_candidates: Capacity[int]

    @classmethod
    def new[S](
        cls, state: S, lens: Lens[S, IsDenseNeighborlistParams]
    ) -> DenseNearestNeighborList:
        params = lens.get(state)
        return DenseNearestNeighborList(
            avg_candidates=LensCapacity(
                params.avg_candidates, lens.focus(lambda x: x.avg_candidates)
            ),
            avg_edges=LensCapacity(params.avg_edges, lens.focus(lambda x: x.avg_edges)),
            avg_image_candidates=LensCapacity(
                params.avg_image_candidates,
                lens.focus(lambda x: x.avg_image_candidates),
            ),
        )

    @classmethod
    def from_state(
        cls, state: IsNeighborListState[IsDenseNeighborlistParams]
    ) -> DenseNearestNeighborList:
        return cls.new(state, lens(lambda s: s.neighborlist_params))

    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        rh_size = rh.size if rh is not None else lh.size
        cutoffs = Table.broadcast_to(cutoffs, systems)
        pipeline = Pipeline[Literal[2]](
            selector=DenseSelector(
                cutoffs=cutoffs,
                max_candidates=self.avg_candidates.multiply(rh_size),
                max_image_candidates=self.avg_image_candidates.multiply(rh_size),
            ),
            masks=(
                InBoundsMask(),
                InclusionMatchMask(),
                RemapDedupMask(),
                DistanceCutoffMask(cutoffs=cutoffs),
                ExclusionMask(),
            ),
            compactor=ReduceCompactor(avg_edges=self.avg_edges.multiply(rh_size)),
        )
        return pipeline(lh, rh, systems, rh_index_remap)

DenseSelector

Selector for the per-system dense algorithm (O(N²/K²) candidate pairs per system).

Source code in src/kups/core/neighborlist/dense.py
@dataclass
class DenseSelector:
    """Selector for the per-system dense ``O(N²/K²)`` algorithm."""

    cutoffs: Table[SystemId, Array]
    max_candidates: Capacity[int]
    max_image_candidates: Capacity[int]

    def __call__(self, ctx: PipelineContext) -> CandidateBatch[Literal[2]]:
        candidates = _dense_subselect(
            ctx.lh, ctx.rh, ctx.systems, max_num_candidates=self.max_candidates
        )
        return replicate_for_images(
            candidates,
            ctx.lh,
            ctx.rh,
            ctx.systems,
            self.cutoffs,
            self.max_image_candidates,
        )

DistanceCutoffMask

Drops candidates whose squared real-space distance exceeds cutoff².
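The mask compares squared distances against each candidate's own system cutoff, which avoids a square root per pair. A plain-Python sketch of that comparison follows. The helper name is hypothetical, and the squared distances are taken as given.

```python
def distance_cutoff_mask(dist_sq, cand_system, cutoffs):
    """True where a candidate's squared distance is below its system's cutoff²."""
    return [d < cutoffs[s] ** 2 for d, s in zip(dist_sq, cand_system)]


# Two systems with cutoffs 2.0 and 4.0; the same distance (3.0) is kept only in system 1.
mask = distance_cutoff_mask([1.0, 9.0, 9.0], [0, 0, 1], [2.0, 4.0])
```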

Source code in src/kups/core/neighborlist/masks.py
@dataclass
class DistanceCutoffMask:
    """Drops candidates whose squared real-space distance exceeds ``cutoff²``."""

    cutoffs: Table[SystemId, Array]

    def __call__(self, batch: CandidateBatch, ctx: PipelineContext) -> Array:
        cutoffs = Table.broadcast_to(self.cutoffs, ctx.systems)
        shifts = batch.edges.shifts[:, 0, :]
        dist_sq = real_distance_sq(
            ctx.lh, ctx.rh, ctx.systems, batch.lh_idx, batch.rh_idx, shifts
        )
        cand_sys = ctx.lh.data.system[batch.lh_idx]
        return dist_sq < cutoffs[cand_sys] ** 2

Edges

Bases: Sliceable

Represents edges (connections) between particles in a molecular system.

An edge connects Degree particles, where degree=2 represents pairwise interactions (bonds), degree=3 represents three-body interactions (angles), etc.

For periodic systems, edges include integer shift vectors that give how many periodic cell translations (multiples of the lattice vectors) to apply when computing distances between connected particles.

Class Type Parameters:

| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| Degree | int | Number of particles connected by each edge (static type check) | required |

Attributes:

| Name | Type | Description |
|---|---|---|
| indices | Index[ParticleId] | Particle indices for each edge, shape (n_edges, Degree). |
| shifts | Array | Periodic shift vectors, shape (n_edges, Degree-1, 3): shifts for the 2nd through Degree-th particle relative to the first. |

Example
```python
import jax.numpy as jnp

from kups.core.neighborlist.edges import Edges

# Pairwise edges (bonds) between particles
edges = Edges(
    indices=jnp.array([[0, 1], [1, 2], [0, 2]]),  # 3 edges
    shifts=jnp.array([[[0, 0, 0]], [[0, 0, 0]], [[1, 0, 0]]]),  # 3rd edge crosses boundary
)
```
Source code in src/kups/core/neighborlist/edges.py
@dataclass
class Edges[Degree: int](Sliceable):
    """Represents edges (connections) between particles in a molecular system.

    An edge connects `Degree` particles, where degree=2 represents pairwise
    interactions (bonds), degree=3 represents three-body interactions (angles), etc.

    For periodic systems, edges include shift vectors that indicate how many
    cells to traverse when computing distances between connected particles.

    Type Parameters:
        Degree: Number of particles connected by each edge (static type check)

    Attributes:
        indices: Particle indices for each edge, shape `(n_edges, Degree)`
        shifts: Periodic shift vectors, shape `(n_edges, Degree-1, 3)`.
            Shift vectors for the 2nd through Degree-th particle relative to the first.

    Example:
        ```python
        # Pairwise edges (bonds) between particles
        edges = Edges(
            indices=jnp.array([[0, 1], [1, 2], [0, 2]]),  # 3 edges
            shifts=jnp.array([[[0, 0, 0]], [[0, 0, 0]], [[1, 0, 0]]])  # 3rd edge crosses boundary
        )
        ```
    """

    # The degree is purely for type checking and does not affect runtime behavior
    indices: Index[ParticleId]  # (n_edges, Degree)
    shifts: Array  # (n_edges, Degree - 1, 3)

    def __post_init__(self):
        # Resolve the underlying array for validation
        raw = self.indices.indices if isinstance(self.indices, Index) else self.indices
        if not isinstance(raw, Array):
            return
        assert jnp.issubdtype(raw.dtype, jnp.integer), (
            f"Indices must be of integer type, got {raw.dtype}"
        )
        target_shape = (
            *self.indices.shape[:-1],
            self.indices.shape[-1] - 1 if self.indices.shape[-1] > 1 else 0,
            3,
        )
        assert self.shifts.shape == target_shape, (
            f"Shifts must have shape {target_shape}, got {self.shifts.shape}"
        )

    def difference_vectors(
        self,
        particles: Table[ParticleId, HasPositionsAndSystemIndex],
        systems: Table[SystemId, HasCell],
    ) -> Array:
        """Compute difference vectors between connected particles.

        For each edge, computes the vector from the first particle to each
        subsequent particle, accounting for periodic boundary conditions.

        Args:
            particles: Particle positions with system index information.
            systems: System data with cell for periodic boundary conditions.

        Returns:
            Array of shape `(n_edges, Degree-1, 3)` containing difference vectors.
        """

        shifts = self.absolute_shifts(particles, systems)
        pos = particles[self.indices].positions
        return pos[:, 1:] - pos[:, :1] + shifts

    def absolute_shifts(
        self,
        particles: Table[ParticleId, HasPositionsAndSystemIndex],
        systems: Table[SystemId, HasCell],
    ) -> Array:
        """Compute absolute shift vectors for all particles in each edge.

        Converts relative shifts to absolute Cartesian shift vectors.

        Args:
            particles: Particle data with system index information.
            systems: System data with cell for periodic boundary conditions.

        Returns:
            Array of shape `(n_edges, Degree-1, 3)` containing absolute shift vectors.
        """
        lattice = systems.map_data(lambda x: x.cell.vectors)
        vecs = lattice[particles[self.indices[:, 0]].system]
        return triangular_3x3_matmul(vecs[:, None], self.shifts)

    @property
    def degree(self) -> int:
        return self.indices.shape[-1]

    def __len__(self) -> int:
        return self.indices.shape[0]

absolute_shifts(particles, systems)

Compute absolute shift vectors for all particles in each edge.

Converts relative shifts to absolute Cartesian shift vectors.
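Converting an integer shift to Cartesian means taking the corresponding combination of lattice vectors, n_a·a + n_b·b + n_c·c. The plain-Python sketch below assumes the cell vectors are stored as rows of `lattice`. The exact convention used by `triangular_3x3_matmul` in kups is not shown here, and `absolute_shift` is a hypothetical helper.

```python
def absolute_shift(frac_shift, lattice):
    """Cartesian shift n_a*a + n_b*b + n_c*c for integer cell offsets.

    Assumes `lattice` stores the cell vectors a, b, c as rows.
    """
    return tuple(
        sum(n * vec[k] for n, vec in zip(frac_shift, lattice)) for k in range(3)
    )


ortho = ((10, 0, 0), (0, 12, 0), (0, 0, 8))
skewed = ((10, 0, 0), (5, 10, 0), (0, 0, 10))
```

In a skewed cell, a shift of one unit along `b` moves the particle along x as well as y.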

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| particles | Table[ParticleId, HasPositionsAndSystemIndex] | Particle data with system index information. | required |
| systems | Table[SystemId, HasCell] | System data with cell for periodic boundary conditions. | required |

Returns:

| Type | Description |
|---|---|
| Array | Array of shape (n_edges, Degree-1, 3) containing absolute shift vectors. |

difference_vectors(particles, systems)

Compute difference vectors between connected particles.

For each edge, computes the vector from the first particle to each subsequent particle, accounting for periodic boundary conditions.
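The method computes `pos[j] - pos[first] + shift` for every later particle `j` in an edge. The sketch below restates that formula in plain Python, taking Cartesian shifts as input. `difference_vectors` here is a standalone hypothetical function, not the method itself.

```python
def difference_vectors(positions, edge, cart_shifts):
    """Vectors from the first particle of `edge` to each later one, plus shifts."""
    first = positions[edge[0]]
    return [
        tuple(positions[j][k] - first[k] + shift[k] for k in range(3))
        for j, shift in zip(edge[1:], cart_shifts)
    ]


# Pair across a 10 Å periodic boundary: the -10 Å shift yields the short vector.
bond = difference_vectors([(0.5, 0.0, 0.0), (9.5, 0.0, 0.0)], (0, 1), [(-10.0, 0.0, 0.0)])
```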

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| particles | Table[ParticleId, HasPositionsAndSystemIndex] | Particle positions with system index information. | required |
| systems | Table[SystemId, HasCell] | System data with cell for periodic boundary conditions. | required |

Returns:

| Type | Description |
|---|---|
| Array | Array of shape (n_edges, Degree-1, 3) containing difference vectors. |

ExclusionMask

Drops minimum-image pairs that share an exclusion segment.

Non-minimum-image periodic copies of excluded pairs are kept: a candidate always passes when batch.is_minimum_image is False for it.
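The keep rule is `(lh_excl != rh_excl) | ~is_minimum_image`. A candidate is dropped only when both ends share an exclusion segment and it is the minimum-image copy. For instance, if segments group the atoms of one molecule, the direct intramolecular pair is excluded, but its periodic replica is still kept. The plain-Python restatement below uses a hypothetical helper name.

```python
def exclusion_keep(lh_excl, rh_excl, is_minimum_image):
    """Keep unless both ends share an exclusion segment AND this is the
    minimum-image copy; periodic replicas of excluded pairs are retained."""
    return [(a != b) or not mic for a, b, mic in zip(lh_excl, rh_excl, is_minimum_image)]


keep = exclusion_keep([0, 0, 0], [0, 1, 0], [True, True, False])
```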

Source code in src/kups/core/neighborlist/masks.py
@dataclass
class ExclusionMask:
    """Drops minimum-image pairs that share an exclusion segment.

    Non-minimum-image periodic copies of excluded pairs survive (allowed when
    ``batch.is_minimum_image`` is False for that copy).
    """

    def __call__(self, batch: CandidateBatch, ctx: PipelineContext) -> Array:
        # batch.edges.indices carries a single set of keys; rebuild a rh-keyed
        # Index so Table indexing aligns when ctx.lh and ctx.rh have distinct keys.
        lh_view = ctx.lh[Index(ctx.lh.keys, batch.lh_idx)]
        rh_view = ctx.rh[Index(ctx.rh.keys, batch.rh_idx)]
        lh_excl, rh_excl = Index.match(lh_view.exclusion, rh_view.exclusion)
        return (lh_excl != rh_excl) | ~batch.is_minimum_image

InBoundsMask

Drops candidates whose lh/rh indices fall outside the valid inclusion-segment range.

Implements the per-side inclusion.indices < num_labels check used to guard scatter/gather lookups when the candidate buffer is padded.
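Padded candidate buffers can contain indices past the end of the particle table. The `.at[idx].get(mode="fill", fill_value=False)` lookup reads these as `False` instead of failing. The plain-Python mirror of that check below assumes padding indices are non-negative and at least the table size. `in_bounds` is a hypothetical name.

```python
def in_bounds(inclusion, num_labels, lh_idx, rh_idx):
    """Both ends must index a real particle whose inclusion label is < num_labels."""
    valid = [seg < num_labels for seg in inclusion]

    def lookup(i):
        # Out-of-range (padding) indices read as False, like mode="fill".
        return i < len(valid) and valid[i]

    return [lookup(i) and lookup(j) for i, j in zip(lh_idx, rh_idx)]


# Particle 3 carries label 2 (== num_labels, i.e. invalid); index 4 is padding.
ok = in_bounds([0, 0, 1, 2], 2, [0, 1, 0, 4], [1, 3, 2, 0])
```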

Source code in src/kups/core/neighborlist/masks.py
@dataclass
class InBoundsMask:
    """Drops candidates whose lh/rh indices fall outside the valid inclusion-segment range.

    Implements the per-side ``inclusion.indices < num_labels`` check used to
    guard scatter/gather lookups when the candidate buffer is padded.
    """

    def __call__(self, batch: CandidateBatch, ctx: PipelineContext) -> Array:
        ngraphs = ctx.lh.data.inclusion.num_labels
        lh_in = (
            (ctx.lh.data.inclusion.indices < ngraphs)
            .at[batch.lh_idx]
            .get(mode="fill", fill_value=False)
        )
        rh_in = (
            (ctx.rh.data.inclusion.indices < ngraphs)
            .at[batch.rh_idx]
            .get(mode="fill", fill_value=False)
        )
        return lh_in & rh_in
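The important detail is the gather with mode="fill": a padded candidate whose index points past the end of the table reads fill_value=False instead of a real entry. A NumPy emulation of that gather (gather_fill and in_bounds_keep are hypothetical names; the sketch assumes non-negative indices, since JAX may wrap negative ones):

```python
import numpy as np


def gather_fill(values, idx, fill_value):
    """Emulates jnp `values.at[idx].get(mode="fill", fill_value=...)` for
    non-negative indices: positions >= len(values) yield fill_value."""
    idx = np.asarray(idx)
    valid = idx < len(values)
    out = np.full(idx.shape, fill_value, dtype=values.dtype)
    out[valid] = values[idx[valid]]
    return out


def in_bounds_keep(lh_incl, rh_incl, num_labels, lh_idx, rh_idx):
    # A particle is in bounds when its inclusion label is a real segment.
    lh_in = gather_fill(lh_incl < num_labels, lh_idx, False)
    rh_in = gather_fill(rh_incl < num_labels, rh_idx, False)
    return lh_in & rh_in


incl = np.array([0, 0, 1, 2])    # label 2 == num_labels marks a padding particle
lh_idx = np.array([0, 1, 0, 4])  # index 4 is past the end (padded candidate)
rh_idx = np.array([1, 2, 3, 0])
keep = in_bounds_keep(incl, incl, 2, lh_idx, rh_idx)
```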

InclusionGroupSelector

Pairs every particle with every other in the same inclusion segment.

Ignores the cutoff entirely. Shifts are integer minimum-image shifts obtained by rounding the fractional displacement, matching all_connected_neighborlist (which is Ewald-only and assumes a fully periodic cell).

Source code in src/kups/core/neighborlist/all_connected.py
@dataclass
class InclusionGroupSelector:
    """Pairs every particle with every other in the same inclusion segment.

    Ignores the cutoff entirely. Shifts are int-typed minimum-image
    fractional rounds — matches today's ``all_connected_neighborlist``
    (which is Ewald-only and assumed fully periodic).
    """

    capacity: Capacity[int]

    def __call__(self, ctx: PipelineContext) -> CandidateBatch[Literal[2]]:
        ngraphs = ctx.lh.data.inclusion.num_labels
        selection_result = subselect(
            ctx.lh.data.inclusion.indices,
            ctx.rh.data.inclusion.indices,
            output_buffer_size=self.capacity,
            num_segments=ngraphs,
        )
        candidates = Candidates(
            lhs=Index(ctx.lh.keys, selection_result.scatter_idxs),
            rhs=Index(ctx.rh.keys, selection_result.gather_idxs),
        )
        deltas = (
            ctx.lh.data.positions[candidates.lhs.indices]
            - ctx.rh.data.positions[candidates.rhs.indices]
        )
        shifts = jnp.round(deltas).astype(int)
        return candidates_to_batch(
            candidates,
            shifts,
            jnp.ones((candidates.lhs.size,), dtype=bool),
        )
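A dense NumPy sketch of what this selector produces, for intuition only: every ordered pair inside one inclusion segment, with shifts round(lh - rh) of the fractional displacement. It is a hypothetical stand-in (inclusion_group_pairs is not a kups function): it builds a full N x N comparison instead of using subselect with a fixed capacity, and it also emits i == j pairs.

```python
import numpy as np


def inclusion_group_pairs(segments: np.ndarray, frac_positions: np.ndarray):
    """All ordered (lh, rh) pairs sharing a segment, plus integer shifts.

    The cutoff is never consulted, as in InclusionGroupSelector.
    """
    same = segments[:, None] == segments[None, :]
    lh, rh = np.nonzero(same)
    # Same rounding as the selector: shifts = round(lh - rh) in fractional units.
    shifts = np.round(frac_positions[lh] - frac_positions[rh]).astype(int)
    return lh, rh, shifts


segments = np.array([0, 0, 1])
frac = np.array([[0.1, 0.0, 0.0], [0.9, 0.0, 0.0], [0.5, 0.0, 0.0]])
lh, rh, shifts = inclusion_group_pairs(segments, frac)
```

For the pair (0, 1) the displacement 0.1 - 0.9 = -0.8 rounds to a shift of -1, so rh - lh + shift = 0.8 - 1 = -0.2: the minimum-image vector across the periodic boundary.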

InclusionMatchMask

Drops candidates whose lh/rh inclusion segments differ.

Source code in src/kups/core/neighborlist/masks.py
@dataclass
class InclusionMatchMask:
    """Drops candidates whose lh/rh inclusion segments differ."""

    def __call__(self, batch: CandidateBatch, ctx: PipelineContext) -> Array:
        lh_incl = ctx.lh.data.inclusion.indices[batch.lh_idx]
        rh_incl = ctx.rh.data.inclusion.indices[batch.rh_idx]
        return lh_incl == rh_incl

IsAllDenseNeighborListParams

Bases: Protocol

Protocol for parameters required by AllDenseNearestNeighborList.

Source code in src/kups/core/neighborlist/all_dense.py
class IsAllDenseNeighborListParams(Protocol):
    """Protocol for parameters required by ``AllDenseNearestNeighborList``."""

    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...

IsCellListParams

Bases: Protocol

Protocol for parameters required by CellListNeighborList.

Source code in src/kups/core/neighborlist/cell_list.py
class IsCellListParams(Protocol):
    """Protocol for parameters required by ``CellListNeighborList``."""

    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_edges(self) -> int: ...
    @property
    def cells(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...

IsDenseNeighborlistParams

Bases: Protocol

Protocol for parameters required by DenseNearestNeighborList.

Source code in src/kups/core/neighborlist/dense.py
class IsDenseNeighborlistParams(Protocol):
    """Protocol for parameters required by ``DenseNearestNeighborList``."""

    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...

IsNeighborListState

Bases: Protocol

Protocol for states that expose neighbor list parameters.

A state satisfying this protocol can be passed to from_state() on any neighbor list class. The type parameter P determines which neighbor list types the state can construct (e.g., IsAllDenseNeighborListParams, IsDenseNeighborlistParams, IsCellListParams, or IsUniversalNeighborlistParams).

Source code in src/kups/core/neighborlist/types.py
class IsNeighborListState[P](Protocol):
    """Protocol for states that expose neighbor list parameters.

    A state satisfying this protocol can be passed to ``from_state()`` on any
    neighbor list class. The type parameter ``P`` determines which neighbor
    list types the state can construct (e.g., ``IsAllDenseNeighborListParams``,
    ``IsDenseNeighborlistParams``, ``IsCellListParams``, or
    ``IsUniversalNeighborlistParams``).
    """

    @property
    def neighborlist_params(self) -> P: ...

IsUniversalNeighborlistParams

Bases: Protocol

Protocol for parameters required by any neighbor list implementation.

A superset of IsAllDenseNeighborListParams, IsDenseNeighborlistParams, and IsCellListParams. Satisfying this protocol allows constructing any of the three neighbor list types.

Source code in src/kups/core/neighborlist/types.py
class IsUniversalNeighborlistParams(Protocol):
    """Protocol for parameters required by any neighbor list implementation.

    A superset of ``IsAllDenseNeighborListParams``, ``IsDenseNeighborlistParams``,
    and ``IsCellListParams``. Satisfying this protocol allows constructing any
    of the three neighbor list types.
    """

    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...
    @property
    def cells(self) -> int: ...
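These protocols are purely structural, so any object with the right read-only attributes qualifies. A sketch under stated assumptions: the three protocol classes below are local, hypothetical mirrors of IsAllDenseNeighborListParams, IsDenseNeighborlistParams and IsCellListParams, made runtime-checkable only so the superset relationship can be tested with isinstance(); Params, DenseOnlyParams and SimState are made-up names.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


# Hypothetical local mirrors; @runtime_checkable is this sketch's addition.
@runtime_checkable
class AllDenseParams(Protocol):
    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...


@runtime_checkable
class DenseParams(Protocol):
    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...


@runtime_checkable
class CellListParams(Protocol):
    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_edges(self) -> int: ...
    @property
    def cells(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...


@dataclass(frozen=True)
class Params:
    """All four fields: the role UniversalNeighborlistParameters plays."""
    avg_edges: int
    avg_candidates: int
    avg_image_candidates: int
    cells: int


@dataclass(frozen=True)
class DenseOnlyParams:
    """Lacks `cells`: enough for dense lists, not for cell lists."""
    avg_edges: int
    avg_candidates: int
    avg_image_candidates: int


@dataclass(frozen=True)
class SimState:
    """Shape of IsNeighborListState[Params]: exposes `neighborlist_params`."""
    neighborlist_params: Params


state = SimState(Params(avg_edges=64, avg_candidates=128, avg_image_candidates=128, cells=27))
params = state.neighborlist_params
dense_only = DenseOnlyParams(avg_edges=64, avg_candidates=128, avg_image_candidates=128)
```

Because the universal parameter set is a superset, one params object covers all three neighbor list types, while dense_only passes the dense check but fails the cell-list check.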

Mask

Bases: Protocol

Returns this criterion's bool array; the pipeline combines the results of all masks with a logical AND.

Degree-agnostic at the type level: the pair-only masks shipped here annotate batch: CandidateBatch (any D) but internally assume D == 2 by using batch.lh_idx / batch.rh_idx. Masks for higher-degree edges would not use those properties.

A mask cannot change batch.edges, batch.is_minimum_image, or the candidate count; it is a pure function (batch, ctx) -> Array.

Source code in src/kups/core/neighborlist/types.py
class Mask(Protocol):
    """Returns this criterion's bool array; pipeline conjuncts the results.

    Degree-agnostic at the type level — the pair-only masks shipped here
    annotate ``batch: CandidateBatch`` (any ``D``) and internally assume
    ``D == 2`` via ``batch.lh_idx`` / ``batch.rh_idx``. Higher-degree masks
    would not use those properties.

    Cannot change ``batch.edges``, ``batch.is_minimum_image``, or the
    candidate count. Pure ``(batch, ctx) -> Array``.
    """

    def __call__(self, batch: CandidateBatch, ctx: PipelineContext) -> Array: ...
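A sketch of two custom masks following this contract, assuming only the lh_idx / rh_idx properties that pair masks read. PairBatch, NoSelfPairMask and SameParityMask are hypothetical names, and NumPy stands in for jax.numpy:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PairBatch:
    """Hypothetical stand-in exposing the two properties pair masks read."""
    lh_idx: np.ndarray
    rh_idx: np.ndarray


@dataclass
class NoSelfPairMask:
    """Drops i == j candidates. Only meaningful when rh is lh (self-neighbors);
    with a disjoint rh the two index spaces differ."""

    def __call__(self, batch: PairBatch, ctx: object) -> np.ndarray:
        return batch.lh_idx != batch.rh_idx


@dataclass
class SameParityMask:
    """Keeps pairs whose indices have equal parity (an arbitrary rule)."""

    def __call__(self, batch: PairBatch, ctx: object) -> np.ndarray:
        return batch.lh_idx % 2 == batch.rh_idx % 2


batch = PairBatch(lh_idx=np.array([0, 0, 1, 2]), rh_idx=np.array([0, 2, 3, 1]))
keep = np.ones(batch.lh_idx.size, dtype=bool)
for mask in (NoSelfPairMask(), SameParityMask()):
    keep &= mask(batch, None)  # masks never edit the batch; results are ANDed
```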

MaskOnlyCompactor

In-place compaction: failing entries become OOB indices and zero shifts.

No size change: the candidate count from the selector is preserved. Applies the shared rh→lh remap so that output indices live in lh-space, matching the contract of ReduceCompactor.

Source code in src/kups/core/neighborlist/compact.py
@dataclass
class MaskOnlyCompactor:
    """In-place compaction: failing entries become OOB indices and zero shifts.

    No size change; preserves the candidate count from the selector. Applies
    the shared rh→lh remap so the output indices live in lh-space — matching
    ``ReduceCompactor``'s contract.
    """

    def __call__(
        self,
        keep: Array,
        batch: CandidateBatch[Literal[2]],
        ctx: PipelineContext,
    ) -> Edges[Literal[2]]:
        oob = ctx.lh.size
        rh_idx_remapped = remap_rh_to_lh(batch.rh_idx, ctx)
        indices_in = jnp.stack([batch.lh_idx, rh_idx_remapped], axis=-1)
        indices = where_broadcast_last(keep, indices_in, oob)
        shifts = where_broadcast_last(keep, batch.edges.shifts, 0)
        return Edges(Index(batch.edges.indices.keys, indices), shifts)
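The in-place compaction amounts to two broadcast selects. A NumPy sketch, assuming the rh indices have already been remapped into lh-space and using np.where with a trailing axis in place of where_broadcast_last (mask_only_compact is a hypothetical name):

```python
import numpy as np


def mask_only_compact(keep, lh_idx, rh_idx_lh_space, shifts, oob):
    """Same output length as input: rejected rows become (oob, oob) with zero shifts."""
    indices = np.stack([lh_idx, rh_idx_lh_space], axis=-1)
    indices = np.where(keep[:, None], indices, oob)   # broadcast keep over last axis
    shifts = np.where(keep[:, None], shifts, 0)
    return indices, shifts


keep = np.array([True, False, True])
lh = np.array([0, 1, 2])
rh = np.array([1, 2, 0])
shifts = np.array([[0, 0, 1], [1, 0, 0], [0, -1, 0]])
idx, sh = mask_only_compact(keep, lh, rh, shifts, oob=3)
```

Keeping the shape fixed avoids any capacity estimate, at the cost of carrying dead rows that downstream scatter/gather code must tolerate as out-of-bounds.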

NearestNeighborList

Bases: Protocol

Protocol for neighbor list construction algorithms.

Implementations find pairs of particles within a cutoff distance, handling periodic boundary conditions and inclusion/exclusion masks.

Source code in src/kups/core/neighborlist/types.py
class NearestNeighborList(Protocol):
    """Protocol for neighbor list construction algorithms.

    Implementations find pairs of particles within a cutoff distance, handling
    periodic boundary conditions and inclusion/exclusion masks.
    """

    def __call__[P: NeighborListPoints](
        self,
        lh: Table[ParticleId, P],
        rh: Table[ParticleId, P] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        """Find all particle pairs within the cutoff distance.

        Args:
            lh: Left-hand particles to find neighbors for
            rh: Right-hand particles to search within (or None for self-neighbors)
            systems: Indexed system data with cell information
            cutoffs: Indexed cutoff data per system
            rh_index_remap: Optional index mapping rh particles back to lh
                particle IDs for self-interaction exclusion. When ``None``,
                rh is treated as disjoint from lh.

        Returns:
            Edges connecting particle pairs within cutoff
        """
        ...
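For intuition, or as a test oracle, here is a brute-force NumPy reference for the simplest case of this protocol: one orthorhombic, fully periodic box, rh=None, all ordered pairs i != j within the cutoff under the minimum-image convention. It is a hypothetical sketch, not a kups implementation: it assumes cutoff < min(box_lengths) / 2 so each pair has at most one image in range, and it ignores triclinic cells, multiple systems, and inclusion/exclusion masks.

```python
import numpy as np


def brute_force_pairs(positions: np.ndarray, box_lengths: np.ndarray, cutoff: float):
    """O(N^2) minimum-image neighbor search in one orthorhombic periodic box.

    Returns (edges, shifts): edges[k] = (i, j) with i != j, and integer shifts
    such that positions[j] - positions[i] + shifts[k] * box_lengths is the
    minimum-image vector from i to j (sign convention chosen by this sketch).
    """
    frac = positions / box_lengths
    delta = frac[:, None, :] - frac[None, :, :]     # lh - rh, fractional
    shifts = np.round(delta).astype(int)            # nearest periodic image
    dist = np.linalg.norm((delta - shifts) * box_lengths, axis=-1)
    within = (dist < cutoff) & ~np.eye(len(positions), dtype=bool)
    i, j = np.nonzero(within)
    return np.stack([i, j], axis=-1), shifts[i, j]


box = np.array([10.0, 10.0, 10.0])
pos = np.array([[0.5, 5.0, 5.0], [9.5, 5.0, 5.0], [5.0, 5.0, 5.0]])
edges, shifts = brute_force_pairs(pos, box, cutoff=2.0)
```

Particles 0 and 1 are 9 Å apart inside the box but only 1 Å apart across the boundary, so they are neighbors with shifts of -1 and +1 along x.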

__call__(lh, rh, systems, cutoffs, rh_index_remap=None)

Find all particle pairs within the cutoff distance.

Parameters:

  • lh (Table[ParticleId, P], required): Left-hand particles to find neighbors for.
  • rh (Table[ParticleId, P] | None, required): Right-hand particles to search within, or None for self-neighbors.
  • systems (Table[SystemId, NeighborListSystems], required): Indexed system data with cell information.
  • cutoffs (Table[SystemId, Array], required): Indexed cutoff data per system.
  • rh_index_remap (Index[ParticleId] | None, default None): Optional index mapping rh particles back to lh particle IDs for self-interaction exclusion. When None, rh is treated as disjoint from lh.

Returns:

  • Edges[Literal[2]]: Edges connecting particle pairs within the cutoff.


Pipeline

Selector → mask sequence → compactor.

Attributes:

  • selector (CandidateSelector[D]): Produces a CandidateBatch[D] (handles PBC replication).
  • masks (tuple[Mask, ...]): Tuple of mask criteria over CandidateBatch[D]; their results are combined with & (logical AND).
  • compactor (Compactor[D]): Produces the final Edges[D] from the accumulated mask.

Source code in src/kups/core/neighborlist/pipeline.py
@dataclass
class Pipeline[D: int]:
    """Selector → mask sequence → compactor.

    Attributes:
        selector: Produces a ``CandidateBatch[D]`` (handles PBC replication).
        masks: Tuple of mask criteria over ``CandidateBatch[D]``; results
            are conjuncted via ``&``.
        compactor: Produces the final ``Edges[D]`` from the accumulated mask.
    """

    selector: CandidateSelector[D]
    masks: tuple[Mask, ...] = field(static=True)
    compactor: Compactor[D]

    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[D]:
        ctx = _prepare(lh, rh, systems, rh_index_remap)
        batch = self.selector(ctx)
        keep = jnp.ones((batch.lh_idx.size,), dtype=bool)
        for mask in self.masks:
            keep &= mask(batch, ctx)
        return self.compactor(keep, batch, ctx)
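The control flow of Pipeline.__call__ can be reproduced with plain callables. In this hypothetical, simplified mirror (MiniPipeline, all_pairs and the dict context are made up), the selector returns index arrays, each mask returns one bool per candidate, and the compactor receives the AND of all masks. Unlike the real compactors, this one uses boolean indexing, which yields a data-dependent shape that jax.jit would reject.

```python
from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np


@dataclass
class MiniPipeline:
    """Selector -> masks -> compactor, mirroring Pipeline.__call__."""
    selector: Callable
    masks: Sequence[Callable]
    compactor: Callable

    def __call__(self, ctx):
        lh_idx, rh_idx = self.selector(ctx)
        keep = np.ones(lh_idx.size, dtype=bool)
        for mask in self.masks:
            keep &= mask(lh_idx, rh_idx, ctx)
        return self.compactor(keep, lh_idx, rh_idx)


def all_pairs(ctx):
    n = len(ctx["x"])
    i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    return i.ravel(), j.ravel()


pipe = MiniPipeline(
    selector=all_pairs,
    masks=(
        lambda i, j, ctx: i != j,                                          # no self pairs
        lambda i, j, ctx: np.abs(ctx["x"][i] - ctx["x"][j]) < ctx["cutoff"],  # 1D cutoff
    ),
    compactor=lambda keep, i, j: np.stack([i[keep], j[keep]], axis=-1),
)
edges = pipe({"x": np.array([0.0, 0.5, 3.0]), "cutoff": 1.0})
```

Swapping one stage (for example, a different selector) leaves the other two untouched, which is the point of the decomposition.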

PipelineContext

Read-only inputs shared by every mask and the compactor.

Positions in lh and rh are in fractional coordinates (transformed by kups.core.neighborlist.pipeline._prepare). There is no out_of_bounds field; masks and compactors that need an OOB sentinel compute one locally (e.g. max(ctx.lh.size, ctx.rh.size)).

Attributes:

  • lh (Table[ParticleId, NeighborListPoints]): Left-hand particle table in fractional coordinates.
  • rh (Table[ParticleId, NeighborListPoints]): Right-hand particle table in fractional coordinates (== lh when the caller passed rh=None).
  • systems (Table[SystemId, NeighborListSystems]): Indexed system data with cell information.
  • rh_index_remap (Array | None): Raw remap array mapping rh-positions to lh-space particle IDs, or None when no remap was supplied. _prepare replaces an empty remap with a one-element OOB-sentinel array so downstream lookups never see a zero-length array.

Source code in src/kups/core/neighborlist/types.py
@dataclass
class PipelineContext:
    """Read-only inputs shared by every mask and the compactor.

    Positions in ``lh`` and ``rh`` are in **fractional** coordinates
    (transformed by [`_prepare`][kups.core.neighborlist.pipeline._prepare]).
    There is no ``out_of_bounds`` field — masks/compactors that need an
    OOB sentinel compute ``max(ctx.lh.size, ctx.rh.size)`` locally.

    Attributes:
        lh: Left-hand particle table in fractional coords.
        rh: Right-hand particle table in fractional coords (== ``lh`` when
            the caller passed ``rh=None``).
        systems: Indexed system data with cell information.
        rh_index_remap: Raw remap array mapping rh-positions to lh-space
            particle IDs, or ``None`` when no remap was supplied. Empty
            remaps are replaced with a one-element OOB-sentinel array by
            ``_prepare`` so downstream lookups never see a zero-length array.
    """

    lh: Table[ParticleId, NeighborListPoints]
    rh: Table[ParticleId, NeighborListPoints]
    systems: Table[SystemId, NeighborListSystems]
    rh_index_remap: Array | None
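Working in fractional coordinates turns periodic wrapping into rounding to integers, whatever the cell shape. A sketch of the conversion, assuming the cell is a 3x3 matrix whose rows are lattice vectors (this convention, and the function names, are assumptions; the layout of NeighborListSystems is not shown here):

```python
import numpy as np


def to_fractional(positions: np.ndarray, lattice: np.ndarray) -> np.ndarray:
    """Cartesian -> fractional for row lattice vectors: positions = frac @ lattice."""
    return positions @ np.linalg.inv(lattice)


def wrapped_distance(frac_a: np.ndarray, frac_b: np.ndarray, lattice: np.ndarray) -> float:
    """Length of the fractional displacement wrapped into [-0.5, 0.5].

    Exact minimum-image distance for orthorhombic cells; for strongly skewed
    triclinic cells another image can be closer.
    """
    d = frac_a - frac_b
    d = d - np.round(d)
    return float(np.linalg.norm(d @ lattice))


lattice = np.diag([10.0, 10.0, 10.0])
a = to_fractional(np.array([0.5, 5.0, 5.0]), lattice)
b = to_fractional(np.array([9.5, 5.0, 5.0]), lattice)
d_ab = wrapped_distance(a, b, lattice)
```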

PrecomputedEdgesSelector

Selector that wraps precomputed Edges for both refine variants.

Precomputed edges use the same index convention as the call that produced them: public lh-space edges when an rh remap is supplied, and raw rh-space indices for a disjoint rh without a remap. Remapped rh rows are overlaid onto lh before this selector runs.

Attributes:

  • candidates (Edges[Literal[2]]): Precomputed edges (indices in lh-space).
  • recompute_mic_shifts (bool, default False): When True, drop the precomputed shifts and recompute minimum-image shifts from the current positions (used by RefineCutoffNeighborList, because the precomputed shifts may be stale relative to the current cell). When False, reuse candidates.shifts directly (used by RefineMaskNeighborList). In both cases is_minimum_image is all-True (no image replication).

Source code in src/kups/core/neighborlist/refine.py
@dataclass
class PrecomputedEdgesSelector:
    """Selector that wraps precomputed ``Edges`` for both refine variants.

    Precomputed edges use the same index convention as the call that produced
    them: public lh-space edges when an ``rh`` remap is supplied, and raw
    rh-space indices for a disjoint ``rh`` without a remap. Remapped ``rh``
    rows are overlaid onto ``lh`` before this selector runs.

    Attributes:
        candidates: Precomputed edges (indices in lh-space).
        recompute_mic_shifts: When ``True``, drop the precomputed shifts and
            recompute minimum-image shifts on the current positions
            (``RefineCutoffNeighborList`` — the precomputed shifts may be
            stale relative to the current cell). When ``False``, reuse
            ``candidates.shifts`` directly (``RefineMaskNeighborList``).
            ``is_minimum_image`` is always all-True (no image replication).
    """

    candidates: Edges[Literal[2]]
    recompute_mic_shifts: bool = field(static=True, default=False)

    def __call__(self, ctx: PipelineContext) -> CandidateBatch[Literal[2]]:
        if self.recompute_mic_shifts:
            indices = self.candidates.indices.indices
            raw_candidates = Candidates(
                lhs=Index(ctx.lh.keys, indices[:, 0]),
                rhs=Index(ctx.rh.keys, indices[:, 1]),
            )
            return make_batch_with_mic(raw_candidates, ctx.lh, ctx.rh, ctx.systems)
        indices = self.candidates.indices.indices
        edges = Edges(Index(ctx.lh.keys, indices), self.candidates.shifts)
        return CandidateBatch(
            edges=edges,
            is_minimum_image=jnp.ones((len(self.candidates),), dtype=bool),
        )
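Why recompute_mic_shifts exists, in a small NumPy sketch. It assumes positions are wrapped back into the unit cell between calls, which is one way stored shifts go stale (the docstring names a changed cell as the other). Shifts follow the round(lh - rh) rule of InclusionGroupSelector; mic_shifts and displacement are hypothetical names.

```python
import numpy as np


def mic_shifts(frac: np.ndarray, edges: np.ndarray) -> np.ndarray:
    # Integer shifts = round(lh - rh) of the fractional displacement.
    return np.round(frac[edges[:, 0]] - frac[edges[:, 1]]).astype(int)


def displacement(frac: np.ndarray, edges: np.ndarray, shifts: np.ndarray) -> np.ndarray:
    # Fractional vector from lh to rh: rh - lh + shift.
    return frac[edges[:, 1]] - frac[edges[:, 0]] + shifts


edges = np.array([[0, 1]])
before = np.array([[0.02, 0.0, 0.0], [0.12, 0.0, 0.0]])
stored = mic_shifts(before, edges)  # cached together with the edges

# Particle 0 moves by -0.04 and is wrapped back into [0, 1): 0.02 -> 0.98.
after = np.array([[0.98, 0.0, 0.0], [0.12, 0.0, 0.0]])
stale = displacement(after, edges, stored)                    # wrong image
fresh = displacement(after, edges, mic_shifts(after, edges))  # true separation
```

With the stored shift the pair appears 0.86 cell lengths apart; the recomputed shift recovers the true separation of 0.14.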

ReduceCompactor

Compacts surviving candidates to a size-bounded Edges[2].

Applies the shared rh→lh remap. When ctx.rh_index_remap is set, it then mirrors each surviving edge with its reverse (concatenating the shifts with their negatives); the mirror restores the symmetry that the paired RemapDedupMask removed upstream.

Source code in src/kups/core/neighborlist/compact.py
@dataclass
class ReduceCompactor:
    """Compacts surviving candidates to a size-bounded ``Edges[2]``.

    Applies the shared rh→lh remap, then — when ``ctx.rh_index_remap`` is
    set — mirrors each surviving edge with its reverse (concatenating shifts
    with their negatives). The mirror restores the symmetry that the paired
    ``RemapDedupMask`` removed upstream.
    """

    avg_edges: Capacity[int]

    def __call__(
        self,
        keep: Array,
        batch: CandidateBatch[Literal[2]],
        ctx: PipelineContext,
    ) -> Edges[Literal[2]]:
        oob = max(ctx.lh.size, ctx.rh.size)
        max_edges = self.avg_edges.generate_assertion(keep.sum())
        sort_idxs = jnp.where(keep, size=max_edges.size, fill_value=keep.size)[0]
        shifts = batch.edges.shifts.at[sort_idxs].get(
            mode="fill", fill_value=0, indices_are_sorted=True
        )
        rh_idx_remapped = remap_rh_to_lh(batch.rh_idx, ctx)
        lh_edge = batch.lh_idx.at[sort_idxs].get(
            mode="fill", fill_value=oob, indices_are_sorted=True
        )
        rh_edge = rh_idx_remapped.at[sort_idxs].get(
            mode="fill", fill_value=oob, indices_are_sorted=True
        )

        if ctx.rh_index_remap is not None:
            shifts = jnp.concatenate([shifts, -shifts], axis=0)
            lh_edge, rh_edge = (
                jnp.concatenate([lh_edge, rh_edge], axis=0),
                jnp.concatenate([rh_edge, lh_edge], axis=0),
            )

        return Edges(
            Index(batch.edges.indices.keys, jnp.stack([lh_edge, rh_edge], axis=-1)),
            shifts,
        )
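A NumPy sketch of the fixed-capacity compaction and the optional mirror. It emulates jnp.where(keep, size=max_edges, fill_value=keep.size): surviving positions first, then padding rows that gather as (oob, oob) with zero shift. jnp.where with size silently drops survivors beyond max_edges; the real code guards against that with avg_edges.generate_assertion(keep.sum()). reduce_compact is a hypothetical name.

```python
import numpy as np


def reduce_compact(keep, lh_idx, rh_idx, shifts, max_edges, oob, mirror):
    """Compact surviving candidates into max_edges rows (2 * max_edges if mirrored)."""
    pos = np.nonzero(keep)[0][:max_edges]
    sel = np.full(max_edges, keep.size)      # fill_value=keep.size marks padding
    sel[: pos.size] = pos
    valid = sel < keep.size
    safe = np.where(valid, sel, 0)           # NumPy has no "fill" gather mode
    lh = np.where(valid, lh_idx[safe], oob)
    rh = np.where(valid, rh_idx[safe], oob)
    sh = np.where(valid[:, None], shifts[safe], 0)
    if mirror:  # append every edge reversed, with negated shift
        lh, rh = np.concatenate([lh, rh]), np.concatenate([rh, lh])
        sh = np.concatenate([sh, -sh])
    return np.stack([lh, rh], axis=-1), sh


keep = np.array([False, True, True, False])
lh_idx = np.array([0, 0, 1, 2])
rh_idx = np.array([1, 2, 2, 0])
shifts = np.array([[0, 0, 0], [1, 0, 0], [0, 0, -1], [0, 1, 0]])
edges, sh = reduce_compact(keep, lh_idx, rh_idx, shifts, max_edges=3, oob=3, mirror=True)
```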

RefineCutoffNeighborList

Refine precomputed edges by re-checking distances with new cutoffs.

This neighbor list takes an existing set of candidate edges and filters them by computing actual distances and comparing to cutoffs. Enables sharing a single conservative neighbor list across multiple potentials with different cutoff distances.

Key benefit: compute the expensive neighbor list once with the maximum cutoff, then refine it for each potential with that potential's cutoff (e.g., Lennard-Jones at 10 Å, Coulomb at 15 Å).

Attributes:

  • candidates (Edges[Literal[2]]): Precomputed edges to refine (should be conservative, i.e. over-inclusive).
  • avg_edges (Capacity[int]): Capacity for the output edge array.

Use cases
  • Multiple potentials sharing one neighbor list with different cutoffs
  • Multi-stage neighbor list construction (coarse then fine)
  • Adaptive cutoffs that change during simulation
  • Using a static "super" neighbor list with varying actual cutoffs
Example
# Compute the base neighbor list once with the maximum cutoff.
# base_nl, particles, systems, cap1/cap2 and the cutoff tables are placeholders;
# cutoffs are per-system Table[SystemId, Array] values, not bare floats.
base_edges = base_nl(particles, None, systems, max_cutoffs)  # max of all cutoffs, e.g. 15 Å

# Share the base edges across potentials with different cutoffs
lj_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap1)
lj_edges = lj_nl(particles, None, systems, lj_cutoffs)  # LJ cutoff, e.g. 10 Å

coulomb_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap2)
coulomb_edges = coulomb_nl(particles, None, systems, coulomb_cutoffs)  # Coulomb cutoff, e.g. 15 Å
Source code in src/kups/core/neighborlist/refine.py
@dataclass
class RefineCutoffNeighborList:
    """Refine precomputed edges by re-checking distances with new cutoffs.

    This neighbor list takes an existing set of candidate edges and filters them
    by computing actual distances and comparing to cutoffs. Enables sharing a
    single conservative neighbor list across multiple potentials with different
    cutoff distances.

    **Key benefit**: Compute expensive neighbor list once with maximum cutoff,
    then refine for each potential with its specific cutoff (e.g., Lennard-Jones
    at 10 Å, Coulomb at 15 Å).

    Attributes:
        candidates: Precomputed edges to refine (should be conservative/over-inclusive).
        avg_edges: Capacity for output edge array.

    Use cases:
        - Multiple potentials sharing one neighbor list with different cutoffs
        - Multi-stage neighbor list construction (coarse then fine)
        - Adaptive cutoffs that change during simulation
        - Using a static "super" neighbor list with varying actual cutoffs

    Example:
        ```python
        # Compute the base neighbor list once with the maximum cutoff.
        # base_nl, particles, systems, cap1/cap2 and the cutoff tables are placeholders;
        # cutoffs are per-system Table[SystemId, Array] values, not bare floats.
        base_edges = base_nl(particles, None, systems, max_cutoffs)  # max of all cutoffs, e.g. 15 Å

        # Share the base edges across potentials with different cutoffs
        lj_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap1)
        lj_edges = lj_nl(particles, None, systems, lj_cutoffs)  # LJ cutoff, e.g. 10 Å

        coulomb_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap2)
        coulomb_edges = coulomb_nl(particles, None, systems, coulomb_cutoffs)  # Coulomb cutoff, e.g. 15 Å
        ```
    """

    candidates: Edges[Literal[2]]
    avg_edges: Capacity[int]

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        resolved_lh, resolved_rh = _resolve_precomputed_inputs(lh, rh, rh_index_remap)
        rh_size = rh.size if rh is not None else lh.size
        cutoffs = Table.broadcast_to(cutoffs, systems)
        pipeline = Pipeline[Literal[2]](
            selector=PrecomputedEdgesSelector(
                self.candidates, recompute_mic_shifts=True
            ),
            masks=(
                InBoundsMask(),
                InclusionMatchMask(),
                DistanceCutoffMask(cutoffs=cutoffs),
                ExclusionMask(),
            ),
            compactor=ReduceCompactor(avg_edges=self.avg_edges.multiply(rh_size)),
        )
        return pipeline(resolved_lh, resolved_rh, systems, None)
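The sharing pattern in miniature: one over-inclusive base edge list, filtered twice with different cutoffs after recomputing minimum-image shifts. Refinement can only remove edges, so the base list must be built with a cutoff at least as large as every refined one. A NumPy sketch assuming an orthorhombic cell given as a row-vector lattice matrix; refine_by_cutoff is a hypothetical name.

```python
import numpy as np


def refine_by_cutoff(edges, frac, lattice, cutoff):
    """Keep precomputed edges whose fresh minimum-image distance is < cutoff."""
    lh, rh = edges[:, 0], edges[:, 1]
    shifts = np.round(frac[lh] - frac[rh])          # recomputed, never reused
    disp = (frac[rh] - frac[lh] + shifts) @ lattice  # Cartesian MIC vectors
    return edges[np.linalg.norm(disp, axis=-1) < cutoff]


lattice = np.diag([40.0, 40.0, 40.0])
frac = np.array([[0.0, 0.0, 0.0], [12.0, 0.0, 0.0], [5.0, 0.0, 0.0]]) / 40.0
# Base list, e.g. from a 15 Å search: every ordered pair is within range.
base = np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]])
lj_edges = refine_by_cutoff(base, frac, lattice, 10.0)       # drops the 12 Å pair
coulomb_edges = refine_by_cutoff(base, frac, lattice, 15.0)  # keeps all six
```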

RefineMaskNeighborList

Refine a precomputed neighbor list by applying inclusion/exclusion masks.

This neighbor list takes an existing set of candidate edges and filters them based on segmentation masks, without recomputing distances. Enables sharing a single base neighbor list across multiple potentials with different interaction rules.

Key benefit: compute the expensive neighbor list once, then apply different masks for different potentials (e.g., Lennard-Jones excludes 1-4 interactions while Coulomb uses a different exclusion set).

Attributes:

  • candidates (Edges[Literal[2]]): Precomputed edges to refine.

Use cases
  • Multiple potentials sharing one neighbor list with different exclusions
  • Excluding bonded pairs (1-2, 1-3, 1-4) from non-bonded interactions
  • Applying group-specific interaction rules
  • Multi-scale simulations with different interaction levels
Example
# Compute base neighbor list once
base_edges = base_nl(particles, None, cells, cutoffs, None)

# Share across potentials with different masks
lj_nl = RefineMaskNeighborList(candidates=base_edges)
lj_edges = lj_nl(lj_particles, None, cells, cutoffs, None)  # 1-4 exclusions

coulomb_nl = RefineMaskNeighborList(candidates=base_edges)
coulomb_edges = coulomb_nl(coulomb_particles, None, cells, cutoffs, None)  # 1-2 exclusions only
Source code in src/kups/core/neighborlist/refine.py
@dataclass
class RefineMaskNeighborList:
    """Refine a precomputed neighbor list by applying inclusion/exclusion masks.

    This neighbor list takes an existing set of candidate edges and filters them
    based on segmentation masks, without recomputing distances. Enables sharing
    a single base neighbor list across multiple potentials with different
    interaction rules.

    **Key benefit**: Compute expensive neighbor list once, apply different masks
    for different potentials (e.g., Lennard-Jones excludes 1-4 interactions,
    Coulomb has different exclusions).

    Attributes:
        candidates: Precomputed edges to refine

    Use cases:
        - Multiple potentials sharing one neighbor list with different exclusions
        - Excluding bonded pairs (1-2, 1-3, 1-4) from non-bonded interactions
        - Applying group-specific interaction rules
        - Multi-scale simulations with different interaction levels

    Example:
        ```python
        # Compute base neighbor list once
        base_edges = base_nl(particles, None, cells, cutoffs, None)

        # Share across potentials with different masks
        lj_nl = RefineMaskNeighborList(candidates=base_edges)
        lj_edges = lj_nl(lj_particles, None, cells, cutoffs, None)  # 1-4 exclusions

        coulomb_nl = RefineMaskNeighborList(candidates=base_edges)
        coulomb_edges = coulomb_nl(coulomb_particles, None, cells, cutoffs, None)  # 1-2 exclusions only
        ```
    """

    candidates: Edges[Literal[2]]

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        resolved_lh, resolved_rh = _resolve_precomputed_inputs(lh, rh, rh_index_remap)
        pipeline = Pipeline[Literal[2]](
            selector=PrecomputedEdgesSelector(self.candidates),
            masks=(InBoundsMask(), InclusionMatchMask(), ExclusionMask()),
            compactor=MaskOnlyCompactor(),
        )
        return pipeline(resolved_lh, resolved_rh, systems, None)

RemapDedupMask

Deduplicate the rh→lh remapped subset.

When ctx.rh_index_remap is set, rh is a subset of lh and each rh-position maps to an lh-position via rh_index_remap. We then keep only one direction per pair: edges where lh_idx is not in the remap (i.e., the pair is lh-only) or where lh_idx >= remapped_rh.

Returns all-True when no remap is in effect.

Source code in src/kups/core/neighborlist/masks.py
@dataclass
class RemapDedupMask:
    """Deduplicate the rh→lh remapped subset.

    When ``ctx.rh_index_remap`` is set, ``rh`` is a subset of ``lh`` and each
    rh-position maps to an lh-position via ``rh_index_remap``. We then keep
    only one direction per pair: edges where ``lh_idx`` is **not** in the
    remap (i.e., the pair is lh-only) or where ``lh_idx >= remapped_rh``.

    Returns all-True when no remap is in effect.
    """

    def __call__(self, batch: CandidateBatch, ctx: PipelineContext) -> Array:
        if ctx.rh_index_remap is None:
            return jnp.ones((batch.lh_idx.size,), dtype=bool)
        oob = max(ctx.lh.size, ctx.rh.size)
        rh_remapped = ctx.rh_index_remap.at[batch.rh_idx].get(
            mode="fill", fill_value=oob
        )
        return ~isin(batch.lh_idx, ctx.rh_index_remap, ctx.lh.size) | (
            batch.lh_idx >= rh_remapped
        )
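The dedup rule in plain NumPy, assuming rh is a subset of lh described by remap (rh-position k is lh particle remap[k]) and that every rh_idx is a valid position in remap; the real mask instead fills out-of-range lookups with an OOB sentinel. remap_dedup_keep is a hypothetical name.

```python
import numpy as np


def remap_dedup_keep(lh_idx: np.ndarray, rh_idx: np.ndarray, remap: np.ndarray) -> np.ndarray:
    rh_in_lh = remap[rh_idx]               # rh-position -> lh particle ID
    lh_in_subset = np.isin(lh_idx, remap)  # is the lh particle also an rh particle?
    # Keep lh-only pairs; inside the subset keep one direction (lh >= rh).
    return ~lh_in_subset | (lh_idx >= rh_in_lh)


# lh holds particles 0..3; rh is the subset [1, 2].
remap = np.array([1, 2])
lh_idx = np.array([0, 1, 2, 3])
rh_idx = np.array([0, 1, 0, 1])  # i.e. lh partners 1, 2, 1, 2
keep = remap_dedup_keep(lh_idx, rh_idx, remap)
```

The pair {1, 2} is found twice, as 1→2 and 2→1; only 2→1 survives, and ReduceCompactor later restores the reverse direction by mirroring.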

UniversalNeighborlistParameters

Concrete parameter dataclass satisfying IsUniversalNeighborlistParams.

Holds the capacity hints needed by every neighbor list implementation. Use the estimate() classmethod to compute reasonable initial values from system geometry rather than guessing manually.

Attributes:

  • avg_edges (int): Average number of edges per particle (for edge capacity).
  • avg_candidates (int): Average number of candidate pairs per particle.
  • avg_image_candidates (int): Average number of image candidate pairs per particle.
  • cells (int): Maximum number of spatial hash cells across all systems.

Source code in src/kups/core/neighborlist/parameters.py
@dataclass
class UniversalNeighborlistParameters:
    """Concrete parameter dataclass satisfying ``IsUniversalNeighborlistParams``.

    Holds the capacity hints needed by every neighbor list implementation.
    Use the ``estimate()`` classmethod to compute reasonable initial values
    from system geometry rather than guessing manually.

    Attributes:
        avg_edges: Average number of edges per particle (for edge capacity).
        avg_candidates: Average number of candidate pairs per particle.
        avg_image_candidates: Average number of image candidate pairs per particle.
        cells: Maximum number of spatial hash cells across all systems.
    """

    avg_edges: int = field(static=True)
    avg_candidates: int = field(static=True)
    avg_image_candidates: int = field(static=True)
    cells: int = field(static=True)

    @classmethod
    @no_jax_tracing
    def estimate(
        cls,
        particles_per_system: Table[SystemId, Array],
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        *,
        base: float = 2,
        multiplier: float = 1.0,
    ) -> UniversalNeighborlistParameters:
        """Estimate parameters for all neighbor list types from system geometry.

        Computes conservative initial capacities based on particle density
        and cutoff radii. The estimates are rounded up to the next power of
        ``base`` to amortize future resizing.

        Args:
            particles_per_system: Number of particles per system.
            systems: System data with cell information.
            cutoffs: Cutoff distance per system.
            base: Base for power-of rounding (default 2).
            multiplier: Safety factor applied to the estimate (default 1.0).

        Returns:
            A ``UniversalNeighborlistParameters`` instance with estimated values.
        """
        sys = Table.join(systems, particles_per_system, cutoffs)
        total_candidates = total_edges = max_cells = 0
        for _, (s, n_p, c) in sys:
            n_bins = num_cells(s, c).prod()
            total_candidates += min(n_p / n_bins * (3**3), n_p)
            total_edges += _estimate_avg_num_edges(
                n_p, s.cell.volume, c, base, multiplier
            )
            max_cells = max(n_bins, max_cells)
        total_candidates = next_higher_power(
            jnp.array(total_candidates * multiplier / sys.size), base=base
        )
        return UniversalNeighborlistParameters(
            avg_edges=int(total_edges // sys.size),
            avg_candidates=int(total_candidates),
            avg_image_candidates=int(total_candidates),  # Image candidates ~ candidates
            cells=int(max_cells),
        )
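The candidate part of estimate() is simple arithmetic. Each particle's candidates come from its own cell and its 26 neighbors (3**3 cells), so a system with n_p particles in n_bins cells contributes min(n_p / n_bins * 27, n_p). The sum is averaged over systems, scaled by multiplier, and rounded up to a power of base. A worked sketch under two stated assumptions: num_cells is taken to be floor(L / cutoff) bins per axis of an orthorhombic box, and next_higher_power is taken to return the smallest power of base that is >= x. Both helpers below are hypothetical stand-ins, and the edge estimate (_estimate_avg_num_edges) is not modeled.

```python
import math


def next_higher_power(x: float, base: int = 2) -> int:
    """Smallest integer power of `base` that is >= x (hypothetical helper)."""
    p = 1
    while p < x:
        p *= base
    return p


def estimate_avg_candidates(systems, base: int = 2, multiplier: float = 1.0) -> int:
    """systems: iterable of (n_particles, box_lengths, cutoff), orthorhombic boxes."""
    systems = list(systems)
    total = 0.0
    for n_p, box, cutoff in systems:
        # Assumed binning: floor(L / cutoff) cells per axis, at least one.
        n_bins = math.prod(max(1, math.floor(L / cutoff)) for L in box)
        # Own cell plus its 26 neighbors, never more than the whole system.
        total += min(n_p / n_bins * 3**3, n_p)
    return next_higher_power(total * multiplier / len(systems), base)


avg = estimate_avg_candidates(
    [(1000, (40.0, 40.0, 40.0), 10.0), (500, (20.0, 20.0, 20.0), 10.0)]
)
```

Here the first system has 64 cells, giving 1000 / 64 * 27 = 421.875 candidates per particle. The second has only 8 cells, so the 27-cell stencil covers the whole box and the count is capped at 500. The mean, 460.9375, rounds up to 512.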

estimate(particles_per_system, systems, cutoffs, *, base=2, multiplier=1.0) classmethod

Estimate parameters for all neighbor list types from system geometry.

Computes conservative initial capacities based on particle density and cutoff radii. The estimates are rounded up to the next power of base to amortize future resizing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| particles_per_system | Table[SystemId, Array] | Number of particles per system. | required |
| systems | Table[SystemId, NeighborListSystems] | System data with cell information. | required |
| cutoffs | Table[SystemId, Array] | Cutoff distance per system. | required |
| base | float | Base of the power to which the estimates are rounded up. | 2 |
| multiplier | float | Safety factor applied to the estimate. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| UniversalNeighborlistParameters | A UniversalNeighborlistParameters instance with estimated values. |

Source code in src/kups/core/neighborlist/parameters.py
@classmethod
@no_jax_tracing
def estimate(
    cls,
    particles_per_system: Table[SystemId, Array],
    systems: Table[SystemId, NeighborListSystems],
    cutoffs: Table[SystemId, Array],
    *,
    base: float = 2,
    multiplier: float = 1.0,
) -> UniversalNeighborlistParameters:
    """Estimate parameters for all neighbor list types from system geometry.

    Computes conservative initial capacities based on particle density
    and cutoff radii. The estimates are rounded up to the next power of
    ``base`` to amortize future resizing.

    Args:
        particles_per_system: Number of particles per system.
        systems: System data with cell information.
        cutoffs: Cutoff distance per system.
        base: Base for power-of rounding (default 2).
        multiplier: Safety factor applied to the estimate (default 1.0).

    Returns:
        A ``UniversalNeighborlistParameters`` instance with estimated values.
    """
    sys = Table.join(systems, particles_per_system, cutoffs)
    total_candidates = total_edges = max_cells = 0
    for _, (s, n_p, c) in sys:
        n_bins = num_cells(s, c).prod()
        total_candidates += min(n_p / n_bins * (3**3), n_p)
        total_edges += _estimate_avg_num_edges(
            n_p, s.cell.volume, c, base, multiplier
        )
        max_cells = max(n_bins, max_cells)
    total_candidates = next_higher_power(
        jnp.array(total_candidates * multiplier / sys.size), base=base
    )
    return UniversalNeighborlistParameters(
        avg_edges=int(total_edges // sys.size),
        avg_candidates=int(total_candidates),
        avg_image_candidates=int(total_candidates),  # Image candidates ~ candidates
        cells=int(max_cells),
    )
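The candidate term in the loop above is simple arithmetic: the expected number of particles in a 3×3×3 block of cells, capped at the system's particle count, averaged over systems, scaled by `multiplier`, and rounded up to a power of `base`. The plain-Python sketch below reproduces only that term. It assumes an orthorhombic box with a hypothetical cell count of floor(L / cutoff) per axis (at least one) standing in for `num_cells`, and a hypothetical `next_power` helper in place of `next_higher_power`; the edge estimate from `_estimate_avg_num_edges` is not modelled.

```python
import math


def next_power(x: float, base: int = 2) -> int:
    """Smallest integer power of ``base`` that is >= x (hypothetical helper)."""
    p = 1
    while p < x:
        p *= base
    return p


def estimate_avg_candidates(systems, base: int = 2, multiplier: float = 1.0) -> int:
    """Sketch of the candidate term in ``estimate``.

    ``systems`` is a list of (n_particles, box_lengths, cutoff) tuples for
    orthorhombic boxes (an assumption; the real code calls ``num_cells``).
    """
    total = 0.0
    for n_p, box_lengths, cutoff in systems:
        # Hypothetical cell grid: floor(L / cutoff) cells per axis, at least 1.
        n_cells = math.prod(max(1, math.floor(length / cutoff)) for length in box_lengths)
        # Expected particles in a 3x3x3 block of cells, capped at n_p.
        total += min(n_p / n_cells * 3**3, n_p)
    # Average over systems, apply the safety factor, round up to a power of base.
    return next_power(total * multiplier / len(systems), base)


print(estimate_avg_candidates([(1000, (30.0, 30.0, 30.0), 10.0)]))  # 1024
print(estimate_avg_candidates([(1000, (30.0, 30.0, 30.0), 5.0)]))   # 128
```

With a cutoff of 10 in a box of side 30 there are only 27 cells, so the 3×3×3 block already spans the whole box and the cap at the particle count applies. Halving the cutoff gives 216 cells and about 125 candidates, which rounds up to 128.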

all_connected_neighborlist(lh, rh, systems, cutoffs, rh_index_remap=None)

Neighbor list connecting all pairs sharing the same inclusion segment, ignoring distance.

Connects every particle pair that belongs to the same inclusion segment and has differing exclusion segment IDs. The cutoff is ignored for neighbor selection; the cell is used only to compute minimum-image shifts.

Requires max_count to be set on the inclusion Index.

Source code in src/kups/core/neighborlist/all_connected.py
def all_connected_neighborlist(
    lh: Table[ParticleId, NeighborListPoints],
    rh: Table[ParticleId, NeighborListPoints] | None,
    systems: Table[SystemId, NeighborListSystems],
    cutoffs: Table[SystemId, Array],
    rh_index_remap: Index[ParticleId] | None = None,
) -> Edges[Literal[2]]:
    """Neighbor list connecting all pairs sharing the same inclusion segment, ignoring distance.

    Connects every particle pair that belongs to the same inclusion segment and has
    differing exclusion segment IDs. The cutoff is ignored for neighbor selection;
    the cell is used only to compute minimum-image shifts.

    Requires ``max_count`` to be set on the inclusion ``Index``.
    """
    if rh is None:
        rh = lh
        rh_index_remap = Index.arange(len(lh), label=ParticleId)

    max_count = lh.data.inclusion.max_count
    assert max_count is not None, "inclusion.max_count must be set"
    capacity = FixedCapacity(max_count).multiply(min(lh.size, rh.size))

    pipeline = Pipeline[Literal[2]](
        selector=InclusionGroupSelector(capacity=capacity),
        masks=(ExclusionMask(), RemapDedupMask()),
        compactor=ReduceCompactor(avg_edges=capacity),
    )
    return pipeline(lh, rh, systems, rh_index_remap)
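Ignoring padding and fixed-size buffers, the selection rule reduces to a loop over particle pairs. The plain-Python sketch below is illustrative only: the helper names are hypothetical, it emits each pair once with i < j (which orientation(s) the real pipeline emits is not stated here), and its minimum-image helper assumes an orthorhombic box and a shift sign convention chosen for the example. The capacity bound `max_count * min(lh.size, rh.size)` in the code above presumably works because `max_count` caps the size of an inclusion segment, so no particle can have more partners than that.

```python
from itertools import combinations


def all_connected_pairs(inclusion, exclusion):
    """Pairs (i, j), i < j, in the same inclusion segment with different exclusion IDs."""
    return [
        (i, j)
        for i, j in combinations(range(len(inclusion)), 2)
        if inclusion[i] == inclusion[j] and exclusion[i] != exclusion[j]
    ]


def minimum_image_shift(xi, xj, box_lengths):
    """Integer shift s per axis so that xj - xi + s * L is minimal (orthorhombic box assumed)."""
    return tuple(-round((b - a) / length) for a, b, length in zip(xi, xj, box_lengths))


# Two inclusion segments; particles 0 and 1 share an exclusion ID, so they are not paired.
inclusion = [0, 0, 0, 1, 1]
exclusion = [0, 0, 1, 2, 3]
print(all_connected_pairs(inclusion, exclusion))  # [(0, 2), (1, 2), (3, 4)]
print(minimum_image_shift((1.0, 0.0, 0.0), (9.0, 0.0, 0.0), (10.0, 10.0, 10.0)))  # (-1, 0, 0)
```

Distance never enters the pair test; positions are only used afterwards to pick the periodic image.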

neighborlist_changes(neighborlist, lh, rh, systems, cutoffs, compaction=0.5)

Compute added/removed edges from a particle change in a single call.

Appends proposed positions to the particle array and queries both old and new interactions at once, then splits the result by filtering edge indices into removed (before) and added (after) sets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| neighborlist | NearestNeighborList | Neighbor list implementation. | required |
| lh | Table[ParticleId, NeighborListPoints] | Full original particle table. | required |
| rh | WithIndices[ParticleId, Table[ParticleId, NeighborListPoints]] | Proposed changes: rh.indices maps each entry to a particle ID in lh, and rh.data holds the new particle data. | required |
| systems | Table[SystemId, NeighborListSystems] | Per-system data (cells, etc.). | required |
| cutoffs | Table[SystemId, Array] | Per-system cutoff distances. | required |
| compaction | float | Fraction of the total edge buffer allocated to each output (0 to 1). At 0.5, added and removed each get half the buffer; at 1.0 no compaction is performed and both outputs keep the full buffer, with invalid entries masked. | 0.5 |

Returns:

| Type | Description |
| --- | --- |
| NeighborListChangesResult | NeighborListChangesResult(added, removed). |
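The mask logic in the source below can be traced on a toy case. The plain-Python sketch assumes N = 4 original particles and a single changed particle (k = 1, p_idx = [2]) whose new copy receives combined index 4; any index >= N + k is treated as padding, and the hand-written edge list stands in for real neighbor-list output. The function `split_changes` is hypothetical and only mirrors the masks, not the JAX implementation.

```python
def split_changes(edges, n, k, p_idx):
    """Split combined-index edges into (added, removed), mirroring the masks
    in ``neighborlist_changes`` (plain-Python sketch, not the library code)."""
    changed = set(p_idx)

    def remap(c):
        # Appended index n + m refers to the new copy of particle p_idx[m].
        return p_idx[c - n] if c >= n else c

    added, removed = [], []
    for c0, c1 in edges:
        if c0 >= n + k or c1 >= n + k:
            continue  # padding entry
        if c0 < n and c1 < n:
            removed.append((c0, c1))  # both endpoints use pre-change data
            continue
        # Stale: an endpoint is the old copy (< n) of a changed particle.
        stale = (c0 < n and c0 in changed) or (c1 < n and c1 in changed)
        if not stale:
            added.append((remap(c0), remap(c1)))
    return added, removed


# (1, 2): old edge; (0, 4), (3, 4): edges to the new copy of particle 2;
# (2, 4), (4, 2): old copy paired with new copy (stale); (5, 5): padding.
edges = [(1, 2), (0, 4), (2, 4), (4, 2), (3, 4), (5, 5)]
print(split_changes(edges, n=4, k=1, p_idx=[2]))  # ([(0, 2), (3, 2)], [(1, 2)])
```

Edges between the old and new copies of the same particle are dropped by the stale check, and edges between two unchanged particles never appear at all, because the right-hand side of the query contains only the changed particles.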

Source code in src/kups/core/neighborlist/changes.py
@partial(jit, static_argnames=("compaction",))
def neighborlist_changes(
    neighborlist: NearestNeighborList,
    lh: Table[ParticleId, NeighborListPoints],
    rh: WithIndices[ParticleId, Table[ParticleId, NeighborListPoints]],
    systems: Table[SystemId, NeighborListSystems],
    cutoffs: Table[SystemId, Array],
    compaction: float = 0.5,
) -> NeighborListChangesResult:
    """Compute added/removed edges from a particle change in a single call.

    Appends proposed positions to the particle array and queries both old
    and new interactions at once, then splits the result by filtering
    edge indices into ``removed`` (before) and ``added`` (after) sets.

    Args:
        neighborlist: Neighbor list implementation.
        lh: Full original particle table.
        rh: Proposed changes — ``rh.indices`` maps entries to particle IDs
            in ``lh``, ``rh.data`` holds the new particle data.
        systems: Per-system data (cells, etc.).
        cutoffs: Per-system cutoff distances.
        compaction: Fraction of total edges allocated per output (0–1).
            0.5 means each of added/removed gets half the buffer.
            1.0 means no compaction — full buffer with masking only.

    Returns:
        ``NeighborListChangesResult(added, removed)``.
    """
    N, k = lh.size, rh.data.size
    p_idx = rh.indices.indices_in(lh.keys)

    # Build a single query. Left-hand side: all N original particles followed
    # by the k proposed particles. Right-hand side: the old data of the k
    # changed particles followed by their new data; the remap sends these to
    # indices p_idx and N + m respectively.
    lh_combined = Table.union((lh, rh.data))
    rh_combined = Table.union((Table.arange(lh[rh.indices], label=ParticleId), rh.data))
    combined_remap = Index(
        lh_combined.keys, jnp.concatenate([p_idx, jnp.arange(k) + N])
    )

    # single neighborlist call
    all_edges = neighborlist(lh_combined, rh_combined, systems, cutoffs, combined_remap)

    # split into removed / added
    raw = all_edges.indices.indices  # (n_edges, 2)
    c0, c1 = raw[:, 0], raw[:, 1]
    # Removed edges use pre-change data at both endpoints (both indices < N).
    removed_mask = (c0 < N) & (c1 < N)

    # An edge is stale if either endpoint is the old copy (index < N) of a
    # changed particle; such edges no longer exist after the change.
    is_stale = (isin(c0, p_idx, N + k) & (c0 < N)) | (isin(c1, p_idx, N + k) & (c1 < N))
    # Added edges touch at least one new copy (index >= N), have both indices
    # in range (< N + k), and have no stale endpoint.
    added_mask = (c0 < N + k) & (c1 < N + k) & ((c0 >= N) | (c1 >= N)) & ~is_stale

    # remap appended indices N+m -> p_idx[m]
    remapped = jnp.where(raw >= N, p_idx[raw - N], raw)

    # compact each output
    n_total = raw.shape[0]
    shifts = all_edges.shifts
    # Per-output buffer size; only used when compaction < 1.0.
    capacity = int(n_total * compaction)

    def _mask_only(mask: Array, indices: Array, shifts: Array) -> Edges[Literal[2]]:
        # Replace masked-out entries with the out-of-range index N and a zero shift.
        idx = where_broadcast_last(mask, indices, N)
        sh = where_broadcast_last(mask, shifts, 0)
        return Edges(Index(lh.keys, idx), sh)

    def _compact(mask: Array, indices: Array, label: str) -> Edges[Literal[2]]:
        count = mask.sum()
        runtime_assert(
            count <= capacity,
            f"neighborlist_changes: {label} edges ({{count}}) exceed "
            f"capacity ({{capacity}})",
            fmt_args={"count": count, "capacity": jnp.array(capacity)},
        )
        # Pad with the out-of-range index n_total so padded slots read as
        # invalid; padding with a real index such as n_total - 1 would
        # duplicate that edge whenever it is itself selected.
        sel: Array = jnp.where(mask, size=capacity, fill_value=n_total)[0]
        valid = mask.at[sel].get(mode="fill", fill_value=False)
        # Out-of-range slots are clipped here and then masked by _mask_only.
        return _mask_only(
            valid,
            indices.at[sel].get(mode="clip"),
            shifts.at[sel].get(mode="clip"),
        )

    if compaction >= 1.0:
        return NeighborListChangesResult(
            _mask_only(added_mask, remapped, shifts),
            _mask_only(removed_mask, remapped, shifts),
        )

    return NeighborListChangesResult(
        _compact(added_mask, remapped, "added"),
        _compact(removed_mask, remapped, "removed"),
    )
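When compaction < 1.0, each output is packed into a fixed buffer of int(n_total * compaction) slots: the selected edge positions come first, and the remaining slots are padding that must read as invalid. The plain-Python sketch below mirrors that step with a hypothetical `compact_positions` helper; the `ValueError` stands in for `runtime_assert`. It pads with the out-of-range index n_total so that a padded slot can never alias a real edge.

```python
def compact_positions(mask, compaction):
    """Pack the selected positions of ``mask`` into a fixed-size buffer.

    Returns (sel, valid): the buffer slots and whether each slot holds a real edge.
    """
    n_total = len(mask)
    capacity = int(n_total * compaction)
    sel = [i for i, m in enumerate(mask) if m]
    if len(sel) > capacity:
        raise ValueError(f"{len(sel)} edges exceed capacity {capacity}")
    # Pad with the out-of-range index n_total so padded slots read as invalid.
    sel += [n_total] * (capacity - len(sel))
    valid = [i < n_total and mask[i] for i in sel]
    return sel, valid


mask = [False, True, False, True, False, False, False, True]
print(compact_positions(mask, 0.5))  # ([1, 3, 7, 8], [True, True, True, False])
```

In this mask the last edge (position 7) is selected, so padding with n_total - 1 = 7 instead would give `valid == [True, True, True, True]` and report that edge twice.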