kups.relaxation.transforms.max_step_size
Per-system max-step-size clipping transform.
Unlike `kups.relaxation.optax.max_step_size`, this version takes an
`index_prefix` pytree (analogous to `in_axes` in `jax.vmap`) whose
leaves are `Index` objects mapping each parameter element to a system.
The maximum displacement is enforced per system, so batching independent
systems through one optimizer is bit-identical to running them one at a time.
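The batching claim can be illustrated with a minimal NumPy sketch. It assumes the updates form a single `(num_elements, dim)` array and that each element's system is given by a plain integer array `system_ids`, a hypothetical stand-in for the `Index` leaves. `clip_global` and `clip_per_system` are illustrative names, not part of kups, and `clip_global` is only an assumed model of a clip that uses one maximum for the whole batch. In JAX, a segment reduction such as `jax.ops.segment_max` would play the role of `np.maximum.at`.

```python
import numpy as np


def clip_global(updates, max_step_size):
    # Hypothetical global clip: one scale factor shared by every system.
    max_norm = np.linalg.norm(updates, axis=-1).max()
    scale = max_step_size / max_norm if max_norm > max_step_size else 1.0
    return updates * scale


def clip_per_system(updates, system_ids, num_systems, max_step_size):
    # Per-element displacement norms, taken along the last axis.
    norms = np.linalg.norm(updates, axis=-1)
    # Worst-case norm per system: np.maximum.at scatters a running maximum.
    max_norm = np.zeros(num_systems)
    np.maximum.at(max_norm, system_ids, norms)
    # Shrink only systems whose worst element exceeds the limit.
    safe = np.where(max_norm > 0, max_norm, 1.0)
    scale = np.where(max_norm > max_step_size, max_step_size / safe, 1.0)
    return updates * scale[system_ids][..., None]


a = np.array([[0.1, 0.0, 0.0], [0.0, 0.2, 0.0]])  # system A: small steps
b = np.array([[3.0, 4.0, 0.0]])                    # system B: one step of norm 5
batch = np.concatenate([a, b])
ids = np.array([0, 0, 1])

per_sys = clip_per_system(batch, ids, 2, max_step_size=0.5)
alone_a = clip_per_system(a, np.zeros(2, dtype=int), 1, max_step_size=0.5)
alone_b = clip_per_system(b, np.zeros(1, dtype=int), 1, max_step_size=0.5)

# Batched and one-at-a-time results agree element for element here.
assert np.array_equal(per_sys, np.concatenate([alone_a, alone_b]))
```

With a single global maximum, system B's large step would also shrink system A's updates; the per-system version leaves A untouched, which matches clipping each system on its own.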
MaxStepSize
Bases: `Optimizer[Params, MaxStepSizeState]`
Clip updates so no element of any system moves more than `max_step_size`.
Per-element norms are computed along the last axis. For every system, the
maximum norm across all elements assigned to that system (across every leaf
of `updates`) is found, and the updates for those elements are uniformly
scaled so the worst-case norm does not exceed `max_step_size`. Different
systems are scaled independently.
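The reduction across several leaves can be sketched as follows. This is a NumPy illustration under stated assumptions, not the kups implementation: `updates` is a dict of leaves, and `system_ids` is a matching dict of integer arrays (a hypothetical stand-in for the expanded `index_prefix`), each of shape `leaf.shape[:-1]`. The function name and the leaf names `positions` and `lattice` are made up for the example.

```python
import numpy as np


def clip_per_system_tree(updates, system_ids, num_systems, max_step_size):
    """Hypothetical sketch of per-system max-step clipping over several leaves.

    updates:    dict of arrays, each of shape (..., dim).
    system_ids: dict with the same keys; each value is an int array of shape
                leaf.shape[:-1] giving the system of every element.
    """
    # One running maximum per system, shared by all leaves.
    max_norm = np.zeros(num_systems)
    for key, leaf in updates.items():
        norms = np.linalg.norm(leaf, axis=-1)  # per-element norms
        np.maximum.at(max_norm, system_ids[key], norms)
    # Uniform per-system scale so the worst-case norm is at most max_step_size.
    safe = np.where(max_norm > 0, max_norm, 1.0)
    scale = np.where(max_norm > max_step_size, max_step_size / safe, 1.0)
    return {
        key: leaf * scale[system_ids[key]][..., None]
        for key, leaf in updates.items()
    }


updates = {
    "positions": np.array([[0.3, 0.0, 0.0], [0.0, 0.4, 0.0], [0.1, 0.0, 0.0]]),
    "lattice": np.array([[0.0, 0.0, 2.0], [0.0, 0.0, 0.2]]),
}
ids = {"positions": np.array([0, 0, 1]), "lattice": np.array([0, 1])}
clipped = clip_per_system_tree(updates, ids, num_systems=2, max_step_size=1.0)
```

Because the largest element of system 0 sits in the `lattice` leaf (norm 2.0), every system-0 update, including its `positions`, is halved, while system 1 (worst norm 0.2) is left unchanged.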
Attributes:

| Name | Type | Description |
|---|---|---|
| `max_step_size` | `float` | Maximum allowed per-element displacement norm. |
Source code in src/kups/relaxation/transforms/max_step_size.py
MaxStepSizeState
Optimizer state holding the `index_prefix` captured at init time.
Attributes:

| Name | Type | Description |
|---|---|---|
| `index_prefix` | `PyTree \| None` | Tree prefix of the parameter pytree whose leaves are :class: |