Logging¶
Everything kUPS calls a step is pure: a function of (key, state) to state, traced and compiled once, reused many times, and allowed to donate its state buffer on return. Logging is the opposite. It writes to files, pushes pixels into a terminal, takes wall-clock samples, or opens a profiler trace, and none of those operations can live inside jax.jit without either breaking the tracing contract or forcing a JAX-to-host round-trip on every step. The pragmatic consequence is that observability stays in Python, around the step function, not inside it.
What kUPS provides is not a logging framework in the usual sense. It is a small Logger protocol and a handful of concrete loggers that satisfy it. Progress bars, HDF5 writers, and JAX profiler traces are all cases of the same shape. Users are free to add their own, to decide what to record, how often, and where to put it, without touching the library internals.
import tempfile
from collections import deque
from pathlib import Path
from typing import Any, Self
import jax
import jax.numpy as jnp
from jax import Array
from kups.core.lens import view
from kups.core.logging import (
    CompositeLogger,
    NullLogger,
    ProfileLogger,
    TqdmLogger,
)
from kups.core.storage import (
    EveryNStep,
    HDF5StorageReader,
    HDF5StorageWriter,
    Once,
    WriterGroupConfig,
)
from kups.core.utils.jax import dataclass, key_chain
@dataclass
class State:
    step: Array
    energy: Array


@jax.jit
def advance(key: Array, state: State) -> State:
    delta = jax.random.normal(key) - 0.1 * state.energy
    return State(step=state.step + 1, energy=state.energy + 0.1 * delta)
def run(initial_state, logger, n_steps, seed=0):
    chain = key_chain(jax.random.key(seed))
    state = initial_state
    with logger:
        for i in range(n_steps):
            state = advance(next(chain), state)
            logger.log(state, i)
    return state
scratch = tempfile.TemporaryDirectory()
scratch_dir = Path(scratch.name)
state0 = State(step=jnp.array(0), energy=jnp.array(1.0))
The Logger protocol¶
A Logger is a context manager with one extra method: log(state, step). The context manager duties carry the setup and teardown that most logging sinks need. A progress bar has to be drawn and closed; an HDF5 file has to be opened, its datasets sized against the full run length, and flushed at the end; a profiler trace has to be armed and released cleanly even when an exception aborts the run. Putting this behind __enter__ and __exit__ means the simulation loop never has to know the difference.
The log(state, step) call is deliberately synchronous. JAX is allowed to donate the state buffer immediately after the step returns, so any reading the logger wants to do has to happen before control returns to the loop. Any implementation that needs to defer work, for example to background-write a trajectory, copies out of the state and queues the copy; it never holds onto the state reference.
The smallest useful logger is a class with three methods.
class PrintLogger:
    def __enter__(self) -> Self:
        print("[logger] entering")
        return self

    def __exit__(self, *exc) -> None:
        print("[logger] exiting")

    def log(self, state: State, step: int) -> None:
        print(f"  step {step}: energy={float(state.energy):.4f}")
run(state0, PrintLogger(), n_steps=3)
[logger] entering
  step 0: energy=0.9325
  step 1: energy=0.7974
  step 2: energy=0.7151
[logger] exiting
State(step=Array(3, dtype=int32, weak_type=True), energy=Array(0.71509, dtype=float32))
Filtering is the logger's job¶
The simulation loop does exactly one thing: it calls logger.log(state, step) after every step. It does not know that HDF5 writes only every N steps, that a progress bar updates every call, or that a profiler trace runs only over a specific range. Every Logger carries its own filtering logic. This keeps the loop uniform across runs with wildly different observability profiles, and adding a new subsampling rule is a local change inside one logger rather than a fork of the simulation driver.
LoggingFrequency is the convention the HDF5 writer uses internally for this. Once writes exactly at step zero and gives the dataset a scalar shape; EveryNStep writes on every Nth step and reserves space accordingly. Loggers that do not use the HDF5 machinery are free to encode their filter however they like, from a simple modulo to a time-based heartbeat.
class EveryNPrintLogger:
    def __init__(self, n: int):
        self.n = n

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc) -> None:
        pass

    def log(self, state: State, step: int) -> None:
        if step % self.n == 0:
            print(f"  step {step}: energy={float(state.energy):.4f}")
run(state0, EveryNPrintLogger(n=3), n_steps=10)
  step 0: energy=0.9325
  step 3: energy=0.7419
  step 6: energy=0.8789
  step 9: energy=0.6537
State(step=Array(10, dtype=int32, weak_type=True), energy=Array(0.6537125, dtype=float32))
CompositeLogger and NullLogger¶
Most real runs want several things at once: a progress bar to know the run is alive, an HDF5 writer to record the trajectory, occasionally a profiler trace to diagnose a slow kernel. CompositeLogger takes a variadic list of loggers and wraps them in one. __enter__ enters them in order, log forwards to each in order, and __exit__ unwinds them in reverse so resources are released in the opposite order they were acquired.
NullLogger is the no-op counterpart. It is what the warmup driver installs when we want the simulation to run through a burn-in without touching the disk, without advancing a progress bar, and without arming the profiler. Swapping NullLogger for the real logger between warmup and production is typically a single line change and leaves the propagator untouched.
events: list = []


def make_append_logger(label: str):
    class AppendLogger:
        def __enter__(self) -> Self:
            events.append((label, "enter"))
            return self

        def __exit__(self, *exc) -> None:
            events.append((label, "exit"))

        def log(self, state, step):
            events.append((label, "log", step))

    return AppendLogger()


a = make_append_logger("A")
b = make_append_logger("B")

run(state0, CompositeLogger(a, b), n_steps=2)
for event in events:
    print(event)

n_steps = 5
run(state0, NullLogger(), n_steps=n_steps)
print(f"NullLogger ran {n_steps} steps with zero observable effect")
('A', 'enter')
('B', 'enter')
('A', 'log', 0)
('B', 'log', 0)
('A', 'log', 1)
('B', 'log', 1)
('B', 'exit')
('A', 'exit')
NullLogger ran 5 steps with zero observable effect
TqdmLogger: progress bars with postfixes¶
TqdmLogger is the simplest useful logger. It wraps tqdm, increments the progress bar once per log call, and optionally displays a postfix dictionary extracted from the state. The postfix hook is a View from state to a dict[str, Any], which is where per-system energy, acceptance rate, or any other running diagnostic is surfaced to the user. Picking what goes in the postfix is a lens expression, not a change to the propagator.
def postfix(state: State) -> dict[str, Any]:
    return {"energy": f"{float(state.energy):.3f}", "step": int(state.step)}
final = run(state0, TqdmLogger(10, postfix=postfix), n_steps=10)
print("final energy:", float(final.energy))
0%| | 0/10 [00:00<?, ?it/s]
final energy: 0.6537125110626221
HDF5StorageWriter: structured trajectories with an async writer¶
Actually recording a simulation is what HDF5StorageWriter is for. It implements Logger but adds enough structure that the user has to declare a schema up front. A schema is a pytree of WriterGroupConfig objects, each pairing a view (state to storage) with a LoggingFrequency. This is what lets a single HDF5 file hold a once-only dump of initial data alongside per-step diagnostics, each in its own group with the right array shape.
On __enter__ the writer opens the file, sizes every dataset against the total step count (so the HDF5 layout is known at file creation time), and starts a background thread that drains a queue. On every log call the main thread extracts the view from the state synchronously, hands the small host-side copy to the background thread, and returns. The heavy write-to-disk work happens out of band and does not block the next step. On __exit__ the writer flushes, records the actual number of steps that ran (for runs that stopped early), and closes the file.
The cell below writes a two-group trajectory (one scalar logged once, the energy logged every step) to a temporary HDF5 file, then reads it back with HDF5StorageReader.
@dataclass
class Schema:
    initial: WriterGroupConfig[State, dict[str, Array]]
    frames: WriterGroupConfig[State, dict[str, Array]]


schema = Schema(
    initial=WriterGroupConfig(
        view=view(lambda s: {"start_energy": s.energy}),
        logging_frequency=Once(),
    ),
    frames=WriterGroupConfig(
        view=view(lambda s: {"energy": s.energy}),
        logging_frequency=EveryNStep(1),
    ),
)
n_steps = 50
out_file = scratch_dir / "traj.h5"
writer = HDF5StorageWriter(out_file, schema, state0, total_steps=n_steps)
final = run(state0, writer, n_steps=n_steps)
print("wrote file:", out_file.stat().st_size, "bytes")
with HDF5StorageReader(out_file) as reader:
    frames = reader.focus_group(lambda s: s.frames)[:]
    initial = reader.focus_group(lambda s: s.initial)[...]
print("frames shape: ", frames["energy"].shape)
print("initial start_energy:", float(initial["start_energy"]))
print("first logged energy: ", float(frames["energy"][0]))
print("last logged energy: ", float(frames["energy"][-1]))
print("live final energy: ", float(final.energy))
wrote file: 11535 bytes
frames shape: (50,)
initial start_energy: 0.9325219988822937
first logged energy: 0.9325219988822937
last logged energy: 0.9091824293136597
live final energy: 0.9091824293136597
ProfileLogger: targeted JAX profiler traces¶
Performance work needs traces, but capturing a full trace over a thousand-step run produces a file large enough to be unusable. ProfileLogger takes a start and end step and arms the JAX profiler only for that window. Because log(state, step) is called after the step it names has executed, the logger arms the trace one step early so the first captured step is genuinely the one asked for. If start_step is zero the trace is started in __enter__ instead. The output can be viewed in TensorBoard or Perfetto.
The logger does nothing to the simulation itself. It is a decorator on wall time, not on state. This makes it safe to drop a profiler logger into a production run for a diagnostic session without changing anything else.
trace_dir = scratch_dir / "traces"
trace_dir.mkdir()
combined = CompositeLogger(
    TqdmLogger(30),
    ProfileLogger(trace_dir, start_step=10, end_step=20),
)
run(state0, combined, n_steps=30)
trace_files = [p for p in trace_dir.rglob("*") if p.is_file()]
print("trace artifacts produced:", len(trace_files))
0%| | 0/30 [00:00<?, ?it/s]
trace artifacts produced: 2
Writing your own logger¶
The protocol is small enough that any custom observability need is a handful of lines. A logger is a class with __enter__, __exit__, and log. Filtering goes inside log. Anything that needs state extraction happens before returning. Anything asynchronous is the logger's own responsibility. Because loggers compose through CompositeLogger, a one-off custom logger drops into existing runs without any change to the surrounding code.
Common shapes include a live energy-versus-step plot that updates every 50 steps, a Slack or webhook notifier that fires on convergence, a ring-buffer of recent states useful for post-mortem inspection after a crash, or a sanity checker that asserts invariants every few hundred steps and raises to stop the run on violation. The example below implements the ring-buffer case.
class RingBufferLogger:
    """Keep the last `k` states for post-mortem inspection."""

    def __init__(self, k: int):
        self.k = k
        self.buf: deque = deque(maxlen=k)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc) -> None:
        pass

    def log(self, state: State, step: int) -> None:
        self.buf.append((step, float(state.energy)))


buffer = RingBufferLogger(k=5)
run(state0, buffer, n_steps=100)
print("last 5 steps recorded:")
for step, energy in buffer.buf:
    print(f"  step {step}: energy={energy:.4f}")
last 5 steps recorded:
  step 95: energy=1.4706
  step 96: energy=1.4429
  step 97: energy=1.5107
  step 98: energy=1.4969
  step 99: energy=1.3860
Where to go next¶
The simulation drivers in kups.application.utils.propagate show how the pieces fit together: warmup with no logger, production with a composite logger, both wrapping the same propagator from the Propagators notebook. For application-level logging schemas that cover a full GCMC or MD run (which fields, at which frequencies, in which groups) see kups.application.mcmc.logging and its MD counterpart. For the HDF5 reader side and the reading-back-a-run workflow, see HDF5StorageReader in the storage module.