kups.core.lens
Lens library for functional data manipulation in JAX.
This module provides a lens-based approach to accessing and modifying nested data structures in a functional manner. Lenses allow you to focus on specific parts of a data structure and perform operations without mutating the original structure.
The main abstractions are:
- View: A function that extracts a value from a data structure
- Update: A function that sets a value in a data structure
- Lens: A bidirectional interface for getting and setting values
- BoundLens: A lens that has been bound to a specific data structure
Lenses satisfy both the View and Update protocols through their __call__ method:
- lens(state) acts as a View, returning the focused value
- lens(state, value) acts as an Update, returning the modified structure
This allows lenses to be used directly wherever a View or Update function is expected.
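The dual View/Update behaviour of `__call__` can be sketched in plain Python. `MiniLens` below is a hypothetical, minimal lens that focuses on one dataclass field by name and uses `dataclasses.replace` for the functional update. It is not the kups implementation, which operates on JAX pytrees, but it shows how a single sentinel default lets one method act as both a getter and a setter.

```python
from dataclasses import dataclass, replace
from typing import Any, Generic, TypeVar

S = TypeVar("S")
R = TypeVar("R")

_NOT_SET: Any = object()  # sentinel: distinguishes "no value passed" from None


class MiniLens(Generic[S, R]):
    """Hypothetical lens focusing on a single named dataclass field."""

    def __init__(self, field: str) -> None:
        self.field = field

    def get(self, state: S) -> R:
        return getattr(state, self.field)

    def set(self, state: S, value: R) -> S:
        # Build a new instance; the original state is never mutated.
        return replace(state, **{self.field: value})

    def __call__(self, state: S, value: Any = _NOT_SET) -> Any:
        # One argument: act as a View. Two arguments: act as an Update.
        if value is _NOT_SET:
            return self.get(state)
        return self.set(state, value)


@dataclass(frozen=True)
class MyState:
    value: int


state = MyState(value=10)
value_lens = MiniLens[MyState, int]("value")
print(value_lens(state))      # 10
print(value_lens(state, 20))  # MyState(value=20)
print(state)                  # MyState(value=10), unchanged
```

Using a private sentinel rather than `None` as the default keeps `None` usable as a legitimate value to set.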
Modifier = Callable[[R], R]
Type alias for a function that transforms a value.
BaseLens
Bases: Lens[S, R], ABC
Base class for lens implementations.
Source code in src/kups/core/lens.py
BoundLens
Bases: Protocol[S, R]
Protocol for a lens that has been bound to a specific data structure.
A bound lens provides the same operations as a regular lens but without requiring the state parameter, since it's already bound to a specific instance.
Bound lenses satisfy View (via zero-argument call) and can update via single-argument call, providing a convenient interface for repeated operations on the same state.
Generic in S (bound data structure) and R (focused value).
Examples:
>>> bound = my_lens.bind(state)
>>> bound() # View: returns focused value
>>> bound(new_value) # Update: returns modified state
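A bound lens pairs a lens with one fixed state, so calls no longer need the state argument. The hypothetical `MiniBoundLens` below (field-name based, built on `dataclasses.replace`) sketches the zero-argument View, the one-argument Update and `apply`. It is an illustration under those assumptions, not the kups `SimpleBoundLens`.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable

_NOT_SET: Any = object()  # sentinel for "no argument given"


class MiniBoundLens:
    """Hypothetical bound lens: a field name plus one fixed state."""

    def __init__(self, state: Any, field: str) -> None:
        self.state = state
        self.field = field

    def get(self) -> Any:
        return getattr(self.state, self.field)

    def set(self, value: Any) -> Any:
        # Returns a new state; self.state itself is left untouched.
        return replace(self.state, **{self.field: value})

    def apply(self, modifier: Callable[[Any], Any]) -> Any:
        return self.set(modifier(self.get()))

    def __call__(self, value: Any = _NOT_SET) -> Any:
        # Zero arguments: View. One argument: Update.
        return self.get() if value is _NOT_SET else self.set(value)


@dataclass(frozen=True)
class Counter:
    count: int


bound = MiniBoundLens(Counter(count=3), "count")
print(bound())                       # 3
print(bound(7))                      # Counter(count=7)
print(bound.apply(lambda c: c + 1))  # Counter(count=4)
```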
__call__(value=_NOT_SET)
Get or set the focused value in the bound state.
When called with no arguments, returns the focused value. When called with one argument, sets the value and returns the modified state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | R2 \| _NOT_SET_TYPE | If provided, the new value to set | _NOT_SET |
Returns:
| Type | Description |
|---|---|
| S \| R2 | The focused value (no args) or modified state (one arg) |
apply(modifier)
Apply a modifier function to the focused value in the bound data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| modifier | Modifier[R2] | A function that transforms the focused value | required |
Returns:
| Type | Description |
|---|---|
| S | A new data structure with the modified value set |
at(idxs, *, args=None)
Create a bound lens that slices the focused pytree at the given indices.
Functionally equivalent to mapping jax.tree.map over the focused pytree and reading or writing each array leaf x through x.at[idxs].get() / x.at[idxs].set(...), applied to the bound data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idxs | Any | The indices to slice the pytree at. | required |
| args | ScatterArgs \| None | Optional ScatterArgs controlling scatter behavior (mode, wrap_negative_indices, fill_value, indices_are_sorted, unique_indices). | None |
Returns:
| Type | Description |
|---|---|
| BoundLens[S, R] | A new bound lens that slices the focused pytree at the given indices. |
focus(where)
Focus this bound lens on a deeper part of the data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| where | Callable[[R2], B] | A function that extracts a value from the current focus | required |
Returns:
| Type | Description |
|---|---|
| BoundLens[S, B] | A new bound lens focused on the result of the where function |
get()
Extract the focused value from the bound data structure.
Returns:
| Type | Description |
|---|---|
| R | The focused value from the bound state |
merge(other)
Merge this lens with another lens to access multiple values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| other | Lens[S2, R3] | Another lens to merge with | required |
Returns:
| Type | Description |
|---|---|
| BoundLens[S2, tuple[R2, R3]] | A new lens that accesses both focused values as a tuple |
nest(other)
Nest another lens or view within this bound lens.
This provides an alternative to focus() that works with both lenses and views. When given a lens, it extracts the view from it; when given a view directly, it uses it as-is.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| other | Lens[R2, U] | Either a lens or view to nest within this lens | required |
Returns:
| Type | Description |
|---|---|
| BoundLens[S, U] | A new bound lens that composes this lens with the provided lens/view |
set(value)
Set the focused value in the bound data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | R2 | The new value to set | required |
Returns:
| Type | Description |
|---|---|
| S | A new data structure with the value set |
ConstLens
Bases: BaseLens[S, R]
Lens that always returns a constant value; set is a no-op.
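The contract of a constant lens is small enough to state directly in code. `MiniConstLens` below is a hypothetical sketch of that contract (get ignores the state, set returns the state unchanged); it is not the kups `ConstLens` class itself.

```python
from typing import Generic, TypeVar

S = TypeVar("S")
R = TypeVar("R")


class MiniConstLens(Generic[S, R]):
    """Hypothetical constant lens: get ignores the state, set is a no-op."""

    def __init__(self, value: R) -> None:
        self.value = value

    def get(self, state: S) -> R:
        # The state is ignored; the stored constant is always returned.
        return self.value

    def set(self, state: S, value: R) -> S:
        # Writing through a constant lens leaves the state unchanged.
        return state


zero: MiniConstLens[dict, float] = MiniConstLens(0.0)
state = {"energy": 1.5}
print(zero.get(state))        # 0.0
print(zero.set(state, 42.0))  # {'energy': 1.5}
```

Such a lens can stand in wherever a lens is required but the quantity should be fixed, since writes through it are silently discarded.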
HasLensFields
Base class for dataclasses that support lens-enabled field access.
Dataclasses inheriting from HasLensFields can use LensField[T] annotations to enable dual-mode field access:
- Class access returns a Lens object for functional operations
- Instance access returns the field value normally
This class uses the FieldMetaAccess metaclass to intercept class attribute access and provide lens objects when appropriate.
Example
>>> from kups.core.utils.jax import dataclass
>>> from kups.core.lens import HasLensFields, LensField
>>> from jax import Array
>>> @dataclass
... class State(HasLensFields):
...     position: LensField[Array]
...     velocity: LensField[Array]
>>> state = State(position=pos, velocity=vel)
>>> pos_lens = State.position  # Returns Lens[State, Array]
>>> current_pos = state.position  # Returns the Array value
Note
HasLensFields itself cannot be instantiated. It must be subclassed.
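The metaclass trick behind dual-mode access can be sketched with the standard library alone. In the hypothetical `LensMeta` below, a metaclass `__getattr__` is only consulted when normal class lookup fails, which is exactly the case for dataclass fields declared without a default, so class access can return a lens while instance access still finds the stored value. `LensMeta` and `FieldLens` are illustrative names; this is not the kups `FieldMetaAccess` code.

```python
from dataclasses import dataclass, replace
from typing import Any


class FieldLens:
    """Hypothetical lens focusing on one named dataclass field."""

    def __init__(self, name: str) -> None:
        self.name = name

    def get(self, state: Any) -> Any:
        return getattr(state, self.name)

    def set(self, state: Any, value: Any) -> Any:
        # Functional update: build a new instance instead of mutating.
        return replace(state, **{self.name: value})


class LensMeta(type):
    """Hypothetical metaclass: class-level field access yields a FieldLens."""

    def __getattr__(cls, name: str) -> FieldLens:
        # Only reached when normal lookup fails. Dunder names keep default
        # behaviour, and nothing resolves until the dataclass fields exist.
        if name.startswith("__"):
            raise AttributeError(name)
        if name in cls.__dict__.get("__dataclass_fields__", {}):
            return FieldLens(name)
        raise AttributeError(name)


@dataclass(frozen=True)
class State(metaclass=LensMeta):
    """Toy state with two lens-enabled fields."""

    position: float
    velocity: float


state = State(position=1.0, velocity=0.5)
pos_lens = State.position        # class access: a FieldLens
print(state.position)            # instance access: 1.0
print(pos_lens.set(state, 2.0))  # State(position=2.0, velocity=0.5)
```

Fields that declare a default would be found by normal class lookup and bypass `__getattr__`; that is one reason a library may need a more thorough mechanism than this sketch.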
IndexLens
Bases: BaseLens[S, R]
A lens that performs array indexing operations on the focused data.
This lens wraps another lens and applies JAX array indexing operations to slice, index, or select specific elements from arrays in the focused data.
at(idxs, **extra_kwargs)
focus(where)
get(state)
Get values by applying array indexing to the focused data.
set(state, value)
Set values by applying array indexing to the focused data.
LambdaLens
Bases: BaseLens[S, R]
A lens that uses custom getter and setter functions.
This allows for more complex lens behavior that cannot be expressed with simple field access or traversal-based operations.
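A getter/setter pair is the natural tool for derived quantities that are not stored fields. The hypothetical `MiniLambdaLens` below exposes a temperature in Celsius over a state that only stores Kelvin, so the setter has to invert the conversion. The class name and the temperature example are illustrative assumptions, not the kups `LambdaLens` API.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


class MiniLambdaLens:
    """Hypothetical lens defined by an explicit getter and setter pair."""

    def __init__(
        self,
        getter: Callable[[Any], Any],
        setter: Callable[[Any, Any], Any],
    ) -> None:
        self.getter = getter
        self.setter = setter

    def get(self, state: Any) -> Any:
        return self.getter(state)

    def set(self, state: Any, value: Any) -> Any:
        return self.setter(state, value)


@dataclass(frozen=True)
class Temperature:
    kelvin: float


# Celsius is not stored, so the setter converts back to Kelvin.
celsius = MiniLambdaLens(
    getter=lambda t: t.kelvin - 273.15,
    setter=lambda t, c: replace(t, kelvin=c + 273.15),
)

t = Temperature(kelvin=300.0)
print(celsius.get(t))       # approximately 26.85
print(celsius.set(t, 0.0))  # Temperature(kelvin=273.15)
```

For such a lens to behave predictably, getter and setter should be mutual inverses: setting a value and reading it back should return that value (up to floating-point rounding here).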
get(state)
Lens
Bases: Protocol[S, R]
Protocol for a lens that provides bidirectional access to data structures.
A lens combines a getter and setter, allowing functional access and modification of nested data structures. Lenses are composable and can be focused on specific parts of a data structure.
Lenses satisfy both the View and Update protocols through their __call__ method, allowing them to be used directly wherever a view or update function is expected.
Generic in S (source data structure) and R (focused value).
Examples:
>>> state = MyState(value=10)
>>> my_lens = lens(lambda s: s.value)
>>> my_lens(state) # View: returns 10
>>> my_lens(state, 20) # Update: returns MyState(value=20)
__call__(state, /, value=_NOT_SET)
Get or set the focused value, satisfying View and Update protocols.
When called with one argument, acts as a View and returns the focused value. When called with two arguments, acts as an Update and returns the modified state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S2 | The data structure to operate on | required |
| value | R2 \| _NOT_SET_TYPE | If provided, the new value to set | _NOT_SET |
Returns:
| Type | Description |
|---|---|
| S2 \| R2 | The focused value (one arg) or modified state (two args) |
apply(state, /, modifier)
Apply a modifier function to the focused value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S2 | The data structure to modify | required |
| modifier | Modifier[R2] | A function that transforms the focused value | required |
Returns:
| Type | Description |
|---|---|
| S2 | A new data structure with the modified value |
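Conceptually, `apply` is a read-modify-write: get the focused value, pass it through the `Modifier`, and set the result. The free function `apply` below is a hypothetical sketch of that composition over a plain dict with a non-mutating setter; it mirrors the documented `Modifier` alias but is not the kups method.

```python
from typing import Any, Callable, Dict, TypeVar

R = TypeVar("R")
Modifier = Callable[[R], R]  # same shape as the documented alias


def apply(
    get: Callable[[Any], R],
    set_: Callable[[Any, R], Any],
    state: Any,
    modifier: Modifier[R],
) -> Any:
    """Read the focused value, transform it, and write it back as a new state."""
    return set_(state, modifier(get(state)))


state: Dict[str, float] = {"position": 1.0, "velocity": 0.5}
new_state = apply(
    lambda s: s["velocity"],            # getter
    lambda s, v: {**s, "velocity": v},  # non-mutating setter
    state,
    lambda v: 2 * v,                    # modifier
)
print(new_state)  # {'position': 1.0, 'velocity': 1.0}
print(state)      # {'position': 1.0, 'velocity': 0.5}
```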
at(idxs, *, args=None)
Create a lens that slices the focused pytree at the given indices.
Functionally equivalent to mapping jax.tree.map over the focused pytree and reading or writing each array leaf x through x.at[idxs].get() / x.at[idxs].set(...).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idxs | Any | The indices to slice the pytree at. | required |
| args | ScatterArgs \| None | Optional ScatterArgs controlling scatter behavior (mode, wrap_negative_indices, fill_value, indices_are_sorted, unique_indices). | None |
Returns:
| Type | Description |
|---|---|
| Lens[S, R] | A new lens that slices the focused pytree at the given indices. |
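The slicing semantics of `at` can be reproduced with public JAX calls: map over every array leaf and use the `.at[...]` indexed get/set on each. `at_get` and `at_set` below are hypothetical helper names for this sketch; the ScatterArgs options are presumably forwarded to the corresponding `.at[...]` keyword arguments (such as `mode`), which the sketch omits.

```python
import jax
import jax.numpy as jnp


def at_get(tree, idxs):
    # Slice every array leaf of the pytree at the same indices.
    return jax.tree_util.tree_map(lambda x: x.at[idxs].get(), tree)


def at_set(tree, idxs, values):
    # Scatter each leaf of `values` back into the matching leaf at `idxs`.
    return jax.tree_util.tree_map(lambda x, v: x.at[idxs].set(v), tree, values)


particles = {
    "pos": jnp.arange(6.0).reshape(3, 2),
    "mass": jnp.array([1.0, 2.0, 3.0]),
}
idx = jnp.array([0, 2])

sub = at_get(particles, idx)  # rows 0 and 2 of every leaf
scaled = jax.tree_util.tree_map(lambda x: x * 10, sub)
updated = at_set(particles, idx, scaled)
print(updated["mass"])  # mass becomes [10, 2, 30]
```

Because every leaf is indexed with the same `idxs`, all leaves must share the indexed leading dimension (here, one row per particle).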
bind(state)
Bind this lens to a specific data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S2 | The data structure to bind to | required |
Returns:
| Type | Description |
|---|---|
| BoundLens[S2, R] | A bound lens that operates on the given state |
focus(where)
Focus this lens on a deeper part of the data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| where | Callable[[R], B] | A function that extracts a value from the current focus | required |
Returns:
| Type | Description |
|---|---|
| Lens[S, B] | A new lens focused on the result of the where function |
get(state)
Extract the focused value from the data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S2 | The data structure to extract from | required |
Returns:
| Type | Description |
|---|---|
| R | The focused value |
merge(other)
nest(other)
Nest another lens or view within this lens.
This provides an alternative to focus() that works with both lenses and views. When given a lens, it extracts the view from it; when given a view directly, it uses it as-is.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| other | Lens[R2, U] | Either a lens or view to nest within this lens | required |
Returns:
| Type | Description |
|---|---|
| Lens[S2, U] | A new lens that composes this lens with the provided lens/view |
set(state, /, value)
Set the focused value in the data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | S2 | The data structure to modify | required |
| value | R2 | The new value to set | required |
Returns:
| Type | Description |
|---|---|
| S2 | A new data structure with the value set |
LensField
Bases: ABC
Type annotation for lens-enabled fields in dataclasses.
LensField provides a type-safe way to enable lens access on dataclass fields. When a dataclass inherits from HasLensFields, fields annotated with LensField[T] can be accessed both as regular attributes and as lenses.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| B | | The type of the field value | required |
Behavior
- Class access (e.g., MyClass.field): Returns a Lens[MyClass, B] that can be used for functional operations like get(), set(), focus()
- Instance access (e.g., obj.field): Returns the actual field value of type B, behaving like a normal attribute
Usage
For dataclasses deriving from HasLensFields, annotate fields with LensField[T] to enable lens access. Regular dataclasses without HasLensFields inheritance do not support lens access through field annotations.
Examples:
>>> import jax
>>> from kups.core.utils.jax import dataclass
>>> from kups.core.lens import LensField, HasLensFields
>>> import jax.numpy as jnp
>>> from jax import Array
>>>
>>> @dataclass
... class Point(HasLensFields):
...     x: LensField[float]
...     y: LensField[Array]
>>>
>>> # Instance access - normal field behavior
>>> point = Point(x=1.0, y=jnp.array([1.0, 2.0, 3.0]))
>>> point.x  # Returns 1.0
>>> point.y  # Returns Array([1., 2., 3.])
>>>
>>> # Class access - returns a lens
>>> x_lens = Point.x  # Returns Lens[Point, float]
>>> x_lens.get(point)  # Returns 1.0
>>> new_point = x_lens.set(point, 5.0)  # Returns Point(x=5.0, y=...)
>>>
>>> # Compose with other lenses
>>> doubled_x_lens = Point.x.focus(lambda x: x * 2)
>>> doubled_x_lens.get(point)  # Returns 2.0
>>>
>>> # Works in JAX transformations
>>> @jax.jit
... def increment_y(p: Point) -> Point:
...     return Point.y.set(p, p.y + 1.0)
Notes
- Only works with dataclasses that inherit from HasLensFields
- The metaclass FieldMetaAccess intercepts class attribute access to return lenses
- Compatible with JAX transformations when used with JAX-compatible dataclasses (e.g., from kups.core.utils.jax)
- Use lens_field() instead of field() for type-safe field definitions with default values or field options
MergedLens
Bases: BaseLens[S, tuple[R, R2]]
A lens that merges two lenses to access multiple values.
This lens combines two lenses into a single lens that accesses both focused values as a tuple. It allows you to work with multiple parts of a data structure simultaneously.
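Merging two lenses amounts to reading both focuses into a tuple and writing a tuple back one half at a time. `PairLens` below is a hypothetical sketch built from two (getter, setter) pairs; it shows the shape of the idea, not the kups `MergedLens` code.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Tuple

Getter = Callable[[Any], Any]
Setter = Callable[[Any, Any], Any]


class PairLens:
    """Hypothetical merged lens built from two (getter, setter) pairs."""

    def __init__(self, first: Tuple[Getter, Setter], second: Tuple[Getter, Setter]) -> None:
        self.first = first
        self.second = second

    def get(self, state: Any) -> Tuple[Any, Any]:
        # Read both focused values and return them as a tuple.
        return self.first[0](state), self.second[0](state)

    def set(self, state: Any, value: Tuple[Any, Any]) -> Any:
        a, b = value
        # Write the first value, then the second; each setter returns a new state.
        return self.second[1](self.first[1](state, a), b)


@dataclass(frozen=True)
class Particle:
    position: float
    velocity: float


pos = (lambda p: p.position, lambda p, v: replace(p, position=v))
vel = (lambda p: p.velocity, lambda p, v: replace(p, velocity=v))
both = PairLens(pos, vel)

p = Particle(position=1.0, velocity=0.5)
print(both.get(p))               # (1.0, 0.5)
print(both.set(p, (2.0, -0.5)))  # Particle(position=2.0, velocity=-0.5)
```

If the two focuses overlap, the sequential writes in this sketch mean the second value wins; disjoint focuses avoid that ambiguity.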
NestedLens
Bases: BaseLens[S, R2], Generic[S, R, R2]
A lens that composes two lenses to access deeply nested data.
This lens combines an outer lens (S -> A) with an inner lens (A -> B) to create a composite lens (S -> B). Operations are performed by first applying the outer lens, then the inner lens.
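The outer-then-inner rule can be written out in a few lines: get goes through the outer lens and then the inner one, while set rebuilds the intermediate value with the inner lens and writes it back through the outer lens. `FieldLens` and `ComposedLens` below are hypothetical names for this sketch, not the kups `NestedLens` implementation.

```python
from dataclasses import dataclass, replace
from typing import Any


class FieldLens:
    """Hypothetical lens over one named dataclass field."""

    def __init__(self, name: str) -> None:
        self.name = name

    def get(self, state: Any) -> Any:
        return getattr(state, self.name)

    def set(self, state: Any, value: Any) -> Any:
        return replace(state, **{self.name: value})


class ComposedLens:
    """Hypothetical nested lens: outer focuses S -> A, inner focuses A -> B."""

    def __init__(self, outer: Any, inner: Any) -> None:
        self.outer = outer
        self.inner = inner

    def get(self, state: Any) -> Any:
        # Apply the outer lens first, then the inner lens.
        return self.inner.get(self.outer.get(state))

    def set(self, state: Any, value: Any) -> Any:
        # Update the intermediate value, then write it back through the outer lens.
        intermediate = self.outer.get(state)
        return self.outer.set(state, self.inner.set(intermediate, value))


@dataclass(frozen=True)
class Thermostat:
    temperature: float


@dataclass(frozen=True)
class System:
    thermostat: Thermostat


temp = ComposedLens(FieldLens("thermostat"), FieldLens("temperature"))
s = System(thermostat=Thermostat(temperature=300.0))
print(temp.get(s))         # 300.0
print(temp.set(s, 310.0))  # System(thermostat=Thermostat(temperature=310.0))
```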
SimpleBoundLens
Bases: BoundLens[S, R]
A lens that has been bound to a specific data structure instance.
This provides a convenient interface for repeatedly operating on the same data structure without having to pass it as a parameter each time.
SimpleLens
Bases: BaseLens[S, R]
A simple lens implementation that uses traversal-based setting.
This is the most basic lens implementation that works with any pytree structure supported by JAX.
focus(where)
get(state)
set(state, value)
Set the focused value using traversal-based setting.
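One way to set a value given only a getter is to swap every leaf of the pytree for a unique marker, run the getter on the marked tree to see which markers it returns, and substitute the new value at those leaf positions. The sketch below assumes that technique (similar in spirit to Equinox's `tree_at`); `traversal_set` and `_Marker` are hypothetical names, and the kups `SimpleLens` may be implemented differently.

```python
from typing import Any, Callable, NamedTuple

import jax


class _Marker:
    """Placeholder recording which flat leaf position it replaced."""

    def __init__(self, index: int) -> None:
        self.index = index


def traversal_set(where: Callable[[Any], Any], state: Any, value: Any) -> Any:
    leaves, treedef = jax.tree_util.tree_flatten(state)
    # Replace each leaf with a marker so `where` reveals which leaves it selects.
    marked = jax.tree_util.tree_unflatten(treedef, [_Marker(i) for i in range(len(leaves))])
    selected = jax.tree_util.tree_leaves(where(marked))
    if not all(isinstance(m, _Marker) for m in selected):
        raise ValueError("where must return existing parts of the state")
    new_values = jax.tree_util.tree_leaves(value)
    if len(selected) != len(new_values):
        raise ValueError("value must have as many leaves as the focus")
    for marker, new in zip(selected, new_values):
        leaves[marker.index] = new
    return jax.tree_util.tree_unflatten(treedef, leaves)


class State(NamedTuple):
    position: float
    velocity: float


s = State(position=1.0, velocity=2.0)
print(traversal_set(lambda t: t.velocity, s, 5.0))  # State(position=1.0, velocity=5.0)
```

A consequence of this approach is that the getter must select existing parts of the state: a getter that computes a new value (for example `lambda t: t.velocity * 2`) has no leaf position to write back to.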
TreePathView
A view that follows a path of keys/attributes through a pytree.
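Following a path through a nested structure is a simple loop over path entries. The hypothetical `PathView` below treats string entries as attribute names (or keys when the current node is a dict) and integers as sequence indices; this path encoding is an assumption made for illustration, and the kups `TreePathView` may represent paths differently (JAX's own key-path entries such as `jax.tree_util.GetAttrKey`, `DictKey` and `SequenceKey` are one richer option).

```python
from dataclasses import dataclass
from typing import Any, Sequence, Union

PathEntry = Union[str, int]


class PathView:
    """Hypothetical view following a fixed path of attribute names, keys, or indices."""

    def __init__(self, path: Sequence[PathEntry]) -> None:
        self.path = tuple(path)

    def __call__(self, state: Any) -> Any:
        node = state
        for entry in self.path:
            if isinstance(node, dict):
                node = node[entry]           # mapping key
            elif isinstance(entry, str):
                node = getattr(node, entry)  # attribute name
            else:
                node = node[entry]           # sequence index
        return node


@dataclass
class Box:
    atoms: dict


state = Box(atoms={"positions": [[0.0, 0.0], [1.0, 1.0]]})
second_position = PathView(["atoms", "positions", 1])
print(second_position(state))  # [1.0, 1.0]
```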
Update
Bases: Protocol
Protocol for an update function that sets a value in a data structure.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| S | | The type of the source data structure | required |
| R | | The type of the value being set | required |
View
Bases: Protocol
Protocol for a view function that extracts a value from a data structure.
A view is a read-only operation that focuses on a specific part of a data structure.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| S | | The type of the source data structure | required |
| R | | The type of the result value | required |
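Because View and Update are protocols, any callable with a matching signature satisfies them structurally, with no subclassing needed. The sketch below re-declares minimal `View` and `Update` protocols with `typing.Protocol` for illustration; these declarations (including the variance choices and positional-only parameters) are assumptions, not the kups definitions.

```python
from dataclasses import dataclass, replace
from typing import Protocol, TypeVar

S = TypeVar("S")
S_contra = TypeVar("S_contra", contravariant=True)
R_co = TypeVar("R_co", covariant=True)
R_contra = TypeVar("R_contra", contravariant=True)


class View(Protocol[S_contra, R_co]):
    """Anything callable as view(state) -> value."""

    def __call__(self, state: S_contra, /) -> R_co: ...


class Update(Protocol[S, R_contra]):
    """Anything callable as update(state, value) -> new state."""

    def __call__(self, state: S, value: R_contra, /) -> S: ...


@dataclass(frozen=True)
class Cell:
    volume: float


def read_volume(c: Cell) -> float:
    return c.volume


def write_volume(c: Cell, v: float) -> Cell:
    return replace(c, volume=v)


# Plain functions satisfy the protocols structurally.
get_volume: View[Cell, float] = read_volume
set_volume: Update[Cell, float] = write_volume

print(get_volume(Cell(volume=8.0)))        # 8.0
print(set_volume(Cell(volume=8.0), 27.0))  # Cell(volume=27.0)
```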
lens_property
Decorator to create a lens-enabled property.
This decorator allows properties to behave like LensField, returning a lens when accessed on the class and the property value when accessed on an instance.
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| B | | The type of the property value | required |
Behavior
- Class access (e.g., MyClass.prop): Returns a Lens[MyClass, B] that can be used for functional operations like get(), set(), focus()
- Instance access (e.g., obj.prop): Returns the property value of type B, behaving like a normal property
Examples:
>>> from kups.core.utils.jax import dataclass
>>> from kups.core.lens import lens_property, HasLensFields
>>>
>>> @dataclass
... class Temperature(HasLensFields):
...     _kelvin: float
...
...     @lens_property
...     def kelvin(self) -> float:
...         return self._kelvin
...
...     @lens_property
...     def celsius(self) -> float:
...         return self._kelvin - 273.15
>>>
>>> temp = Temperature(_kelvin=300.0)
>>>
>>> # Instance access returns values
>>> temp.kelvin  # 300.0
>>> temp.celsius  # approximately 26.85
>>>
>>> # Class access returns lenses
>>> kelvin_lens = Temperature.kelvin  # Lens[Temperature, float]
>>> celsius_lens = Temperature.celsius  # Lens[Temperature, float]
>>>
>>> # Read values through the lenses
>>> kelvin_lens.get(temp)  # 300.0
>>> celsius_lens.get(temp)  # approximately 26.85
Notes
- Only works with classes that inherit from HasLensFields
- The decorated function should be a simple getter (no parameters other than self)
- Setting through the lens creates a new instance with the updated value
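The instance-versus-class split that `lens_property` provides can be sketched with an ordinary descriptor: `__get__` receives `obj=None` on class access, so it can return a lens-like object there and the computed value otherwise. `mini_lens_property` and `ReadOnlyLens` are hypothetical names; the sketch only supports reading and is not the kups decorator.

```python
from dataclasses import dataclass
from typing import Any, Callable


class ReadOnlyLens:
    """Hypothetical getter-only lens returned on class-level access."""

    def __init__(self, getter: Callable[[Any], Any]) -> None:
        self.getter = getter

    def get(self, state: Any) -> Any:
        return self.getter(state)


class mini_lens_property:
    """Hypothetical decorator: value on instances, lens on the class."""

    def __init__(self, fget: Callable[[Any], Any]) -> None:
        self.fget = fget

    def __get__(self, obj: Any, objtype: Any = None) -> Any:
        if obj is None:  # class access, e.g. Temperature.celsius
            return ReadOnlyLens(self.fget)
        return self.fget(obj)  # instance access


@dataclass(frozen=True)
class Temperature:
    kelvin: float

    @mini_lens_property
    def celsius(self) -> float:
        return self.kelvin - 273.15


t = Temperature(kelvin=300.0)
print(t.celsius)                   # approximately 26.85
print(Temperature.celsius.get(t))  # same value, read through the lens
```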
all_isinstance_lens(obj, cls)
Create a lens that focuses on all elements in a pytree of a specific type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obj | S | An instance of the state type to infer the pytree structure | required |
| cls | type[Target] | The type to filter elements by | required |
Returns:
| Type | Description |
|---|---|
| Lens[S, tuple[Target, ...]] | A lens that focuses on all elements of the specified type as a tuple. |
all_where_lens(obj, conditional, *, target_cls=None)
Create a lens that focuses on all elements in a pytree satisfying a condition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obj | S | An instance of the state type to infer the pytree structure | required |
| conditional | Callable[[Any], bool] | A predicate function that tests each element | required |
| target_cls | type[Target] \| None | Optional type hint for the target type (currently unused) | None |
Returns:
| Type | Description |
|---|---|
| Lens[S, tuple[Target, ...]] | A lens that focuses on all elements satisfying the condition as a tuple. |
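Focusing on "all elements satisfying a predicate" can be sketched with `jax.tree_util`: passing the predicate as `is_leaf` stops flattening at matching nodes, so they are collected whole, and setting replaces them in flattening order. `where_get` and `where_set` are hypothetical helpers; with `pred = lambda x: isinstance(x, cls)` they mirror the all_isinstance_lens case. This is a sketch of the idea, not the kups implementation.

```python
from typing import Any, Callable, Tuple

import jax

Predicate = Callable[[Any], bool]


def where_get(state: Any, pred: Predicate) -> Tuple[Any, ...]:
    # is_leaf=pred stops flattening at matching nodes, so they come back whole.
    leaves = jax.tree_util.tree_leaves(state, is_leaf=pred)
    return tuple(x for x in leaves if pred(x))


def where_set(state: Any, pred: Predicate, values: Tuple[Any, ...]) -> Any:
    leaves, treedef = jax.tree_util.tree_flatten(state, is_leaf=pred)
    if sum(1 for x in leaves if pred(x)) != len(values):
        raise ValueError("need exactly one new value per matching element")
    replacements = iter(values)
    # Swap matching elements in flattening order; keep all other leaves as-is.
    new_leaves = [next(replacements) if pred(x) else x for x in leaves]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


class Charge:
    def __init__(self, q: float) -> None:
        self.q = q

    def __repr__(self) -> str:
        return f"Charge({self.q})"


state = {"a": Charge(1.0), "b": [Charge(-1.0), 3.0]}
is_charge = lambda x: isinstance(x, Charge)

print(where_get(state, is_charge))  # (Charge(1.0), Charge(-1.0))
new_state = where_set(state, is_charge, (Charge(2.0), Charge(-2.0)))
print(new_state["b"])  # [Charge(-2.0), 3.0]
```

The tuple order follows JAX's flattening order (dict keys sorted, sequences in order), so values passed to `where_set` must be listed in that same order.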
bind(obj, where=None)
Create a bound lens from a getter function and a data structure.
This is a convenience function that creates a lens and immediately binds it to a specific data structure instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obj | S | The data structure to bind the lens to | required |
| where | Callable[[S], R] \| None | A function that extracts a value from the data structure | None |
Returns:
| Type | Description |
|---|---|
| BoundLens[S, R] | A BoundLens that operates on the given object |
const_lens(value)
identity_lens(_cls)
Create an identity lens for a type.
An identity lens is a lens that focuses on the entire data structure. It's primarily useful as a starting point for composition using .focus().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| _cls | type[S] | The type to create an identity lens for (the leading underscore signals that only the type matters; the value itself is not used) | required |
Returns:
| Type | Description |
|---|---|
| Lens[S, S] | A SimpleLens that acts as an identity function on the data structure |
Examples:
>>> # Create an identity lens and compose it
>>> State = identity_lens(SimState)
>>> position_lens = State.focus(lambda s: s.position)
>>> velocity_lens = State.focus(lambda s: s.velocity)
lens(where, /, *, cls=None)
Create a lens from a getter function.
This is the main factory function for creating lenses from getter functions that extract values from a data structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| where | Callable[[S], R] | A function that extracts a value from the data structure | required |
| cls | type[S] \| None | Optional class type hint (type inference only) | None |
Returns:
| Type | Description |
|---|---|
| Lens[S, R] | A SimpleLens that can operate on the data structure |
Examples:
>>> # Direct lens creation
>>> position_lens = lens(lambda s: s.position)
>>> velocity_lens = lens(lambda s: s.velocity)
update(where, *, cls=None)
Create an update function from a getter function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| where | Callable[[S], R] | A function that extracts a value from the data structure | required |
| cls | type[S] \| None | Optional class type hint (type inference only) | None |
Returns:
| Type | Description |
|---|---|
| Update[S, R] | A function that updates a value in the data structure |
view(where, /, cls=None)
Create a view from a callable function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| where | Callable[[S], R] | A function that extracts a value from the data structure | required |
Returns:
| Type | Description |
|---|---|
| View[S, R] | A View instance that wraps the provided function |