
kups.core.utils.position

Utilities for computing particle positions and center of mass.

Functions in this module honor the cell's per-axis periodic mask via cell.wrap: periodic axes are folded into the primary cell, while non-periodic axes pass through unchanged. The same code path therefore covers fully periodic crystals, vacuum clusters, and slab or wire geometries.
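To make the per-axis behavior concrete, here is a minimal JAX sketch of a masked wrap. It is not kups' Cell.wrap. The helper name masked_wrap, the row-vector lattice layout, and the choice of the minimum-image range (fractional coordinates mapped into [-0.5, 0.5]) are all assumptions; a primary cell defined as [0, 1) would use floor instead of round.

```python
import jax.numpy as jnp


def masked_wrap(positions, lattice, periodic):
    """Wrap Cartesian positions on periodic axes only (hypothetical helper).

    positions: (N, 3) Cartesian coordinates.
    lattice:   (3, 3) matrix whose rows are the lattice vectors (assumed layout).
    periodic:  (3,) boolean mask, True for periodic axes.
    """
    # Cartesian -> fractional: r = f @ lattice, so f = r @ inv(lattice).
    frac = positions @ jnp.linalg.inv(lattice)
    # Minimum image on periodic axes (fractional range [-0.5, 0.5]);
    # non-periodic axes are returned unchanged.
    frac = jnp.where(periodic, frac - jnp.round(frac), frac)
    # Fractional -> Cartesian.
    return frac @ lattice
```

With a 10-unit cubic cell that is periodic in x and y only, the point (12, -3, 25) maps to (2, -3, 25): x is folded, y is already inside the range, and z is left alone.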

center_of_mass(particles, cells)

Compute center of mass for indexed particle groups.

Calculates the center of mass of each group of particles (as defined by the group index), honoring the cell's per-axis periodic mask. Before averaging, each particle's displacement from a reference particle in its group is wrapped, so that groups split across a periodic boundary are unwrapped correctly.
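The reference-particle strategy can be sketched in plain JAX as follows. This is an illustration under assumptions, not the kups implementation: it uses raw arrays instead of Table and Cell, a single shared lattice (rows are lattice vectors) instead of one cell per group, minimum-image wrapping for displacements, and a [0, 1) fold for the final result. The names min_image and sketch_center_of_mass are hypothetical.

```python
import jax
import jax.numpy as jnp


def min_image(delta, lattice, periodic):
    # Assumed convention: periodic axes mapped to [-0.5, 0.5] in fractional units.
    frac = delta @ jnp.linalg.inv(lattice)
    frac = jnp.where(periodic, frac - jnp.round(frac), frac)
    return frac @ lattice


def sketch_center_of_mass(positions, weights, group_ids, num_groups, lattice, periodic):
    """Per-group center of mass via reference-particle unwrapping (sketch)."""
    # One reference particle per group; with duplicate indices, whichever
    # write wins is fine, since any particle of the group works.
    ref_idx = jnp.zeros(num_groups, dtype=jnp.int32).at[group_ids].set(
        jnp.arange(group_ids.shape[0], dtype=jnp.int32), mode="drop"
    )
    offsets = positions[ref_idx]
    # Displacements from each group's reference, minimum-imaged.
    rel = min_image(positions - offsets[group_ids], lattice, periodic)
    w = weights[:, None]
    weighted = jax.ops.segment_sum(rel * w, group_ids, num_segments=num_groups)
    total = jax.ops.segment_sum(w, group_ids, num_segments=num_groups)
    com = weighted / total + offsets
    # Fold the result into an assumed [0, 1) primary cell on periodic axes.
    frac = com @ jnp.linalg.inv(lattice)
    frac = jnp.where(periodic, frac - jnp.floor(frac), frac)
    return frac @ lattice
```

For a group with particles at x = 9 and x = 2 in a 10-unit periodic box, the displacement is imaged across the boundary, giving a center of mass at x = 0.5 rather than the naive average 5.5.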

Warning

On periodic axes, this function assumes that each molecular structure is smaller than half the cell length along that axis. Structures spanning more than half the box may yield incorrect results.
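A one-dimensional sketch shows why the half-box assumption matters. It mirrors the reference-particle unwrapping with minimum-image displacements in an assumed periodic box of length 10; unwrapped_mean is a hypothetical helper, not part of kups.

```python
import jax.numpy as jnp

L = 10.0  # assumed box length along a single periodic axis


def unwrapped_mean(x):
    """1D mean via minimum-image unwrapping around x[0] (illustrative only)."""
    d = x - x[0]
    d = d - L * jnp.round(d / L)  # minimum-image displacement from the reference
    return (x[0] + jnp.mean(d)) % L  # fold the mean back into [0, L)


compact = jnp.array([9.0, 1.0, 2.0])   # extent 3 < L/2, straddles the boundary
spanning = jnp.array([1.0, 4.0, 7.0])  # extent 6 > L/2
```

For the compact group, the result is about 0.667, matching the true mean of the unwrapped positions 9, 11, 12 (10.667) folded into the box. For the spanning group, the displacement of 6 from the reference is imaged to -4, so the particle at 7 is treated as if it sat at -3 and the result is again about 0.667 instead of the correct 4.0.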

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| particles | Table[ParticleId, P] | Indexed particles, optionally supporting HasWeights for mass weighting. | required |
| cells | Cell | Cell(s) carrying lattice geometry and per-axis periodicity. Must have one cell per group. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Shape (num_groups, 3) containing center-of-mass positions for each group. |

Example

```python
# Compute COM for each molecule
com = center_of_mass(molecules, cell)
```
Source code in src/kups/core/utils/position.py
@jit
def center_of_mass[P: HasPositionsAndGroupIndex](
    particles: Table[ParticleId, P], cells: Cell
) -> Array:
    """Compute center of mass for indexed particle groups.

    Calculates the center of mass for each group of particles defined by
    group index, honoring the cell's per-axis ``periodic`` mask. The
    computation ensures that wrapped particles are unwrapped relative to a
    reference particle in each group before averaging.

    Warning:
        On periodic axes, assumes each molecular structure is smaller than
        half the cell size. Structures spanning more than half the box may
        yield incorrect results.

    Args:
        particles: Indexed particles, optionally supporting `HasWeights` for mass weighting.
        cells: Cell(s) carrying lattice geometry and per-axis periodicity.
            Must have one cell per group.

    Returns:
        Shape `(num_groups, 3)` containing center-of-mass positions for each group.

    Example:
        ```python
        # Compute COM for each molecule
        com = center_of_mass(molecules, cell)
        ```
    """
    group_ids = particles.data.group.indices
    num_groups = particles.data.group.num_labels
    # TODO: This function assumes that the structure is less than half the size of the cell!
    assert cells.vectors.shape[0] == num_groups, (
        f"Cells must match the number of groups. Got {cells.vectors.shape[0]} for {num_groups} groups."
    )
    batched_cells = cells[group_ids]
    # Index any particle in each group
    ref_idx = (
        jnp.zeros((num_groups,), dtype=int)
        .at[group_ids]
        .set(jnp.arange(len(group_ids)), mode="drop")
    )
    positions = particles.data.positions
    if isinstance(particles.data, HasWeights):
        weight = particles.data.weights[:, None]
    else:
        weight = jnp.ones_like(positions[:, 0])[:, None]
    offsets = positions[ref_idx]
    rel_positions = positions - offsets[group_ids]
    rel_positions = batched_cells.wrap(rel_positions)
    center_of_masses = jax.ops.segment_sum(
        rel_positions * weight,
        group_ids,
        num_groups,
    )
    total_mass = jax.ops.segment_sum(
        weight,
        group_ids,
        num_groups,
    )
    center_of_masses /= total_mass
    center_of_masses += offsets
    center_of_masses = cells.wrap(center_of_masses)
    return center_of_masses

to_absolute_positions(particles, cells, center_of_masses)

Calculate absolute positions from relative positions and group COMs.

Inverse operation of to_relative_positions. Converts positions defined relative to group centers of mass back to absolute coordinates, honoring the cell's per-axis periodic mask via wrap (identity on non-periodic axes).
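Under similar assumptions, the inverse mapping amounts to adding each group's center of mass back and folding periodic axes. This sketch uses plain arrays, a single shared lattice (rows are lattice vectors), and an assumed [0, 1) primary cell; sketch_to_absolute is a hypothetical name. The .at[...].get(mode="fill", fill_value=0) gather mirrors the source listing and gives out-of-range group ids a zero offset.

```python
import jax.numpy as jnp


def sketch_to_absolute(rel_positions, group_ids, com, lattice, periodic):
    """abs = fold(rel + com[group]) on periodic axes (hypothetical sketch)."""
    # Out-of-range group ids (e.g. padding) receive a zero offset.
    offsets = com.at[group_ids].get(mode="fill", fill_value=0)
    abs_pos = rel_positions + offsets
    # Fold periodic axes into an assumed [0, 1) primary cell.
    frac = abs_pos @ jnp.linalg.inv(lattice)
    frac = jnp.where(periodic, frac - jnp.floor(frac), frac)
    return frac @ lattice
```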

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| particles | Table[ParticleId, P] | Indexed particles with relative positions and group index. | required |
| cells | Cell | Cell(s) carrying lattice geometry and per-axis periodicity. | required |
| center_of_masses | Array | Centers of mass for each group, shape (num_groups, 3). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Shape (N, 3) containing absolute particle positions. |

Example

```python
abs_pos = to_absolute_positions(rel_molecules, cell, com)
```
Source code in src/kups/core/utils/position.py
@jit
def to_absolute_positions[P: HasPositionsAndGroupIndex](
    particles: Table[ParticleId, P],
    cells: Cell,
    center_of_masses: Array,
) -> Array:
    """Calculate absolute positions from relative positions and group COMs.

    Inverse operation of `to_relative_positions`. Converts positions defined
    relative to group centers of mass back to absolute coordinates, honoring
    the cell's per-axis ``periodic`` mask via ``wrap`` (identity on
    non-periodic axes).

    Args:
        particles: Indexed particles with relative positions and group index.
        cells: Cell(s) carrying lattice geometry and per-axis periodicity.
        center_of_masses: Centers of mass for each group, shape `(num_groups, 3)`.

    Returns:
        Shape `(N, 3)` containing absolute particle positions.

    Example:
        ```python
        abs_pos = to_absolute_positions(rel_molecules, cell, com)
        ```
    """
    group_ids = particles.data.group.indices
    positions = particles.data.positions
    abs_positions = positions + center_of_masses.at[group_ids].get(
        mode="fill", fill_value=0
    )
    abs_positions = cells[group_ids].wrap(abs_positions)
    return abs_positions

to_relative_positions(particles, cells, center_of_masses=None)

Calculate particle positions relative to their group's center of mass.

Transforms absolute particle positions to positions relative to each group's center of mass, honoring the cell's per-axis periodic mask (periodic axes wrap; non-periodic axes pass through unchanged).
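A matching sketch for the forward transform subtracts each group's center of mass and applies a minimum-image wrap on periodic axes only, so that non-periodic coordinates pass through unchanged. As before, plain arrays and a single shared lattice are assumed, the [-0.5, 0.5] fractional range is an assumed convention, and sketch_to_relative is a hypothetical name rather than part of kups.

```python
import jax.numpy as jnp


def sketch_to_relative(positions, group_ids, com, lattice, periodic):
    """rel = min_image(r - com[group]) on periodic axes (hypothetical sketch)."""
    # Out-of-range group ids (e.g. padding) receive a zero offset.
    offsets = com.at[group_ids].get(mode="fill", fill_value=0)
    frac = (positions - offsets) @ jnp.linalg.inv(lattice)
    # Minimum image on periodic axes (fractional range [-0.5, 0.5]);
    # non-periodic axes are left untouched.
    frac = jnp.where(periodic, frac - jnp.round(frac), frac)
    return frac @ lattice
```

In a 10-unit cell that is periodic in x and y only, a particle at (9, 0, 25) in a group centered at (0.5, 0, 0) gets the relative position (-1.5, 0, 25): x is imaged across the boundary while z keeps its raw offset.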

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| particles | Table[ParticleId, P] | Indexed particles with position and group index data. Supports HasWeights if the center of mass needs to be computed. | required |
| cells | Cell | Cell(s) carrying lattice geometry and per-axis periodicity. | required |
| center_of_masses | Array \| None | Optional precomputed centers of mass. If None, they are computed with center_of_mass. | None |

Returns:

| Type | Description |
| --- | --- |
| Array | Shape (N, 3) containing positions relative to group COMs. |

Example

```python
rel_pos = to_relative_positions(molecules, cell)
```
Source code in src/kups/core/utils/position.py
@jit
def to_relative_positions[P: HasPositionsAndGroupIndex](
    particles: Table[ParticleId, P],
    cells: Cell,
    center_of_masses: Array | None = None,
) -> Array:
    """Calculate particle positions relative to their group's center of mass.

    Transforms absolute particle positions to positions relative to each
    group's center of mass, honoring the cell's per-axis ``periodic`` mask
    (periodic axes wrap; non-periodic axes pass through unchanged).

    Args:
        particles: Indexed particles with position and group index data. Supports
            `HasWeights` if center of mass needs to be computed.
        cells: Cell(s) carrying lattice geometry and per-axis periodicity.
        center_of_masses: Optional precomputed centers of mass. If `None`,
            will be computed automatically.

    Returns:
        Shape `(N, 3)` containing positions relative to group COMs.

    Example:
        ```python
        rel_pos = to_relative_positions(molecules, cell)
        ```
    """
    group_ids = particles.data.group.indices
    if center_of_masses is None:
        center_of_masses = center_of_mass(particles, cells)
    positions = particles.data.positions
    rel_positions = positions - center_of_masses.at[group_ids].get(
        mode="fill", fill_value=0
    )
    rel_positions = cells[group_ids].wrap(rel_positions)
    return rel_positions