Tables
Molecular simulations deal with data at multiple levels. A simulation contains systems, each system contains particles, and pairs of particles form edges. These entities are not independent — every particle belongs to a system, every edge connects two particles. Operations constantly cross these levels: computing a force on a particle requires knowing which system it belongs to, so that the correct unit cell can be applied.
kUPS organizes this hierarchical, relational data using two core primitives: Table and Index.
```python
import jax.numpy as jnp
from jax import Array
from kups.core.data import Index, Table
from kups.core.typing import ParticleId, SystemId
from kups.core.utils.jax import dataclass
```
Table
A Table[Key, Data] is a keyed container. It pairs a tuple of unique keys (.keys) with a pytree of arrays (.data), where the leading dimension of every array aligns with the keys. Think of it as a lightweight database table with a primary key column and one or more value columns.
Under the hood, a Table is a JAX pytree. The keys are stored as a static field. This is a plain Python tuple that becomes part of the pytree's structure rather than a traced array, so JAX never turns it into array operations. The data is the dynamic part: it holds the actual arrays that flow through jit, grad, and vmap. This split is what makes Table both relational and differentiable.
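The static/dynamic split can be seen in a minimal keyed container registered with `jax.tree_util.register_pytree_node`. `ToyTable` is a hypothetical stand-in and does not reproduce kUPS's implementation. The sketch only shows that the keys travel as pytree auxiliary data, while the data array is the single traced leaf.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ToyTable:
    """Hypothetical stand-in for a keyed table: static keys plus dynamic data."""

    keys: tuple
    data: jax.Array


def _flatten(table):
    # Children (traced leaves) first, then aux data (static, must be hashable).
    return (table.data,), table.keys


def _unflatten(keys, children):
    return ToyTable(keys, children[0])


jax.tree_util.register_pytree_node(ToyTable, _flatten, _unflatten)


@jax.jit
def double(table):
    # Inside jit, table.keys is still a plain tuple; table.data is a tracer.
    return ToyTable(table.keys, table.data * 2.0)


t = ToyTable(keys=(0, 1, 2), data=jnp.array([0.5, 1.2, 0.8]))
print(jax.tree_util.tree_leaves(t))  # only the data array is a leaf
print(double(t).keys)                # (0, 1, 2)
```

The keys live in the tree structure. Two tables with different keys therefore have different pytree structures, and jit treats them as different inputs.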
The constructor Table.arange builds a table with sequential integer keys from a pytree of arrays, inferring the row count from the leading dimension.
```python
Table.arange(jnp.array([0.5, 1.2, 0.8]), label=SystemId)
```
```
Table(keys=(0, 1, 2), data=Array([0.5, 1.2, 0.8], dtype=float32), _cls=<class 'kups.core.typing.SystemId'>)
```
Each key uniquely identifies a row. The data can be any JAX pytree — a single array, a dictionary, a dataclass — as long as all leaf arrays share the same leading dimension.
```python
particles = Table.arange(
    {"positions": jnp.zeros((4, 3)), "mass": jnp.ones(4)},
    label=ParticleId,
)
particles
```
```
Table(keys=(0, 1, 2, 3), data={'mass': Array([1., 1., 1., 1.], dtype=float32), 'positions': Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)}, _cls=<class 'kups.core.typing.ParticleId'>)
```
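The row-count inference that Table.arange performs can be sketched with `jax.tree_util.tree_leaves`. The idea is to collect the leading dimension of every leaf and require them all to agree. The helper `arange_keys` is hypothetical and only mirrors the behavior described above; kUPS's own validation may differ.

```python
import jax
import jax.numpy as jnp


def arange_keys(data):
    """Hypothetical helper: sequential keys inferred from the shared leading dimension."""
    sizes = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(data)}
    if len(sizes) != 1:
        raise ValueError(f"leaves disagree on leading dimension: {sorted(sizes)}")
    return tuple(range(sizes.pop()))


print(arange_keys({"positions": jnp.zeros((4, 3)), "mass": jnp.ones(4)}))  # (0, 1, 2, 3)

try:
    # Mismatched leading dimensions (4 vs 3) are rejected.
    arange_keys({"positions": jnp.zeros((4, 3)), "mass": jnp.ones(3)})
except ValueError as err:
    print(err)
```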
Index
An Index[Key] is a column of foreign-key references pointing into another table's key space. It stores a key vocabulary and an integer array of positions into that vocabulary. An Index typically lives inside a table's data — for instance, each particle carries an Index[SystemId] field that says which system it belongs to.
Like Table, an Index splits into static and dynamic parts. The key vocabulary is static — it defines the set of possible values and is fixed at compile time. The integer array of positions is dynamic — it can change across calls without recompilation. This encoding is similar to a categorical column in pandas: the categories are fixed, but which category each row points to is a runtime value.
The factory Index.new constructs an Index directly from a list of key values, inferring the key vocabulary from the unique values.
```python
# Five particles: first three in system 0, last two in system 1
system_index = Index.new(
    [SystemId(0), SystemId(0), SystemId(0), SystemId(1), SystemId(1)]
)
system_index
```
```
Index(['0' '0' '0' '1' '1'], cls=SystemId, keys=0-1, shape=(5,), max_count=None)
```
An Index supports out-of-bounds entries. The sentinel value len(keys) marks an entry as invalid — it points past the end of the vocabulary. The property Index.valid_mask returns a boolean array that is True for in-bounds entries and False for OOB ones. This is the mechanism that Buffered tables use to track occupation: unoccupied rows have their foreign-key index set to OOB, and everything downstream can filter on .valid_mask.
```python
# Mark the last two entries as out-of-bounds: the vocabulary holds a single key,
# so position 1 == len(keys) is the OOB sentinel
system_index = Index(keys=(SystemId(0),), indices=jnp.array([0, 0, 0, 1, 1]))
# Each entry is either in-bounds (True) or OOB (False).
print("valid_mask:", system_index.valid_mask)
system_index
```
```
valid_mask: [ True  True  True False False]
Index(['0' '0' '0' 'OOB' 'OOB'], cls=SystemId, keys=(0,), shape=(5,), max_count=None)
```
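The valid_mask semantics described above come down to a bounds comparison against the vocabulary size. This plain-JAX sketch assumes that positions are never negative, so a single comparison with `len(keys)` is enough. It also shows one typical downstream use: masking values so that OOB rows contribute nothing.

```python
import jax.numpy as jnp

keys = (0,)                                  # vocabulary with a single key
positions = jnp.array([0, 0, 0, 1, 1])       # 1 == len(keys) is the OOB sentinel

valid_mask = positions < len(keys)           # True for in-bounds entries
print(valid_mask)                            # [ True  True  True False False]

# Downstream code can zero out contributions from OOB rows.
per_row_energy = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(jnp.where(valid_mask, per_row_energy, 0.0).sum())  # 6.0
```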
Keys don't have to be integers — any sortable type works. String keys are useful when the identity carries meaning, such as species labels like "C_CO2" or "O_CO2" that encode both the element and the molecule it belongs to.
```python
# Three particles in a CO2 molecule, indexed by species label
species = Index.new(["C_CO2", "O_CO2", "O_CO2"])
species
```
```
Index(['C_CO2' 'O_CO2' 'O_CO2'], cls=str, keys=('C_CO2', 'O_CO2'), shape=(3,), max_count=None)
```
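Both Index.new calls above build a vocabulary from the unique values and store integer positions into it. The sketch below assumes a sorted vocabulary, which the reprs above are consistent with. Under that assumption the encoding matches what `numpy.unique(..., return_inverse=True)` computes, the same trick behind categorical columns. `encode_index` is a hypothetical name, not a kUPS function, and it drops key subclasses such as SystemId.

```python
import numpy as np


def encode_index(values):
    """Hypothetical sketch: sorted unique vocabulary plus integer positions into it."""
    vocab, positions = np.unique(np.asarray(values), return_inverse=True)
    return tuple(vocab.tolist()), positions.reshape(-1)


# Integer keys: vocabulary (0, 1), positions [0 0 0 1 1]
print(encode_index([0, 0, 0, 1, 1]))
# String keys: vocabulary ('C_CO2', 'O_CO2'), positions [0 1 1]
print(encode_index(["C_CO2", "O_CO2", "O_CO2"]))
```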
Foreign-Key Lookup
The central operation in kUPS is table[index], which gathers rows from a table according to a foreign-key index. This is a relational JOIN: it broadcasts data from one level of the hierarchy to another.
Because both the table's keys and the index's key vocabulary are static, the mapping between the two key spaces is resolved at trace time. The result is a precomputed integer remapping array that translates the index's positions into positions in the table. At runtime, the lookup is a single jnp.take — no key comparisons, no searching.
```python
# Two systems with different energies
systems = Table.arange(jnp.array([0.5, 1.2]), label=SystemId)
# Four particles: first two in system 0, last two in system 1
system_of_particle = Index.new([SystemId(0), SystemId(0), SystemId(1), SystemId(1)])
# Broadcast system energy to each particle
per_particle_energy = systems[system_of_particle]
print(per_particle_energy)
```
```
[0.5 0.5 1.2 1.2]
```
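The trace-time remapping can be sketched in a few lines. Below, `build_remap` (a hypothetical helper) runs in ordinary Python on the static key tuples and produces an integer array. The jitted lookup then performs only two `jnp.take` gathers. The index vocabulary is deliberately different from the table's keys so that the remap is non-trivial.

```python
import jax
import jax.numpy as jnp


def build_remap(index_keys, table_keys):
    """Hypothetical: map each index-vocabulary position to a row position in the table."""
    row_of = {key: row for row, key in enumerate(table_keys)}
    missing = [key for key in index_keys if key not in row_of]
    if missing:
        raise ValueError(f"Keys not found in table: {missing}")
    return jnp.array([row_of[key] for key in index_keys])


table_keys = (0, 1, 2)
table_data = jnp.array([0.5, 1.2, 0.8])
index_keys = (1, 2)                  # vocabulary of the index
positions = jnp.array([0, 0, 1])     # entries refer to keys 1, 1, 2

remap = build_remap(index_keys, table_keys)  # computed once, outside the traced function


@jax.jit
def lookup(data, positions):
    rows = jnp.take(remap, positions)        # vocabulary position -> table row
    return jnp.take(data, rows, axis=0)      # gather the selected rows


print(lookup(table_data, positions))  # [1.2 1.2 0.8]
```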
In practice, a table's data is often a dataclass whose fields include Index leaves pointing at other tables. This is how kUPS encodes relationships: particles store an Index[SystemId] referencing the systems table, and the foreign-key lookup systems[particles.data.system] broadcasts system-level data (like the unit cell) down to each particle.
The same pattern chains through every level of the hierarchy. System data broadcasts to particles via systems[particles.data.system], particle data broadcasts to edges via particles[edges.indices], and so on. Each level only stores a foreign-key index to the level above — there is no duplication of data, and the relationships are always explicit in the type signatures.
```python
@dataclass
class SystemData:
    unitcell: Array


@dataclass
class ParticleData:
    positions: Array
    system: Index[SystemId]


# Two systems with different unit cells
systems = Table.arange(
    SystemData(unitcell=jnp.array([5.0 * jnp.eye(3), 10.0 * jnp.eye(3)])),
    label=SystemId,
)
# Three particles: first two in system 0, last in system 1
particles = Table.arange(
    ParticleData(
        positions=jnp.array([[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0]]),
        system=Index.new([SystemId(0), SystemId(0), SystemId(1)]),
    ),
    label=ParticleId,
)
# Broadcast unit cell to each particle via foreign-key lookup
per_particle_unitcell = systems[particles.data.system]
print(per_particle_unitcell.unitcell)
```
```
[[[ 5.  0.  0.]
  [ 0.  5.  0.]
  [ 0.  0.  5.]]

 [[ 5.  0.  0.]
  [ 0.  5.  0.]
  [ 0.  0.  5.]]

 [[10.  0.  0.]
  [ 0. 10.  0.]
  [ 0.  0. 10.]]]
```
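The chain systems → particles → edges can be sketched with plain arrays. Two details here are assumptions made for illustration. The edge layout is an (E, 2) array of particle rows. The minimum-image step is orthorhombic, which only makes sense when each unit cell is diagonal. The sketch shows why an edge needs its system's cell: the displacement is wrapped using the box of the system that the edge's first particle belongs to.

```python
import jax.numpy as jnp

# System level: one (diagonal) unit cell per system.
cells = jnp.stack([5.0 * jnp.eye(3), 10.0 * jnp.eye(3)])
# Particle level: positions plus a foreign key into the systems.
positions = jnp.array([[0.1, 0.0, 0.0], [4.9, 0.0, 0.0], [0.3, 0.0, 0.0]])
particle_system = jnp.array([0, 0, 1])
# Edge level (hypothetical layout): each row holds two particle rows.
edges = jnp.array([[0, 1], [1, 0]])

# Particle data -> edges, then system data -> edges via the first particle's system.
edge_positions = positions[edges]                 # (E, 2, 3)
edge_cells = cells[particle_system[edges[:, 0]]]  # (E, 3, 3)

# Minimum-image displacement, valid only for orthorhombic (diagonal) cells.
box = jnp.diagonal(edge_cells, axis1=1, axis2=2)  # (E, 3) box lengths
delta = edge_positions[:, 1] - edge_positions[:, 0]
delta = delta - box * jnp.round(delta / box)
print(delta)  # approximately [[-0.2, 0, 0], [0.2, 0, 0]]
```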
Key Types
kUPS defines key types such as SystemId, ParticleId, and Label: lightweight int or str subclasses that serve as type-safe identifiers. A Table[SystemId, ...] and a Table[ParticleId, ...] are distinct types, even though both use integer keys under the hood. This prevents accidentally indexing a systems table with a particle index, because static type checkers like pyright flag such mismatches as errors before the code ever runs. At runtime the keys are still compared by value, so the lookup below fails only because key 5 has no matching system.
```python
systems = Table.arange(jnp.array([0.5, 1.2]), label=SystemId)
wrong_index = Index.new([ParticleId(0), ParticleId(5)])
# pyright would flag this: Index[ParticleId] cannot index Table[SystemId, ...]
try:
    systems[wrong_index]  # type: ignore[index]
except ValueError as e:
    print(e)
```
```
Keys in self not found in target: [5]
```
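Why does the runtime error mention only key 5? The sketch below uses hypothetical int subclasses as stand-ins for kUPS's SystemId and ParticleId, whose definitions are not shown here. Subclasses of int compare and hash by value, so key 0 of one type matches key 0 of another at runtime. Only a static checker can tell them apart; with an invariant generic container like the hypothetical `Keyed`, a checker such as pyright should reject the mixed call.

```python
from typing import Generic, TypeVar


class SystemId(int):
    """Hypothetical stand-in for a system key type."""


class ParticleId(int):
    """Hypothetical stand-in for a particle key type."""


K = TypeVar("K")


class Keyed(Generic[K]):
    """Hypothetical keyed container, generic in its key type."""

    def __init__(self, keys: tuple[K, ...]):
        self.keys = keys


def rows_for(table: Keyed[K], wanted: tuple[K, ...]) -> list[int]:
    # A static checker binds K from `table`; keys of another type should be flagged.
    row_of = {key: row for row, key in enumerate(table.keys)}
    missing = [key for key in wanted if key not in row_of]
    if missing:
        raise ValueError(f"Keys not found: {missing}")
    return [row_of[key] for key in wanted]


systems = Keyed((SystemId(0), SystemId(1)))
print(SystemId(0) == ParticleId(0))         # True: compared by value at runtime
print(rows_for(systems, (ParticleId(0),)))  # [0] at runtime, despite the type mismatch
```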
Union: Batching Multiple Systems
In practice, we often want to simulate many independent systems at once as a single vectorized computation. Rather than using jax.vmap — which requires all systems to have the same number of particles — kUPS flattens everything into a single set of tables. Table.union concatenates multiple tables into one, automatically remapping integer keys so they remain globally unique.
The critical detail is what happens to foreign-key indices nested inside the data. When two particle tables are concatenated, each particle's Index[SystemId] must be shifted to point at the correct system in the merged systems table. Table.union handles this automatically: it detects Index leaves, identifies which key space they reference, and applies the corresponding offset. The result is a single flat representation where all cross-references remain valid.
```python
# System A: one system with two particles
sys_a = Table.arange(SystemData(unitcell=5.0 * jnp.eye(3)[None]), label=SystemId)
parts_a = Table.arange(
    ParticleData(
        positions=jnp.array([[0.1, 0.0, 0.0], [0.2, 0.0, 0.0]]),
        system=Index.new([SystemId(0), SystemId(0)]),
    ),
    label=ParticleId,
)
# System B: one system with three particles
sys_b = Table.arange(SystemData(unitcell=10.0 * jnp.eye(3)[None]), label=SystemId)
parts_b = Table.arange(
    ParticleData(
        positions=jnp.array([[0.3, 0.0, 0.0], [0.4, 0.0, 0.0], [0.5, 0.0, 0.0]]),
        system=Index.new([SystemId(0), SystemId(0), SystemId(0)]),
    ),
    label=ParticleId,
)
# Union merges both levels, remapping keys and foreign-key indices
systems, particles = Table.union([sys_a, sys_b], [parts_a, parts_b])
print("system keys:", systems.keys)
print("particle keys:", particles.keys)
print("particle→system index:", particles.data.system.indices)
# The foreign-key lookup still works correctly after the merge
print("unit cell per particle:", systems[particles.data.system].unitcell)
```
```
system keys: (0, 1)
particle keys: (0, 1, 2, 3, 4)
particle→system index: [0 0 1 1 1]
unit cell per particle: [[[ 5.  0.  0.]
  [ 0.  5.  0.]
  [ 0.  0.  5.]]

 [[ 5.  0.  0.]
  [ 0.  5.  0.]
  [ 0.  0.  5.]]

 [[10.  0.  0.]
  [ 0. 10.  0.]
  [ 0.  0. 10.]]

 [[10.  0.  0.]
  [ 0. 10.  0.]
  [ 0.  0. 10.]]

 [[10.  0.  0.]
  [ 0. 10.  0.]
  [ 0.  0. 10.]]]
```
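The index-shifting step of Table.union can be approximated as follows. The sketch assumes each input table's keys run 0..n-1, as Table.arange produces. Each particle table's system index is offset by the number of systems in the tables before it, and particle keys are renumbered consecutively. `union_system_index` is a hypothetical name. The real implementation also has to discover which Index leaves to shift, and this sketch does not handle OOB entries.

```python
import numpy as np


def union_system_index(system_counts, particle_system_indices):
    """Hypothetical sketch: concatenate per-table system indices with cumulative offsets."""
    offsets = np.cumsum([0, *system_counts[:-1]])  # systems preceding each table
    shifted = [idx + off for idx, off in zip(particle_system_indices, offsets)]
    merged = np.concatenate(shifted)
    particle_keys = tuple(range(len(merged)))      # renumbered, globally unique
    return particle_keys, merged


keys, merged = union_system_index([1, 1], [np.array([0, 0]), np.array([0, 0, 0])])
print(keys)    # (0, 1, 2, 3, 4)
print(merged)  # [0 0 1 1 1]
```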
This flat representation is what flows through the simulation. Every kernel — neighbor lists, potentials, integrators — operates on the merged table without needing to know how many systems it contains or where one ends and the next begins. The foreign-key indices handle the bookkeeping.
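Broadcasting down the hierarchy has a natural counterpart: reducing per-particle values back up to per-system totals. On the flat representation, one way to do this is `jax.ops.segment_sum`, using the merged particle→system index from the union above as segment ids. The text does not say whether kUPS kernels use this exact call. The energies below are made-up values for illustration.

```python
import jax
import jax.numpy as jnp

particle_energy = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])   # illustrative values
particle_system = jnp.array([0, 0, 1, 1, 1])             # merged index from the union above

# num_segments is a static size, so the output shape stays fixed under jit.
system_energy = jax.ops.segment_sum(particle_energy, particle_system, num_segments=2)
print(system_energy)  # [ 3. 12.]
```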
JAX Compatibility
Both Table and Index store their keys as trace-time static fields. This means the key vocabulary is fixed at jax.jit compile time and never flows through the computation graph. Only the data arrays and the integer index arrays are dynamic.
This has an important consequence: changing the number of keys — adding a particle, removing a system — changes the static structure and forces JAX to recompile. Within a compiled function, tables are fixed-size containers. You can mutate the data (update positions, swap energies) but not the shape. This is a fundamental constraint of JAX's tracing model, and it is what the Buffered table (below) is designed to work around.
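The recompilation rule can be observed with a Python-side counter, because the body of a jitted function only runs while JAX traces it. Here the keys are passed as a hashable static argument, standing in for a table's static field. The function `total` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def total(keys, data):
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX traces
    return data.sum() * len(keys)


total_jit = jax.jit(total, static_argnums=0)

total_jit((0, 1, 2), jnp.ones(3))    # trace #1
total_jit((0, 1, 2), jnp.zeros(3))   # new data values only: cached
total_jit((0, 1, 5), jnp.ones(3))    # same shapes, different keys: trace #2
total_jit((0, 1, 2), jnp.ones(3))    # keys seen before: cached
print(trace_count)  # 2
```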
Because the key mapping is resolved statically, index maps between different key spaces can be precomputed — so the foreign-key lookup table[index] compiles down to a simple integer gather.
Buffered Tables
As noted above, jax.jit compiles a function for a fixed set of static fields. A Table with different keys triggers recompilation. This is a problem for simulations where the number of particles changes — such as Grand Canonical Monte Carlo, which inserts and deletes molecules at every step.
Buffered solves this by pre-allocating extra unoccupied rows. The total capacity (occupied + free) is fixed at compile time, so the keys never change and no recompilation is needed. An occupation mask derived from the data tracks which rows are active. Insertions fill free slots by writing data into them and marking them as occupied; deletions mark rows as unoccupied by setting their foreign-key index to an out-of-bounds sentinel. The compiled kernel stays the same throughout — only the mask and data values change.
The occupation mask is not stored as a separate field. Instead, Buffered derives it from a designated Index leaf in the data (by default, the .system field). Rows where this index points out of bounds are considered unoccupied. This keeps the representation minimal and avoids the need to synchronize a separate mask with the data. The Buffered.pad factory adds free slots to an existing table.
```python
from kups.core.data import Buffered

# Start with 3 occupied particles, add 2 free buffer slots
particles = Table.arange(
    ParticleData(
        positions=jnp.array([[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0]]),
        system=Index.new([SystemId(0), SystemId(0), SystemId(1)]),
    ),
    label=ParticleId,
)
buffered = Buffered.pad(particles, num_free=2)
print("total capacity:", len(buffered))
print("occupied:", buffered.num_occupied)
print("occupation mask:", buffered.occupation)
```
```
total capacity: 5
occupied: 3
occupation mask: [ True  True  True False False]
```
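The padding and occupation logic described above can be mimicked with plain arrays. In this sketch, `pad_free` is a hypothetical helper, not Buffered.pad. Free rows receive zeroed data and the OOB sentinel in their system index. Occupation is recomputed from that index rather than stored.

```python
import jax.numpy as jnp


def pad_free(positions, system, num_systems, num_free):
    """Hypothetical sketch: append free rows (zero data, OOB system index)."""
    positions = jnp.concatenate([positions, jnp.zeros((num_free, positions.shape[1]))])
    oob = jnp.full((num_free,), num_systems, dtype=system.dtype)
    return positions, jnp.concatenate([system, oob])


def occupation(system, num_systems):
    # A row is occupied iff its system index is in bounds.
    return system < num_systems


positions = jnp.array([[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0]])
system = jnp.array([0, 0, 1])
positions, system = pad_free(positions, system, num_systems=2, num_free=2)

mask = occupation(system, num_systems=2)
print(len(system), int(mask.sum()))  # 5 3
print(mask)                          # [ True  True  True False False]
```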
To insert a particle, Buffered.select_free finds an unoccupied slot and returns an Index pointing at it. Writing valid data (with an in-bounds system index) into that slot marks it as occupied.
```python
# Insert a new particle into a free slot
free = buffered.select_free(1)
new_particle = ParticleData(
    positions=jnp.array([[9.0, 9.0, 9.0]]),
    system=Index.new([SystemId(0)]),
)
buffered = buffered.update(free, new_particle)
print("occupied:", buffered.num_occupied)
print("occupation mask:", buffered.occupation)
print("positions:\n", buffered.data.positions)
```
```
occupied: 4
occupation mask: [ True  True  True  True False]
positions:
 [[0.1 0.  0. ]
 [0.2 0.  0. ]
 [0.3 0.  0. ]
 [9.  9.  9. ]
 [0.  0.  0. ]]
```
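To stay jit-compatible, the search for free slots has to return a fixed-shape result. The sketch below shows one way this can work: `jnp.nonzero` with a static `size` argument finds the slots, and `.at[].set` scatters the new row into them. `insert_rows` and `NUM_SYSTEMS` are hypothetical names, and this is not taken from Buffered's implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp

NUM_SYSTEMS = 2  # len(system keys); doubles as the OOB sentinel


@partial(jax.jit, static_argnames="n")
def insert_rows(positions, system, new_positions, new_system, n):
    """Hypothetical sketch: write n new rows into the first n free slots."""
    free_mask = system >= NUM_SYSTEMS
    # size=n keeps the output shape static; if fewer slots are free,
    # the result is padded with an out-of-range row id.
    (free,) = jnp.nonzero(free_mask, size=n, fill_value=system.shape[0])
    positions = positions.at[free].set(new_positions)
    system = system.at[free].set(new_system)
    return positions, system


positions = jnp.array([[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0],
                       [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
system = jnp.array([0, 0, 1, 2, 2])  # last two rows are free

positions, system = insert_rows(positions, system, jnp.array([[9.0, 9.0, 9.0]]),
                                jnp.array([0]), n=1)
print(system)  # [0 0 1 0 2]
```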
Deletion is the reverse: writing an OOB system index into a row marks it as unoccupied. On the next construction of the Buffered, all data leaves in that row are sanitized — arrays are zeroed out and other Index leaves are set to OOB. The slot is now free for reuse.
```python
# Delete particle 1 by setting its system index to OOB
to_delete = Index(buffered.keys, jnp.array([1]))
oob = len(buffered.data.system.keys)
deleted_data = ParticleData(
    positions=buffered.data.positions[1:2],
    system=Index(buffered.data.system.keys, jnp.array([oob])),
)
buffered = buffered.update(to_delete, deleted_data)
print("occupied:", buffered.num_occupied)
print("occupation mask:", buffered.occupation)
print("positions:\n", buffered.data.positions)
```
```
occupied: 3
occupation mask: [ True False  True  True False]
positions:
 [[0.1 0.  0. ]
 [0.  0.  0. ]
 [0.3 0.  0. ]
 [9.  9.  9. ]
 [0.  0.  0. ]]
```
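Deletion and sanitization follow the same pattern. In the sketch below, `delete_rows` and `sanitize` are hypothetical helpers. The first writes the OOB sentinel into a row; the second zeroes every row whose index is out of bounds, reproducing the effect shown above. The sketch ends with a per-system particle count computed by `jax.ops.segment_sum`. That function documents that segment ids outside `[0, num_segments)` are dropped, so freed rows fall out of the reduction automatically.

```python
import jax
import jax.numpy as jnp

NUM_SYSTEMS = 2  # OOB sentinel for the system index


def delete_rows(system, rows):
    """Hypothetical: mark rows as free by writing the OOB sentinel."""
    return system.at[rows].set(NUM_SYSTEMS)


def sanitize(positions, system):
    """Hypothetical: zero the data of unoccupied rows and normalize their index to OOB."""
    occupied = system < NUM_SYSTEMS
    positions = jnp.where(occupied[:, None], positions, 0.0)
    system = jnp.where(occupied, system, NUM_SYSTEMS)
    return positions, system


positions = jnp.array([[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0],
                       [9.0, 9.0, 9.0], [0.0, 0.0, 0.0]])
system = jnp.array([0, 0, 1, 0, 2])

system = delete_rows(system, jnp.array([1]))
positions, system = sanitize(positions, system)
print(system)  # [0 2 1 0 2]

# Out-of-range segment ids (the OOB rows) do not contribute to the sum.
counts = jax.ops.segment_sum(jnp.ones(system.shape[0]), system, num_segments=NUM_SYSTEMS)
print(counts)  # [2. 1.]
```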