Conventions
kUPS is a toolkit. Users define their own simulation states with exactly the fields they need, and there is no universal base class that all simulations inherit from. Instead, kUPS relies on a set of conventions that keep code composable, type-safe, and compatible with JAX.
This notebook explains those conventions and why they exist.
```python
from dataclasses import replace
from typing import Callable, Protocol, runtime_checkable

import jax.numpy as jnp
from jax import Array

from kups.core.lens import Lens, View, lens
from kups.core.utils.jax import dataclass
```
Protocols over Classes
The classic OOP approach to shared interfaces is inheritance: define an abstract base class, make everything inherit from it. This works poorly for simulation data. A Lennard-Jones simulation needs positions, masses, and LJ parameters. A molecular dynamics simulation needs those plus momenta. A GCMC simulation needs charges on top. With inheritance, you either build a deep class hierarchy where every combination is a subclass, or you create a God class with every possible field, most of which go unused in any given simulation.
kUPS uses structural typing via Python Protocol classes instead. A dataclass does not inherit from HasPositions. It simply has a positions field, which satisfies the protocol's positions property, and the type checker recognizes the structural match. Marking the protocol @runtime_checkable additionally makes it usable with isinstance at runtime. Data types can satisfy multiple protocols without multiple inheritance, and new protocols can be added without touching existing code. Each simulation defines only the fields it actually uses.
```python
@runtime_checkable
class HasPositions(Protocol):
    @property
    def positions(self) -> Array: ...


@dataclass
class MyParticles:
    positions: Array
    masses: Array


particles = MyParticles(positions=jnp.zeros((3, 3)), masses=jnp.ones(3))

# No inheritance: a structural match is enough
print(isinstance(particles, HasPositions))
```

```text
True
```
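Because the match is structural, a class written once can satisfy protocols that are defined later, even in other modules, without being edited. The sketch below uses plain lists instead of JAX arrays so it runs on the standard library alone; HasMasses, HasCharges, Particles, and Mislabelled are illustrative stand-ins, not kUPS imports. It also shows a limit worth knowing: isinstance on a @runtime_checkable protocol only checks that the named attributes exist, not their types, so full checking remains the job of a static type checker.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@dataclass
class Particles:
    # Written once, with no knowledge of the protocols below.
    positions: list[float]
    masses: list[float]


@runtime_checkable
class HasPositions(Protocol):
    @property
    def positions(self) -> list[float]: ...


# Defined later (e.g. in another module); Particles is never edited.
@runtime_checkable
class HasMasses(Protocol):
    @property
    def masses(self) -> list[float]: ...


@runtime_checkable
class HasCharges(Protocol):
    @property
    def charges(self) -> list[float]: ...


p = Particles(positions=[0.0, 1.0], masses=[1.0, 2.0])
print(isinstance(p, HasPositions), isinstance(p, HasMasses))  # True True
print(isinstance(p, HasCharges))  # False: no `charges` attribute


@dataclass
class Mislabelled:
    positions: str  # wrong type, but the attribute exists


# Runtime checks look at attribute presence only, not annotations.
print(isinstance(Mislabelled(positions="oops"), HasPositions))  # True
```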
Protocol Naming
With dozens of protocols in the codebase, a consistent naming scheme matters. The verb prefix tells you at a glance what kind of requirement the protocol encodes. Names should read as proper English: "this object has positions", "this object is a state".
- `Has*` for a single specific field or property: HasPositions, HasUnitCell, HasSystemIndex. This is the most granular level, one protocol per field.
- `Is*` for a multi-field role or identity: IsState, IsMDState. The object "is" conceptually that thing. Used when several fields always appear together.
- `Supports*` for an operation capability: SupportsAdd, SupportsDType. Used for generic algorithms that need specific operators or methods.
- No verb for callables and framework abstractions defined by behavior: Propagator, Lens. These are already nouns that describe what the object does.
```python
# Has*: single field
@runtime_checkable
class HasMasses(Protocol):
    @property
    def masses(self) -> Array: ...


# Is*: multi-field role
@runtime_checkable
class IsParticleData(HasPositions, HasMasses, Protocol): ...


# Supports*: operation capability
@runtime_checkable
class SupportsScale(Protocol):
    def scale(self, factor: float) -> "SupportsScale": ...
```
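A Supports* protocol pays off in generic algorithms that care about an operation rather than a data layout. The sketch below defines its own minimal SupportsAdd (a stand-in, not the kUPS definition) and a hypothetical accumulate function bound to it; the same function then works for plain integers and for a small vector type.

```python
from dataclasses import dataclass
from typing import Any, Protocol, TypeVar


class SupportsAdd(Protocol):
    # Minimal stand-in: any type that implements the `+` operator.
    def __add__(self, other: Any) -> Any: ...


T = TypeVar("T", bound=SupportsAdd)


def accumulate(items: list[T], start: T) -> T:
    """Sum items using only `+`; no particular fields are required."""
    total = start
    for item in items:
        total = total + item
    return total


@dataclass(frozen=True)
class Vec2:
    x: float
    y: float

    def __add__(self, other: "Vec2") -> "Vec2":
        return Vec2(self.x + other.x, self.y + other.y)


print(accumulate([1, 2, 3], start=0))  # 6
print(accumulate([Vec2(1.0, 0.0), Vec2(0.5, 2.0)], start=Vec2(0.0, 0.0)))  # Vec2(x=1.5, y=2.0)
```

The bound on T only promises `+`, so a type checker accepts any addable type while still rejecting, say, a list of objects without `__add__`.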
Composite Protocols
Real functions rarely need just one field. A center-of-mass calculation needs both positions and masses; an integrator needs positions, momenta, masses, and forces. Rather than defining a monolithic protocol for every combination, kUPS composes protocols freely. IsParticleData above simply inherits from HasPositions and HasMasses. Because protocols only declare members and carry no implementation, combining them raises none of the diamond or method-resolution-order conflicts that mixing concrete base classes can.
The result is that functions declare exactly what they need, no more and no less. A function that only reads positions accepts HasPositions, so it works with any dataclass that has that field, even one written for an entirely different simulation. This composability is what makes kUPS components reusable across simulation types.
```python
@dataclass
class SimpleParticles:
    positions: Array
    masses: Array


@dataclass
class ChargedParticles:
    positions: Array
    masses: Array
    charges: Array


def center_of_mass(particles: IsParticleData) -> Array:
    """Works with any object that has positions and masses."""
    total_mass = jnp.sum(particles.masses)
    return jnp.sum(particles.masses[:, None] * particles.positions, axis=0) / total_mass


simple = SimpleParticles(
    positions=jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]),
    masses=jnp.array([1.0, 1.0]),
)
charged = ChargedParticles(
    positions=jnp.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0]]),
    masses=jnp.array([1.0, 3.0]),
    charges=jnp.zeros(2),
)

# Same function, different types: protocols bridge the gap
print("simple:", center_of_mass(simple))
print("charged:", center_of_mass(charged))
```

```text
simple: [1. 0. 0.]
charged: [3. 0. 0.]
```
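The flip side of declaring exact requirements is that a type missing a field is rejected. In this stdlib-only sketch (lists instead of arrays, protocols redefined locally, and PositionsOnly and max_coordinate as hypothetical names), a positions-only type still works with a function that asks for HasPositions, but it does not satisfy the composite IsParticleData. A static checker such as mypy or pyright would likewise reject passing it to center_of_mass.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasPositions(Protocol):
    @property
    def positions(self) -> list[float]: ...


@runtime_checkable
class HasMasses(Protocol):
    @property
    def masses(self) -> list[float]: ...


@runtime_checkable
class IsParticleData(HasPositions, HasMasses, Protocol): ...


@dataclass
class PositionsOnly:
    # e.g. a trajectory frame that carries coordinates but no masses
    positions: list[float]


def max_coordinate(particles: HasPositions) -> float:
    # Declares only what it reads, so it accepts PositionsOnly too.
    return max(particles.positions)


frame = PositionsOnly(positions=[0.5, 3.0, 1.5])
print(max_coordinate(frame))  # 3.0
print(isinstance(frame, HasPositions))  # True
print(isinstance(frame, IsParticleData))  # False: `masses` is missing
```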
Data and Logic Separation
JAX's jit, grad, and vmap operate on pure functions over pytrees. Simulation state must therefore be a pytree (a nested structure of arrays), and the operations on it must be standalone functions with no hidden side effects.
kUPS follows this by convention. Dataclasses hold state as fields. Standalone functions transform it. There are no computational methods on the dataclass itself. The one exception is @property for derived quantities that are trivially computed from stored fields, like forces from the position gradient, or velocities from momenta and masses. These are cheap to recompute on access, so storing them separately would only add redundancy.
Why store the position gradient instead of the forces directly? The gradient is what jax.grad produces. Forces are just its negation, a convenience wrapper. By storing the primitive and deriving the convenience, the pytree stays minimal and aligned with what JAX actually computes.
```python
@dataclass
class MDParticles:
    positions: Array
    momenta: Array
    masses: Array
    position_gradient: Array  # stored: comes from autodiff

    @property
    def forces(self) -> Array:  # derived: negative gradient
        return -self.position_gradient

    @property
    def velocities(self) -> Array:  # derived: momenta / mass
        return self.momenta / self.masses[:, None]


p = MDParticles(
    positions=jnp.array([[1.0, 0.0, 0.0]]),
    momenta=jnp.array([[2.0, 0.0, 0.0]]),
    masses=jnp.array([1.0]),
    position_gradient=jnp.array([[-0.5, 0.0, 0.0]]),
)
print("forces:", p.forces)
print("velocities:", p.velocities)
```

```text
forces: [[ 0.5 -0.  -0. ]]
velocities: [[2. 0. 0.]]
```
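To see where position_gradient comes from, the sketch below differentiates a toy harmonic energy with jax.grad, stores the result, and derives forces on access. A NamedTuple stands in for the kUPS dataclass here because JAX treats NamedTuples as pytrees without registration; ToyState and harmonic_energy are illustrative names, not kUPS API. Flattening the state confirms that only the two stored arrays become pytree leaves, so the forces property adds nothing that jit or vmap has to carry.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # NamedTuple stands in for a kUPS dataclass: JAX treats it as a pytree natively.
    positions: jax.Array
    position_gradient: jax.Array  # stored: produced by jax.grad

    @property
    def forces(self) -> jax.Array:  # derived on access, never stored
        return -self.position_gradient


def harmonic_energy(positions: jax.Array, k: float = 2.0) -> jax.Array:
    # Toy potential U = 0.5 * k * sum(x**2), so dU/dx = k * x.
    return 0.5 * k * jnp.sum(positions**2)


positions = jnp.array([[1.0, 0.0, 0.0], [0.0, -0.5, 0.0]])
state = ToyState(
    positions=positions,
    position_gradient=jax.grad(harmonic_energy)(positions),  # gradient w.r.t. positions
)

print(state.forces)  # negated gradient, i.e. -k * x
print(len(jax.tree_util.tree_leaves(state)))  # only the two stored arrays are leaves
```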
Logic that transforms state lives in standalone functions. The original state is never modified; a new one is returned. jax.jit requires this pure-functional contract, and it also makes code easier to reason about since you can inspect the old and new state side by side.
```python
def apply_forces(p: MDParticles, dt: float) -> MDParticles:
    # Momentum kick: momenta + forces * dt; returns a new object and leaves p unchanged
    new_momenta = p.momenta + p.forces * dt
    return replace(p, momenta=new_momenta)


p2 = apply_forces(p, dt=0.1)
print("original momenta:", p.momenta)
print("updated momenta: ", p2.momenta)
```

```text
original momenta: [[2. 0. 0.]]
updated momenta:  [[2.05 0.   0.  ]]
```
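Because updates in this style are pure, they can be compiled with jax.jit and iterated with jax.lax.scan without modification. The sketch below is standalone and rests on two assumptions: a NamedTuple replaces the kUPS dataclass (so JAX sees a pytree with no registration), and drift is a hypothetical position update, not a kUPS integrator.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Particles(NamedTuple):
    positions: jax.Array
    momenta: jax.Array
    masses: jax.Array


@jax.jit
def drift(p: Particles, dt: float) -> Particles:
    # Pure update: positions advance by velocity * dt; a new tuple is returned.
    velocities = p.momenta / p.masses[:, None]
    return p._replace(positions=p.positions + velocities * dt)


def step(p: Particles, _):
    # scan carries the state through; no per-step output is collected.
    return drift(p, 0.1), None


p0 = Particles(
    positions=jnp.zeros((1, 3)),
    momenta=jnp.array([[2.0, 0.0, 0.0]]),
    masses=jnp.array([1.0]),
)
p1 = drift(p0, 0.1)  # one compiled step
p10, _ = jax.lax.scan(step, p0, xs=None, length=10)  # ten steps in one traced loop

print(p1.positions)
print(p10.positions)
```

The original p0 is untouched after both calls; each step returns a new tuple via _replace, the NamedTuple counterpart of dataclasses.replace.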
make_* Constructors
Because kUPS components are generic over state types (via protocols and lenses), constructing them requires wiring up Views and Lenses that point into the specific state layout. This wiring is the cost of generality, but it does not have to be paid every time.
kUPS provides two tiers of factory functions. make_*_potential is the full-featured constructor. It accepts explicit View and Lens arguments for every data source: positions, parameters, gradients. You can point each view at a different part of the state, share caches across potentials, or wire up non-standard layouts. It is the foundation that everything else builds on.
make_*_from_state is the shorthand wrapper. It takes a single Lens to a state object and extracts the individual views automatically, assuming the state follows the naming conventions (for example, fields named positions and masses). One lens argument replaces one view per data source. Use it when the data layout matches, and drop down to make_*_potential when you need custom wiring.
```python
# Toy example mirroring the real pattern.
# make_spring_potential: full control, with explicit views for each data source.
def make_spring_potential[S](
    positions_view: View[S, Array],
    masses_view: View[S, Array],
    k: float,
) -> Callable[[S], Array]:
    def potential(state: S) -> Array:
        pos = positions_view(state)
        m = masses_view(state)
        return 0.5 * k * jnp.sum(m[:, None] * pos**2)

    return potential


# make_spring_from_state: shorthand that extracts views from a single lens.
def make_spring_from_state[S](
    state_lens: Lens[S, IsParticleData],
    k: float,
) -> Callable[[S], Array]:
    return make_spring_potential(
        positions_view=state_lens.focus(lambda s: s.positions),
        masses_view=state_lens.focus(lambda s: s.masses),
        k=k,
    )


# Both produce the same result
@dataclass
class State:
    particles: SimpleParticles


s = State(
    particles=SimpleParticles(
        positions=jnp.array([[1.0, 0.0, 0.0]]), masses=jnp.array([2.0])
    )
)

# Full control
pot_a = make_spring_potential(
    positions_view=lens(lambda s: s.particles.positions, cls=State),
    masses_view=lens(lambda s: s.particles.masses, cls=State),
    k=3.0,
)

# Shorthand
pot_b = make_spring_from_state(
    state_lens=lens(lambda s: s.particles, cls=State),
    k=3.0,
)

print("full control:", pot_a(s))
print("from_state: ", pot_b(s))
```

```text
full control: 3.0
from_state:  3.0
```
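When the state does not use the conventional field names, the shorthand has nothing to extract, and the full constructor is the way in. The sketch below mirrors make_spring_potential with plain Python callables standing in for kUPS Views and tuples instead of JAX arrays. OddState and its per-species mass table are a hypothetical layout in which masses must be looked up through a type index rather than read from a single field.

```python
from dataclasses import dataclass
from typing import Callable, Sequence, TypeVar

S = TypeVar("S")
Vec3 = tuple[float, float, float]


def make_spring_potential(
    positions_view: Callable[[S], Sequence[Vec3]],
    masses_view: Callable[[S], Sequence[float]],
    k: float,
) -> Callable[[S], float]:
    # Same wiring as the toy constructor above; plain callables play the role of Views.
    def potential(state: S) -> float:
        pos = positions_view(state)
        m = masses_view(state)
        return 0.5 * k * sum(mi * sum(x * x for x in r) for mi, r in zip(m, pos))

    return potential


# Hypothetical non-standard layout: no single `.positions` / `.masses` pair exists.
@dataclass(frozen=True)
class Species:
    mass_by_type: tuple[float, ...]


@dataclass(frozen=True)
class OddState:
    coords: tuple[Vec3, ...]
    types: tuple[int, ...]
    species: Species


pot = make_spring_potential(
    positions_view=lambda s: s.coords,
    # Per-particle masses are derived from the species table on each call.
    masses_view=lambda s: [s.species.mass_by_type[t] for t in s.types],
    k=3.0,
)

state = OddState(
    coords=((1.0, 0.0, 0.0), (0.0, 2.0, 0.0)),
    types=(0, 1),
    species=Species(mass_by_type=(2.0, 0.5)),
)
energy = pot(state)
print(energy)  # 0.5 * 3.0 * (2.0 * 1.0 + 0.5 * 4.0) = 6.0
```

The potential itself never learns about the layout; only the two views change, which is exactly the flexibility the explicit-view constructor exists to provide.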