kups.core.neighborlist

Neighbor list construction and edge representations for molecular systems.

This module provides several neighbor list algorithms for finding interacting pairs of particles within cutoff distances. They differ in performance and in applicability (periodic boundaries, multiple systems).

Core Components

Neighbor List Implementations

Primary Implementations

  1. CellListNeighborList (Recommended when cutoff << box size)

    • O(N) complexity using spatial hashing
    • Best when cutoff / box_size < 0.3 (cutoff much smaller than box)
    • Requires periodic boundary conditions
    • Efficiency improves as cutoff/box ratio decreases
  2. DenseNearestNeighborList

    • O(N²/K) complexity (K = number of systems, assuming systems of similar size)
    • Best when cutoff / box_size ~ 1 (cutoff comparable to box)
    • Works with or without periodic boundaries
    • More efficient than a cell list when only a few cutoff-sized cells fit in the box
  3. AllDenseNearestNeighborList

    • O(N²) complexity across all systems
    • Only for single-system simulations or testing
    • Crosses system boundaries (use with caution!)
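The O(N²) versus O(N²/K) distinction can be checked by counting candidate pairs. The sketch below uses a hypothetical helper, `candidate_pair_counts`, that counts ordered pairs i ≠ j for a search over all particles and for a search restricted to each system. With K equal-sized systems the per-system total is K·(N/K)·(N/K − 1), roughly N²/K, so about K times fewer pairs are examined.

```python
def candidate_pair_counts(system_sizes):
    """Ordered candidate pairs (i != j): all-pairs search vs. per-system search."""
    n_total = sum(system_sizes)
    # Every particle against every other particle, ignoring system membership.
    all_pairs = n_total * (n_total - 1)
    # Only pairs whose two particles belong to the same system.
    per_system = sum(n * (n - 1) for n in system_sizes)
    return all_pairs, per_system


# K = 4 systems of 100 particles each, N = 400 particles in total.
all_pairs, per_system = candidate_pair_counts([100] * 4)
print(all_pairs, per_system)  # 159600 39600
```

The ratio 159600 / 39600 ≈ 4.03 is close to K = 4, as the O(N²/K) estimate predicts.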

Refinement Implementations

These allow sharing a single base neighbor list across multiple potentials with different cutoffs or interaction rules (e.g., Lennard-Jones and Coulomb).

  1. RefineMaskNeighborList

    • Applies inclusion/exclusion masks to precomputed edges
    • Use for bonded exclusions or group-specific interactions
    • No distance recalculation
    • Share one neighbor list, apply different masks per potential
  2. RefineCutoffNeighborList

    • Refines precomputed edges with new cutoff distances
    • Use for multi-stage construction or adaptive cutoffs
    • Recalculates distances
    • Share one conservative neighbor list, apply different cutoffs per potential
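The mask-based sharing idea can be illustrated with plain Python lists, independently of kups: one candidate edge list is computed once, and each potential filters it with its own exclusion rule without recomputing any distance. The names `refine_by_mask`, `candidate_edges`, and the exclusion sets are hypothetical; this is a conceptual sketch, not the RefineMaskNeighborList API.

```python
Edge = tuple[int, int]


def refine_by_mask(candidates: list[Edge], excluded: set[frozenset[int]]) -> list[Edge]:
    """Keep candidate edges whose unordered particle pair is not excluded.

    Only the shared candidates are filtered; no distances are recomputed.
    """
    return [e for e in candidates if frozenset(e) not in excluded]


# Shared candidate edges, computed once by some base neighbor list.
candidate_edges = [(0, 1), (1, 2), (0, 2), (2, 3)]

# Hypothetical per-potential exclusion rules (e.g. bonded pairs).
lj_excluded = {frozenset((0, 1)), frozenset((1, 2))}
coulomb_excluded = {frozenset((0, 1))}

lj_edges = refine_by_mask(candidate_edges, lj_excluded)
coulomb_edges = refine_by_mask(candidate_edges, coulomb_excluded)
print(lj_edges)       # [(0, 2), (2, 3)]
print(coulomb_edges)  # [(1, 2), (0, 2), (2, 3)]
```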

Features

All neighbor lists handle:

  • Periodic boundary conditions via shift vectors
  • Multiple systems in parallel with segmentation
  • Automatic capacity management for variable neighbor counts
  • Integration with JAX transformations (JIT, vmap, etc.)

Choosing an Implementation

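# When cutoff << box size (cutoff/box < 0.3)
# Example: 10 Å cutoff, 50 Å box → use CellList
nl = CellListNeighborList.new(state, lens=lens(lambda s: s.nl_params))

# When cutoff ~ box size (cutoff/box ~ 1)
# Example: 15 Å cutoff, 20 Å box → use Dense
nl = DenseNearestNeighborList.new(state, lens=lens(lambda s: s.nl_params))

# Share one neighbor list across multiple potentials with different masks.
# base_nl is any primary neighbor list (e.g. `nl` above); cutoffs is a
# Table[SystemId, Array] of per-system cutoffs. Mask configuration not shown.
base_edges = base_nl(particles, None, cells, cutoffs, None)
lj_nl = RefineMaskNeighborList(candidates=base_edges)  # e.g. exclude 1-4 interactions
coulomb_nl = RefineMaskNeighborList(candidates=base_edges)  # different exclusions

# Share one neighbor list across potentials with different cutoffs.
# max_cutoffs, lj_cutoffs and coulomb_cutoffs are Table[SystemId, Array]
# per-system cutoffs; cap1 and cap2 are Capacity[int] objects.
# The refined cutoff is passed when the refining list is called.
base_edges = base_nl(particles, None, cells, max_cutoffs, None)
lj_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap1)
lj_edges = lj_nl(particles, None, cells, lj_cutoffs, None)  # r_cut = 10 Å
coulomb_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap2)
coulomb_edges = coulomb_nl(particles, None, cells, coulomb_cutoffs, None)  # r_cut = 15 Å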
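The rules of thumb in the comments above can be condensed into a small decision helper. `suggest_neighbor_list` is hypothetical and not part of kups; it assumes the ratio is taken against the shortest box edge (the most restrictive one) and hard-codes the 0.3 threshold quoted on this page.

```python
def suggest_neighbor_list(cutoff: float, box_lengths: tuple[float, ...], periodic: bool) -> str:
    """Return the class name suggested by the cutoff/box rule of thumb."""
    if not periodic:
        # CellListNeighborList requires periodic boundary conditions.
        return "DenseNearestNeighborList"
    # Assumption: compare against the shortest box edge.
    ratio = cutoff / min(box_lengths)
    return "CellListNeighborList" if ratio < 0.3 else "DenseNearestNeighborList"


print(suggest_neighbor_list(10.0, (50.0, 50.0, 50.0), periodic=True))   # CellListNeighborList
print(suggest_neighbor_list(15.0, (20.0, 20.0, 20.0), periodic=True))   # DenseNearestNeighborList
print(suggest_neighbor_list(10.0, (50.0, 50.0, 50.0), periodic=False))  # DenseNearestNeighborList
```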

AllDenseNearestNeighborList

Dense O(N²) neighbor list considering all pairs across all systems.

This implementation generates all possible particle pairs without spatial optimization. It is only suitable for very small systems or testing.

Warning: This crosses system boundaries! Only use for single-system simulations. For multiple systems, use DenseNearestNeighborList instead.

Complexity: O(N²) where N is the total number of particles across all systems.

Attributes:

  • avg_edges (Capacity[int]): Capacity manager for edge array.
  • avg_image_candidates (Capacity[int]): Capacity manager for image candidate pairs.

Example
# Construct from state and a lens to the neighbor list parameters:
nl = AllDenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

# Or, if the state implements IsNeighborListState:
nl = AllDenseNearestNeighborList.from_state(state)

edges = nl(particles, None, unit_cells, cutoffs, None)
Source code in src/kups/core/neighborlist.py
@dataclass
class AllDenseNearestNeighborList:
    """Dense O(N²) neighbor list considering all pairs across all systems.

    This implementation generates all possible particle pairs without spatial
    optimization. It is only suitable for very small systems or testing.

    **Warning**: This crosses system boundaries! Only use for single-system
    simulations. For multiple systems, use
    [DenseNearestNeighborList][kups.core.neighborlist.DenseNearestNeighborList]
    instead.

    Complexity: O(N²) where N is the total number of particles across all systems.

    Attributes:
        avg_edges: Capacity manager for edge array.
        avg_image_candidates: Capacity manager for image candidate pairs.

    Example:
        ```python
        # Construct from state and a lens to the neighbor list parameters:
        nl = AllDenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

        # Or, if the state implements IsNeighborListState:
        nl = AllDenseNearestNeighborList.from_state(state)

        edges = nl(particles, None, unit_cells, cutoffs, None)
        ```
    """

    avg_edges: Capacity[int]
    avg_image_candidates: Capacity[int]

    @classmethod
    def new[S](
        cls, state: S, lens: Lens[S, IsAllDenseNeighborListParams]
    ) -> AllDenseNearestNeighborList:
        params = lens.get(state)
        return AllDenseNearestNeighborList(
            avg_edges=LensCapacity(params.avg_edges, lens.focus(lambda x: x.avg_edges)),
            avg_image_candidates=LensCapacity(
                params.avg_image_candidates,
                lens.focus(lambda x: x.avg_image_candidates),
            ),
        )

    @classmethod
    def from_state(
        cls, state: IsNeighborListState[IsAllDenseNeighborListParams]
    ) -> AllDenseNearestNeighborList:
        return cls.new(state, lens(lambda s: s.neighborlist_params))

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        if lh.data.inclusion.num_labels >= 2:
            logging.warning(
                "AllDenseNearestNeighborList is intended for single-system simulations. "
                "Performance may be degraded when using multiple systems. "
                "Consider using DenseNearestNeighborList or CellListNeighborList instead."
            )
        rh_size = rh.size if rh is not None else lh.size
        return basic_neighborlist(
            lh,
            rh,
            systems,
            cutoffs,
            rh_index_remap,
            candidate_selector=_all_subselect,
            max_num_edges=self.avg_edges.multiply(rh_size),
            max_image_candidates=self.avg_image_candidates.multiply(rh_size)
            if self.avg_image_candidates
            else None,
        )

CellListNeighborList

Efficient O(N) neighbor list using spatial hashing with cell lists.

This is the recommended implementation when the cutoff is much smaller than the box size. It divides space into a grid of cells and only checks pairs in neighboring cells, achieving linear scaling with system size.

Requires periodic boundary conditions (UnitCell).

Complexity: O(N) for well-distributed particles where cutoff << box size. Efficiency improves as cutoff/box ratio decreases.

Attributes:

  • avg_candidates (Capacity[int]): Capacity for candidate pair storage (from cell list).
  • avg_edges (Capacity[int]): Capacity for final edge array.
  • cells (Capacity[int]): Capacity for cell hash table (grows with box_size³/cutoff³).
  • avg_image_candidates (Capacity[int]): Capacity for image candidate pairs.

Algorithm
  1. Partition space into grid cells with side length of at least the cutoff
  2. Hash each particle to its cell
  3. For each particle, check only its own cell and the 26 adjacent cells (27 cells in 3D)
  4. Filter candidates by actual distance
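The four steps can be sketched in pure Python. This is an illustrative re-implementation of the textbook cell-list technique, not the kups code (which works on JAX arrays with fixed capacities). It assumes an orthorhombic periodic box and at least three cells per axis, returns each unordered pair once, and checks itself against a brute-force search. All function names are hypothetical.

```python
import itertools
import random


def min_image_d2(a, b, box):
    """Squared minimum-image distance in an orthorhombic periodic box."""
    d2 = 0.0
    for pa, pb, length in zip(a, b, box):
        d = pb - pa
        d -= length * round(d / length)
        d2 += d * d
    return d2


def cell_list_pairs(positions, box, cutoff):
    # Step 1: grid whose cell side (length / n) is >= cutoff on every axis.
    n_cells = [int(length // cutoff) for length in box]
    if min(n_cells) < 3:
        raise ValueError("need >= 3 cells per axis; use a dense search instead")
    # Step 2: hash each particle to its (wrapped) cell.
    cells = {}
    for idx, p in enumerate(positions):
        key = tuple(int((x % length) / length * n) % n
                    for x, length, n in zip(p, box, n_cells))
        cells.setdefault(key, []).append(idx)
    pairs = set()
    for key, members in cells.items():
        # Step 3: visit the cell itself and its 26 neighbours (27 cells in 3D).
        for offset in itertools.product((-1, 0, 1), repeat=3):
            nkey = tuple((k + o) % n for k, o, n in zip(key, offset, n_cells))
            for i in members:
                for j in cells.get(nkey, ()):
                    # Step 4: keep only pairs inside the cutoff, each pair once.
                    if i < j and min_image_d2(positions[i], positions[j], box) <= cutoff * cutoff:
                        pairs.add((i, j))
    return sorted(pairs)


def brute_force_pairs(positions, box, cutoff):
    return [(i, j) for i, j in itertools.combinations(range(len(positions)), 2)
            if min_image_d2(positions[i], positions[j], box) <= cutoff * cutoff]


random.seed(0)
box = (20.0, 20.0, 20.0)
positions = [tuple(random.uniform(0.0, length) for length in box) for _ in range(200)]
fast = cell_list_pairs(positions, box, cutoff=3.0)
print(fast == brute_force_pairs(positions, box, cutoff=3.0))  # True
```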
When to use
  • When cutoff/box_size << 1 (cutoff much smaller than box)
  • Typically cutoff/box < 0.3 for good efficiency
  • Periodic boundary conditions required
Example
# Example: 10 Å cutoff in 50 Å box → cutoff/box = 0.2 -- Good for CellList
nl = CellListNeighborList.new(state, lens(lambda s: s.nl_params))

# Or, if the state implements IsNeighborListState:
nl = CellListNeighborList.from_state(state)

edges = nl(particles, None, unit_cells, cutoffs, None)
Source code in src/kups/core/neighborlist.py
@dataclass
class CellListNeighborList:
    """Efficient O(N) neighbor list using spatial hashing with cell lists.

    This is the recommended implementation when the cutoff is much smaller than
    the box size. It divides space into a grid of cells and only checks pairs in
    neighboring cells, achieving linear scaling with system size.

    **Requires periodic boundary conditions** (UnitCell).

    Complexity: O(N) for well-distributed particles where cutoff << box size.
    Efficiency improves as cutoff/box ratio decreases.

    Attributes:
        avg_candidates: Capacity for candidate pair storage (from cell list).
        avg_edges: Capacity for final edge array.
        cells: Capacity for cell hash table (grows with box_size³/cutoff³).
        avg_image_candidates: Capacity for image candidate pairs.

    Algorithm:
        1. Partition space into grid cells with side length of at least the cutoff
        2. Hash each particle to its cell
        3. For each particle, check only its own cell and the 26 adjacent cells (27 cells in 3D)
        4. Filter candidates by actual distance

    When to use:
        - When cutoff/box_size << 1 (cutoff much smaller than box)
        - Typically cutoff/box < 0.3 for good efficiency
        - Periodic boundary conditions required

    Example:
        ```python
        # Example: 10 Å cutoff in 50 Å box → cutoff/box = 0.2 -- Good for CellList
        nl = CellListNeighborList.new(state, lens(lambda s: s.nl_params))

        # Or, if the state implements IsNeighborListState:
        nl = CellListNeighborList.from_state(state)

        edges = nl(particles, None, unit_cells, cutoffs, None)
        ```
    """

    avg_candidates: Capacity[int]
    avg_edges: Capacity[int]
    cells: Capacity[int]
    avg_image_candidates: Capacity[int]

    @classmethod
    def new[S](cls, state: S, lens: Lens[S, IsCellListParams]) -> CellListNeighborList:
        params = lens.get(state)
        return CellListNeighborList(
            avg_candidates=LensCapacity(
                params.avg_candidates, lens.focus(lambda x: x.avg_candidates)
            ),
            avg_edges=LensCapacity(params.avg_edges, lens.focus(lambda x: x.avg_edges)),
            avg_image_candidates=LensCapacity(
                params.avg_image_candidates,
                lens.focus(lambda x: x.avg_image_candidates),
            ),
            cells=LensCapacity(params.cells, lens.focus(lambda x: x.cells), base=1),
        )

    @classmethod
    def from_state(
        cls, state: IsNeighborListState[IsCellListParams]
    ) -> CellListNeighborList:
        return cls.new(state, lens(lambda s: s.neighborlist_params))

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        rh_size = rh.size if rh is not None else lh.size
        return basic_neighborlist(
            lh,
            rh,
            systems,
            cutoffs,
            rh_index_remap,
            candidate_selector=partial(
                _cell_list_subselect,
                cutoffs=cutoffs.data,
                max_num_cells=self.cells,
                max_num_candidates=self.avg_candidates.multiply(rh_size),
            ),
            max_num_edges=self.avg_edges.multiply(rh_size),
            max_image_candidates=self.avg_image_candidates.multiply(rh_size),
        )

DenseNearestNeighborList

Dense O(N²) neighbor list respecting system boundaries.

This implementation generates all particle pairs within each system separately, avoiding cross-system interactions. Efficient when the cutoff is comparable to the box size (cutoff/box ~ 1).

Complexity: O(N²/K) where N is the total number of particles and K is the number of systems (K systems of about N/K particles each contribute K·(N/K)² pairs).

Attributes:

  • avg_candidates (Capacity[int]): Capacity for candidate pair storage.
  • avg_edges (Capacity[int]): Capacity for final edge array.
  • avg_image_candidates (Capacity[int]): Capacity for image candidate pairs.

When to use
  • When cutoff/box_size ~ 1 (cutoff comparable to box dimensions)
  • Small box relative to cutoff (few cells would fit)
  • Non-periodic systems
Example
# Example: 15 Å cutoff in 20 Å box → cutoff/box = 0.75
nl = DenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

# Or, if the state implements IsNeighborListState:
nl = DenseNearestNeighborList.from_state(state)

edges = nl(particles, None, systems, cutoffs, None)
Source code in src/kups/core/neighborlist.py
@dataclass
class DenseNearestNeighborList:
    """Dense O(N²) neighbor list respecting system boundaries.

    This implementation generates all particle pairs within each system
    separately, avoiding cross-system interactions. Efficient when the cutoff
    is comparable to the box size (cutoff/box ~ 1).

    Complexity: O(N²/K) where N is total particles and K is number of systems.

    Attributes:
        avg_candidates: Capacity for candidate pair storage.
        avg_edges: Capacity for final edge array.
        avg_image_candidates: Capacity for image candidate pairs.

    When to use:
        - When cutoff/box_size ~ 1 (cutoff comparable to box dimensions)
        - Small box relative to cutoff (few cells would fit)
        - Non-periodic systems

    Example:
        ```python
        # Example: 15 Å cutoff in 20 Å box → cutoff/box = 0.75
        nl = DenseNearestNeighborList.new(state, lens(lambda s: s.nl_params))

        # Or, if the state implements IsNeighborListState:
        nl = DenseNearestNeighborList.from_state(state)

        edges = nl(particles, None, systems, cutoffs, None)
        ```
    """

    avg_candidates: Capacity[int]
    avg_edges: Capacity[int]
    avg_image_candidates: Capacity[int]

    @classmethod
    def new[S](
        cls, state: S, lens: Lens[S, IsDenseNeighborlistParams]
    ) -> DenseNearestNeighborList:
        params = lens.get(state)
        return DenseNearestNeighborList(
            avg_candidates=LensCapacity(
                params.avg_candidates, lens.focus(lambda x: x.avg_candidates)
            ),
            avg_edges=LensCapacity(params.avg_edges, lens.focus(lambda x: x.avg_edges)),
            avg_image_candidates=LensCapacity(
                params.avg_image_candidates,
                lens.focus(lambda x: x.avg_image_candidates),
            ),
        )

    @classmethod
    def from_state(
        cls, state: IsNeighborListState[IsDenseNeighborlistParams]
    ) -> DenseNearestNeighborList:
        return cls.new(state, lens(lambda s: s.neighborlist_params))

    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        rh_size = rh.size if rh is not None else lh.size
        selector = partial(
            _dense_subselect,
            max_num_candidates=self.avg_candidates.multiply(rh_size),
        )
        return basic_neighborlist(
            lh,
            rh,
            systems,
            cutoffs,
            rh_index_remap,
            candidate_selector=selector,
            max_num_edges=self.avg_edges.multiply(rh_size),
            max_image_candidates=self.avg_image_candidates.multiply(rh_size),
        )

Edges

Bases: Sliceable

Represents edges (connections) between particles in a molecular system.

An edge connects Degree particles, where Degree=2 represents pairwise interactions (bonds), Degree=3 represents three-body interactions (angles), etc.

For periodic systems, edges include shift vectors that indicate how many unit cells to traverse when computing distances between connected particles.
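A small worked example shows what a shift does. It assumes an orthorhombic box, so a shift of (1, 0, 0) simply adds one box length along x. The helper `difference_vector` is hypothetical; it mirrors the `pos[:, 1:] - pos[:, :1] + shifts` expression in `Edges.difference_vectors` for a single pair.

```python
def difference_vector(pos_i, pos_j, shift, box):
    """r_j - r_i + shift * L for an orthorhombic box (shift counted in unit cells)."""
    return tuple(b - a + s * length for a, b, s, length in zip(pos_i, pos_j, shift, box))


box = (10.0, 10.0, 10.0)
p0 = (9.5, 0.0, 0.0)
p2 = (0.5, 0.0, 0.0)
# Without a shift the pair appears 9 Å apart; with shift (1, 0, 0) the image of
# p2 in the next cell along +x is used and the separation is 1 Å.
print(difference_vector(p0, p2, (0, 0, 0), box))  # (-9.0, 0.0, 0.0)
print(difference_vector(p0, p2, (1, 0, 0), box))  # (1.0, 0.0, 0.0)
```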

Class Type Parameters:

  • Degree (bound: int, required): Number of particles connected by each edge (static type check)

Attributes:

  • indices (Index[ParticleId]): Particle indices for each edge, shape (n_edges, Degree)
  • shifts (Array): Periodic shift vectors, shape (n_edges, Degree-1, 3). Shift vectors for the 2nd through Degree-th particle relative to the first.

Example
import jax.numpy as jnp
from kups.core.neighborlist import Edges

# Pairwise edges (bonds) between particles
edges = Edges(
    indices=jnp.array([[0, 1], [1, 2], [0, 2]]),  # 3 edges
    shifts=jnp.array([[[0, 0, 0]], [[0, 0, 0]], [[1, 0, 0]]]),  # 3rd edge crosses boundary
)
Source code in src/kups/core/neighborlist.py
@dataclass
class Edges[Degree: int](Sliceable):
    """Represents edges (connections) between particles in a molecular system.

    An edge connects `Degree` particles, where degree=2 represents pairwise
    interactions (bonds), degree=3 represents three-body interactions (angles), etc.

    For periodic systems, edges include shift vectors that indicate how many
    unit cells to traverse when computing distances between connected particles.

    Type Parameters:
        Degree: Number of particles connected by each edge (static type check)

    Attributes:
        indices: Particle indices for each edge, shape `(n_edges, Degree)`
        shifts: Periodic shift vectors, shape `(n_edges, Degree-1, 3)`.
            Shift vectors for the 2nd through Degree-th particle relative to the first.

    Example:
        ```python
        # Pairwise edges (bonds) between particles
        edges = Edges(
            indices=jnp.array([[0, 1], [1, 2], [0, 2]]),  # 3 edges
            shifts=jnp.array([[[0, 0, 0]], [[0, 0, 0]], [[1, 0, 0]]])  # 3rd edge crosses boundary
        )
        ```
    """

    # The degree is purely for type checking and does not affect runtime behavior
    indices: Index[ParticleId]  # (n_edges, Degree)
    shifts: Array  # (n_edges, Degree - 1, 3)

    def __post_init__(self):
        # Resolve the underlying array for validation
        raw = self.indices.indices if isinstance(self.indices, Index) else self.indices
        if not isinstance(raw, Array):
            return
        assert jnp.issubdtype(raw.dtype, jnp.integer), (
            f"Indices must be of integer type, got {raw.dtype}"
        )
        target_shape = (
            *self.indices.shape[:-1],
            self.indices.shape[-1] - 1 if self.indices.shape[-1] > 1 else 0,
            3,
        )
        assert self.shifts.shape == target_shape, (
            f"Shifts must have shape {target_shape}, got {self.shifts.shape}"
        )

    def difference_vectors(
        self,
        particles: Table[ParticleId, HasPositionsAndSystemIndex],
        systems: Table[SystemId, HasUnitCell],
    ) -> Array:
        """Compute difference vectors between connected particles.

        For each edge, computes the vector from the first particle to each
        subsequent particle, accounting for periodic boundary conditions.

        Args:
            particles: Particle positions with system index information.
            systems: System data with unit cell for periodic boundary conditions.

        Returns:
            Array of shape `(n_edges, Degree-1, 3)` containing difference vectors.
        """

        shifts = self.absolute_shifts(particles, systems)
        pos = particles[self.indices].positions
        return pos[:, 1:] - pos[:, :1] + shifts

    def absolute_shifts(
        self,
        particles: Table[ParticleId, HasPositionsAndSystemIndex],
        systems: Table[SystemId, HasUnitCell],
    ) -> Array:
        """Compute absolute shift vectors for all particles in each edge.

        Converts relative shifts to absolute Cartesian shift vectors.

        Args:
            particles: Particle data with system index information.
            systems: System data with unit cell for periodic boundary conditions.

        Returns:
            Array of shape `(n_edges, Degree-1, 3)` containing absolute shift vectors.
        """
        lattice = systems.map_data(lambda x: x.unitcell.lattice_vectors)
        vecs = lattice[particles[self.indices[:, 0]].system]
        return triangular_3x3_matmul(vecs[:, None], self.shifts)

    @property
    def degree(self) -> int:
        return self.indices.shape[-1]

    def __len__(self) -> int:
        return self.indices.shape[0]

absolute_shifts(particles, systems)

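Compute absolute (Cartesian) shift vectors for the 2nd through Degree-th particle of each edge.

Converts each edge's integer shifts (numbers of unit cells to traverse) into Cartesian vectors using the lattice vectors of the system that contains the edge's first particle.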
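Concretely, an integer shift (n1, n2, n3) stands for n1·a1 + n2·a2 + n3·a3, where a1, a2, a3 are the lattice vectors. The sketch below assumes the lattice vectors are stored as the rows of a 3×3 matrix; the storage convention actually used by `lattice_vectors` and `triangular_3x3_matmul` is not shown in this section, and `absolute_shift` is a hypothetical helper.

```python
def absolute_shift(shift, lattice):
    """Cartesian shift n1*a1 + n2*a2 + n3*a3, with lattice vectors as rows."""
    return tuple(sum(n * lattice[k][c] for k, n in enumerate(shift)) for c in range(3))


# A hypothetical triclinic cell in lower-triangular form:
# a1 along x, a2 in the xy-plane, a3 general.
lattice = (
    (10.0, 0.0, 0.0),
    (2.0, 9.0, 0.0),
    (1.0, 1.0, 8.0),
)
print(absolute_shift((1, 0, 0), lattice))   # (10.0, 0.0, 0.0)
print(absolute_shift((0, 1, -1), lattice))  # (1.0, 8.0, -8.0)
```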

Parameters:

  • particles (Table[ParticleId, HasPositionsAndSystemIndex], required): Particle data with system index information.
  • systems (Table[SystemId, HasUnitCell], required): System data with unit cell for periodic boundary conditions.

Returns:

  • Array: Array of shape (n_edges, Degree-1, 3) containing absolute shift vectors.


difference_vectors(particles, systems)

Compute difference vectors between connected particles.

For each edge, computes the vector from the first particle to each subsequent particle, accounting for periodic boundary conditions.
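For a Degree-3 edge the method returns two vectors, both starting at the edge's first particle, from which an angle can be computed. The sketch below reproduces this for one edge in an orthorhombic box; `edge_difference_vectors` and `angle_deg` are hypothetical helpers, and whether a given kups potential uses the first particle as the angle vertex is not stated here.

```python
import math


def edge_difference_vectors(positions, edge, shifts, box):
    """Vectors from the edge's first particle to each later particle.

    `shifts` holds one integer cell offset per particle after the first,
    mirroring the (Degree - 1, 3) layout of Edges.shifts (orthorhombic box assumed).
    """
    first = positions[edge[0]]
    return [
        tuple(b - a + s * length for a, b, s, length in zip(first, positions[j], shift, box))
        for j, shift in zip(edge[1:], shifts)
    ]


def angle_deg(u, v):
    """Angle between two 3-vectors in degrees."""
    dot = sum(a * b for a, b in zip(u, v))
    return math.degrees(math.acos(dot / (math.hypot(*u) * math.hypot(*v))))


box = (10.0, 10.0, 10.0)
positions = [(0.5, 0.5, 0.5), (9.5, 0.5, 0.5), (0.5, 1.5, 0.5)]
# Edge (0, 1, 2): particle 1 is reached through the -x boundary.
vecs = edge_difference_vectors(positions, (0, 1, 2), [(-1, 0, 0), (0, 0, 0)], box)
print(vecs)                          # [(-1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
print(round(angle_deg(*vecs), 6))    # 90.0
```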

Parameters:

  • particles (Table[ParticleId, HasPositionsAndSystemIndex], required): Particle positions with system index information.
  • systems (Table[SystemId, HasUnitCell], required): System data with unit cell for periodic boundary conditions.

Returns:

  • Array: Array of shape (n_edges, Degree-1, 3) containing difference vectors.


IsAllDenseNeighborListParams

Bases: Protocol

Protocol for parameters required by AllDenseNearestNeighborList.

Source code in src/kups/core/neighborlist.py
class IsAllDenseNeighborListParams(Protocol):
    """Protocol for parameters required by ``AllDenseNearestNeighborList``."""

    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...

IsCellListParams

Bases: Protocol

Protocol for parameters required by CellListNeighborList.

Source code in src/kups/core/neighborlist.py
class IsCellListParams(Protocol):
    """Protocol for parameters required by ``CellListNeighborList``."""

    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_edges(self) -> int: ...
    @property
    def cells(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...

IsDenseNeighborlistParams

Bases: Protocol

Protocol for parameters required by DenseNearestNeighborList.

Source code in src/kups/core/neighborlist.py
class IsDenseNeighborlistParams(Protocol):
    """Protocol for parameters required by ``DenseNearestNeighborList``."""

    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...

IsNeighborListState

Bases: Protocol

Protocol for states that expose neighbor list parameters.

A state satisfying this protocol can be passed to from_state() on any neighbor list class. The type parameter P determines which neighbor list types the state can construct (e.g., IsAllDenseNeighborListParams, IsDenseNeighborlistParams, IsCellListParams, or IsUniversalNeighborlistParams).

Source code in src/kups/core/neighborlist.py
class IsNeighborListState[P](Protocol):
    """Protocol for states that expose neighbor list parameters.

    A state satisfying this protocol can be passed to ``from_state()`` on any
    neighbor list class. The type parameter ``P`` determines which neighbor
    list types the state can construct (e.g., ``IsAllDenseNeighborListParams``,
    ``IsDenseNeighborlistParams``, ``IsCellListParams``, or
    ``IsUniversalNeighborlistParams``).
    """

    @property
    def neighborlist_params(self) -> P: ...

IsUniversalNeighborlistParams

Bases: Protocol

Protocol for parameters required by any neighbor list implementation.

A superset of IsAllDenseNeighborListParams, IsDenseNeighborlistParams, and IsCellListParams. Satisfying this protocol allows constructing any of the three neighbor list types.
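Because these protocols only declare read-only properties, a plain frozen dataclass with matching field names should satisfy them structurally for a static type checker. The classes below are hypothetical and the values are placeholders, not recommended capacities. With such a state, `CellListNeighborList.from_state(state)` would read the parameters through `lens(lambda s: s.neighborlist_params)`, as in the source above; whether the state must also meet runtime requirements (for example, being a JAX pytree) is not stated in this section.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NeighborListParams:
    """Hypothetical record with every field named by IsUniversalNeighborlistParams."""
    avg_edges: int = 64             # placeholder value
    avg_candidates: int = 128       # placeholder value
    avg_image_candidates: int = 32  # placeholder value
    cells: int = 1000               # placeholder value


@dataclass(frozen=True)
class SimState:
    """Hypothetical state exposing `neighborlist_params` (IsNeighborListState)."""
    neighborlist_params: NeighborListParams


state = SimState(neighborlist_params=NeighborListParams())
print(state.neighborlist_params.cells)  # 1000
```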

Source code in src/kups/core/neighborlist.py
class IsUniversalNeighborlistParams(Protocol):
    """Protocol for parameters required by any neighbor list implementation.

    A superset of ``IsAllDenseNeighborListParams``, ``IsDenseNeighborlistParams``,
    and ``IsCellListParams``. Satisfying this protocol allows constructing any
    of the three neighbor list types.
    """

    @property
    def avg_edges(self) -> int: ...
    @property
    def avg_candidates(self) -> int: ...
    @property
    def avg_image_candidates(self) -> int: ...
    @property
    def cells(self) -> int: ...

NearestNeighborList

Bases: Protocol

Protocol for neighbor list construction algorithms.

Implementations find pairs of particles within a cutoff distance, handling periodic boundary conditions and inclusion/exclusion masks.
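The contract can be illustrated by a brute-force reference search in pure Python. This is a hypothetical sketch (`reference_pairs` is not a kups function). It assumes an orthorhombic periodic box with cutoff ≤ half the box length, so the minimum image suffices. It also emits both orientations (i, j) and (j, i) for self-search; whether kups implementations emit one or both orientations is not stated here.

```python
def reference_pairs(lh, rh, box, cutoff):
    """O(N*M) reference search in a periodic orthorhombic box.

    lh, rh: lists of (x, y, z) positions; rh=None searches lh against itself.
    Returns (i, j) pairs with i indexing lh and j indexing rh (or lh).
    """
    self_search = rh is None
    targets = lh if self_search else rh
    pairs = []
    for i, a in enumerate(lh):
        for j, b in enumerate(targets):
            if self_search and i == j:
                continue  # a particle is not its own neighbour
            d2 = 0.0
            for pa, pb, length in zip(a, b, box):
                d = pb - pa
                d -= length * round(d / length)  # minimum image
                d2 += d * d
            if d2 <= cutoff * cutoff:
                pairs.append((i, j))
    return pairs


box = (10.0, 10.0, 10.0)
lh = [(0.5, 0.0, 0.0), (9.5, 0.0, 0.0), (5.0, 5.0, 5.0)]
rh = [(9.8, 0.0, 0.0), (3.0, 0.0, 0.0)]
print(reference_pairs(lh, None, box, cutoff=1.5))  # [(0, 1), (1, 0)]
print(reference_pairs(lh, rh, box, cutoff=1.5))    # [(0, 0), (1, 0)]
```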

Source code in src/kups/core/neighborlist.py
class NearestNeighborList(Protocol):
    """Protocol for neighbor list construction algorithms.

    Implementations find pairs of particles within a cutoff distance, handling
    periodic boundary conditions and inclusion/exclusion masks.
    """

    def __call__[P: NeighborListPoints](
        self,
        lh: Table[ParticleId, P],
        rh: Table[ParticleId, P] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        """Find all particle pairs within the cutoff distance.

        Args:
            lh: Left-hand particles to find neighbors for
            rh: Right-hand particles to search within (or None for self-neighbors)
            systems: Indexed system data with unit cell information
            cutoffs: Indexed cutoff data per system
            rh_index_remap: Optional index mapping rh particles back to lh
                particle IDs for self-interaction exclusion. When ``None``,
                rh is treated as disjoint from lh.

        Returns:
            Edges connecting particle pairs within cutoff
        """
        ...

__call__(lh, rh, systems, cutoffs, rh_index_remap=None)

Find all particle pairs within the cutoff distance.

Parameters:

  • lh (Table[ParticleId, P], required): Left-hand particles to find neighbors for.
  • rh (Table[ParticleId, P] | None, required): Right-hand particles to search within (or None for self-neighbors).
  • systems (Table[SystemId, NeighborListSystems], required): Indexed system data with unit cell information.
  • cutoffs (Table[SystemId, Array], required): Indexed cutoff data per system.
  • rh_index_remap (Index[ParticleId] | None, default None): Optional index mapping rh particles back to lh particle IDs for self-interaction exclusion. When None, rh is treated as disjoint from lh.

Returns:

  • Edges[Literal[2]]: Edges connecting particle pairs within cutoff.
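The role of rh_index_remap can be sketched with plain lists. The inversion below mirrors the construction in `RefineCutoffNeighborList.__call__` (an array of length lh.size filled with the sentinel rh.size, then set at the remapped positions); the helper names and the example mapping are hypothetical.

```python
def inverse_remap(rh_to_lh, lh_size):
    """Invert an rh -> lh index map; lh slots without an rh counterpart get len(rh_to_lh)."""
    rh_size = len(rh_to_lh)
    inv = [rh_size] * lh_size
    for k, i in enumerate(rh_to_lh):
        inv[i] = k
    return inv


def is_self_pair(i, k, rh_to_lh):
    """True if lh particle i and rh particle k are the same physical particle."""
    return rh_to_lh[k] == i


# rh holds copies of lh particles 3 and 1 (for example, a subset of lh).
rh_to_lh = [3, 1]
print(inverse_remap(rh_to_lh, lh_size=5))  # [2, 1, 2, 0, 2]
print(is_self_pair(3, 0, rh_to_lh))        # True
print(is_self_pair(1, 0, rh_to_lh))        # False
```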


RefineCutoffNeighborList

Refine precomputed edges by re-checking distances with new cutoffs.

This neighbor list takes an existing set of candidate edges and filters them by computing actual distances and comparing to cutoffs. Enables sharing a single conservative neighbor list across multiple potentials with different cutoff distances.

Key benefit: Compute expensive neighbor list once with maximum cutoff, then refine for each potential with its specific cutoff (e.g., Lennard-Jones at 10 Å, Coulomb at 15 Å).
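Refinement can only remove edges. The shared candidate list must therefore have been built with a cutoff at least as large as every refined cutoff. The pure-Python sketch below shows one conservative 15 Å candidate list refined at 10 Å and at 15 Å. `refine_by_cutoff` is hypothetical, and the sketch assumes an orthorhombic periodic box.

```python
def refine_by_cutoff(candidates, positions, box, cutoff):
    """Keep candidate edges whose minimum-image distance is <= cutoff."""
    kept = []
    for i, j in candidates:
        d2 = 0.0
        for a, b, length in zip(positions[i], positions[j], box):
            d = b - a
            d -= length * round(d / length)  # minimum image
            d2 += d * d
        if d2 <= cutoff * cutoff:
            kept.append((i, j))
    return kept


box = (40.0, 40.0, 40.0)
positions = [(0.0, 0.0, 0.0), (8.0, 0.0, 0.0), (0.0, 12.0, 0.0), (0.0, 0.0, 14.5)]
# Conservative shared list, built once with the largest cutoff (15 Å).
shared = [(0, 1), (0, 2), (0, 3)]
print(refine_by_cutoff(shared, positions, box, cutoff=10.0))  # [(0, 1)]
print(refine_by_cutoff(shared, positions, box, cutoff=15.0))  # [(0, 1), (0, 2), (0, 3)]
```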

Attributes:

  • candidates (Edges[Literal[2]]): Precomputed edges to refine (should be conservative/over-inclusive).
  • avg_edges (Capacity[int]): Capacity for output edge array.

Use cases
  • Multiple potentials sharing one neighbor list with different cutoffs
  • Multi-stage neighbor list construction (coarse then fine)
  • Adaptive cutoffs that change during simulation
  • Using a static "super" neighbor list with varying actual cutoffs
Example
# Compute base neighbor list once with the maximum cutoff (15 Å here).
# base_nl is any primary neighbor list; max_cutoffs, lj_cutoffs and
# coulomb_cutoffs are Table[SystemId, Array] per-system cutoffs;
# cap1 and cap2 are Capacity[int] objects.
base_edges = base_nl(particles, None, cells, max_cutoffs, None)

# Share across potentials with different cutoffs
lj_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap1)
lj_edges = lj_nl(particles, None, cells, lj_cutoffs, None)  # LJ: 10 Å

coulomb_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap2)
coulomb_edges = coulomb_nl(particles, None, cells, coulomb_cutoffs, None)  # Coulomb: 15 Å
Source code in src/kups/core/neighborlist.py
@dataclass
class RefineCutoffNeighborList:
    """Refine precomputed edges by re-checking distances with new cutoffs.

    This neighbor list takes an existing set of candidate edges and filters them
    by computing actual distances and comparing to cutoffs. Enables sharing a
    single conservative neighbor list across multiple potentials with different
    cutoff distances.

    **Key benefit**: Compute expensive neighbor list once with maximum cutoff,
    then refine for each potential with its specific cutoff (e.g., Lennard-Jones
    at 10 Å, Coulomb at 15 Å).

    Attributes:
        candidates: Precomputed edges to refine (should be conservative/over-inclusive).
        avg_edges: Capacity for output edge array.

    Use cases:
        - Multiple potentials sharing one neighbor list with different cutoffs
        - Multi-stage neighbor list construction (coarse then fine)
        - Adaptive cutoffs that change during simulation
        - Using a static "super" neighbor list with varying actual cutoffs

    Example:
        ```python
        # Compute the base neighbor list once with the largest cutoff.
        # max_cutoffs, lj_cutoffs and coulomb_cutoffs are per-system
        # Table[SystemId, Array] values (15 Å, 10 Å and 15 Å, respectively).
        base_edges = base_nl(particles, None, cells, max_cutoffs, None)

        # Share across potentials with different cutoffs
        lj_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap1)
        lj_edges = lj_nl(particles, None, cells, lj_cutoffs, None)  # LJ cutoff

        coulomb_nl = RefineCutoffNeighborList(candidates=base_edges, avg_edges=cap2)
        coulomb_edges = coulomb_nl(particles, None, cells, coulomb_cutoffs, None)  # Coulomb cutoff
        ```
    """

    candidates: Edges[Literal[2]]
    avg_edges: Capacity[int]

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        rh_remap_raw = (
            rh_index_remap.indices_in(lh.keys) if rh_index_remap is not None else None
        )

        if rh_remap_raw is not None:
            assert rh is not None
            inv_rh_index_remap = jnp.full(lh.size, rh.size, dtype=int)
            inv_rh_index_remap = inv_rh_index_remap.at[rh_remap_raw].set(
                jnp.arange(rh.size, dtype=int)
            )
        else:
            inv_rh_index_remap = None

        def _cand_selector(
            lh: Table[ParticleId, NeighborListPoints],
            rh: Table[ParticleId, NeighborListPoints],
            systems: Table[SystemId, NeighborListSystems],
        ) -> _Candidates:
            rh_c = self.candidates.indices[:, 1].indices
            if inv_rh_index_remap is not None:
                rh_c = inv_rh_index_remap.at[rh_c].get(mode="fill", fill_value=len(lh))
            return _Candidates(self.candidates.indices[:, 0], Index(rh.keys, rh_c))

        rh_size = rh.size if rh is not None else lh.size
        return basic_neighborlist(
            lh,
            rh,
            systems,
            cutoffs,
            rh_index_remap,
            candidate_selector=_cand_selector,
            max_num_edges=self.avg_edges.multiply(rh_size),
            consider_images=False,
        )
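The opening lines of __call__ invert rh_index_remap: the stored candidates index right-hand particles by their position in lh, while the refined query needs positions in rh. A NumPy sketch of that inversion follows. invert_remap and to_rh_indices are hypothetical helpers, and the remap is assumed to be a plain integer array with remap[m] the lh position of rh entry m. As in the source, lh positions not covered by the remap map to rh.size, and out-of-range inputs (padding) receive an explicit fill value.

```python
import numpy as np

# Hypothetical helpers (sketch) mirroring the inverse-remap construction.


def invert_remap(remap, n_lh):
    """Return inv with inv[remap[m]] == m; other lh positions get len(remap)."""
    n_rh = len(remap)
    inv = np.full(n_lh, n_rh, dtype=int)  # sentinel meaning "not in rh"
    inv[remap] = np.arange(n_rh)
    return inv


def to_rh_indices(lh_indices, inv, fill_value):
    """Translate lh positions to rh positions, like .at[...].get(mode="fill")."""
    in_range = (lh_indices >= 0) & (lh_indices < len(inv))
    safe = np.where(in_range, lh_indices, 0)  # avoid indexing out of bounds
    return np.where(in_range, inv[safe], fill_value)


remap = np.array([3, 1])  # rh entry 0 is lh particle 3, rh entry 1 is lh particle 1
inv = invert_remap(remap, n_lh=5)
rh_idx = to_rh_indices(np.array([3, 1, 4, 5]), inv, fill_value=5)  # 5 is padding
```

Lh particle 4 is not part of rh, so it maps to the sentinel 2 (rh.size) and is invalid on the right-hand side, while the padding entry 5 falls through to fill_value.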

RefineMaskNeighborList

Refine a precomputed neighbor list by applying inclusion/exclusion masks.

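This neighbor list takes an existing set of candidate edges and keeps only the pairs whose particles share the same inclusion segment and belong to different exclusion segments, without recomputing distances. Rejected edges keep their slot in the output but point at the padding index (lh.size) with a zero shift. The rh, systems, cutoffs and rh_index_remap arguments are not used by this implementation. This enables sharing a single base neighbor list across multiple potentials with different interaction rules.

Key benefit: compute the expensive neighbor list once, then apply a different mask for each potential (e.g., Lennard-Jones excludes 1-4 interactions while Coulomb uses different exclusions).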

Attributes:

  • candidates (Edges[Literal[2]]): Precomputed edges to refine.

Use cases
  • Multiple potentials sharing one neighbor list with different exclusions
  • Excluding bonded pairs (1-2, 1-3, 1-4) from non-bonded interactions
  • Applying group-specific interaction rules
  • Multi-scale simulations with different interaction levels
Example
# Compute base neighbor list once
base_edges = base_nl(particles, None, cells, cutoffs, None)

# Share across potentials with different masks
lj_nl = RefineMaskNeighborList(candidates=base_edges)
lj_edges = lj_nl(lj_particles, None, cells, cutoffs, None)  # 1-4 exclusions

coulomb_nl = RefineMaskNeighborList(candidates=base_edges)
coulomb_edges = coulomb_nl(coulomb_particles, None, cells, cutoffs, None)  # 1-2 exclusions only
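The masking rule in __call__ can be sketched with plain integer segment IDs per particle (an assumption: the library stores them as Index objects and aligns them with Index.match). A candidate survives when both particles share an inclusion ID and carry different exclusion IDs; rejected edges keep their slot but point at the padding index with a zero shift, so the output has the same fixed size as the input. refine_by_mask is a hypothetical name, and the sketch assumes the input contains no padding rows.

```python
import numpy as np


def refine_by_mask(pairs, shifts, inclusion, exclusion, pad_index):
    """Mask candidate edges by inclusion/exclusion segment IDs (sketch)."""
    i, j = pairs[:, 0], pairs[:, 1]
    keep = (inclusion[i] == inclusion[j]) & (exclusion[i] != exclusion[j])
    # Rejected rows point at the padding index and get a zero shift.
    new_pairs = np.where(keep[:, None], pairs, pad_index)
    new_shifts = np.where(keep[:, None], shifts, 0)
    return new_pairs, new_shifts


# One system (inclusion 0) holding two molecules: atoms 0-1 and atoms 2-3.
inclusion = np.array([0, 0, 0, 0])
exclusion = np.array([0, 0, 1, 1])  # one exclusion ID per molecule
pairs = np.array([[0, 1], [0, 2], [1, 3], [2, 3]])
shifts = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]])
new_pairs, new_shifts = refine_by_mask(pairs, shifts, inclusion, exclusion, pad_index=4)
```

Giving each molecule its own exclusion ID, as here, drops every intramolecular pair (0-1 and 2-3) while keeping the intermolecular ones.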
Source code in src/kups/core/neighborlist.py
@dataclass
class RefineMaskNeighborList:
    """Refine a precomputed neighbor list by applying inclusion/exclusion masks.

    This neighbor list takes an existing set of candidate edges and filters them
    based on segmentation masks, without recomputing distances. Enables sharing
    a single base neighbor list across multiple potentials with different
    interaction rules.

    **Key benefit**: Compute expensive neighbor list once, apply different masks
    for different potentials (e.g., Lennard-Jones excludes 1-4 interactions,
    Coulomb has different exclusions).

    Attributes:
        candidates: Precomputed edges to refine

    Use cases:
        - Multiple potentials sharing one neighbor list with different exclusions
        - Excluding bonded pairs (1-2, 1-3, 1-4) from non-bonded interactions
        - Applying group-specific interaction rules
        - Multi-scale simulations with different interaction levels

    Example:
        ```python
        # Compute base neighbor list once
        base_edges = base_nl(particles, None, cells, cutoffs, None)

        # Share across potentials with different masks
        lj_nl = RefineMaskNeighborList(candidates=base_edges)
        lj_edges = lj_nl(lj_particles, None, cells, cutoffs, None)  # 1-4 exclusions

        coulomb_nl = RefineMaskNeighborList(candidates=base_edges)
        coulomb_edges = coulomb_nl(coulomb_particles, None, cells, cutoffs, None)  # 1-2 exclusions only
        ```
    """

    candidates: Edges[Literal[2]]

    @jit
    def __call__(
        self,
        lh: Table[ParticleId, NeighborListPoints],
        rh: Table[ParticleId, NeighborListPoints] | None,
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        rh_index_remap: Index[ParticleId] | None = None,
    ) -> Edges[Literal[2]]:
        lh_c = self.candidates.indices[:, 0]
        rh_c = self.candidates.indices[:, 1]
        lh_d, rh_d = lh[lh_c], lh[rh_c]
        lh_incl, rh_incl = Index.match(lh_d.inclusion, rh_d.inclusion)
        lh_excl, rh_excl = Index.match(lh_d.exclusion, rh_d.exclusion)
        mask = lh_incl == rh_incl
        mask &= lh_excl != rh_excl
        indices = where_broadcast_last(mask, self.candidates.indices.indices, lh.size)
        shifts = where_broadcast_last(mask, self.candidates.shifts, 0)
        return Edges(Index(self.candidates.indices.keys, indices), shifts)

UniversalNeighborlistParameters

Concrete parameter dataclass satisfying IsUniversalNeighborlistParams.

Holds the capacity hints needed by every neighbor list implementation. Use the estimate() classmethod to compute reasonable initial values from system geometry rather than guessing manually.

Attributes:

  • avg_edges (int): Average number of edges per particle (for edge capacity).
  • avg_candidates (int): Average number of candidate pairs per particle.
  • avg_image_candidates (int): Average number of image candidate pairs per particle.
  • cells (int): Maximum number of spatial hash cells across all systems.

Source code in src/kups/core/neighborlist.py
@dataclass
class UniversalNeighborlistParameters:
    """Concrete parameter dataclass satisfying ``IsUniversalNeighborlistParams``.

    Holds the capacity hints needed by every neighbor list implementation.
    Use the ``estimate()`` classmethod to compute reasonable initial values
    from system geometry rather than guessing manually.

    Attributes:
        avg_edges: Average number of edges per particle (for edge capacity).
        avg_candidates: Average number of candidate pairs per particle.
        avg_image_candidates: Average number of image candidate pairs per particle.
        cells: Maximum number of spatial hash cells across all systems.
    """

    avg_edges: int = field(static=True)
    avg_candidates: int = field(static=True)
    avg_image_candidates: int = field(static=True)
    cells: int = field(static=True)

    @classmethod
    @no_jax_tracing
    def estimate(
        cls,
        particles_per_system: Table[SystemId, Array],
        systems: Table[SystemId, NeighborListSystems],
        cutoffs: Table[SystemId, Array],
        *,
        base: float = 2,
        multiplier: float = 1.0,
    ) -> UniversalNeighborlistParameters:
        """Estimate parameters for all neighbor list types from system geometry.

        Computes conservative initial capacities based on particle density
        and cutoff radii. The estimates are rounded up to the next power of
        ``base`` to amortize future resizing.

        Args:
            particles_per_system: Number of particles per system.
            systems: System data with unit cell information.
            cutoffs: Cutoff distance per system.
            base: Base for power-of rounding (default 2).
            multiplier: Safety factor applied to the estimate (default 1.0).

        Returns:
            A ``UniversalNeighborlistParameters`` instance with estimated values.
        """
        sys = Table.join(systems, particles_per_system, cutoffs)
        total_candidates = total_edges = max_cells = 0
        for _, (s, n_p, c) in sys:
            num_cells = _num_cells(s, c).prod()
            total_candidates += min(n_p / num_cells * (3**3), n_p)
            total_edges += _estimate_avg_num_edges(
                n_p, s.unitcell.volume, c, base, multiplier
            )
            max_cells = max(num_cells, max_cells)
        total_candidates = next_higher_power(
            jnp.array(total_candidates * multiplier / sys.size), base=base
        )
        return UniversalNeighborlistParameters(
            avg_edges=int(total_edges // sys.size),
            avg_candidates=int(total_candidates),
            avg_image_candidates=int(total_candidates),  # Image candidates ~ candidates
            cells=int(max_cells),
        )

estimate(particles_per_system, systems, cutoffs, *, base=2, multiplier=1.0) classmethod

Estimate parameters for all neighbor list types from system geometry.

Computes conservative initial capacities based on particle density and cutoff radii. The estimates are rounded up to the next power of base to amortize future resizing.

Parameters:

  • particles_per_system (Table[SystemId, Array], required): Number of particles per system.
  • systems (Table[SystemId, NeighborListSystems], required): System data with unit cell information.
  • cutoffs (Table[SystemId, Array], required): Cutoff distance per system.
  • base (float, default 2): Base of the power used when rounding estimates up.
  • multiplier (float, default 1.0): Safety factor applied to the estimate.

Returns:

  • UniversalNeighborlistParameters: A UniversalNeighborlistParameters instance with estimated values.
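The candidate part of estimate() is simple arithmetic and can be reproduced without the library. In the sketch below each system is given as a hypothetical (n_particles, n_cells) pair, with n_cells standing in for _num_cells(s, c).prod(), whose definition is not shown here; next_power assumes that next_higher_power returns the smallest power of base that is at least its argument. The edge estimate is omitted because _estimate_avg_num_edges is not shown either.

```python
def next_power(x, base=2):
    """Smallest power of `base` that is >= x (assumed next_higher_power behaviour)."""
    p = 1
    while p < x:
        p *= base
    return p


def estimate_avg_candidates(systems, multiplier=1.0, base=2):
    """Candidate-per-particle estimate from (n_particles, n_cells) pairs (sketch).

    A particle is assumed to see its own cell and the 26 surrounding cells
    (3**3 cells in total), capped at the number of particles in the system.
    """
    total = 0.0
    for n_p, n_cells in systems:
        total += min(n_p / n_cells * 3**3, n_p)
    # Average over systems, apply the safety factor, round up.
    return next_power(total * multiplier / len(systems), base)


# 1000 particles in 125 cells (8 per cell), and a tiny 100-particle system.
print(estimate_avg_candidates([(1000, 125), (100, 1)]))       # 256
print(estimate_avg_candidates([(1000, 125), (100, 1)], 2.0))  # 512
```

The first system contributes 8 × 27 = 216 candidates per particle. For the second, the 27-cell estimate (2700) exceeds its particle count, so the cap of 100 applies; the mean of 158 is rounded up to 256.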


all_connected_neighborlist(lh, rh, systems, cutoffs, rh_index_remap=None)

Neighbor list connecting all pairs sharing the same inclusion segment, ignoring distance.

Connects every particle pair that belongs to the same inclusion segment and has differing exclusion segment IDs. The cutoff is ignored for neighbor selection; the unit cell is used only to compute minimum-image shifts.

Requires max_count to be set on the inclusion Index.
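A brute-force NumPy sketch of the same selection rule follows, with hypothetical names and plain integer segment IDs. It ignores distance entirely and reports each unordered pair once with i < j (the library's orientation of edges may differ). Shifts are computed as in the source, round(frac_i - frac_j), which leaves every fractional component of frac_i - frac_j - shift within [-0.5, 0.5].

```python
import numpy as np


def all_connected(frac, inclusion, exclusion):
    """All pairs with equal inclusion IDs and different exclusion IDs (sketch).

    `frac` holds fractional coordinates; distance is never checked.
    """
    pairs, shifts = [], []
    for i in range(len(frac)):
        for j in range(i + 1, len(frac)):
            if inclusion[i] == inclusion[j] and exclusion[i] != exclusion[j]:
                pairs.append((i, j))
                # Minimum-image shift, rounded as in the source.
                shifts.append(np.round(frac[i] - frac[j]).astype(int))
    return (
        np.array(pairs, dtype=int).reshape(-1, 2),
        np.array(shifts, dtype=int).reshape(-1, 3),
    )


frac = np.array([[0.05, 0.0, 0.0], [0.95, 0.0, 0.0], [0.5, 0.5, 0.5]])
inclusion = np.array([0, 0, 1])  # particle 2 belongs to another segment
exclusion = np.array([0, 1, 2])
pairs, shifts = all_connected(frac, inclusion, exclusion)
```

Particles 0 and 1 are 0.9 box lengths apart in the home cell, but with the rounded shift of -1 their fractional separation becomes 0.1, the minimum-image distance.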

Source code in src/kups/core/neighborlist.py
def all_connected_neighborlist(
    lh: Table[ParticleId, NeighborListPoints],
    rh: Table[ParticleId, NeighborListPoints] | None,
    systems: Table[SystemId, NeighborListSystems],
    cutoffs: Table[SystemId, Array],
    rh_index_remap: Index[ParticleId] | None = None,
) -> Edges[Literal[2]]:
    """Neighbor list connecting all pairs sharing the same inclusion segment, ignoring distance.

    Connects every particle pair that belongs to the same inclusion segment and has
    differing exclusion segment IDs. The cutoff is ignored for neighbor selection;
    the unit cell is used only to compute minimum-image shifts.

    Requires ``max_count`` to be set on the inclusion ``Index``.
    """
    if rh is None:
        rh = lh
        rh_index_remap = Index.arange(len(lh), label=ParticleId)

    ngraphs = lh.data.inclusion.num_labels
    max_count = lh.data.inclusion.max_count
    assert max_count is not None, "inclusion.max_count must be set"
    capacity = FixedCapacity(max_count).multiply(min(lh.size, rh.size))
    out_of_bounds = max(lh.size, rh.size)

    lh_sys = systems[lh.data.system]
    rh_sys = systems[rh.data.system]

    selection_result = subselect(
        lh.data.inclusion.indices,
        rh.data.inclusion.indices,
        output_buffer_size=capacity,
        num_segments=ngraphs,
    )
    candidates = _Candidates(
        lhs=Index(lh.keys, selection_result.scatter_idxs),
        rhs=Index(rh.keys, selection_result.gather_idxs),
    )
    lh_idx, rh_idx = candidates.lhs, candidates.rhs
    lh_data, rh_data = lh[lh_idx], rh[rh_idx]
    lh_excl, rh_excl = Index.match(lh_data.exclusion, rh_data.exclusion)
    mask = lh_excl != rh_excl
    if rh_index_remap is not None:
        lh_i, rh_i = Index.match(lh_idx, rh.set_data(rh_index_remap)[rh_idx])
        mask &= ~lh_idx.isin(rh_index_remap) | (lh_i >= rh_i)

    lh_frac = triangular_3x3_matmul(
        lh_sys.unitcell.inverse_lattice_vectors, lh.data.positions
    )
    rh_frac = triangular_3x3_matmul(
        rh_sys.unitcell.inverse_lattice_vectors, rh.data.positions
    )
    shifts = jnp.round(lh_frac[lh_idx.indices] - rh_frac[rh_idx.indices]).astype(int)
    return _compact_edges(
        candidates,
        mask,
        shifts,
        rh_index_remap.indices_in(lh.keys) if rh_index_remap is not None else None,
        capacity,
        out_of_bounds,
    )

basic_neighborlist(lh, rh, systems, cutoffs, rh_index_remap, *, candidate_selector, max_num_edges, max_image_candidates=None, consider_images=True)

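Core neighbor list construction algorithm with pluggable candidate selection. Positions of lh and rh (rh defaults to lh) are converted to fractional coordinates with the inverse lattice vectors of each particle's system, candidate_selector proposes candidate pairs, and _filter_candidates keeps those within the per-system cutoffs.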
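The three stages can be illustrated with a small NumPy sketch. The names are hypothetical, the selector here takes only fractional coordinates (the library's CandidateSelector receives the lh, rh and systems tables and returns _Candidates), and the distance filter stands in for _filter_candidates, which is not shown on this page. Rounding fractional differences is only guaranteed to find the nearest image for orthorhombic or mildly skewed cells.

```python
import numpy as np


def all_pairs_selector(frac):
    """Hypothetical candidate selector: every pair (i, j) with i < j."""
    i, j = np.triu_indices(len(frac), k=1)
    return np.stack([i, j], axis=1)


def basic_neighborlist_sketch(positions, lattice, cutoff, candidate_selector):
    """Fractional transform, pluggable candidate selection, cutoff filter.

    `lattice` holds the lattice vectors as columns (sketch assumption).
    """
    frac = positions @ np.linalg.inv(lattice).T  # cartesian -> fractional
    pairs = candidate_selector(frac)             # pluggable step
    d_frac = frac[pairs[:, 0]] - frac[pairs[:, 1]]
    shifts = np.round(d_frac).astype(int)        # minimum-image shift
    d_cart = (d_frac - shifts) @ lattice.T       # back to cartesian
    keep = np.linalg.norm(d_cart, axis=1) <= cutoff
    return pairs[keep], shifts[keep]


lattice = np.diag([10.0, 10.0, 10.0])
positions = np.array([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0], [5.0, 5.0, 5.0]])
pairs, shifts = basic_neighborlist_sketch(positions, lattice, 2.0, all_pairs_selector)
```

Only pair (0, 1) survives: it is 9 Å apart in the home cell but 1 Å apart across the periodic boundary, reached with shift (-1, 0, 0). Swapping all_pairs_selector for a cell-list selector changes which pairs are examined, not how they are filtered.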

Source code in src/kups/core/neighborlist.py
def basic_neighborlist(
    lh: Table[ParticleId, NeighborListPoints],
    rh: Table[ParticleId, NeighborListPoints] | None,
    systems: Table[SystemId, NeighborListSystems],
    cutoffs: Table[SystemId, Array],
    rh_index_remap: Index[ParticleId] | None,
    *,
    candidate_selector: CandidateSelector,
    max_num_edges: Capacity[int],
    max_image_candidates: Capacity[int] | None = None,
    consider_images: bool = True,
) -> Edges[Literal[2]]:
    """Core neighbor list construction algorithm with pluggable candidate selection."""
    cutoffs = Table.broadcast_to(cutoffs, systems)
    if rh is None:
        rh = lh

    # Transform coordinates to fractional using per-particle system data
    lh_inv = systems[lh.data.system].unitcell.inverse_lattice_vectors
    lh = (
        bind(lh)
        .focus(lambda x: x.data.positions)
        .apply(lambda r: triangular_3x3_matmul(lh_inv, r))
    )
    rh_inv = systems[rh.data.system].unitcell.inverse_lattice_vectors
    rh = (
        bind(rh)
        .focus(lambda x: x.data.positions)
        .apply(lambda r: triangular_3x3_matmul(rh_inv, r))
    )

    candidates = candidate_selector(lh, rh, systems)

    return _filter_candidates(
        candidates,
        lh,
        rh,
        systems,
        cutoffs,
        rh_index_remap,
        max_num_edges=max_num_edges,
        max_image_candidates=max_image_candidates,
        consider_images=consider_images,
    )

neighborlist_changes(neighborlist, lh, rh, systems, cutoffs, compaction=0.5)

Compute added/removed edges from a particle change in a single call.

Appends proposed positions to the particle array and queries both old and new interactions at once, then splits the result by filtering edge indices into removed (before) and added (after) sets.

Parameters:

  • neighborlist (NearestNeighborList, required): Neighbor list implementation.
  • lh (Table[ParticleId, NeighborListPoints], required): Full original particle table.
  • rh (WithIndices[ParticleId, Table[ParticleId, NeighborListPoints]], required): Proposed changes; rh.indices maps each entry to a particle ID in lh, and rh.data holds the new particle data.
  • systems (Table[SystemId, NeighborListSystems], required): Per-system data (unit cells, etc.).
  • cutoffs (Table[SystemId, Array], required): Per-system cutoff distances.
  • compaction (float, default 0.5): Fraction of total edges allocated per output (0–1). With 0.5, added and removed each get half the buffer; with 1.0, no compaction is performed and both outputs use the full buffer with masking only.

Returns:

  • NeighborListChangesResult: NeighborListChangesResult(added, removed).
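The splitting step can be traced on a toy example. The sketch reproduces the masks from the source on a plain (E, 2) index array, but returns the selected rows directly with boolean indexing instead of compacting them into fixed-size buffers; split_changes is a hypothetical name, and np.isin stands in for the library's isin helper. Index conventions follow the source: values below N are original particles, N + m is the proposed copy of changed particle m, and anything at or above N + k is treated as padding.

```python
import numpy as np


def split_changes(raw, p_idx, n_orig):
    """Split combined-query edges into (added, removed) index arrays (sketch).

    raw:    (E, 2) edge indices from the combined query.
    p_idx:  original indices of the k >= 1 changed particles.
    n_orig: number of original particles N.
    """
    k = len(p_idx)
    c0, c1 = raw[:, 0], raw[:, 1]
    # Removed: both endpoints are original entries (old state).
    removed = (c0 < n_orig) & (c1 < n_orig)
    # Stale: an endpoint is the old entry of a changed particle.
    stale = (np.isin(c0, p_idx) & (c0 < n_orig)) | (np.isin(c1, p_idx) & (c1 < n_orig))
    # Added: non-padding edges touching a new entry, excluding stale ones.
    added = (
        (c0 < n_orig + k) & (c1 < n_orig + k)
        & ((c0 >= n_orig) | (c1 >= n_orig)) & ~stale
    )
    # Map new entries N + m back to p_idx[m]; padding rows are clipped
    # here but never selected.
    remapped = np.where(raw >= n_orig, p_idx[np.clip(raw - n_orig, 0, k - 1)], raw)
    return remapped[added], remapped[removed]


# N = 4 particles; particle 2 moves (new entry 4); 5 is padding.
raw = np.array([[1, 2], [3, 4], [4, 2], [5, 5]])
added, removed = split_changes(raw, p_idx=np.array([2]), n_orig=4)
```

The old edge 1-2 is reported as removed and the new edge 3-2 as added (with new entry 4 mapped back to particle 2). The edge between the old and new copies of particle 2 is stale and appears in neither set.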

Source code in src/kups/core/neighborlist.py
@partial(jit, static_argnames=("compaction",))
def neighborlist_changes(
    neighborlist: NearestNeighborList,
    lh: Table[ParticleId, NeighborListPoints],
    rh: WithIndices[ParticleId, Table[ParticleId, NeighborListPoints]],
    systems: Table[SystemId, NeighborListSystems],
    cutoffs: Table[SystemId, Array],
    compaction: float = 0.5,
) -> NeighborListChangesResult:
    """Compute added/removed edges from a particle change in a single call.

    Appends proposed positions to the particle array and queries both old
    and new interactions at once, then splits the result by filtering
    edge indices into ``removed`` (before) and ``added`` (after) sets.

    Args:
        neighborlist: Neighbor list implementation.
        lh: Full original particle table.
        rh: Proposed changes — ``rh.indices`` maps entries to particle IDs
            in ``lh``, ``rh.data`` holds the new particle data.
        systems: Per-system data (unit cells, etc.).
        cutoffs: Per-system cutoff distances.
        compaction: Fraction of total edges allocated per output (0–1).
            0.5 means each of added/removed gets half the buffer.
            1.0 means no compaction — full buffer with masking only.

    Returns:
        ``NeighborListChangesResult(added, removed)``.
    """
    N, k = lh.size, rh.data.size
    p_idx = rh.indices.indices_in(lh.keys)

    # Build a single query. Left-hand side: all N original particles followed
    # by the k proposed particles. Right-hand side: the old entries of the k
    # changed particles followed by their proposed versions.
    lh_combined = Table.union((lh, rh.data))
    rh_combined = Table.union((Table.arange(lh[rh.indices], label=ParticleId), rh.data))
    combined_remap = Index(
        lh_combined.keys, jnp.concatenate([p_idx, jnp.arange(k) + N])
    )

    # single neighborlist call
    all_edges = neighborlist(lh_combined, rh_combined, systems, cutoffs, combined_remap)

    # split into removed / added
    raw = all_edges.indices.indices  # (n_edges, 2)
    c0, c1 = raw[:, 0], raw[:, 1]
    # Removed mask checks for edges that exist in the original set (both indices < N).
    removed_mask = (c0 < N) & (c1 < N)

    # is_stale flags edges with an endpoint at the old (pre-change) entry of a
    # changed particle. Excluding them from the added set drops edges that
    # would connect a proposed particle to a pre-change state.
    is_stale = isin(c0, p_idx, N + k) & (c0 < N) | isin(c1, p_idx, N + k) & (c1 < N)
    # Added mask checks for edges that involve at least one new particle.
    added_mask = (c0 < N + k) & (c1 < N + k) & ((c0 >= N) | (c1 >= N)) & ~is_stale

    # remap appended indices N+m -> p_idx[m]
    remapped = jnp.where(raw >= N, p_idx[raw - N], raw)

    # compact each output
    n_total = raw.shape[0]
    shifts = all_edges.shifts

    def _mask_only(mask: Array, indices: Array, shifts: Array) -> Edges[Literal[2]]:
        idx = where_broadcast_last(mask, indices, N)
        sh = where_broadcast_last(mask, shifts, 0)
        return Edges(Index(lh.keys, idx), sh)

    def _compact(mask: Array, indices: Array, label: str) -> Edges[Literal[2]]:
        count = mask.sum()
        runtime_assert(
            count <= capacity,
            f"neighborlist_changes: {label} edges ({{count}}) exceed "
            f"capacity ({{capacity}})",
            fmt_args={"count": count, "capacity": jnp.array(capacity)},
        )
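        # Pad with an out-of-bounds index so padding slots are always invalid;
        # padding with n_total - 1 would duplicate the last edge whenever that
        # edge is itself selected.
        sel: Array = jnp.where(mask, size=capacity, fill_value=n_total)[0]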
        valid = mask.at[sel].get(mode="fill", fill_value=False)
        return _mask_only(valid, indices[sel], shifts[sel])

    if compaction >= 1.0:
        return NeighborListChangesResult(
            _mask_only(added_mask, remapped, shifts),
            _mask_only(removed_mask, remapped, shifts),
        )

    capacity = int(n_total * compaction)
    return NeighborListChangesResult(
        _compact(added_mask, remapped, "added"),
        _compact(removed_mask, remapped, "removed"),
    )
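Because array shapes must be static under jit, _compact selects edges with jnp.where(mask, size=capacity, fill_value=...) instead of boolean indexing, with capacity fixed at int(n_total * compaction). A NumPy sketch of the resulting layout follows (compact_fixed and pad_value are hypothetical names). Selected rows come first in their original order, the remaining slots are padding flagged invalid, and exceeding the capacity is an error, much like the runtime_assert in the source.

```python
import numpy as np


def compact_fixed(mask, values, capacity, pad_value):
    """Pack rows selected by `mask` into exactly `capacity` rows (hypothetical sketch)."""
    count = int(mask.sum())
    if count > capacity:
        raise ValueError(f"{count} selected rows exceed capacity {capacity}")
    out = np.full((capacity,) + values.shape[1:], pad_value, dtype=values.dtype)
    out[:count] = values[np.flatnonzero(mask)]  # selected rows, in order
    valid = np.arange(capacity) < count         # padding slots are invalid
    return out, valid


values = np.arange(12).reshape(6, 2)  # six candidate edges
mask = np.array([False, True, False, False, True, False])
capacity = int(len(values) * 0.5)     # compaction = 0.5 -> 3 slots
out, valid = compact_fixed(mask, values, capacity, pad_value=6)
```

The trade-off is memory against headroom: a smaller compaction shrinks both outputs, but a move that adds or removes more edges than the reserved slots trips the capacity check.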