
kups.core.neighborlist.postprocess

Postprocessors for the neighbor list pipeline.

Postprocessors run after compaction and can transform the final Edges using both the edges and the prepared pipeline context. They are the right place for graph-level output shaping that should not be coupled to row compaction itself.
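To show where postprocessors sit, here is a minimal sketch that models a postprocessor as any callable taking `(edges, ctx)` and returning new edges, applied in order after compaction. `ToyEdges`, `ToyContext`, `apply_postprocessors`, `reverse_pairs` and `sort_pairs` are hypothetical names for this illustration only. They are not part of kups, and the real `Edges` and `PipelineContext` types carry more structure (index keys, shifts).

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple


# Hypothetical stand-ins for kups' Edges and PipelineContext.
@dataclass(frozen=True)
class ToyEdges:
    pairs: Tuple[Tuple[int, int], ...]


@dataclass(frozen=True)
class ToyContext:
    for_indices: Optional[Tuple[int, ...]] = None


# A postprocessor is modeled as any callable (edges, ctx) -> edges.
ToyPostprocessor = Callable[[ToyEdges, ToyContext], ToyEdges]


def apply_postprocessors(
    edges: ToyEdges, ctx: ToyContext, postprocessors: Sequence[ToyPostprocessor]
) -> ToyEdges:
    """Run each postprocessor in order on the already-compacted edges."""
    for post in postprocessors:
        edges = post(edges, ctx)
    return edges


def reverse_pairs(edges: ToyEdges, ctx: ToyContext) -> ToyEdges:
    """Toy shaping step: append (j, i) for every (i, j)."""
    return ToyEdges(edges.pairs + tuple((j, i) for i, j in edges.pairs))


def sort_pairs(edges: ToyEdges, ctx: ToyContext) -> ToyEdges:
    """Toy shaping step: order pairs lexicographically."""
    return ToyEdges(tuple(sorted(edges.pairs)))


out = apply_postprocessors(
    ToyEdges(((0, 1), (2, 1))), ToyContext(), [reverse_pairs, sort_pairs]
)
# out.pairs -> ((0, 1), (1, 0), (1, 2), (2, 1))
```

Keeping these steps as separate callables over the final edges, rather than folding them into compaction, is what lets them be composed or swapped per pipeline.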

MirrorPairEdges

Bases: Postprocessor[Literal[2]]

Append reversed pair edges for undirected graph outputs.

By default, only self-graph update calls, i.e. those that set ctx.for_indices, are mirrored. Full self-neighbor calls already emit both directions before compaction. Calls with for_indices instead operate on the affected ids in the already-updated lh table and are deduplicated by ForIndicesDedupMask, so their reverse edges are missing from the compacted output; this postprocessor restores them after compaction.
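The mirroring step itself comes down to two concatenations. Below is a minimal sketch on raw JAX arrays, assuming `indices` of shape `(n_edges, 2)` and `shifts` of shape `(n_edges, 3)` for illustration; the actual array shapes inside kups' `Edges` may differ, and `mirror_pairs` is a hypothetical helper, not a kups function.

```python
import jax.numpy as jnp


def mirror_pairs(indices, shifts):
    """Append the reverse of every directed pair.

    indices: int array of shape (n_edges, 2); each row is (i, j).
    shifts:  array of shape (n_edges, 3); periodic image shift of each edge.
    """
    # Swap the two columns so each row (i, j) becomes (j, i).
    reversed_indices = indices[:, ::-1]
    # Assumed convention: if (i, j) reaches the image of j shifted by s,
    # then (j, i) reaches the image of i shifted by -s.
    reversed_shifts = -shifts
    return (
        jnp.concatenate([indices, reversed_indices], axis=0),
        jnp.concatenate([shifts, reversed_shifts], axis=0),
    )


indices = jnp.array([[0, 1], [1, 2]])
shifts = jnp.array([[0, 0, 0], [1, 0, 0]])
mi, ms = mirror_pairs(indices, shifts)
# mi -> [[0, 1], [1, 2], [1, 0], [2, 1]]
# ms -> [[0, 0, 0], [1, 0, 0], [0, 0, 0], [-1, 0, 0]]
```

The output always has exactly twice as many rows as the input, with the original edges first and their reverses appended.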

Attributes:

only_when_for_indices (bool): When True (the default), the postprocessor is a no-op unless ctx.for_indices is set; ctx.rh is not consulted. Set it to False for pipelines whose selector emits only one direction even in full calls.
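The effect of `only_when_for_indices` can be sketched with the same guard that `__call__` uses. `ToyContext` and `maybe_mirror` are hypothetical stand-ins (not kups API), and the shapes follow the same illustrative assumption of `(n_edges, 2)` indices and `(n_edges, 3)` shifts.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyContext:  # hypothetical stand-in for PipelineContext
    for_indices: Optional[jnp.ndarray] = None


def maybe_mirror(indices, shifts, ctx, only_when_for_indices=True):
    # Same guard as MirrorPairEdges.__call__: by default, skip full calls
    # (no for_indices), which already contain both directions.
    if only_when_for_indices and ctx.for_indices is None:
        return indices, shifts
    return (
        jnp.concatenate([indices, indices[:, ::-1]], axis=0),
        jnp.concatenate([shifts, -shifts], axis=0),
    )


idx = jnp.array([[0, 1]])
sh = jnp.zeros((1, 3), dtype=jnp.int32)
full_call = maybe_mirror(idx, sh, ToyContext())  # unchanged: 1 edge
update_call = maybe_mirror(idx, sh, ToyContext(for_indices=jnp.array([0])))  # 2 edges
forced = maybe_mirror(idx, sh, ToyContext(), only_when_for_indices=False)  # 2 edges
```

Since the flag is declared with `static=True` in `MirrorPairEdges`, it is presumably treated as static metadata rather than a traced value, so the Python-level `if` is resolved once at trace time instead of becoming a traced branch.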

Source code in src/kups/core/neighborlist/postprocess.py
@dataclass
class MirrorPairEdges(Postprocessor[Literal[2]]):
    """Append reversed pair edges for undirected graph outputs.

    The default mirrors only self-graph update calls selected by
    ``ctx.for_indices``. Full self-neighbor calls already emit both directions
    before compaction, while ``for_indices`` calls operate on affected ids in
    the already-updated ``lh`` table and are deduplicated by
    ``ForIndicesDedupMask``. Their reverse edges are restored after compaction.

    Attributes:
        only_when_for_indices: When ``True``, no-op unless ``ctx.for_indices``
            is active; ``ctx.rh`` is not involved. Set to ``False`` for
            pipelines whose selector emits only one direction even in full
            calls.
    """

    only_when_for_indices: bool = field(default=True, static=True)

    def __call__(
        self, edges: Edges[Literal[2]], ctx: PipelineContext
    ) -> Edges[Literal[2]]:
        if self.only_when_for_indices and ctx.for_indices is None:
            return edges

        indices = edges.indices.indices
        mirrored_indices = jnp.concatenate([indices, indices[:, ::-1]], axis=0)
        mirrored_shifts = jnp.concatenate([edges.shifts, -edges.shifts], axis=0)
        return Edges(Index(edges.indices.keys, mirrored_indices), mirrored_shifts)
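When testing a pipeline that is meant to produce undirected output, a simple property check is that every directed edge `(i, j, s)` has a partner `(j, i, -s)`. The helper below is a hypothetical sketch working on plain arrays (not part of kups); it assumes the arrays contain no padding or sentinel rows and does not detect duplicate edges.

```python
import numpy as np


def is_mirror_symmetric(indices, shifts):
    """Return True if every edge (i, j, s) has a reverse edge (j, i, -s)."""
    indices = np.asarray(indices)
    shifts = np.asarray(shifts)
    # Collect every directed edge as a hashable (i, j, shift) tuple.
    edges = {
        (int(i), int(j), tuple(s.tolist())) for (i, j), s in zip(indices, shifts)
    }
    # Each edge must have its reversed, shift-negated partner in the set.
    return all((j, i, tuple(-x for x in s)) in edges for i, j, s in edges)


# One direction only: not symmetric.
one_way = is_mirror_symmetric([[0, 1]], [[1, 0, 0]])
# Both directions with negated shifts: symmetric.
two_way = is_mirror_symmetric([[0, 1], [1, 0]], [[1, 0, 0], [-1, 0, 0]])
```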