kups.relaxation.transforms.clip_by_global_norm
Per-system L2-norm clipping transform.
Per-system analogue of `optax.clip_by_global_norm`. For every system,
the L2 norm is computed over all update entries assigned to that system,
across all leaves of the parameter pytree; those entries are then rescaled
by a common factor so that the per-system L2 norm does not exceed `max_norm`.
Different systems are clipped independently, so a batched run is
bit-identical to running each system one at a time.
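A minimal NumPy sketch of the per-system rule described above, assuming each leaf of the update pytree is paired with a same-shaped integer array giving the system of every entry. That ID layout, the dict-as-pytree, and the helper name `clip_per_system` are hypothetical illustrations, not the kups implementation (which is JAX-based).

```python
import numpy as np


def clip_per_system(updates, system_ids, num_systems, max_norm):
    """Hypothetical helper: cap each system's update L2 norm at max_norm.

    updates    -- dict mapping leaf name -> float array of update entries
    system_ids -- dict with the same keys; integer array of the same shape
                  as the matching leaf, giving the system of every entry
    """
    # Sum squared entries per system across *all* leaves of the pytree.
    sq = np.zeros(num_systems)
    for name, leaf in updates.items():
        np.add.at(sq, np.ravel(system_ids[name]), np.ravel(leaf) ** 2)
    norms = np.sqrt(sq)
    # One factor per system: 1 if already within the cap, else max_norm / norm.
    # np.maximum guards the division for all-zero systems (factor is 1 there anyway).
    scale = np.where(norms > max_norm, max_norm / np.maximum(norms, 1e-30), 1.0)
    # Gather each entry's system factor; all entries of a system share it.
    return {name: leaf * scale[system_ids[name]] for name, leaf in updates.items()}


# Two systems: particles 0-1 belong to system 0, particles 2-3 to system 1.
particle_system = np.array([0, 0, 1, 1])
system_ids = {
    "pos": np.repeat(particle_system[:, None], 3, axis=1),            # shape (4, 3)
    "cell": np.broadcast_to(np.arange(2)[:, None, None], (2, 3, 3)),  # shape (2, 3, 3)
}
updates = {
    "pos": np.array([[3.0, 0, 0], [0, 0, 0], [0.1, 0, 0], [0, 0.1, 0]]),
    "cell": np.zeros((2, 3, 3)),
}
updates["cell"][0, 0, 0] = 4.0  # system 0 norm is 5, spread across both leaves

clipped = clip_per_system(updates, system_ids, num_systems=2, max_norm=1.0)
```

System 0 (norm 5 across `pos` and `cell`) is scaled by 0.2, while system 1 (norm about 0.14) is left untouched: each factor depends only on that system's own entries, which is why batched and one-at-a-time runs agree.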
For per-particle (rather than per-system) caps, see
`kups.relaxation.transforms.MaxStepSize`, which constrains the
largest single particle's displacement rather than the system's total.
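To make the distinction concrete, a small NumPy comparison for one system. `cap_largest_particle` is a hypothetical stand-in for a per-particle cap that rescales the whole step so the largest particle displacement is at most `max_step` (an ASE-style convention assumed here); it is not the actual `MaxStepSize` implementation.

```python
import numpy as np


def cap_system_norm(step, max_norm):
    """Per-system cap: rescale so the L2 norm of the whole step is <= max_norm."""
    norm = np.linalg.norm(step)
    return step if norm <= max_norm else step * (max_norm / norm)


def cap_largest_particle(step, max_step):
    """Hypothetical per-particle cap: rescale so no particle moves more than max_step."""
    largest = np.linalg.norm(step, axis=-1).max()
    return step if largest <= max_step else step * (max_step / largest)


# One system of 4 particles, each displaced by 0.3 along x.
step = np.tile([0.3, 0.0, 0.0], (4, 1))

system_norm = np.linalg.norm(step)                  # about 0.6
largest_move = np.linalg.norm(step, axis=-1).max()  # about 0.3
```

With a cap of 0.5, `cap_system_norm` shrinks this step (its total norm is about 0.6), while `cap_largest_particle` leaves it unchanged because no single particle moves more than 0.3. The per-system norm grows with the number of particles; the per-particle measure does not.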
ClipByGlobalNorm
Bases: Optimizer[Params, ClipByGlobalNormState]
Clip the per-system L2 norm of updates to max_norm.
With `index_prefix=None` this reduces to the standard
`optax.clip_by_global_norm` (a single tree-global L2 norm).
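For reference, the single tree-global behaviour can be sketched in NumPy as follows. It mirrors the documented semantics of `optax.clip_by_global_norm` (leave updates unchanged when the global norm is below `max_norm`, otherwise scale every leaf by `max_norm / global_norm`), with a plain dict standing in for an arbitrary pytree; the function names are illustrative and this is not optax itself. It coincides with the per-system rule when every entry belongs to one system.

```python
import numpy as np


def global_norm(tree):
    """L2 norm over every entry of every leaf (dict used as a stand-in pytree)."""
    return np.sqrt(sum(np.sum(np.square(leaf)) for leaf in tree.values()))


def clip_by_global_norm_np(tree, max_norm):
    """NumPy sketch of optax.clip_by_global_norm semantics (not optax itself)."""
    g = global_norm(tree)
    if g < max_norm:
        return tree  # already within the cap: returned unchanged
    # One shared factor for the whole tree, so the update direction is preserved.
    return {k: leaf * (max_norm / g) for k, leaf in tree.items()}


updates = {"a": np.array([3.0]), "b": np.array([4.0])}  # global norm 5
clipped = clip_by_global_norm_np(updates, max_norm=1.0)  # every leaf scaled by 1/5
```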
Attributes:

| Name | Type | Description |
|---|---|---|
| `max_norm` | `float` | Maximum allowed per-system L2 norm of the update. |
Source code in src/kups/relaxation/transforms/clip_by_global_norm.py
ClipByGlobalNormState
State carrying the index_prefix captured at init time.
Attributes:
| Name | Type | Description |
|---|---|---|
| `index_prefix` | `PyTree \| None` | Tree prefix of the parameter pytree whose leaves are |