Skip to content

kups.application.simulations.relax_torch

Config

Bases: BaseModel

Top-level configuration for structure-relaxation runs driven by a PyTorch machine-learned force field (MLFF).

Source code in src/kups/application/simulations/relax_torch.py
class Config(BaseModel):
    """Top-level configuration for torch-MLFF relaxation runs.

    Groups the run-level settings, the relaxation parameters, the input
    structure files and the torch model that supplies the potential.
    """

    # Run-level settings: ``run`` reads ``seed`` from it and passes it on to
    # ``run_relax``; ``main`` reads ``out_file`` from it after the run.
    run: RelaxRunConfig
    # Relaxation parameters: ``run`` reads ``optimizer`` and ``optimize_cell``.
    relax: RelaxParameters
    # Input structure files. ``init_state`` loads each one with
    # ``relax_state_from_ase`` and merges them with ``Table.union``.
    inp_files: tuple[str | Path, ...]
    # Which torch model to load (handed to ``_load_torch_model``).
    model: ModelConfig

MACEModelConfig

Bases: BaseModel

Configuration for loading a PyTorch MACE checkpoint.

Source code in src/kups/application/simulations/relax_torch.py
class MACEModelConfig(BaseModel):
    """Configuration for loading a PyTorch MACE checkpoint."""

    # Backend tag. Presumably the discriminator that selects this class inside
    # ``ModelConfig`` (defined elsewhere) — NOTE(review): confirm.
    backend: Literal["mace"] = "mace"
    # Path to the MACE checkpoint.
    model_path: str | Path
    # Torch device name; defaults to "cuda". Presumably where the model is
    # placed when loaded — confirm against the loader.
    device: str = "cuda"
    # Floating-point precision requested for the model; defaults to "float32".
    dtype: Literal["float32", "float64"] = "float32"

RelaxTorchState

Simulation state for torch-backed MLFF relaxation.

Source code in src/kups/application/simulations/relax_torch.py
@dataclass
class RelaxTorchState:
    """Simulation state for torch-backed MLFF relaxation.

    Constructed by ``init_state``. It bundles the particle and system
    tables, the neighbour-list parameters, the optimizer state, a step
    counter and the loaded torch model.
    """

    # Per-particle data, keyed by ParticleId.
    particles: Table[ParticleId, RelaxParticles]
    # Per-system data, keyed by SystemId.
    systems: Table[SystemId, RelaxSystems]
    # Neighbour-list parameters; init_state estimates them from the particle
    # tables, the systems and the model cutoff.
    neighborlist_params: UniversalNeighborlistParameters
    # Optimizer state; init_state creates it with the propagator's ``opt_init``.
    opt_state: optax.OptState
    # Step counter; init_state starts it at ``jnp.array([0])``.
    step: Array
    # Loaded torch model wrapper (init_state reads its ``cutoff``).
    torch_mliap_model: TorchMliap

    @property
    def neighborlist(self) -> NearestNeighborList:
        """Dense nearest-neighbour list derived from this state.

        A new list is built from ``self`` on every access; nothing is cached.
        """
        return DenseNearestNeighborList.from_state(self)

UMAModelConfig

Bases: BaseModel

Configuration for loading a Meta FAIR Chemistry UMA checkpoint.

Source code in src/kups/application/simulations/relax_torch.py
class UMAModelConfig(BaseModel):
    """Configuration for loading a Meta FAIR Chemistry UMA checkpoint."""

    # Backend tag. Presumably the discriminator that selects this class inside
    # ``ModelConfig`` (defined elsewhere) — NOTE(review): confirm.
    backend: Literal["uma"] = "uma"
    # Path to the UMA checkpoint.
    model_path: str | Path
    # Torch device name; defaults to "cuda". Presumably where the model is
    # placed when loaded — confirm against the loader.
    device: str = "cuda"
    # UMA task name; defaults to "omat". Presumably selects which task head of
    # the checkpoint is used — NOTE(review): confirm against the loader.
    task_name: Literal["omat", "omol", "oc20", "odac", "omc"] = "omat"
    # Inference preset, "default" or "turbo". The exact trade-off is defined
    # by the UMA library — NOTE(review): confirm semantics.
    inference_settings: Literal["default", "turbo"] = "default"

init_state(config, opt_init)

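Initialise the relaxation state from a config. This loads the torch model and the input structures, estimates the neighbour-list parameters and initialises the optimizer state.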

Source code in src/kups/application/simulations/relax_torch.py
def init_state(config: Config, opt_init: OptInit) -> RelaxTorchState:
    """Initialise relaxation state from config.

    Loads the torch model and reads every structure in ``config.inp_files``.
    The structures are merged into one particle/system table pair. The
    neighbour-list parameters are then estimated and the optimizer state is
    initialised.

    Args:
        config: Full relaxation configuration.
        opt_init: Optimizer-state initialiser (as returned by
            ``make_relax_propagator``). It is called with the merged
            particles and systems.

    Returns:
        A new ``RelaxTorchState`` whose step counter is ``jnp.array([0])``.

    Raises:
        ValueError: If ``config.inp_files`` is empty.
    """
    # Fail fast, before the (potentially expensive) model load: with no input
    # structures there is nothing to relax. The alternative is an obscure
    # error deep inside the table merge.
    if not config.inp_files:
        raise ValueError(
            "Config.inp_files is empty; provide at least one structure file."
        )
    torch_mliap_model = _load_torch_model(config.model)
    all_particles, all_systems = [], []
    for inp_file in config.inp_files:
        logging.info("Loading structure from %s", inp_file)
        particles_i, systems_i = relax_state_from_ase(inp_file)
        all_particles.append(particles_i)
        all_systems.append(systems_i)
    # Combine all input structures into one particle table and one system table.
    particles, systems = Table.union(all_particles, all_systems)
    # Neighbour-list parameters depend on the per-system counts, the systems
    # themselves and the model's interaction cutoff.
    neighborlist_params = UniversalNeighborlistParameters.estimate(
        particles.data.system.counts, systems, torch_mliap_model.cutoff
    )
    opt_state = opt_init(particles, systems)
    return RelaxTorchState(
        particles=particles,
        systems=systems,
        neighborlist_params=neighborlist_params,
        opt_state=opt_state,
        step=jnp.array([0]),
        torch_mliap_model=torch_mliap_model,
    )

main()

CLI entry point for torch-MLFF relaxation.

Source code in src/kups/application/simulations/relax_torch.py
def main() -> None:
    """Command-line entry point for a torch-MLFF relaxation.

    Parses the command line into a ``Config`` and echoes it, then runs the
    relaxation. Finally it prints the analysis of the output file named by
    ``config.run.out_file``.
    """
    config = NanoArgs(Config).parse()
    rich.print(config)
    run(config)
    summary = analyze_relax_file(config.run.out_file)
    rich.print(summary)

run(config)

Run a torch-MLFF relaxation.

Source code in src/kups/application/simulations/relax_torch.py
def run(config: Config) -> None:
    """Run a torch-MLFF relaxation.

    Builds the optimizer and a torch MLIAP potential with position and cell
    gradients. It then builds the relaxation propagator, initialises the
    state from ``config`` and hands everything to ``run_relax``.

    Args:
        config: Full relaxation configuration. ``config.run.seed`` seeds the
            JAX PRNG key. When it is ``None``, a wall-clock seed is used and
            logged so the run can be reproduced.
    """
    # Test against None rather than truthiness: an explicit seed of 0 is a
    # valid, reproducible seed and must not be replaced by the time fallback.
    seed = config.run.seed
    if seed is None:
        seed = time.time_ns()
        logging.info("No seed given; using time-based seed %d", seed)
    key = jax.random.key(seed)
    state_lens = identity_lens(RelaxTorchState)
    optimizer = make_optimizer(config.relax.optimizer)
    potential = make_torch_mliap_from_state(
        state_lens, compute_position_and_cell_gradients=True
    )
    propagator, opt_init = make_relax_propagator(
        state_lens, potential, optimizer, config.relax.optimize_cell
    )
    state = init_state(config, opt_init)
    logging.info("Starting relaxation")
    run_relax(key, propagator, state, config.run)