Lenses

JAX data structures are immutable. Updating a deeply nested field means reconstructing every intermediate layer — a process that is verbose, error-prone, and tedious to maintain. Lenses solve this by letting you describe what to access and automatically deriving how to update it.

from dataclasses import replace

import jax.numpy as jnp
from jax import Array

from kups.core.lens import LambdaLens, Lens, bind, lens, view
from kups.core.utils.jax import dataclass

The Problem

Consider a simulation state with three levels of nesting. To update the positions at the bottom, you must reconstruct every layer by hand.

@dataclass
class UnitCell:
    lattice: Array


@dataclass
class Atoms:
    positions: Array
    cell: UnitCell


@dataclass
class State:
    atoms: Atoms
    energy: Array


state = State(
    atoms=Atoms(
        positions=jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        cell=UnitCell(lattice=5.0 * jnp.eye(3)),
    ),
    energy=jnp.array(0.5),
)

# Manual nested update — reconstruct every layer
new_positions = jnp.zeros((2, 3))
new_state = replace(state, atoms=replace(state.atoms, positions=new_positions))
print(new_state.atoms.positions)
[[0. 0. 0.]
 [0. 0. 0.]]

View and Update

Simulation code works with arbitrary pytrees — we don't know in advance that .atoms gives us an Atoms object or that .positions holds coordinates. We need a way to abstract over data access.

A View is a function that extracts a value from a structure. It is easy to write: usually just a lambda. An Update is a function that returns a new structure with a value replaced. It is much harder to write, because you must reconstruct every intermediate layer. Some code needs only one of the two, but many operations need both: read a field, compute something, write it back. In that case the two hand-written functions must be kept in sync, which is tedious.

# View: trivial
def get_positions(state):
    return state.atoms.positions


# Update: must reconstruct every layer
def set_positions(state, value):
    return replace(state, atoms=replace(state.atoms, positions=value))


print(get_positions(state))
print(set_positions(state, jnp.ones((2, 3))).atoms.positions)
[[1. 2. 3.]
 [4. 5. 6.]]
[[1. 1. 1.]
 [1. 1. 1.]]

Lens: Unifying View and Update

A Lens combines both into a single object. You only provide the getter — the easy part. The lens derives the setter automatically. This eliminates the duplication: one definition, both directions.

Because a Lens satisfies both the View and Update protocols, it can be passed anywhere either is expected.

pos_lens = lens(lambda s: s.atoms.positions, cls=State)

# Get
print(pos_lens.get(state))

# Set — no manual replace chain needed
new_state = pos_lens.set(state, jnp.zeros((2, 3)))
print(new_state.atoms.positions)
[[1. 2. 3.]
 [4. 5. 6.]]
[[0. 0. 0.]
 [0. 0. 0.]]

The lens() factory accepts an optional cls= argument to help static type checkers. A bare lambda like lambda s: s.atoms.positions gives pyright no information about what s is, so the resulting lens is untyped. Passing cls=State tells the type checker what the lens operates on:

pos_lens = lens(lambda s: s.atoms.positions, cls=State)  # Lens[State, Array]

In practice, you rarely need this. When a lens is passed into a typed context — for example, as a function argument annotated Lens[State, Array] — the type is inferred automatically from the call site. The cls= argument exists for the cases where you define a lens at module level without such context.

Read-only views: view()

Sometimes a component only needs to read a field, not write it. The view() factory returns a plain View, which is callable on the state but has no setter. The two factories compare as follows:

  • lens(where, cls=...) returns a Lens[S, R]. Satisfies both View and Update; the setter is inferred from the getter. Read via lens.get(state); write via lens.set(state, value).
  • view(where, cls=...) returns a View[S, R]. Read-only; no setter is generated. Read by calling it directly: v(state).

Use view() when a function advertises (via its type signature) that it will only read. Use lens() when it may also need to write back.

pos_view = view(lambda s: s.atoms.positions, cls=State)

# A View is callable directly on the state (no `.get`):
print(pos_view(state))

# It has no `.set`, so misuse is a type error at the call site,
# not a silent failure at runtime.
[[1. 2. 3.]
 [4. 5. 6.]]

How Set is Inferred

Writing a getter for nested data is trivial — lambda s: s.atoms.positions reads like English. Writing the corresponding setter is mechanical but verbose. Since the setter is fully determined by the getter's path, the Lens can infer it: it traces the getter, recording each attribute access, then replays the path in reverse to reconstruct the structure.

This works as long as every leaf in the returned value is a path into the original data. The structure itself can be computed — for example, a tuple collecting fields from different places — as long as each leaf traces back to a concrete location.

# Getter returns a computed structure (tuple), but each leaf is a path
both_lens = lens(lambda s: (s.atoms.positions, s.energy), cls=State)

print(both_lens.get(state))

# Set updates both fields at once
new_state = both_lens.set(state, (jnp.ones((2, 3)), jnp.array(9.9)))
print(new_state.atoms.positions)
print(new_state.energy)
(Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float32), Array(0.5, dtype=float32, weak_type=True))
[[1. 1. 1.]
 [1. 1. 1.]]
9.9

Computed leaves — arithmetic, aggregation — break set. There is no path to write back to.

bad_lens = lens(lambda s: s.energy * 2, cls=State)

# Get works fine
print(bad_lens.get(state))

# Set fails — the result is computed, not a path
try:
    bad_lens.set(state, jnp.array(1.0))
except ValueError as e:
    print(e)
1.0
Cannot set value through this lens: Focus function returned a computed value instead of a path into the data. Use attribute access (x.field) or indexing (x[i]) to reference data. Got: ArrayImpl
Hint: The focus function must return references to parts of the data, not computed values or literals.
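
The tracing idea described in this section can be illustrated with a small standalone sketch. It is not the kups implementation: PathRecorder, trace_path, set_along_path, Inner and Outer are hypothetical names, and the sketch handles only a single chain of attribute accesses on plain dataclasses (no tuples, no indexing). It shows how running a getter on a recording stand-in yields a path, and how that path is enough to rebuild each layer.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


class PathRecorder:
    """Stand-in for the state that records every attribute accessed on it."""

    def __init__(self, path: tuple[str, ...] = ()) -> None:
        self._path = path

    def __getattr__(self, name: str) -> "PathRecorder":
        # Called only for names not found normally, i.e. the traced fields.
        return PathRecorder(self._path + (name,))


def trace_path(getter: Callable[[Any], Any]) -> tuple[str, ...]:
    """Run the getter on a recorder and return the attribute path it followed."""
    result = getter(PathRecorder())
    if not isinstance(result, PathRecorder):
        raise ValueError("getter returned a value that is not a path into the data")
    return result._path


def set_along_path(obj: Any, path: tuple[str, ...], value: Any) -> Any:
    """Rebuild each dataclass layer on the way back up from the leaf."""
    if not path:
        return value
    head, rest = path[0], path[1:]
    return replace(obj, **{head: set_along_path(getattr(obj, head), rest, value)})


@dataclass(frozen=True)
class Inner:
    x: int


@dataclass(frozen=True)
class Outer:
    inner: Inner
    y: int


outer = Outer(inner=Inner(x=1), y=2)
path = trace_path(lambda s: s.inner.x)
updated = set_along_path(outer, path, 10)

print(path)     # ('inner', 'x')
print(updated)  # Outer(inner=Inner(x=10), y=2)
```

In this sketch a getter that returns a literal, such as lambda s: 3, is rejected by trace_path with a ValueError, because the result is not a recorded path.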

LambdaLens: Custom Setters

Sometimes the setter is not a simple reversal of the getter — setting one field should also update a derived field, or the mapping between external and internal representation is non-trivial. LambdaLens lets you provide explicit get and set functions for these cases. Use it sparingly — automatic inference covers the vast majority of cases.

# Setting energy also zeroes out positions (contrived but illustrative)
custom: Lens[State, Array] = LambdaLens(
    lambda s: s.energy,
    lambda state, value: replace(
        state,
        energy=value,
        atoms=replace(state.atoms, positions=state.atoms.positions * 0),
    ),
)

new_state = custom.set(state, jnp.array(99.0))
print("energy:", new_state.energy)
print("positions zeroed:", new_state.atoms.positions)
energy: 99.0
positions zeroed: [[0. 0. 0.]
 [0. 0. 0.]]
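
The other case mentioned above, a non-trivial mapping between external and internal representation, can be sketched without kups. The GetSetPair class below is a hypothetical stand-in for a lens built from two explicit functions, and Ensemble is an invented state type. The state stores an inverse temperature, while callers read and write a temperature.

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")
R = TypeVar("R")


@dataclass(frozen=True)
class GetSetPair(Generic[S, R]):
    """Hypothetical stand-in for a lens defined by explicit get and set functions."""

    get: Callable[[S], R]
    set: Callable[[S, R], S]


@dataclass(frozen=True)
class Ensemble:
    beta: float  # internal representation: inverse temperature, 1 / T


# External representation is the temperature T; the setter converts back to beta.
temperature = GetSetPair(
    get=lambda e: 1.0 / e.beta,
    set=lambda e, t: replace(e, beta=1.0 / t),
)

e0 = Ensemble(beta=0.5)
e1 = temperature.set(e0, 4.0)

print(temperature.get(e0))  # 2.0
print(e1.beta)              # 0.25
```

No getter tracing could derive this setter, since the getter computes its result. That is the situation in which an explicit pair is needed.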

Focus: Composing Deeper

Lens.focus narrows an existing lens to a nested field, creating a composed lens. This is useful when you build lenses incrementally — a module defines a lens to its slice of the state, and the caller focuses it further. Each .focus() composes get and set automatically.

atoms_lens = lens(lambda s: s.atoms, cls=State)
pos_via_focus = atoms_lens.focus(lambda a: a.positions)

print(pos_via_focus.get(state))

new_state = pos_via_focus.set(state, jnp.ones((2, 3)) * 7)
print(new_state.atoms.positions)
[[1. 2. 3.]
 [4. 5. 6.]]
[[7. 7. 7.]
 [7. 7. 7.]]

A single deep lambda achieves the same result, but chained .focus() calls are more composable when building lenses dynamically — each piece can come from a different module or configuration.
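
How composition derives both directions can be shown with a minimal sketch. SketchLens and its then method are hypothetical and take an explicit inner lens, whereas Lens.focus takes a lambda; Particles and System are invented types. The composed getter applies the outer getter and then the inner one. The composed setter reads the outer part, updates it with the inner setter, and writes the result back with the outer setter.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class SketchLens:
    """Hypothetical minimal lens: an explicit getter plus an explicit setter."""

    get: Callable[[Any], Any]
    set: Callable[[Any, Any], Any]

    def then(self, inner: "SketchLens") -> "SketchLens":
        """Compose: apply self (outer) first, then inner."""
        return SketchLens(
            get=lambda s: inner.get(self.get(s)),
            # Read the outer part, update it via inner, write it back via outer.
            set=lambda s, v: self.set(s, inner.set(self.get(s), v)),
        )


@dataclass(frozen=True)
class Particles:
    positions: tuple[float, ...]


@dataclass(frozen=True)
class System:
    particles: Particles
    step: int


particles_lens = SketchLens(lambda s: s.particles, lambda s, v: replace(s, particles=v))
positions_lens = SketchLens(lambda p: p.positions, lambda p, v: replace(p, positions=v))
composed = particles_lens.then(positions_lens)

system = System(particles=Particles(positions=(1.0, 2.0)), step=0)
moved = composed.set(system, (3.0, 4.0))

print(composed.get(system))  # (1.0, 2.0)
print(moved)
```

Neither piece knows about the other: positions_lens works on any Particles value, and particles_lens could be swapped for a lens into a different state type.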

Array Indexing with .at()

Simulations frequently need to update a subset of particles or entries, not the whole array. Lens.at focuses on specific array elements within the pytree, mirroring JAX's array.at[idx].set() but applying across the whole pytree at once.

pos_lens = lens(lambda s: s.atoms.positions, cls=State)

# Read a subset
print(pos_lens.at(jnp.array([0])).get(state))

# Update only the first particle
new_state = pos_lens.at(jnp.array([0])).set(state, jnp.array([[9.0, 8.0, 7.0]]))
print(new_state.atoms.positions)
[[1. 2. 3.]]
[[9. 8. 7.]
 [4. 5. 6.]]

This is how Table.at works under the hood for foreign-key updates — the Lens handles the gather and scatter across the entire data pytree.
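
For comparison, these are the plain-array operations that .at() is described as mirroring, written with the standard JAX indexed-update syntax on a single array. The sketch involves no lens and makes no claim about how kups implements the gather and scatter internally.

```python
import jax.numpy as jnp

positions = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
idx = jnp.array([0])

# Gather: read the selected rows.
subset = positions[idx]

# Scatter: returns a new array; `positions` itself is unchanged.
updated = positions.at[idx].set(jnp.array([[9.0, 8.0, 7.0]]))

print(subset)
print(updated)
```

With a nested state, the updated array would still have to be threaded back through every replace call by hand, which is the step the lens removes.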

Bind: Repeated Access

Passing the same state to every get/set call is repetitive when doing multiple operations in sequence. bind(state) creates a BoundLens whose operations no longer take the state as an argument.

bound = bind(state)

print(bound.focus(lambda s: s.energy).get())
print(bound.focus(lambda s: s.atoms.positions).get())

new_state = bound.focus(lambda s: s.atoms.positions).set(jnp.zeros((2, 3)))
print(new_state.atoms.positions)
0.5
[[1. 2. 3.]
 [4. 5. 6.]]
[[0. 0. 0.]
 [0. 0. 0.]]
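
Conceptually, binding is partial application: the state argument is supplied once and the remaining operations take one argument fewer. The sketch below illustrates that idea with functools.partial on a hypothetical Box type and hand-written accessor functions. It is an analogy, not the BoundLens implementation.

```python
from dataclasses import dataclass, replace
from functools import partial


@dataclass(frozen=True)
class Box:
    value: int


def get_value(box: Box) -> int:
    return box.value


def set_value(box: Box, value: int) -> Box:
    return replace(box, value=value)


box = Box(value=1)

# Supply the state once; the bound callables no longer take it.
bound_get = partial(get_value, box)
bound_set = partial(set_value, box)

new_box = bound_set(5)
print(bound_get())  # 1
print(new_box)      # Box(value=5)
```

In this sketch the bound callables keep referring to the original box: bound_set returns a new value and bound_get() still reads 1 afterwards.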

Higher-Order Pattern: Lenses as Configuration

Simulation components — integrators, Monte Carlo moves — need to read and write specific fields of the state. But the state type varies between simulations. Rather than hardcoding field access, components accept Lenses as parameters. The caller provides lenses that tell the component where to find its data. One implementation works with any state type.

@dataclass
class StateA:
    positions: Array
    energy: Array


@dataclass
class StateB:
    coords: Array
    potential: Array


def scale_positions[S](state: S, pos_lens: Lens[S, Array], factor: float) -> S:
    """Generic function — works with any state type via the lens."""
    return pos_lens.set(state, pos_lens.get(state) * factor)


a = StateA(positions=jnp.array([1.0, 2.0]), energy=jnp.array(0.5))
b = StateB(coords=jnp.array([3.0, 4.0]), potential=jnp.array(1.0))

print(scale_positions(a, lens(lambda s: s.positions, cls=StateA), 2.0).positions)
print(scale_positions(b, lens(lambda s: s.coords, cls=StateB), 2.0).coords)
[2. 4.]
[6. 8.]