kups.relaxation.propagator
Gradient-based relaxation using Optax optimizers.
This module provides a Propagator implementation for gradient-based optimization using Optax.
The RelaxationPropagator supports both standard optimizers (Adam, SGD) and line-search optimizers (L-BFGS, backtracking).
RelaxationPropagator
Bases: Propagator[State]
Unified propagator for gradient-based optimization using Optax.
Uses a Potential to compute energy and gradients. Supports both standard optimizers (Adam, SGD) and line-search optimizers (L-BFGS, backtracking).
For line-search optimizers, the potential is evaluated at trial points during the line search. For standard optimizers, it's evaluated once per step.
After computing energy and gradients, the potential's patch is applied to the state. This allows potentials to update internal state (e.g., neighbor lists) at each relaxation step.
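The ordering described above can be sketched as a single step function. Everything here is hypothetical. `ToyState`, `toy_potential` (which returns energy, gradient, and a `patch` callable) and `relaxation_step` are illustrative names, not the kups API. Plain gradient descent stands in for the Optax update, and applying the patch before the update is an assumption.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    x: float              # the property being optimized
    refreshes: int = 0    # stand-in for cached data such as a neighbor list


def toy_potential(state):
    # Hypothetical potential for E(x) = (x - 3)^2. Besides energy and
    # gradient, it returns a patch that refreshes cached data on the state.
    energy = (state.x - 3.0) ** 2
    grad = 2.0 * (state.x - 3.0)

    def patch(s):
        return replace(s, refreshes=s.refreshes + 1)

    return energy, grad, patch


def relaxation_step(state, lr=0.1):
    # 1. Compute energy and gradient at the current point.
    energy, grad, patch = toy_potential(state)
    # 2. Apply the potential's patch (e.g. rebuild a neighbor list).
    state = patch(state)
    # 3. Update the optimized property (plain gradient descent here).
    return replace(state, x=state.x - lr * grad), energy
```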
Class Type Parameters:

| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| `State` | | The simulation state type | required |
| `PyTree` | | The type of the property being optimized (must match the Potential's gradient type) | required |
Attributes:

| Name | Type | Description |
|---|---|---|
| `potential` | `Potential[State, PyTree, Any, Any]` | Potential that computes energy and gradients of type `PyTree` |
| `property` | `Lens[State, PyTree]` | Lens to get/set the property being optimized |
| `opt_state` | `Lens[State, OptState]` | Lens to get/set the Optax optimizer state |
| `optimizer` | `GradientTransformationExtraArgs` | Optax gradient transformation |
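The `property` and `opt_state` attributes are lenses. Each pairs a getter with a setter that returns an updated copy of the state. Below is a hypothetical minimal version built on frozen dataclasses. `SimpleLens` and `SimState` are illustrative names, and the kups `Lens` type may be structured differently.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Generic, TypeVar

S = TypeVar("S")
A = TypeVar("A")


@dataclass(frozen=True)
class SimpleLens(Generic[S, A]):
    # Hypothetical lens: a getter plus a setter returning an updated copy.
    getter: Callable[[S], A]
    setter: Callable[[S, A], S]

    def modify(self, state: S, fn: Callable[[A], A]) -> S:
        # Read the focused value, transform it, and write it back.
        return self.setter(state, fn(self.getter(state)))


@dataclass(frozen=True)
class SimState:
    positions: tuple
    opt_state: Any = None


positions_lens = SimpleLens(
    getter=lambda s: s.positions,
    setter=lambda s, v: replace(s, positions=v),
)
opt_state_lens = SimpleLens(
    getter=lambda s: s.opt_state,
    setter=lambda s, v: replace(s, opt_state=v),
)
```

A propagator written against such lenses never needs to know where in the state the property or the optimizer state lives. Updates produce new states rather than mutating the old one.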
Example
```python
import optax

from kups.core.potential import MappedPotential
from kups.relaxation.propagator import RelaxationPropagator

# Assumed to be defined elsewhere: `lens`, `my_potential`, `full_potential`,
# `positions_lens`, `key`, and `state`.

# Standard optimizer (Adam)
propagator = RelaxationPropagator(
    potential=my_potential,
    property=positions_lens,
    opt_state=lens(lambda s: s.opt_state),
    optimizer=optax.adam(0.01),
)

# Line-search optimizer (L-BFGS)
propagator = RelaxationPropagator(
    potential=my_potential,
    property=positions_lens,
    opt_state=lens(lambda s: s.opt_state),
    optimizer=optax.lbfgs(),
)

# With gradient projection: keep only the positions part of the full
# gradient so that it matches the type of `property`
mapped_potential = MappedPotential(
    full_potential,
    gradient_map=lambda g: g.positions,
    hessian_map=lambda h: h,  # Hessian passed through unchanged
)
propagator = RelaxationPropagator(
    potential=mapped_potential,
    property=positions_lens,
    opt_state=lens(lambda s: s.opt_state),
    optimizer=optax.lbfgs(),
)

state = propagator(key, state)  # One optimization step
```
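Each call performs one optimization step, so a full relaxation is a loop of calls. Below is a minimal sketch that assumes only the `propagator(key, state)` calling convention shown above. `toy_propagator` is a hypothetical stand-in that takes one gradient-descent step on a quadratic and ignores its key. `relax` splits a fresh subkey for every step, which is the usual JAX idiom. This page does not say whether `RelaxationPropagator` itself consumes the key.

```python
import jax
import jax.numpy as jnp


def toy_propagator(key, state):
    # Hypothetical stand-in: one gradient-descent step on E(x) = sum(x**2).
    # This toy ignores the key.
    del key
    return state - 0.1 * jax.grad(lambda x: jnp.sum(x ** 2))(state)


def relax(propagator, key, state, n_steps):
    # Repeated calls, each with a fresh subkey split from the running key.
    for _ in range(n_steps):
        key, subkey = jax.random.split(key)
        state = propagator(subkey, state)
    return state
```

If the propagator is traceable, the Python loop could be replaced by `jax.lax.fori_loop` or `jax.lax.scan` for long runs.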
Source code in src/kups/relaxation/propagator.py