# kups.relaxation.optax

Optax-based optimizers for structure relaxation.

## `Transform = str | dict[str, bool | int | float | str | list | None]` (module attribute)

A single transform spec: either a name string or a dict with a `"transform"` key.

## `TransformationConfig = list[Transform]` (module attribute)

An ordered list of transform specs to chain into an optimizer.
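For illustration, a config mixing both spec forms might look like the following (the transform names are taken from the examples elsewhere on this page; dict keys other than `"transform"` appear to be forwarded as keyword arguments):

```python
# A sample TransformationConfig: bare names and dict specs can be mixed.
config = [
    "scale_by_fire",                                    # name only, defaults
    {"transform": "clip_by_global_norm", "max_norm": 1.0},
    {"transform": "scale_by_fire", "dt_start": 0.1},    # name plus kwargs
]
```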
## ScaleByFireState

Bases: `NamedTuple`

State for the `scale_by_fire` transform.

Attributes:

| Name | Type | Description |
|---|---|---|
| `velocity` | `Params` | Velocity estimate (PyTree matching params). |
| `dt` | `Array` | Current adaptive timestep. |
| `alpha` | `Array` | Current velocity mixing parameter. |
| `n_pos` | `Array` | Count of consecutive positive power steps. |

Source code in src/kups/relaxation/optax/fire.py
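As a rough picture of the state's shape, here is a scalar stand-in for the NamedTuple (in the real state, `dt`, `alpha`, and `n_pos` are JAX `Array`s and `velocity` is a PyTree; the initial values below merely mirror `scale_by_fire`'s documented defaults):

```python
from typing import Any, NamedTuple

class FireStateSketch(NamedTuple):
    velocity: Any   # velocity estimate (PyTree matching params)
    dt: float       # current adaptive timestep
    alpha: float    # current velocity mixing parameter
    n_pos: int      # count of consecutive positive power steps

# Plausible initial values, mirroring scale_by_fire's defaults.
state = FireStateSketch(velocity=0.0, dt=0.1, alpha=0.1, n_pos=0)
```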
## get_transform(transform)

Convert a transform config entry to an Optax `GradientTransformation`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transform` | `Transform` | Either a plain string name (e.g. `"scale_by_fire"`) or a dict with a `"transform"` key naming the transform. | required |

Returns:

| Type | Description |
|---|---|
| `GradientTransformation` | The constructed `GradientTransformation`. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the transform name is not found in custom transforms or optax. |
Source code in src/kups/relaxation/optax/optimizer.py
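The documented lookup behavior (custom transforms first, then optax, else `ValueError`) can be sketched with plain dicts standing in for the real registries. This is an illustrative stand-in under those assumptions, not the actual implementation:

```python
def get_transform_sketch(spec, custom, fallback):
    """Resolve a Transform spec against two factory registries."""
    if isinstance(spec, str):               # bare name: no extra kwargs
        name, kwargs = spec, {}
    else:                                   # dict form: pop the name,
        kwargs = dict(spec)                 # remaining keys become kwargs
        name = kwargs.pop("transform")
    factory = custom.get(name, fallback.get(name))
    if factory is None:
        raise ValueError(f"transform {name!r} not found in custom transforms or optax")
    return factory(**kwargs)

# Toy registries: tuples stand in for GradientTransformations.
custom = {"scale_by_fire": lambda **kw: ("fire", kw)}
fallback = {"clip_by_global_norm": lambda max_norm: ("clip", max_norm)}

a = get_transform_sketch("scale_by_fire", custom, fallback)
b = get_transform_sketch({"transform": "clip_by_global_norm", "max_norm": 1.0},
                         custom, fallback)
```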
## get_transformations(transformations)

Convert a list of transform configs to Optax `GradientTransformation`s.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transformations` | `TransformationConfig` | List of transform specifications. | required |

Returns:

| Type | Description |
|---|---|
| `list[GradientTransformation]` | List of `GradientTransformation`s, in the same order. |
Source code in src/kups/relaxation/optax/optimizer.py
## make_optimizer(transformations)

Create a chained optimizer from a list of transform configs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `transformations` | `TransformationConfig` | List of transform specifications. | required |

Returns:

| Type | Description |
|---|---|
| `GradientTransformationExtraArgs` | Chained Optax `GradientTransformation`. |

Example:

```python
config = [
    {"transform": "clip_by_global_norm", "max_norm": 1.0},
    {"transform": "scale_by_fire", "dt_start": 0.1},
]
optimizer = make_optimizer(config)
```
Source code in src/kups/relaxation/optax/optimizer.py
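The chaining itself works like `optax.chain`: each transform's `update` receives the updates produced by the previous one. A scalar pure-Python sketch of that composition (an illustration, not the optax implementation):

```python
def chain_sketch(*transforms):
    """Compose (init, update) pairs in order, scalar version."""
    def init(params):
        return tuple(t_init(params) for t_init, _ in transforms)
    def update(updates, state, params=None):
        new_state = []
        for (_, t_update), s in zip(transforms, state):
            updates, s = t_update(updates, s, params)
            new_state.append(s)
        return updates, tuple(new_state)
    return init, update

# Two toy stateless transforms: clip the update's magnitude, then flip
# its sign (a gradient-descent direction), mimicking the chained configs.
clip = (lambda p: None, lambda u, s, p: (max(-1.0, min(1.0, u)), s))
negate = (lambda p: None, lambda u, s, p: (-u, s))

init, update = chain_sketch(clip, negate)
state = init(0.0)
step, state = update(5.0, state)   # 5.0 clipped to 1.0, then negated
```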
## scale_by_ase_lbfgs(memory_size=100, alpha=70.0)

L-BFGS preconditioner with an ASE-style initial inverse Hessian.

Equivalent to `optax.scale_by_lbfgs`, except that the initial Hessian approximation is `(1/alpha) * I` (following the ASE convention) rather than the curvature-based initialization Optax uses by default.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `memory_size` | `int` | Number of past (param, gradient) differences to store. Must be >= 1. | `100` |
| `alpha` | `float` | The initial inverse Hessian is `(1/alpha) * I`. | `70.0` |

Returns:

| Type | Description |
|---|---|
| `GradientTransformation` | Optax `GradientTransformation` applying L-BFGS preconditioning. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `memory_size` is less than 1. |

Example:

```python
optimizer = optax.chain(
    scale_by_ase_lbfgs(memory_size=10, alpha=70.0),
    optax.scale(-1.0),
)
```
Source code in src/kups/relaxation/optax/lbfgs.py
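The effect of the ASE-style initialization is easiest to see in the two-loop recursion: with an empty memory, the preconditioned direction is simply `grad / alpha`. A scalar sketch of that recursion (vectors reduced to floats; an illustration of the convention, not the module's code):

```python
def lbfgs_direction(grad, s_hist, y_hist, alpha=70.0):
    """Two-loop L-BFGS recursion with H0 = (1/alpha) * I (ASE-style)."""
    q, coeffs = grad, []
    for s, y in zip(reversed(s_hist), reversed(y_hist)):
        rho = 1.0 / (y * s)
        a = rho * s * q
        coeffs.append((rho, s, y, a))
        q -= a * y
    q *= 1.0 / alpha                 # fixed initial inverse Hessian
    for rho, s, y, a in reversed(coeffs):
        q += s * (a - rho * y * q)
    return q

d0 = lbfgs_direction(140.0, [], [])   # empty memory: grad / alpha = 2.0
```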
## scale_by_fire(dt_start=0.1, dt_max=None, dt_min=None, max_step=0.2, f_inc=1.1, f_dec=0.5, alpha_start=0.1, f_alpha=0.99, n_min=5)

FIRE (Fast Inertial Relaxation Engine) optimizer.

A composable Optax transform implementing the FIRE algorithm for structure relaxation. It can be chained with other transforms.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dt_start` | `float` | Initial timestep. | `0.1` |
| `dt_max` | `float \| None` | Maximum timestep. Defaults to `10 * dt_start`. | `None` |
| `dt_min` | `float \| None` | Minimum timestep. Defaults to `dt_start * 1e-4`. | `None` |
| `max_step` | `float \| None` | Maximum step size (clips position updates). Defaults to 0.2 Å; set to `None` to disable clipping. | `0.2` |
| `f_inc` | `float` | Factor by which to increase `dt` when making progress. | `1.1` |
| `f_dec` | `float` | Factor by which to decrease `dt` on a bad step. | `0.5` |
| `alpha_start` | `float` | Initial velocity mixing parameter. | `0.1` |
| `f_alpha` | `float` | Factor by which to decay `alpha` when making progress. | `0.99` |
| `n_min` | `int` | Minimum number of positive power steps before increasing `dt`. | `5` |

Returns:

| Type | Description |
|---|---|
| `GradientTransformation` | Optax `GradientTransformation` implementing FIRE. |

Reference:

Bitzek et al., Phys. Rev. Lett. 97, 170201 (2006).
Source code in src/kups/relaxation/optax/fire.py
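The parameter table above maps directly onto the FIRE update rule. A 1-D pure-Python sketch of one step (a simplification for illustration; the actual transform operates on PyTrees of JAX arrays, and ordering details may differ):

```python
def fire_step(x, v, grad, dt, alpha, n_pos,
              dt_max=1.0, dt_min=1e-5, max_step=0.2, f_inc=1.1,
              f_dec=0.5, alpha_start=0.1, f_alpha=0.99, n_min=5):
    """One scalar FIRE step (after Bitzek et al. 2006), simplified."""
    f = -grad                          # force is the negative gradient
    if f * v > 0.0:                    # positive power: downhill progress
        n_pos += 1
        if n_pos > n_min:                  # sustained progress:
            dt = min(dt * f_inc, dt_max)   #   grow the timestep,
            alpha *= f_alpha               #   decay the mixing parameter
    else:                              # uphill: freeze, brake, reset mixing
        n_pos, v, alpha = 0, 0.0, alpha_start
        dt = max(dt * f_dec, dt_min)
    # Mix the velocity toward the force direction, then integrate.
    v = (1.0 - alpha) * v + alpha * abs(v) * (1.0 if f > 0 else -1.0)
    v += dt * f                        # semi-implicit Euler velocity update
    step = max(-max_step, min(max_step, dt * v))   # max_step clipping
    return x + step, v, dt, alpha, n_pos

# Relax f(x) = x**2 from x = 1.0; grad = 2*x.
x, v, dt, alpha, n_pos = 1.0, 0.0, 0.1, 0.1, 0
for _ in range(30):
    x, v, dt, alpha, n_pos = fire_step(x, v, 2.0 * x, dt, alpha, n_pos)
```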