kups.core.utils.jax
¶
JAX utility functions and decorators for functional programming.
This module provides type-safe wrappers around JAX transformations and utilities for working with PyTrees, including JIT compilation, vectorization, and custom dataclass registration.
NotJaxCompatibleError
¶
PyTreeDef
¶
Bases: Protocol
Typed protocol for JAX pytree structure definitions.
Source code in src/kups/core/utils/jax.py
SupportsTreeMatch
¶
Bases: Protocol
Protocol for pytree nodes that align themselves before mapping.
When tree_map encounters a SupportsTreeMatch node, it calls
__tree_match__ to reconcile self with the corresponding nodes
from the other input trees (e.g., merging label vocabularies in
Index). The returned tuple replaces the originals for the
remainder of the map.
Source code in src/kups/core/utils/jax.py
dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=True, match_args=True, kw_only=False, slots=False, weakref_slot=False)
¶
Create a dataclass that works as a JAX PyTree.
Combines Python's @dataclass with JAX's PyTree registration, enabling
dataclasses to be used with JAX transformations like jit, grad, and vmap.
Dataclasses are frozen by default for immutability (unlike standard dataclasses).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cls` | `T \| None` | Class to convert into a JAX-compatible dataclass, or `None` when used as a decorator with arguments. | `None` |
| `init` | `bool` | If True (default), generate an `__init__` method. | `True` |
| `repr` | `bool` | If True (default), generate a `__repr__` method. | `True` |
| `eq` | `bool` | If True (default), generate `__eq__`. | `True` |
| `order` | `bool` | If True, generate `__lt__`, `__le__`, `__gt__`, and `__ge__`. | `False` |
| `unsafe_hash` | `bool` | If True, generate a `__hash__` method. | `False` |
| `frozen` | `bool` | If True (default, unlike standard dataclasses), fields cannot be assigned after instance creation. This is the default because JAX transformations work best with immutable data structures. | `True` |
| `match_args` | `bool` | If True (default), generate `__match_args__`. | `True` |
| `kw_only` | `bool` | If True, all fields become keyword-only in `__init__`. | `False` |
| `slots` | `bool` | If True, generate `__slots__`. | `False` |
| `weakref_slot` | `bool` | If True and `slots` is True, add a `__weakref__` slot. | `False` |
Returns:
| Type | Description |
|---|---|
| `T \| Callable[[T], T]` | A dataclass registered as a JAX PyTree, or a decorator if `cls` is `None`. |
Example
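A minimal sketch of the underlying idea in plain JAX, not this module's implementation: a frozen dataclass is registered as a pytree so that its array fields become leaves while a static field travels as auxiliary data. The class `Params`, its field names, and the `apply` function are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Params:
    """Hypothetical container: two array leaves plus one static label."""

    w: jax.Array
    b: jax.Array
    name: str = "layer"  # static metadata, never traced


# Array fields are children (leaves); `name` is carried as auxiliary data.
jax.tree_util.register_pytree_node(
    Params,
    lambda p: ((p.w, p.b), p.name),
    lambda name, children: Params(*children, name=name),
)


@jax.jit
def apply(p: Params, x: jax.Array) -> jax.Array:
    return x @ p.w + p.b


params = Params(w=jnp.ones((3, 2)), b=jnp.zeros((2,)))
y = apply(params, jnp.ones((4, 3)))
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
```

Because the instance is frozen, transformations such as `tree_map` build a new `Params` rather than mutating the old one.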
Note
Unlike standard dataclasses.dataclass, this decorator defaults to
frozen=True.
Source code in src/kups/core/utils/jax.py
field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, static=False)
¶
Create a dataclass field with JAX PyTree registration support.
This is an enhanced version of dataclasses.field that adds a static parameter
for controlling JAX PyTree registration. When static=True, the field is marked
as static metadata and excluded from JAX transformations like jit, grad, and vmap.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
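| `default` | `T \| _MISSING_TYPE` | Default value for the field. Cannot be used with `default_factory`. | `MISSING` |
| `default_factory` | `Callable[[], T] \| _MISSING_TYPE` | Factory function to generate default values. Cannot be used with `default`. | `MISSING` |
| `init` | `bool` | If `True`, include the field in the generated `__init__`. | `True` |
| `repr` | `bool` | If `True`, include the field in the generated `__repr__`. | `True` |
| `hash` | `bool \| None` | If `True`, include the field in the generated `__hash__`. | `None` |
| `compare` | `bool` | If `True`, include the field in comparison methods. | `True` |
| `metadata` | `Mapping[Any, Any] \| None` | Additional metadata dictionary for the field. | `None` |
| `kw_only` | `bool \| _MISSING_TYPE` | If `True`, the field is keyword-only in `__init__`. | `MISSING` |
| `static` | `bool` | If `True`, the field is marked as static metadata and excluded from JAX transformations such as `jit`, `grad`, and `vmap`. | `False` |
Returns:
| Type | Description |
|---|---|
| `T` | A dataclass field configured with the specified parameters. |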
Example
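import jax.numpy as jnp
from jax import Array
from kups.core.utils.jax import dataclass, field

@dataclass
class Config:
    learning_rate: Array  # Dynamic field
    model_name: str = field(default="transformer", static=True)  # Static field
    weights: Array = field(default_factory=lambda: jnp.zeros((10,)))  # Dynamic field

# learning_rate has no default, so it must be passed explicitly
config = Config(learning_rate=jnp.asarray(1e-3))
# Only learning_rate and weights are traced by JAX transformations
# model_name remains constant as static metadata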
Note
Static fields are useful for configuration parameters, model hyperparameters, or any values that should remain constant during JAX transformations.
Source code in src/kups/core/utils/jax.py
is_traced(x)
¶
isin(a, b, max_item)
¶
Fast membership test for integer arrays using index-based lookup.
Optimized alternative to jnp.isin for integer arrays with known maximum
value. Uses array indexing instead of comparisons for better performance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Array` | Query array of integers to test for membership. | required |
| `b` | `Array` | Reference array of integers to test membership against. | required |
| `max_item` | `int` | Maximum possible value in both arrays (exclusive upper bound). | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Boolean array of same shape as `a`, indicating whether each element exists in `b`. |
Example
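A minimal sketch of how an index-based membership test can work, written in plain `jax.numpy`. The helper name `isin_by_lookup` is hypothetical and this is an assumption about the technique, not the module's implementation: values from `b` are marked in a boolean table of length `max_item`, and the table is then read back at the positions given by `a`.

```python
import jax.numpy as jnp


def isin_by_lookup(a, b, max_item):
    """Membership test via a boolean lookup table of length `max_item`."""
    # Mark every value that occurs in `b` ...
    table = jnp.zeros((max_item,), dtype=bool).at[b].set(True)
    # ... then read the marks back at the query positions.
    return table[a]


a = jnp.array([0, 3, 5, 3])
b = jnp.array([3, 4])
mask = isin_by_lookup(a, b, max_item=6)
```

The sketch only gives meaningful results when every value in `a` and `b` lies in `[0, max_item)`, which matches the exclusive upper bound described for `max_item`.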
Source code in src/kups/core/utils/jax.py
jit(fn=None, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)
¶
jit(
    fn: C,
    *,
    in_shardings: Any = ...,
    out_shardings: Any = ...,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Iterable[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    donate_argnames: str | Iterable[str] | None = ...,
    keep_unused: bool = ...,
    device: Any = ...,
    backend: str | None = ...,
    inline: bool = ...,
    abstracted_axes: Any = ...,
    compiler_options: dict[str, Any] | None = ...,
) -> C

jit(
    fn: None = None,
    *,
    in_shardings: Any = ...,
    out_shardings: Any = ...,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Iterable[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    donate_argnames: str | Iterable[str] | None = ...,
    keep_unused: bool = ...,
    device: Any = ...,
    backend: str | None = ...,
    inline: bool = ...,
    abstracted_axes: Any = ...,
    compiler_options: dict[str, Any] | None = ...,
) -> Callable[[C], C]
Type-preserving JIT compilation decorator for JAX functions.
Sets up a function for just-in-time compilation with XLA. Wraps jax.jit
while preserving function names and type signatures.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `C \| None` | Function to be jitted. Should be a pure function. The arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. | `None` |
| `in_shardings` | `Any` | Optional sharding specification for inputs. If provided, the positional arguments must have compatible shardings. | `None` |
| `out_shardings` | `Any` | Optional sharding specification for outputs. Has the same effect as applying `jax.lax.with_sharding_constraint` to the output. | `None` |
| `static_argnums` | `int \| Sequence[int] \| None` | Int or collection of ints specifying which positional arguments to treat as static (trace- and compile-time constant). Static arguments should be hashable and immutable. | `None` |
| `static_argnames` | `str \| Iterable[str] \| None` | String or collection of strings specifying which named arguments to treat as static (compile-time constant). | `None` |
| `donate_argnums` | `int \| Sequence[int] \| None` | Int or collection of ints specifying which positional argument buffers can be overwritten by the computation and marked deleted in the caller. Useful for memory optimization. | `None` |
| `donate_argnames` | `str \| Iterable[str] \| None` | String or collection of strings specifying which named arguments are donated to the computation. | `None` |
| `keep_unused` | `bool` | If False (default), arguments that JAX determines to be unused may be dropped from compiled executables. If True, unused arguments will not be pruned. | `False` |
| `device` | `Any` | Optional device to run the jitted function on. | `None` |
| `backend` | `str \| None` | Optional string representing the XLA backend: 'cpu', 'gpu', or 'tpu'. | `None` |
| `inline` | `bool` | If True, inline this function into enclosing jaxprs. Default False. | `False` |
| `abstracted_axes` | `Any` | Optional axis abstraction specification. | `None` |
| `compiler_options` | `dict[str, Any] \| None` | Optional dictionary of compiler options. | `None` |
Returns:
| Type | Description |
|---|---|
| `C \| Callable[[C], C]` | A wrapped version of the function, set up for just-in-time compilation with preserved type signature. |
Example
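A minimal sketch of what a type-preserving `jit` wrapper can look like, assuming it simply forwards to `jax.jit` and casts the result back to the input callable type. The name `typed_jit` and the functions `scale` and `add_one` are hypothetical; the module's own wrapper may differ in detail.

```python
from typing import Any, Callable, Optional, TypeVar, Union, cast

import jax
import jax.numpy as jnp

C = TypeVar("C", bound=Callable[..., Any])


def typed_jit(fn: Optional[C] = None, **jit_kwargs: Any) -> Union[C, Callable[[C], C]]:
    """Hypothetical wrapper: forwards to jax.jit but keeps the callable's type."""
    if fn is None:
        # Called as @typed_jit(...): return a decorator awaiting the function.
        return lambda f: typed_jit(f, **jit_kwargs)
    return cast(C, jax.jit(fn, **jit_kwargs))


@typed_jit(static_argnames="power")
def scale(x, power):
    # `power` is static, so each distinct value triggers its own compilation.
    return x**power


@typed_jit
def add_one(x):
    return x + 1


squares = scale(jnp.arange(3.0), power=2)
incremented = add_one(jnp.array(1))
```

Both decorator forms, with and without arguments, go through the same function, which mirrors the two overloads listed above.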
Source code in src/kups/core/utils/jax.py
kahan_summation(*summands, compensate=None)
¶
Numerically stable summation using Kahan's compensated algorithm.
Reduces floating-point accumulation errors when summing many numbers by tracking and compensating for rounding errors at each step. Works with arbitrary PyTree structures.
The algorithm maintains an error compensation term that captures lost precision, significantly reducing numerical drift in iterative computations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*summands` | `T` | One or more PyTrees to sum together. | `()` |
| `compensate` | `T \| None` | Optional error compensation term from previous summation. | `None` |
Returns:
| Type | Description |
|---|---|
| `T` | Tuple of (sum, compensation) where compensation should be passed to subsequent calls for continued stability. |
Example
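A minimal sketch of one compensated addition over pytrees, written with `jax.tree_util.tree_map`. The helper `kahan_step` and the sign convention of its compensation term are assumptions for illustration and need not match this module's `kahan_summation`.

```python
import jax
import jax.numpy as jnp


def kahan_step(total, summand, comp):
    """One compensated addition over matching pytrees.

    Assumed convention: `comp` holds the error to subtract from the next summand.
    """
    tree_map = jax.tree_util.tree_map
    y = tree_map(lambda s, c: s - c, summand, comp)
    new_total = tree_map(lambda t, y_: t + y_, total, y)
    # Whatever part of `y` did not make it into `new_total` is kept for later.
    new_comp = tree_map(lambda n, t, y_: (n - t) - y_, new_total, total, y)
    return new_total, new_comp


total = {"energy": jnp.float32(1e8)}
comp = {"energy": jnp.float32(0.0)}
one = {"energy": jnp.float32(1.0)}
naive = jnp.float32(1e8)

for _ in range(10):
    total, comp = kahan_step(total, one, comp)
    # float32 values near 1e8 are spaced 8 apart, so this addition is lost.
    naive = naive + jnp.float32(1.0)

# Under the assumed convention, total minus comp recovers the accumulated sum.
recovered = float(total["energy"]) - float(comp["energy"])
```

In this sketch the naive accumulator never moves away from `1e8`, while the compensated pair keeps track of the ten added units.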
Reference
W. Kahan, "Further remarks on reducing truncation errors", 1965.
Source code in src/kups/core/utils/jax.py
key_chain(rng, shape=())
¶
Generate an infinite sequence of PRNG keys with deterministic iteration.
Creates a generator that produces an infinite stream of JAX PRNG keys by folding in incrementing counters. Useful for iterative algorithms that need reproducible randomness at each step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng` | `Array` | Initial JAX PRNG key. | required |
| `shape` | `tuple[int, ...]` | Shape of key batches to generate. Default `()`. | `()` |
Yields:
| Type | Description |
|---|---|
| `Array` | JAX PRNG keys with the specified shape, incremented deterministically. |
Example
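A minimal sketch of a key generator built on `jax.random.fold_in`, assuming scalar keys only (the `shape` argument is not modeled). The name `key_stream` is hypothetical.

```python
import itertools

import jax
import jax.numpy as jnp


def key_stream(rng):
    """Yield an endless, reproducible sequence of keys derived from `rng`."""
    for step in itertools.count():
        # fold_in mixes the counter into the base key without mutating it.
        yield jax.random.fold_in(rng, step)


first_run = list(itertools.islice(key_stream(jax.random.PRNGKey(0)), 3))
second_run = list(itertools.islice(key_stream(jax.random.PRNGKey(0)), 3))

# Each key can be consumed by any jax.random sampler.
noise = jax.random.normal(first_run[0], (2,))
```

Starting two streams from the same seed yields the same keys in the same order, which is what makes per-step randomness reproducible.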
Source code in src/kups/core/utils/jax.py
lens_field(*, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init=True, repr=True, hash=None, compare=True, metadata=None, kw_only=dataclasses.MISSING, static=False)
¶
Create a field for use with LensField annotations.
This is a type-safe wrapper around field() specifically for LensField[T] annotations. It returns the proper type for static type checkers while creating a regular dataclass field at runtime.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `default` | `T \| _MISSING_TYPE` | Default value for the field. | `MISSING` |
| `default_factory` | `Callable[[], T] \| _MISSING_TYPE` | Factory function to generate default values. | `MISSING` |
| `init` | `bool` | Include field in init. | `True` |
| `repr` | `bool` | Include field in repr. | `True` |
| `hash` | `bool` | Include field in hash. | `None` |
| `compare` | `bool` | Include field in comparison methods. | `True` |
| `metadata` | `dict` | Additional metadata dictionary. | `None` |
| `kw_only` | `bool` | Make this a keyword-only argument in init. | `MISSING` |
| `static` | `bool` | Mark field as static (JAX tree registration). | `False` |
Returns:
| Type | Description |
|---|---|
| `LensField[T]` | A dataclass field that type checkers treat as LensField[T]. |
Example
>>> from kups.core.lens import LensField, HasLensFields
>>> from kups.core.utils.jax import dataclass, lens_field
>>> from jax import Array
>>> @dataclass
... class Point(HasLensFields):
...     # The field without a default must precede the field with one.
...     y: LensField[Array] = lens_field(static=True)
...     x: LensField[float] = lens_field(default=0.0)
Source code in src/kups/core/utils/jax.py
linearize(fn, *primals, has_aux=False)
¶
Linearize fn at primals, returning the output and a JVP function.
Type-preserving wrapper around jax.linearize.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
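| `fn` | `Callable[[*T], R1] \| Callable[[*T], tuple[R1, R2]]` | Function to linearize. | required |
| `*primals` | `*T` | Points at which to linearize. | `()` |
| `has_aux` | `bool` | If `True`, `fn` returns a pair `(output, aux)` and `aux` is returned as a third element. | `False` |
Returns:
| Type | Description |
|---|---|
| `tuple[R1, Callable[[*T], R1], R2] \| tuple[R1, Callable[[*T], R1]]` | The output of `fn` at `primals` and the JVP function, followed by the auxiliary value when `has_aux` is `True`. |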
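For orientation, this is how the wrapped `jax.linearize` behaves on a scalar function. The sketch calls JAX directly rather than this module's wrapper, whose typed interface is assumed to return the same values.

```python
import jax
import jax.numpy as jnp

# Linearize sin at x = 3.0: `y` is sin(3.0), `f_jvp` maps input tangents
# to output tangents at that point.
y, f_jvp = jax.linearize(jnp.sin, 3.0)

# The derivative of sin at 3.0, applied to the tangent 1.0.
tangent_out = f_jvp(1.0)
```

The returned `f_jvp` can be evaluated repeatedly for different tangents without re-running the forward pass of `fn`.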
Source code in src/kups/core/utils/jax.py
no_jax_tracing(fn)
¶
Decorator to mark functions that should not be used within JAX transformations.
Checks if any input pytree contains a JAX tracer. If so, raises NotJaxCompatibleError. Use this to prevent functions from being called inside jit, vmap, grad, etc.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `C` | Function to protect from JAX tracing. | required |
Returns:
| Type | Description |
|---|---|
| `C` | Wrapped function that raises NotJaxCompatibleError if traced. |
Example
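A minimal sketch of such a guard, assuming tracers are detected with `isinstance(leaf, jax.core.Tracer)` over all flattened arguments. `TracedInputError` and `forbid_tracing` are hypothetical stand-ins for `NotJaxCompatibleError` and `no_jax_tracing`.

```python
import functools

import jax
import jax.numpy as jnp


class TracedInputError(TypeError):
    """Hypothetical stand-in for NotJaxCompatibleError."""


def forbid_tracing(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Flatten every argument pytree and look for abstract tracer values.
        leaves = jax.tree_util.tree_leaves((args, kwargs))
        if any(isinstance(leaf, jax.core.Tracer) for leaf in leaves):
            raise TracedInputError(
                f"{fn.__name__} must not be called under a JAX transformation"
            )
        return fn(*args, **kwargs)

    return wrapper


@forbid_tracing
def to_python_list(x):
    # Needs concrete values, so it cannot run on tracers.
    return x.tolist()


eager = to_python_list(jnp.arange(3))

try:
    jax.jit(to_python_list)(jnp.arange(3))
    raised = False
except TracedInputError:
    raised = True
```

Calling the function eagerly works as usual, while the jitted call fails early with a clear error instead of a tracer-related failure deeper in the function body.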
Source code in src/kups/core/utils/jax.py
non_differentiable(x)
¶
Identity function that raises on differentiation.
Use to mark values that must not be differentiated through. Any
attempt to compute a JVP will raise NotImplementedError.
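A minimal sketch of how an identity with a raising differentiation rule can be built with `jax.custom_jvp`. The name `no_grad` is hypothetical and the module's own implementation may differ.

```python
import jax


@jax.custom_jvp
def no_grad(x):
    """Identity in the forward pass."""
    return x


@no_grad.defjvp
def _no_grad_jvp(primals, tangents):
    # Any attempt to differentiate through `no_grad` ends up in this rule.
    raise NotImplementedError("this value must not be differentiated through")


forward = no_grad(2.0)

try:
    jax.grad(lambda x: no_grad(x) ** 2)(1.0)
    raised = False
except NotImplementedError:
    raised = True
```

Forward evaluation is unaffected; only transformations that request a derivative trigger the error.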
Source code in src/kups/core/utils/jax.py
sequential_vmap_with_vjp(func)
¶
Create a sequentially vmapped function with custom VJP support.
Wraps a function with sequential vmap (processes batch elements one at a time) and defines custom forward/backward passes for automatic differentiation. This is useful when the underlying function doesn't support standard vmap batching rules.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[[*P], R]` | Function to be sequentially vmapped. | required |
Returns:
| Type | Description |
|---|---|
| `Callable[[*P], R]` | Vmapped function with proper VJP (vector-Jacobian product) support. |
Source code in src/kups/core/utils/jax.py
shard_map(f, /, *, out_specs, in_specs=None, mesh=None, axis_names=frozenset(), check_vma=True)
¶
Map a function over shards of data for multi-device parallel computation.
Wraps jax.shard_map for SPMD (Single Program Multiple Data) parallel
execution across multiple devices. Each application of the function takes
as input a shard of the mapped-over arguments and produces a shard of the
output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
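| `f` | `C` | Callable to be mapped. Each instance of `f` takes as input a shard of the mapped-over arguments and produces a shard of the output. | required |
| `out_specs` | `Any` | A pytree with `PartitionSpec` instances as leaves, describing how the outputs are sharded. | required |
| `in_specs` | `Any` | A pytree with `PartitionSpec` instances as leaves, describing how the inputs are sharded. | `None` |
| `mesh` | `Mesh \| None` | A `jax.sharding.Mesh` describing the devices over which to shard the data. | `None` |
| `axis_names` | `frozenset` | Set of axis names from `mesh` over which `f` is manual. | `frozenset()` |
| `check_vma` | `bool` | If True (default), enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in `out_specs` are consistent with how the outputs of `f` are replicated. | `True` |
Returns:
| Type | Description |
|---|---|
| `C` | A callable that applies the input function `f` across shards according to the `mesh`, `in_specs`, and `out_specs`. |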
Example
from functools import partial
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from kups.core.utils.jax import shard_map

# Create a 2x2 mesh of devices (requires four available devices)
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

@partial(shard_map, mesh=mesh, in_specs=P('x', None), out_specs=P('x', None))
def parallel_fn(x):
    return x * 2

# x is sharded along its first axis across the 'x' mesh axis
x = jnp.arange(16.0).reshape(4, 4)
result = parallel_fn(x)
Note
Requires understanding of JAX's sharding model and mesh configuration. For an introduction to sharded data, refer to JAX's sharded computation documentation at https://docs.jax.dev/en/latest/notebooks/shard_map.html.
Source code in src/kups/core/utils/jax.py
skip_post_init_if_disabled(post_init)
¶
Skip __post_init__ validation when inside a no_post_init context.
JAX dataclass containers like Table and Buffered validate
invariants (unique keys, matching dimensions, …) in __post_init__.
During deserialization or lens-based structural updates the intermediate
objects may temporarily violate those invariants, so validation is
suppressed via the no_post_init context manager. Decorate a
__post_init__ with this function to opt into that suppression.
Source code in src/kups/core/utils/jax.py
tree_concat(*trees)
¶
Concatenate pytrees along the leading axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*trees` | `T` | Two or more pytrees with matching structure. | `()` |
Returns:
| Type | Description |
|---|---|
| `T` | Pytree with each leaf concatenated along axis 0. |
Source code in src/kups/core/utils/jax.py
tree_map(fn, tree, *trees, is_leaf=None)
¶
Apply fn to every leaf of one or more pytrees, with label alignment.
Extends jax.tree.map with support for SupportsTreeMatch nodes. Before
fn is called, any node implementing __tree_match__ is aligned
across all input trees (e.g., Index objects merge their label
vocabularies so integer indices become comparable).
Nodes marked by is_leaf or SupportsTreeMatch are treated as leaves at
the top level. If a SupportsTreeMatch node is also a pytree (not flagged
by is_leaf), its children are recursed into after alignment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | Function applied to each aligned leaf. | required |
| `tree` | `T` | Primary pytree (determines output structure). | required |
| `*trees` | `*S` | Additional pytrees with matching structure. | `()` |
| `is_leaf` | `Callable[[Any], bool] \| None` | Optional predicate for extra leaf types. | `None` |
Returns:
| Type | Description |
|---|---|
| `T` | Transformed pytree with the same structure as `tree`. |
Source code in src/kups/core/utils/jax.py
tree_scatter_set(item, value, idxs, args)
¶
Set values at indices in a pytree, respecting HasScatterArgs.
Traverses the pytree and applies arr.at[idxs].set(val) to each array
leaf. Nodes implementing HasScatterArgs merge their scatter_args
(e.g. mode="drop") into the call before recursing into children.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `item` | `T` | Pytree to update. | required |
| `value` | `T` | Pytree of replacement values (same structure as `item`). | required |
| `idxs` | `Array` | Integer index array for the scatter operation. | required |
| `args` | `ScatterArgs` | Scatter keyword args passed to `arr.at[idxs].set()`. | required |
Returns:
| Type | Description |
|---|---|
| `T` | Updated pytree with the same structure as `item`. |
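A minimal sketch of the per-leaf scatter described above, using only `jax.tree_util.tree_map` and `arr.at[idxs].set(...)`. The helper `scatter_set` is hypothetical and leaves out the `HasScatterArgs` merging, which is specific to this library.

```python
import jax
import jax.numpy as jnp


def scatter_set(item, value, idxs, **scatter_kwargs):
    """Apply arr.at[idxs].set(val, **scatter_kwargs) to every pair of leaves."""
    return jax.tree_util.tree_map(
        lambda arr, val: arr.at[idxs].set(val, **scatter_kwargs), item, value
    )


state = {"pos": jnp.zeros((4, 2)), "mass": jnp.zeros((4,))}
update = {"pos": jnp.ones((2, 2)), "mass": jnp.array([1.0, 2.0])}

# Index 10 is out of bounds; mode="drop" makes the scatter skip that update.
new_state = scatter_set(state, update, jnp.array([1, 10]), mode="drop")
```

Only row 1 of each leaf is overwritten here, because the second index falls outside the arrays and is dropped.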
Source code in src/kups/core/utils/jax.py
tree_stack(*trees)
¶
Stack pytrees into a new leading dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*trees` | `T` | Two or more pytrees with matching structure. | `()` |
Returns:
| Type | Description |
|---|---|
| `T` | Pytree with each leaf stacked along a new leading axis. |
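The difference between stacking and concatenating pytrees can be sketched with `jax.tree_util.tree_map`. The helpers `stack_trees` and `concat_trees` are hypothetical equivalents written for illustration, not the module's own functions.

```python
import jax
import jax.numpy as jnp


def stack_trees(*trees):
    """Stack matching leaves along a new leading axis."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *trees)


def concat_trees(*trees):
    """Concatenate matching leaves along the existing leading axis."""
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=0), *trees
    )


a = {"x": jnp.zeros((2, 3)), "y": jnp.zeros((2,))}
b = {"x": jnp.ones((2, 3)), "y": jnp.ones((2,))}

stacked = stack_trees(a, b)  # leaf shapes: (2, 2, 3) and (2, 2)
joined = concat_trees(a, b)  # leaf shapes: (4, 3) and (4,)
```

Stacking adds a dimension whose length equals the number of input trees, whereas concatenation grows the existing leading dimension.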
Source code in src/kups/core/utils/jax.py
tree_structure(x, is_leaf=None)
¶
Return the pytree structure of x.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `T` | Pytree to inspect. | required |
| `is_leaf` | `Callable[[Any], bool] \| None` | Optional predicate marking extra leaf types. | `None` |
Returns:
| Type | Description |
|---|---|
| `PyTreeDef[T]` | A `PyTreeDef` describing the structure of `x`. |
Source code in src/kups/core/utils/jax.py
tree_where_broadcast_last(accept, tree1, tree2)
¶
Element-wise jnp.where over two pytrees, broadcasting accept on trailing dims.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `accept` | `Array` | Boolean condition array, broadcast to match each leaf's shape. | required |
| `tree1` | `T` | Pytree selected where `accept` is True. | required |
| `tree2` | `T` | Pytree selected where `accept` is False. | required |
Returns:
| Type | Description |
|---|---|
| `T` | Pytree with the same structure, each leaf chosen per `accept`. |
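A minimal sketch of the broadcasting involved, assuming `accept` indexes the leading dimensions of every leaf and is padded with singleton axes on the right. The helper name `where_broadcast_last` and the example data are hypothetical.

```python
import jax
import jax.numpy as jnp


def where_broadcast_last(accept, tree1, tree2):
    def select(x, y):
        # Append singleton axes so `accept` lines up with the leading dims of the leaf.
        cond = accept.reshape(accept.shape + (1,) * (x.ndim - accept.ndim))
        return jnp.where(cond, x, y)

    return jax.tree_util.tree_map(select, tree1, tree2)


accept = jnp.array([True, False])
proposed = {"pos": jnp.ones((2, 3)), "energy": jnp.array([1.0, 2.0])}
current = {"pos": jnp.zeros((2, 3)), "energy": jnp.array([5.0, 6.0])}

out = where_broadcast_last(accept, proposed, current)
```

With a per-row `accept` mask, whole rows of higher-rank leaves are taken from `tree1` or `tree2` together, which is the usual pattern for accept/reject updates.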
Source code in src/kups/core/utils/jax.py
tree_zeros_like(tree)
¶
vectorize(pyfunc=None, *, excluded=frozenset(), signature=None)
¶
Define a vectorized function with broadcasting.
Wraps jax.numpy.vectorize for defining vectorized functions with broadcasting,
in the style of NumPy's generalized universal functions. It allows defining
functions that are automatically repeated across any leading dimensions, without
the implementation needing to handle higher dimensional inputs.
Unlike numpy.vectorize, this is syntactic sugar for an auto-batching
transformation (vmap) rather than a Python loop, making it considerably
more efficient.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pyfunc` | `C \| None` | Function to vectorize, or `None` when used as a decorator with arguments. | `None` |
| `excluded` | `frozenset[int]` | Optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to `pyfunc` unmodified. | `frozenset()` |
| `signature` | `str \| None` | Optional generalized universal function signature, e.g., `(m,n),(n)->(m)` for vectorized matrix-vector multiplication. | `None` |
Returns:
| Type | Description |
|---|---|
| `C \| Callable[[C], C]` | Vectorized version of the given function that broadcasts over batch dimensions. |
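To illustrate the `signature` argument, here is a short sketch that calls the wrapped `jax.numpy.vectorize` directly. The function `dot` is a hypothetical example, and this module's wrapper is assumed to accept the same arguments.

```python
import jax.numpy as jnp


def dot(a, b):
    # Written for 1-D inputs only; batching is added by vectorize.
    return jnp.sum(a * b)


# Core dimensions: two vectors of length n go in, a scalar comes out.
batched_dot = jnp.vectorize(dot, signature="(n),(n)->()")

a = jnp.arange(12.0).reshape(4, 3)
b = jnp.ones((3,))
out = batched_dot(a, b)  # shape (4,): one dot product per row of `a`
```

The leading dimension of `a` is treated as a batch dimension and `b` is broadcast against it, so `dot` itself never sees more than one row at a time.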