Skip to content

kups.application.simulations.md_mlff

Entry point for molecular dynamics (MD) simulations driven by a machine-learned force field (MLFF).

Config

Bases: BaseModel

Top-level configuration for MLFF MD simulations.

Source code in src/kups/application/simulations/md_mlff.py
class Config(BaseModel):
    """Top-level configuration for MLFF MD simulations.

    Attributes:
        run: Run-level settings; this module reads ``seed`` and ``out_file``
            from it and passes it on to ``run_md``.
        md: MD parameters; passed to ``md_state_from_ase`` and used for
            ``initialize_momenta`` and the choice of ``integrator``.
        inp_files: Input structure files. Each one is read with
            ``md_state_from_ase`` and the results are merged with
            ``Table.union``.
        model_path: Model location, resolved with ``get_model_path`` and
            loaded via ``TojaxedMliap.from_zip_file``.
    """

    run: MdRunConfig
    md: MdParameters
    inp_files: tuple[str | Path, ...]
    model_path: str | Path

MlffMdState

Simulation state for MLFF MD.

Source code in src/kups/application/simulations/md_mlff.py
@dataclass
class MlffMdState:
    """Simulation state for MLFF MD.

    Bundles the particle and system tables, the neighbour-list sizing
    parameters, a step counter and the loaded model. It is built by
    ``init_state`` and threaded through the MD propagator in ``run`` via an
    identity lens.
    """

    # Per-particle MD data, keyed by particle id.
    particles: Table[ParticleId, MDParticles]
    # Per-system MD data, keyed by system id (presumably one system per input
    # file — confirm against md_state_from_ase).
    systems: Table[SystemId, MDSystems]
    # Neighbour-list parameters estimated from per-system particle counts,
    # the systems table and the model cutoff (see init_state).
    neighborlist_params: UniversalNeighborlistParameters
    # Step counter; init_state creates it as the length-1 array jnp.array([0]).
    step: Array
    # Loaded model; run() wraps it as the potential via make_tojaxed_from_state.
    jaxified_model: TojaxedMliap

    @property
    def neighborlist(self) -> NearestNeighborList:
        """Dense nearest-neighbour list derived from this state.

        A fresh list is built via ``DenseNearestNeighborList.from_state`` on
        every access; nothing is cached on the instance.
        """
        return DenseNearestNeighborList.from_state(self)

init_state(key, config)

Initialise an MLFF MD state from configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `Array` | JAX PRNG key for momenta initialisation. | *required* |
| `config` | `Config` | Simulation configuration. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `MlffMdState` | Fully constructed MLFF MD state. |

Source code in src/kups/application/simulations/md_mlff.py
def init_state(key: Array, config: Config) -> MlffMdState:
    """Initialise an MLFF MD state from configuration.

    Every file in ``config.inp_files`` is read with ``md_state_from_ase`` and
    the per-file particle/system tables are merged with ``Table.union``. When
    ``config.md.initialize_momenta`` is set, each input file receives its own
    subkey split from ``key``, so momenta drawn for different systems are
    independent rather than repeated.

    Args:
        key: JAX PRNG key for momenta initialisation.
        config: Simulation configuration.

    Returns:
        Fully constructed MLFF MD state.

    Raises:
        ValueError: If ``config.inp_files`` is empty.
    """
    n_files = len(config.inp_files)
    if n_files == 0:
        raise ValueError("Config.inp_files must contain at least one input file.")

    model_path = get_model_path(config.model_path)
    jaxified_model = TojaxedMliap.from_zip_file(model_path)

    # Passing the same key for every file would reproduce the same random
    # momentum draws in each system; split one independent subkey per file.
    if config.md.initialize_momenta:
        file_keys = list(jax.random.split(key, n_files))
    else:
        file_keys = [None] * n_files

    all_particles, all_systems = [], []
    for inp_file, file_key in zip(config.inp_files, file_keys):
        particles_i, systems_i = md_state_from_ase(inp_file, config.md, key=file_key)
        all_particles.append(particles_i)
        all_systems.append(systems_i)
    particles, systems = Table.union(all_particles, all_systems)

    # Size the neighbour list from per-system particle counts and model cutoff.
    neighborlist_params = UniversalNeighborlistParameters.estimate(
        particles.data.system.counts, systems, jaxified_model.cutoff, base=1
    )
    return MlffMdState(
        particles=particles,
        systems=systems,
        neighborlist_params=neighborlist_params,
        jaxified_model=jaxified_model,
        step=jnp.array([0]),
    )

main()

CLI entry point for MLFF MD simulations.

Source code in src/kups/application/simulations/md_mlff.py
def main() -> None:
    """Command-line entry point for MLFF MD simulations.

    Parses a ``Config`` from the command line, echoes it, runs the simulation
    and finally prints the analysis of the run's output file.
    """
    parsed = NanoArgs(Config).parse()
    rich.print(parsed)
    run(parsed)
    summary = analyze_md_file(parsed.run.out_file)
    rich.print(summary)

run(config)

Run an MLFF MD simulation from the given configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Config` | Simulation configuration. | *required* |

Source code in src/kups/application/simulations/md_mlff.py
def run(config: Config) -> None:
    """Run an MLFF MD simulation from the given configuration.

    Args:
        config: Simulation configuration. If ``config.run.seed`` is ``None``,
            the PRNG seed is taken from ``time.time_ns()``.
    """
    # `seed or time.time_ns()` would silently discard an explicit seed of 0;
    # only fall back to a time-based seed when no seed was provided.
    # NOTE(review): assumes MdRunConfig.seed is `int | None` — confirm.
    seed = config.run.seed if config.run.seed is not None else time.time_ns()
    chain = key_chain(jax.random.key(seed))
    state = init_state(next(chain), config)
    state_lens = identity_lens(MlffMdState)
    # Request gradients with respect to both positions and the unit cell.
    potential = make_tojaxed_from_state(
        state_lens, compute_position_and_unitcell_gradients=True
    )
    propagator = make_md_propagator(state_lens, config.md.integrator, potential)
    # The final state returned by run_md is not needed here; main() reads the
    # results back from config.run.out_file.
    run_md(next(chain), propagator, state, config.run)