Skip to content

kups.application.simulations.relax_mlff

MLFF structure relaxation entry point.

init_state(config, opt_init)

Initialize relaxation state from config.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Config` | Run configuration. | required |
| `opt_init` | `OptInit` | Optimizer initializer. | required |

Returns:

| Type | Description |
| --- | --- |
| `RelaxMlffState` | Initial relaxation state. |

Source code in src/kups/application/simulations/relax_mlff.py
def init_state(config: Config, opt_init: OptInit) -> RelaxMlffState:
    """Build the starting relaxation state described by ``config``.

    Args:
        config: Run configuration; ``model_path`` and ``inp_files`` are read.
        opt_init: Optimizer initializer, called with the
            ``(positions, lattice_vectors)`` tuple of the merged structures.

    Returns:
        Initial relaxation state.
    """
    # Load the model; its cutoff feeds the neighbor-list estimate below.
    model = TojaxedMliap.from_zip_file(get_model_path(config.model_path))

    # Read every input structure in the order listed in the config.
    loaded = []
    for inp_file in config.inp_files:
        logging.info(f"Loading structure from {inp_file}")
        loaded.append(relax_state_from_ase(inp_file))
    particle_tables = [particle_table for particle_table, _ in loaded]
    system_tables = [system_table for _, system_table in loaded]

    # Merge the per-file tables into one particle table and one system table.
    particles, systems = Table.union(particle_tables, system_tables)

    nl_params = UniversalNeighborlistParameters.estimate(
        particles.data.system.counts,
        systems,
        model.cutoff,
    )

    # The optimizer state is initialized from positions and lattice vectors.
    optimized_vars = (
        particles.data.positions,
        systems.data.unitcell.lattice_vectors,
    )
    initial_opt_state = opt_init(optimized_vars)

    return RelaxMlffState(
        particles=particles,
        systems=systems,
        neighborlist_params=nl_params,
        opt_state=initial_opt_state,
        step=jnp.array([0]),
        jaxified_model=model,
    )

main()

CLI entry point for MLFF relaxation.

Source code in src/kups/application/simulations/relax_mlff.py
def main() -> None:
    """Parse the CLI config, run the relaxation, and print the analysis.

    The parsed configuration is printed before the run; the analysis of
    ``config.run.out_file`` is printed once the run has finished.
    """
    config = NanoArgs(Config).parse()
    rich.print(config)

    run(config)

    summary = analyze_relax_file(config.run.out_file)
    rich.print(summary)

run(config)

Run structure relaxation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Config` | Full run configuration. | required |

Source code in src/kups/application/simulations/relax_mlff.py
def run(config: Config) -> None:
    """Run structure relaxation.

    Builds the PRNG key, optimizer, potential and relaxation propagator from
    ``config``, initializes the state, and hands everything to ``run_relax``.

    Args:
        config: Full run configuration. ``config.run.seed`` seeds the PRNG
            key; when it is ``None`` the current time in nanoseconds is used
            instead.
    """
    # Fall back to a time-based seed only when no seed was provided. The
    # previous ``seed or time.time_ns()`` also discarded an explicit seed of
    # 0, making such runs non-reproducible.
    # NOTE(review): assumes 0 is not used as a "random seed" sentinel in
    # Config -- confirm against the config definition.
    seed = config.run.seed
    if seed is None:
        seed = time.time_ns()
    key = jax.random.key(seed)

    state_lens = identity_lens(RelaxMlffState)
    optimizer = make_optimizer(config.relax.optimizer)
    potential = make_tojaxed_from_state(
        state_lens, compute_position_and_unitcell_gradients=True
    )
    propagator, opt_init = make_relax_propagator(
        state_lens, potential, optimizer, config.relax.optimize_cell
    )
    state = init_state(config, opt_init)
    logging.info("Starting relaxation")
    run_relax(key, propagator, state, config.run)