jax_dna.simulators.jax_md

jax_md sampler implementation for jax_dna.

Submodules

Classes

JaxMDSimulator

A sampler based on running a jax_md simulation routine.

NeighborList

Neighbor list for managing unbonded neighbors.

NoNeighborList

A no-op neighbor list that keeps a fixed set of unbonded neighbors and never recomputes it.

SimulationState

A typing protocol for the state objects produced by jax_md simulators.

StaticSimulatorParams

Static parameters for the simulator.

Package Contents

class jax_dna.simulators.jax_md.JaxMDSimulator

Bases: jax_dna.simulators.base.BaseSimulation

A sampler based on running a jax_md simulation routine.

energy_configs: list[jax_dna.energy.configuration.BaseConfiguration]
energy_fns: list[jax_dna.energy.base.BaseEnergyFunction]
simulator_params: jax_dna.simulators.jax_md.utils.StaticSimulatorParams
space: jax_md.space.Space
transform_fn: collections.abc.Callable
simulator_init: collections.abc.Callable[[collections.abc.Callable, collections.abc.Callable], jax_md.simulate.Simulator]
neighbors: jax_dna.simulators.jax_md.utils.NeighborHelper
topology: jax_dna.input.topology.Topology
__post_init__() -> None

Builds the run function using the provided parameters.
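The pattern described for __post_init__ (build the run function once from the configured energies and simulator) can be illustrated with a minimal pure-Python sketch. Everything here is an assumption for illustration, not jax_dna's implementation: ToySimulator, toy_descent and the 1-D float "positions" are hypothetical, and a finite difference stands in for jax.grad. The sketch does mirror the shape of the simulator_init field: a factory that takes an energy function and a shift function and returns an (init_fn, apply_fn) pair, which is how jax_md's simulate.Simulator is structured.

```python
from dataclasses import dataclass, field
from typing import Callable

# Hypothetical 1-D stand-ins for jax_md types: a "position" is a float.
EnergyFn = Callable[[float], float]
ShiftFn = Callable[[float, float], float]


def toy_descent(energy_fn: EnergyFn, shift_fn: ShiftFn, dt: float = 0.1):
    """Hypothetical simulator factory returning (init_fn, apply_fn)."""

    def grad(x: float, h: float = 1e-6) -> float:
        # Central finite difference; a JAX version would use jax.grad.
        return (energy_fn(x + h) - energy_fn(x - h)) / (2 * h)

    def init_fn(x: float) -> float:
        return x

    def apply_fn(x: float) -> float:
        # Steepest-descent step, applied through the space's shift function.
        return shift_fn(x, -dt * grad(x))

    return init_fn, apply_fn


@dataclass
class ToySimulator:
    energy_fns: list[EnergyFn]
    shift_fn: ShiftFn
    simulator_init: Callable[[EnergyFn, ShiftFn], tuple]
    checkpoint_every: int = 10
    run: Callable = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Build run() once from the configured energies and simulator."""

        def total_energy(x: float) -> float:
            # The total energy is the sum of the individual terms.
            return sum(fn(x) for fn in self.energy_fns)

        init_fn, apply_fn = self.simulator_init(total_energy, self.shift_fn)

        def run(x0: float, n_steps: int) -> list[float]:
            state = init_fn(x0)
            trajectory = []
            for step in range(1, n_steps + 1):
                state = apply_fn(state)
                if step % self.checkpoint_every == 0:
                    trajectory.append(state)  # keep every checkpoint_every-th state
            return trajectory

        self.run = run


# Usage: two quadratic terms whose sum has its minimum at x = 2.
sim = ToySimulator(
    energy_fns=[lambda x: (x - 1.0) ** 2, lambda x: (x - 3.0) ** 2],
    shift_fn=lambda x, dx: x + dx,  # free (non-periodic) space
    simulator_init=toy_descent,
)
traj = sim.run(0.0, n_steps=50)
```

Because the energy terms are combined before the simulator is built, adding or removing an entry in energy_fns changes run without touching the integrator; here the five checkpoints relax toward x = 2.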

class jax_dna.simulators.jax_md.NeighborList

Bases: NeighborHelper

Neighbor list for managing unbonded neighbors.

displacement_fn: collections.abc.Callable
topology: jax_dna.input.topology.Topology
r_cutoff: float
dr_threshold: float
box_size: jax.numpy.ndarray
init_positions: jax_md.rigid_body.RigidBody
__post_init__() -> None

Initialize the neighbor list.

property idx: jax.numpy.ndarray

Return the indices of the unbonded neighbors.

allocate(locs: jax_md.rigid_body.RigidBody) -> NeighborList

Allocate memory for the neighbor list.

update(locs: jax_md.rigid_body.RigidBody) -> NeighborList

Update the neighbor list.
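The r_cutoff and dr_threshold fields suggest the usual Verlet-skin scheme that jax_md's partition.neighbor_list uses: pairs are collected within r_cutoff + dr_threshold, and the list is rebuilt only once some particle has moved more than dr_threshold / 2 since the last build. The following is a minimal pure-Python sketch of that scheme under stated assumptions; ToyNeighborList, min_image_distance, plain (x, y, z) tuples instead of RigidBody, a scalar cubic box, and a frozenset of bonded (i, j) pairs standing in for the topology are all hypothetical. The immutable "return a new list" style mirrors how allocate and update return a NeighborList.

```python
import itertools
import math
from dataclasses import dataclass, replace

Position = tuple[float, float, float]


def min_image_distance(a: Position, b: Position, box_size: float) -> float:
    """Distance between a and b in a cubic periodic box (minimum-image convention)."""
    total = 0.0
    for ai, bi in zip(a, b):
        d = ai - bi
        d -= box_size * round(d / box_size)  # wrap to the nearest periodic image
        total += d * d
    return math.sqrt(total)


@dataclass(frozen=True)
class ToyNeighborList:
    bonded_pairs: frozenset  # (i, j) pairs with i < j, excluded from the list
    r_cutoff: float
    dr_threshold: float
    box_size: float
    pairs: tuple = ()
    reference: tuple = ()  # positions at the last build
    n_builds: int = 0

    @property
    def idx(self) -> tuple:
        """Unbonded (i, j) pairs within r_cutoff + dr_threshold at the last build."""
        return self.pairs

    def allocate(self, locs) -> "ToyNeighborList":
        # Pad the cutoff by the skin so the list stays valid for small moves.
        cutoff = self.r_cutoff + self.dr_threshold
        pairs = tuple(
            (i, j)
            for i, j in itertools.combinations(range(len(locs)), 2)
            if (i, j) not in self.bonded_pairs
            and min_image_distance(locs[i], locs[j], self.box_size) < cutoff
        )
        return replace(self, pairs=pairs, reference=tuple(locs), n_builds=self.n_builds + 1)

    def update(self, locs) -> "ToyNeighborList":
        # Rebuild only if some particle moved more than half the skin.
        max_move = max(
            min_image_distance(new, old, self.box_size)
            for new, old in zip(locs, self.reference)
        )
        if max_move > self.dr_threshold / 2:
            return self.allocate(locs)
        return self


# Usage: particle 3 sits near particles 0 and 1 through the periodic boundary.
nl = ToyNeighborList(
    bonded_pairs=frozenset({(0, 1), (1, 2)}), r_cutoff=1.5, dr_threshold=0.5, box_size=10.0
).allocate([(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.2, 0.0, 0.0), (9.5, 0.0, 0.0)])
```

Padding the cutoff trades a slightly longer pair list for far fewer rebuilds; small moves reuse the existing list unchanged.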

class jax_dna.simulators.jax_md.NoNeighborList

Bases: NeighborHelper

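A no-op neighbor list: it holds a fixed array of unbonded neighbors (unbonded_nbrs) and does not recompute it when allocate or update is called.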

unbonded_nbrs: jax.numpy.ndarray
property idx: jax.numpy.ndarray

Return the indices of the unbonded neighbors.

allocate(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

Allocate memory for the neighbor list.

update(locs: jax_md.rigid_body.RigidBody) -> NoNeighborList

Update the neighbor list.
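A no-op neighbor list can be sketched as an immutable object whose allocate and update simply return itself. This is a minimal sketch, assuming unbonded_nbrs holds every non-bonded (i, j) pair; ToyNoNeighborList and all_unbonded_pairs are hypothetical names, and plain tuples replace jax.numpy arrays.

```python
import itertools
from dataclasses import dataclass


def all_unbonded_pairs(n: int, bonded_pairs: set) -> tuple:
    """Every (i, j) pair with i < j that is not bonded: O(n^2) pairs in total."""
    return tuple(
        (i, j)
        for i, j in itertools.combinations(range(n), 2)
        if (i, j) not in bonded_pairs
    )


@dataclass(frozen=True)
class ToyNoNeighborList:
    unbonded_nbrs: tuple  # fixed pair list, computed once up front

    @property
    def idx(self) -> tuple:
        return self.unbonded_nbrs

    def allocate(self, locs) -> "ToyNoNeighborList":
        return self  # positions are ignored; nothing to allocate

    def update(self, locs) -> "ToyNoNeighborList":
        return self  # the pair list never changes


# Usage: a 4-nucleotide strand with backbone bonds 0-1, 1-2 and 2-3.
nbrs = ToyNoNeighborList(all_unbonded_pairs(4, {(0, 1), (1, 2), (2, 3)}))
```

Keeping the same interface as a real neighbor list lets calling code stay unchanged; the cost is that every unbonded pair is evaluated on every step, which is typically only affordable for small systems.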

class jax_dna.simulators.jax_md.SimulationState

Bases: Protocol

A typing protocol for the state objects produced by jax_md simulators.

Every state provides at least position and mass attributes. The concrete state classes are defined in jax_md's simulate module:

https://github.com/jax-md/jax-md/blob/main/jax_md/simulate.py

position: jax_md.rigid_body.RigidBody
mass: jax_md.rigid_body.RigidBody
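Because SimulationState is a Protocol, any object with position and mass attributes satisfies it structurally, so helpers can be written against the protocol rather than a specific state class. A minimal sketch, assuming hypothetical names (SimulationStateLike, ToyNVEState, center_of_mass) and 1-D tuples instead of RigidBody; the ToyNVEState field names follow jax_md's NVE state (position, momentum, force, mass).

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SimulationStateLike(Protocol):
    """Hypothetical mirror of the protocol: any object with position and mass."""

    position: Any
    mass: Any


@dataclass
class ToyNVEState:
    # Field names follow jax_md's NVE state; values here are plain 1-D tuples.
    position: tuple
    momentum: tuple
    force: tuple
    mass: tuple


def center_of_mass(state: SimulationStateLike) -> float:
    """Uses only the attributes the protocol guarantees."""
    total_mass = sum(state.mass)
    return sum(m * x for m, x in zip(state.mass, state.position)) / total_mass


state = ToyNVEState(
    position=(0.0, 2.0), momentum=(0.0, 0.0), force=(0.0, 0.0), mass=(1.0, 3.0)
)
```

A static type checker accepts ToyNVEState wherever SimulationStateLike is expected without any inheritance; the runtime isinstance check enabled by runtime_checkable only verifies that the attributes exist, not their types.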
class jax_dna.simulators.jax_md.StaticSimulatorParams

Static parameters for the simulator.

seq: jax_dna.utils.types.Arr_Nucleotide
mass: jax_md.rigid_body.RigidBody
gamma: jax_md.rigid_body.RigidBody
bonded_neighbors: jax.numpy.ndarray
checkpoint_every: int
dt: float
kT: float
property sim_init_fn: collections.abc.Callable

Return the simulator init function.

property init_fn: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

Return the keyword arguments used to initialize the simulator state.

property step_fn: dict[str, jax_md.rigid_body.RigidBody | jax.numpy.ndarray]

Return the keyword arguments passed to the simulator's step function.
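Bundling the static values once and exposing them as keyword-argument dictionaries lets the same parameters be splatted into both the init and step calls. In jax_md, as we understand its simulate module, extra keyword arguments given to a simulator's init and apply functions are forwarded to the energy function, which is one reason such bundles are convenient. The sketch below is an assumption throughout: ToyStaticParams, toy_init and toy_step are hypothetical, and the exact keys returned by init_fn and step_fn are guesses, not jax_dna's documented key sets.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyStaticParams:
    # Field names mirror StaticSimulatorParams; values are plain Python here.
    seq: tuple
    mass: float
    gamma: float
    bonded_neighbors: tuple
    checkpoint_every: int
    dt: float
    kT: float

    @property
    def init_fn(self) -> dict[str, Any]:
        # Hypothetical key set for initializing the state.
        return {"mass": self.mass, "seq": self.seq, "bonded_neighbors": self.bonded_neighbors}

    @property
    def step_fn(self) -> dict[str, Any]:
        # Hypothetical key set passed on every step.
        return {"seq": self.seq, "bonded_neighbors": self.bonded_neighbors}


def toy_init(position: float, *, mass: float, seq: tuple, bonded_neighbors: tuple) -> dict:
    """Hypothetical init function that consumes the init kwargs."""
    return {
        "position": position,
        "mass": mass,
        "n_nucleotides": len(seq),
        "n_bonds": len(bonded_neighbors),
        "steps": 0,
    }


def toy_step(state: dict, *, seq: tuple, bonded_neighbors: tuple) -> dict:
    """Hypothetical step function; a real one would integrate the dynamics."""
    return {**state, "steps": state["steps"] + 1}


params = ToyStaticParams(
    seq=("A", "T", "G"), mass=1.0, gamma=1.0, bonded_neighbors=((0, 1), (1, 2)),
    checkpoint_every=100, dt=0.005, kT=0.1,
)
state = toy_init(0.0, **params.init_fn)
state = toy_step(state, **params.step_fn)
```

Keeping the dictionaries as properties means the keyword names are defined in one place, so the init and step call sites cannot drift apart.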