Runtime Assertions
JAX compiles functions into computation graphs with fixed array shapes. When something goes wrong at runtime, such as a neighbor list overflowing its pre-allocated buffer, compiled code can neither resize arrays nor raise exceptions. The simulation either silently produces garbage or crashes with an opaque error. We need a way to detect problems inside JIT-compiled code and handle them outside, on the host, where we can actually do something about them.
from dataclasses import replace
import jax
import jax.numpy as jnp
from jax import Array
from kups.core.assertion import runtime_assert, with_runtime_assertions
from kups.core.capacity import LensCapacity
from kups.core.lens import lens
from kups.core.result import as_result_function
from kups.core.utils.jax import dataclass, field
The Approach
The idea is to declare side-channel operations using custom JAX primitives. Under normal execution, these primitives are no-ops: they pass their inputs through unchanged and produce no visible effect. The computation runs as if they were not there.
To actually collect the side-channel data, we wrap the function with a custom interpreter built on Slub (a library for writing JAX interpreters). The with_runtime_assertions decorator installs that interpreter: it intercepts the custom primitives during tracing, threads their payloads through the computation, and delivers them alongside the return value at the boundary.
Without the interpreter, the runtime_assert primitives are invisible. With it, they carry structured information out of compiled code.
def compute(x):
    runtime_assert(x > 0, "x must be positive")
    return x**2
# Without the interpreter: runtime_assert is a no-op
print("bare:", jax.jit(compute)(jnp.array(-1.0)))
# With the interpreter: assertions are collected alongside the return value
compute_traced = with_runtime_assertions(compute)
value, assertions = compute_traced(jnp.array(-1.0))
print("traced value:", value)
print("num assertions:", len(assertions))
print("predicate:", assertions[0].predicate)
bare: 1.0
traced value: 1.0
num assertions: 1
predicate: False
The bare call happily returns 1.0 (the square of -1) and nobody notices the invalid input. The wrapped call returns the same value but also carries the failed assertion out of the computation.
RuntimeAssertion
A RuntimeAssertion is the structured payload carried by these primitives. It holds a boolean predicate (pass/fail), a human-readable message with format placeholders for runtime values, and optionally a fix function that can repair the state when the check fails.
Unlike Python's assert, a RuntimeAssertion does not interrupt execution. The computation runs to completion, assertions are collected at the boundary, and only then are they inspected on the host. This separation is necessary because JAX traces the full computation graph ahead of time and cannot branch on runtime values during compilation.
def checked_sqrt(x):
    runtime_assert(x >= 0, "Cannot take sqrt of {val}", fmt_args={"val": x})
    return jnp.sqrt(jnp.abs(x))
checked_sqrt_traced = with_runtime_assertions(checked_sqrt)
# Passing
value, assertions = checked_sqrt_traced(jnp.array(4.0))
print("value:", value, "pass:", assertions[0].predicate)
# Failing — the computation still produces a value
value, assertions = checked_sqrt_traced(jnp.array(-4.0))
print("value:", value, "pass:", assertions[0].predicate)
# The assertion carries the error message with runtime values substituted
try:
    assertions[0].check()
except AssertionError as e:
    print("error:", str(e).split("\n")[0])
value: 2.0 pass: True
value: 2.0 pass: False
error: Cannot take sqrt of -4.0
Notice that the computation still produces a value (sqrt of abs(-4) = 2.0). The assertion does not stop it. This is by design: we want the full computation to finish so all assertions can be collected and inspected together.
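To make the shape of such a payload concrete, here is a minimal plain-Python sketch of a record with a predicate, a message template, and deferred formatting. `SoftAssertion` is a hypothetical name and its fields are assumptions modelled on the description above; the real RuntimeAssertion holds JAX arrays and may differ in detail.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass(frozen=True)
class SoftAssertion:
    """Hypothetical payload: a pass/fail flag plus a message template."""
    predicate: bool
    message: str
    fmt_args: Dict[str, Any] = field(default_factory=dict)

    def check(self) -> None:
        # Runtime values are substituted only when the check is inspected.
        if not self.predicate:
            raise AssertionError(self.message.format(**self.fmt_args))


x = -4.0
assertion = SoftAssertion(x >= 0, "Cannot take sqrt of {val}", {"val": x})
value = abs(x) ** 0.5  # the computation itself is never interrupted

try:
    assertion.check()
    error_text = None
except AssertionError as err:
    error_text = str(err)

print("value:", value)
print("error:", error_text)
```

Creating the record and inspecting it are separate steps, which mirrors the split between running the compiled computation and checking its assertions on the host.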
Assertions Inside JIT
runtime_assert works inside jax.jit, jax.lax.scan, jax.lax.while_loop, and jax.lax.cond. The Slub interpreter handles each case: in scan and while_loop, failures are accumulated across iterations. In cond, both branches must produce structurally matching assertions so the interpreter can merge them.
@jax.jit
@with_runtime_assertions
def jitted_compute(x):
    runtime_assert(x > 0, "x must be positive, got {val}", fmt_args={"val": x})
    return x**2
value, assertions = jitted_compute(jnp.array(5.0))
print("pass:", assertions[0].predicate, "value:", value)
value, assertions = jitted_compute(jnp.array(-1.0))
print("pass:", assertions[0].predicate, "value:", value)
pass: True value: 25.0
pass: False value: 1.0
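One way to picture "failures are accumulated across iterations" is a loop carry that ANDs each iteration's predicate into a single flag. The following is a plain-Python analogy of that fold using `functools.reduce` in place of a scan; the carry layout and the `step` function are assumptions made for illustration, not the interpreter's actual representation.

```python
from functools import reduce


def step(carry, x):
    """One loop iteration: update the running total and fold in the check."""
    total, all_ok = carry
    ok = x > 0  # this iteration's predicate
    return (total + x**2, all_ok and ok)


xs = [3.0, -1.0, 2.0]
total, all_ok = reduce(step, xs, (0.0, True))
print("total:", total, "all iterations passed:", all_ok)
```

A single failing iteration leaves the flag False for the rest of the loop, so one boolean at the boundary is enough to tell whether anything went wrong.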
Result: Convenience Wrapper
Manually unpacking the return tuple of with_runtime_assertions and iterating over assertions gets tedious. Result wraps the value and assertions into a single object with helper methods. as_result_function is a decorator that applies with_runtime_assertions and packs the output into a Result.
The caller can check Result.all_assertions_pass (a single boolean, cheap to transfer from device to host), inspect Result.failed_assertions for details, or call Result.raise_assertion to raise the first failure.
@as_result_function
def checked_divide(x, y):
    runtime_assert(y != 0, "Division by zero: y={val}", fmt_args={"val": y})
    return x / y
result = checked_divide(jnp.array(10.0), jnp.array(2.0))
print("value:", result.value, "pass:", result.all_assertions_pass)
result = checked_divide(jnp.array(10.0), jnp.array(0.0))
print("value:", result.value, "pass:", result.all_assertions_pass)
print("failed:", len(result.failed_assertions))
value: 5.0 pass: True
value: inf pass: False
failed: 1
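The three accessors can be pictured with a small plain-Python container. This is a hedged sketch: `Check` and `SketchResult` are hypothetical names, and only the attribute names `all_assertions_pass`, `failed_assertions` and `raise_assertion` are taken from the text above; their real implementations are not shown in this document.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class Check:
    predicate: bool
    message: str


@dataclass(frozen=True)
class SketchResult:
    """Hypothetical Result-like container: a value plus its collected checks."""
    value: Any
    assertions: Tuple[Check, ...]

    @property
    def all_assertions_pass(self) -> bool:
        return all(c.predicate for c in self.assertions)

    @property
    def failed_assertions(self) -> Tuple[Check, ...]:
        return tuple(c for c in self.assertions if not c.predicate)

    def raise_assertion(self) -> None:
        # Raise the first failure, if there is one.
        for c in self.failed_assertions:
            raise AssertionError(c.message)


y = 0.0
result = SketchResult(
    value=float("inf"),
    assertions=(Check(y != 0, f"Division by zero: y={y}"),),
)
print("pass:", result.all_assertions_pass)
print("failed:", len(result.failed_assertions))
```

Reducing every predicate to one flag is what keeps the common case cheap: the host only looks at individual assertions after that flag turns out to be False.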
Fix Functions
Assertions become much more powerful when they carry repair instructions. A RuntimeAssertion can include a fix_fn that, given the current state, returns a corrected version. When the host inspects a failed assertion, it calls Result.fix_or_raise to apply all available fixes. Assertions without a fix function raise their configured exception.
This enables a retry loop: run the computation, check assertions, fix what failed, run again. The simulation makes progress even when buffers overflow, because the fix resizes them and the next attempt succeeds.
@dataclass
class State:
    buffer_size: int = field(static=True)
    data: Array

@as_result_function
def process(state: State):
    needed = 10
    runtime_assert(
        jnp.array(state.buffer_size >= needed),
        "Buffer too small: {size} < {needed}",
        fmt_args={"size": jnp.array(state.buffer_size), "needed": jnp.array(needed)},
        fix_fn=lambda s, new_size: replace(s, buffer_size=int(new_size)),
        fix_args=jnp.array(needed),
    )
    return state.data.sum()
state = State(buffer_size=5, data=jnp.ones(5))
result = process(state)
print("pass:", result.all_assertions_pass)
# Fix the state and retry
fixed_state = result.fix_or_raise(state)
print("fixed buffer_size:", fixed_state.buffer_size)
result = process(fixed_state)
print("after fix:", result.all_assertions_pass)
pass: False
fixed buffer_size: 10
after fix: True
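The fix-or-raise step itself is a short loop over the failed assertions. The sketch below shows one plausible reading of the behaviour described above, in plain Python: `FixableCheck` and this free-standing `fix_or_raise` are hypothetical, and it is an assumption that fixes are applied in order and that a missing fix raises an `AssertionError`.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class State:
    buffer_size: int


@dataclass(frozen=True)
class FixableCheck:
    predicate: bool
    message: str
    fix_fn: Optional[Callable[[Any, Any], Any]] = None
    fix_args: Any = None


def fix_or_raise(state, checks):
    """Apply every available fix for failed checks; raise if one has no fix."""
    for check in checks:
        if check.predicate:
            continue  # passing checks need no repair
        if check.fix_fn is None:
            raise AssertionError(check.message)
        state = check.fix_fn(state, check.fix_args)
    return state


def process(state, needed=10):
    """Toy computation that only reports whether its buffer was large enough."""
    return [
        FixableCheck(
            predicate=state.buffer_size >= needed,
            message=f"Buffer too small: {state.buffer_size} < {needed}",
            fix_fn=lambda s, new_size: replace(s, buffer_size=int(new_size)),
            fix_args=needed,
        )
    ]


state = State(buffer_size=5)
fixed_state = fix_or_raise(state, process(state))
passes_after_fix = all(c.predicate for c in process(fixed_state))
print("fixed buffer_size:", fixed_state.buffer_size)
print("after fix:", passes_after_fix)
```

The fix returns a new state rather than mutating the old one, which fits the immutable, dataclass-based state used throughout this page.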
Capacity
Capacity is the main user of the fix-and-retry pattern. It manages dynamically-sized buffers like neighbor lists, where the required size is only known at runtime.
A Capacity object knows its current size and can generate an assertion that checks whether a required size fits. If it does not, the assertion carries a fix function that resizes the buffer (via a Lens into the state) to the next power of a growth base (2 by default). This geometric growth amortizes the cost of resizing: once a buffer sits on a power of the base, every further resize grows it by at least that factor, so the number of retries is at most logarithmic in the final size. LensCapacity is the lens-backed implementation used below; FixedCapacity is a non-resizing variant for buffers of known, final size.
@dataclass
class SimState:
    max_edges: int = field(static=True)
    positions: Array

def simulate(state: SimState):
    capacity = LensCapacity(
        size=state.max_edges,
        size_lens=lens(lambda s: s.max_edges, cls=SimState),
    )
    actual_edges = jnp.array(150)
    capacity.generate_assertion(actual_edges)
    return state.positions.sum()
state = SimState(max_edges=100, positions=jnp.ones((10, 3)))
result = as_result_function(simulate)(state)
print("pass:", result.all_assertions_pass)
# The fix resizes to the next power of 2
fixed = result.fix_or_raise(state)
print("resized max_edges:", fixed.max_edges)
pass: False
resized max_edges: 256
The capacity was 100, the simulation needed 150, so the fix resized to 256 (the next power of 2 above 150). The host-side loop updates the state and retries. The next JIT call sees the larger buffer and succeeds.
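The resize arithmetic is easy to check by hand. The sketch below assumes the rule is "smallest power of the base that is at least the required size"; how the library treats a requirement that is already an exact power is not stated here, so that detail is an assumption. Both function names are hypothetical.

```python
def next_power(required: int, base: int = 2) -> int:
    """Smallest power of `base` that is >= `required` (assumed rounding rule)."""
    size = 1
    while size < required:
        size *= base
    return size


def resizes_needed(start: int, final: int, base: int = 2) -> int:
    """Worst case: demand creeps one element past each new size until `final` fits."""
    count, size = 0, start
    while size < final:
        size = next_power(size + 1, base)
        count += 1
    return count


print(next_power(150))                    # 256, matching the example above
print(next_power(150, base=3))            # 243
print(resizes_needed(100, 1_000_000))     # 14 resizes to reach a million slots
```

Even in the worst case, growing from 100 slots to a million takes 14 resizes, which illustrates why the retries stay rare once the simulation has warmed up.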
The Retry Loop
Putting it all together: the propagation loop runs a JIT-compiled propagator, checks assertions, applies fixes, and retries. It lives on the host side, outside JAX, because resizing arrays changes the static structure and requires recompilation. propagate_and_fix is the built-in host-side driver that implements this pattern; the function of the same name defined below is a simplified version that spells out the control flow.
During warmup, retries are expected as the simulation discovers the right buffer sizes. Once capacities stabilize, the assertions keep passing and the loop runs without interruption. The custom primitives remain in the compiled code but their predicates evaluate to True, so the host-side check is a single boolean transfer with negligible cost.
def propagate_and_fix(fn, state, max_retries=5):
    """Run fn, check assertions, fix and retry if needed."""
    for attempt in range(max_retries):
        result = fn(state)
        if result.all_assertions_pass:
            return result.value, state
        print(f"  retry {attempt + 1}: fixing failed assertions")
        state = result.fix_or_raise(state)
    raise RuntimeError("Too many retries")
wrapped = as_result_function(simulate)
state = SimState(max_edges=10, positions=jnp.ones((10, 3)))
value, final_state = propagate_and_fix(wrapped, state)
print("final max_edges:", final_state.max_edges)
print("value:", value)
retry 1: fixing failed assertions
final max_edges: 256
value: 30.0
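The warmup behaviour described above can be seen by running the same loop over several steps with varying demand. This is a plain-Python simulation rather than kups code: the edge counts are made up, `run_step` is a hypothetical helper, and a simple comparison stands in for the assertion check while a power-of-two resize stands in for the fix.

```python
def next_power_of_two(required: int) -> int:
    size = 1
    while size < required:
        size *= 2
    return size


def run_step(capacity: int, demand: int, max_retries: int = 5):
    """Host-side loop for one step: retry with a larger capacity until it fits."""
    for retries in range(max_retries):
        if demand <= capacity:  # stands in for all_assertions_pass
            return capacity, retries
        capacity = next_power_of_two(demand)  # stands in for fix_or_raise
    raise RuntimeError("Too many retries")


capacity = 10
demands = [150, 90, 300, 280, 310, 305]  # hypothetical edge counts per step
retries_per_step = []
for demand in demands:
    capacity, retries = run_step(capacity, demand)
    retries_per_step.append(retries)

print("retries per step:", retries_per_step)
print("final capacity:", capacity)
```

Retries occur only on the steps where demand first exceeds the current capacity; after the capacity settles, every step passes on the first attempt.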