jax_dna.simulators.jax_md.jaxmd

A sampler based on running a jax_md simulation routine.

Attributes

REQUIRED_KEYS

ERR_MISSING_REQUIRED_KEYS

SIM_STATE

Classes

JaxMDSimulator

A sampler based on running a jax_md simulation routine.

Functions

build_run_fn(→ collections.abc.Callable[[dict[str, ...)

Builds the run function for the jax_md simulation.

Module Contents

jax_dna.simulators.jax_md.jaxmd.REQUIRED_KEYS
jax_dna.simulators.jax_md.jaxmd.ERR_MISSING_REQUIRED_KEYS = 'Missing required keys: {}'
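ERR_MISSING_REQUIRED_KEYS is a str.format template. That suggests it is filled with the names of missing entries when a configuration lacks something listed in REQUIRED_KEYS. This page does not show the contents of REQUIRED_KEYS or where the check happens. The sketch below therefore uses a hypothetical helper, check_required_keys, and placeholder key names in EXAMPLE_REQUIRED_KEYS. It only illustrates how such a template is typically used.

```python
# The template string matches the module constant. EXAMPLE_REQUIRED_KEYS and
# check_required_keys are hypothetical; they are not the module's own contents.
ERR_MISSING_REQUIRED_KEYS = "Missing required keys: {}"
EXAMPLE_REQUIRED_KEYS = {"n_steps", "dt"}


def check_required_keys(config: dict, required: set) -> None:
    """Raise ValueError naming every required key absent from ``config``."""
    # Sort so the message is deterministic (set iteration order is not).
    missing = sorted(required - set(config))
    if missing:
        raise ValueError(ERR_MISSING_REQUIRED_KEYS.format(missing))


check_required_keys({"n_steps": 100, "dt": 1e-3}, EXAMPLE_REQUIRED_KEYS)  # passes silently
```

Called with {"dt": 1e-3}, the helper raises ValueError("Missing required keys: ['n_steps']").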
jax_dna.simulators.jax_md.jaxmd.SIM_STATE
class jax_dna.simulators.jax_md.jaxmd.JaxMDSimulator

Bases: jax_dna.simulators.base.BaseSimulation

A sampler based on running a jax_md simulation routine.

energy_configs: list[jax_dna.energy.configuration.BaseConfiguration]
energy_fns: list[jax_dna.energy.base.BaseEnergyFunction]
simulator_params: jax_dna.simulators.jax_md.utils.StaticSimulatorParams
space: jax_md.space.Space
transform_fn: collections.abc.Callable
simulator_init: collections.abc.Callable[[collections.abc.Callable, collections.abc.Callable], jax_md.simulate.Simulator]
neighbors: jax_dna.simulators.jax_md.utils.NeighborHelper
topology: jax_dna.input.topology.Topology
__post_init__() → None

Builds the run function using the provided parameters.
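Python's dataclasses machinery calls __post_init__ immediately after the generated __init__ has assigned every field. This lets a class derive extra state, such as a run function, from its fully initialised fields. The sketch below shows that pattern. It assumes JaxMDSimulator is a dataclass, which the hook suggests, and that it stores the built function as an attribute. ToySimulator, build_scaler, the factor field, and the run attribute name are all hypothetical stand-ins.

```python
import dataclasses
from collections.abc import Callable


def build_scaler(factor: float) -> Callable[[float], float]:
    """Hypothetical stand-in for build_run_fn: returns a closure over its input."""

    def run(x: float) -> float:
        return factor * x

    return run


@dataclasses.dataclass
class ToySimulator:
    factor: float  # hypothetical; the real class has energy_fns, space, etc.
    # Excluded from __init__; populated by __post_init__ below.
    run: Callable[[float], float] = dataclasses.field(init=False, repr=False)

    def __post_init__(self) -> None:
        # Runs after __init__ has set every init field, so the builder
        # sees the complete set of inputs.
        self.run = build_scaler(self.factor)


sim = ToySimulator(factor=2.0)
```

Callers construct the object with its inputs only and then use sim.run directly. The builder is invoked exactly once per instance.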

jax_dna.simulators.jax_md.jaxmd.build_run_fn(energy_configs: list[jax_dna.energy.configuration.BaseConfiguration], energy_fns: list[jax_dna.energy.base.BaseEnergyFunction], simulator_params: jax_dna.simulators.jax_md.utils.StaticSimulatorParams, space: jax_md.space.Space, transform_fn: collections.abc.Callable, simulator_init: collections.abc.Callable[[collections.abc.Callable, collections.abc.Callable], jax_md.simulate.Simulator], neighbors: jax_dna.simulators.jax_md.utils.NeighborHelper) → collections.abc.Callable[[dict[str, float], jax_md.rigid_body.RigidBody, int, jax.random.PRNGKey], jax_dna.input.trajectory.Trajectory]

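Builds the run function for the jax_md simulation. The returned callable takes a parameter dictionary (dict[str, float]), an initial jax_md.rigid_body.RigidBody state, an int, and a jax.random.PRNGKey, and returns a jax_dna.input.trajectory.Trajectory.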
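The return annotation describes a closure. build_run_fn receives the static pieces of the simulation once and returns a function of (parameters, initial state, int, key) that produces a trajectory. The stdlib-only sketch below illustrates that builder shape on a toy one-dimensional system. Everything in it is hypothetical:

- The state is a single float instead of a RigidBody.
- An integer seed replaces the PRNG key.
- The int argument is assumed to be a step count.
- Overdamped Langevin dynamics stands in for whatever integrator simulator_init supplies.
- The space, transform_fn, and neighbors inputs of the real function are omitted.

```python
import math
import random
from collections.abc import Callable

# Hypothetical toy types: a state is one float coordinate, and an energy
# function maps (state, params) to a scalar energy.
EnergyFn = Callable[[float, dict[str, float]], float]


def build_toy_run_fn(
    energy_fns: list[EnergyFn], dt: float, kT: float
) -> Callable[[dict[str, float], float, int, int], list[float]]:
    """Return a run(params, init_state, n_steps, seed) closure (toy sketch)."""

    def total_energy(x: float, params: dict[str, float]) -> float:
        # Sum the contributions of every energy term.
        return sum(fn(x, params) for fn in energy_fns)

    def force(x: float, params: dict[str, float], h: float = 1e-5) -> float:
        # Central finite-difference estimate of -dU/dx.
        return -(total_energy(x + h, params) - total_energy(x - h, params)) / (2 * h)

    def run(params: dict[str, float], init_state: float, n_steps: int, seed: int) -> list[float]:
        rng = random.Random(seed)  # stands in for a jax.random.PRNGKey
        noise_scale = math.sqrt(2.0 * kT * dt)
        x = init_state
        trajectory = [x]
        for _ in range(n_steps):
            # Overdamped Langevin (Euler-Maruyama) step with unit friction.
            x = x + dt * force(x, params) + noise_scale * rng.gauss(0.0, 1.0)
            trajectory.append(x)
        return trajectory

    return run


def harmonic(x: float, params: dict[str, float]) -> float:
    # Toy energy term U(x) = k x^2 / 2, with k read from the params dict.
    return 0.5 * params["k"] * x**2


run = build_toy_run_fn([harmonic], dt=0.01, kT=0.0)
traj = run({"k": 1.0}, 1.0, 100, seed=0)
```

With kT=0.0 the noise term vanishes, and each step multiplies x by 1 - dt*k = 0.99, so the last coordinate is close to 0.99**100. The static inputs are fixed when the closure is built, while parameters, initial state, step count, and seed can vary from call to call.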