Basic Usage
Installation
To use jax_dna, first install it using pip:
pip install git+https://github.com/rkruegs123/jax-dna.git
Basic Usage
The two basic use cases for jax_dna are:
Running a simulation
Running an optimization
In this section, we’ll discuss how to run simple simulations and simple optimizations.
Running a single simulation
jax_dna currently supports two simulation engines:
jax_md and
oxDNA.
Regardless of which engine you choose, setting up the system to simulate is the
same, as jax_dna supports reading oxDNA input, topology, and trajectory
files. See jax_dna.input for more details on the input
format.
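import jax_dna.input.topology as jdna_top
import jax_dna.input.trajectory as jdna_traj
# read the system topology (strands, sequence, bonded neighbors)
topology = jdna_top.from_oxdna_file("path/to/oxdna/topology.top")
# use the first state of the trajectory as the initial configuration
initial_positions = jdna_traj.from_file(
"path/to/oxdna/trajectory.conf",
strand_lengths=topology.strand_counts,
).states[0].to_rigid_body()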
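To make the topology input more concrete, the sketch below parses the classic oxDNA topology format by hand. The format has a header line with the number of nucleotides and strands, then one line per nucleotide giving its strand (1-based), its base, and the indices of its 3' and 5' neighbors (-1 when there is none). The function parse_oxdna_topology is hypothetical and only illustrates what quantities such as strand counts and bonded neighbors mean. In practice, use jdna_top.from_oxdna_file, whose internal representation may differ (for example, in how bonded pairs are ordered).

```python
from collections import Counter


def parse_oxdna_topology(text):
    """Hypothetical parser for the classic oxDNA topology format (illustration only)."""
    lines = [ln.split() for ln in text.strip().splitlines()]
    n_nucleotides, n_strands = int(lines[0][0]), int(lines[0][1])
    strand_ids, seq, bonded = [], [], []
    for i, fields in enumerate(lines[1:]):
        strand, base, n3, n5 = int(fields[0]), fields[1], int(fields[2]), int(fields[3])
        strand_ids.append(strand)
        seq.append(base)
        # record each backbone bond once, as (nucleotide, its 3' neighbor)
        if n3 != -1:
            bonded.append((i, n3))
    if len(seq) != n_nucleotides:
        raise ValueError("header nucleotide count does not match the body")
    counts = Counter(strand_ids)
    strand_counts = [counts[s] for s in sorted(counts)]  # nucleotides per strand
    if len(strand_counts) != n_strands:
        raise ValueError("header strand count does not match the body")
    return "".join(seq), strand_counts, bonded


# Two 3-nucleotide strands, each listed from its 3' end to its 5' end
TOPOLOGY = """\
6 2
1 A -1 1
1 C 0 2
1 G 1 -1
2 C -1 4
2 G 3 5
2 T 4 -1
"""
seq, strand_counts, bonded = parse_oxdna_topology(TOPOLOGY)
```

Here strand_counts plays the same role as the strand lengths passed to jdna_traj.from_file above.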
Using jax_md
Running a simulation with jax_md requires a working jax installation,
which pip installs alongside jax_dna if it isn't already present. For
information on installing jax_md and jax, please refer to their respective
documentation. For more details on the jax_md simulator, see
jax_dna.simulators.jax_md.
Running a simulation using jax_md involves reading some input data as shown
above and then building the energy function:
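import functools
import jax
import jax.numpy as jnp
import jax_md
import jax_dna.energy.dna1 as jdna_energy
# the float64 arrays below require JAX's 64-bit mode
jax.config.update("jax_enable_x64", True)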
experiment_config, energy_config = jdna_energy.default_configs()
dt = experiment_config["dt"]
kT = experiment_config["kT"]
diff_coef = experiment_config["diff_coef"]
rot_diff_coef = experiment_config["rot_diff_coef"]
# Friction coefficients (gamma) and masses used by the jax_md Langevin integrator
gamma = jax_md.rigid_body.RigidBody(
center=jnp.array([kT / diff_coef], dtype=jnp.float64),
orientation=jnp.array([kT / rot_diff_coef], dtype=jnp.float64),
)
mass = jax_md.rigid_body.RigidBody(
center=jnp.array([experiment_config["nucleotide_mass"]], dtype=jnp.float64),
orientation=jnp.array([experiment_config["moment_of_inertia"]], dtype=jnp.float64),
)
geometry = energy_config["geometry"]
transform_fn = functools.partial(
jdna_energy.Nucleotide.from_rigid_body,
com_to_backbone=geometry["com_to_backbone"],
com_to_hb=geometry["com_to_hb"],
com_to_stacking=geometry["com_to_stacking"],
)
# The jax_md simulator needs an energy function. We can use the default
# energy functions and configurations for dna1 simulations. For more
# information on energy functions and configurations, see the documentation.
energy_fn_configs = jdna_energy.default_energy_configs()
params = [{} for _ in range(len(energy_fn_configs))]
energy_fns = jdna_energy.default_energy_fns()
The default energy functions in energy_fns, together with their
configurations in energy_fn_configs, describe the total energy of the system
as a function of its rigid bodies. To run a simulation, we pass them to the
jax_md simulator along with the topology and the integrator settings:
import jax_dna.simulators.jax_md as jdna_jaxmd
simulator = jdna_jaxmd.JaxMDSimulator(
energy_configs=energy_fn_configs,
energy_fns=energy_fns,
topology=topology,
simulator_params=jdna_jaxmd.StaticSimulatorParams(
seq=jnp.array(topology.seq),
mass=mass,
bonded_neighbors=topology.bonded_neighbors,
# gradient checkpointing interval; checkpointing isn't used in this example
checkpoint_every=100,
dt=dt,
kT=kT,
gamma=gamma,
),
space=jax_md.space.free(),
transform_fn=transform_fn,
simulator_init=jax_md.simulate.nvt_langevin,
neighbors=jdna_jaxmd.NoNeighborList(unbonded_nbrs=topology.unbonded_neighbors),
)
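key = jax.random.PRNGKey(0)  # random key for the Langevin thermostat noise
run_config = {"n_steps": 1000}  # number of simulation steps (example value)
# jit-compile a function of the parameters, then run it
sim_fn = jax.jit(lambda opts: simulator.run(opts, initial_positions, run_config["n_steps"], key))
trajectory = sim_fn(params)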
A runnable version of this example can be found in the examples folder in the repository.
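In the jax_md example, the two components of gamma are set to kT / diff_coef and kT / rot_diff_coef. Written out, with k_B T = kT, D = diff_coef, and D_rot = rot_diff_coef, this is the Einstein relation between a diffusion coefficient and a friction coefficient (how jax_md's nvt_langevin integrator consumes these values is described in the jax_md documentation):

```latex
\gamma_{\mathrm{center}} = \frac{k_B T}{D}, \qquad
\gamma_{\mathrm{orientation}} = \frac{k_B T}{D_{\mathrm{rot}}}
```

At a fixed temperature, a larger diffusion coefficient therefore corresponds to weaker friction.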
Using oxDNA
When running oxDNA simulations, jax_dna acts as a thin wrapper around the
oxDNA executable. To run a simulation, you need to have a working oxDNA
installation. For more information on installing oxDNA, please refer to the
oxDNA documentation. Additionally, the environment variable OXDNA_BIN_PATH
must point to the oxDNA executable.
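A missing or wrong OXDNA_BIN_PATH is easier to diagnose before launching a run, so a quick check can help. The helper below is a hypothetical convenience, not part of jax_dna. It uses only the Python standard library to confirm that the variable is set and names an executable file.

```python
import os
from pathlib import Path


def find_oxdna_binary(env_var="OXDNA_BIN_PATH"):
    """Hypothetical helper: return the oxDNA executable named by env_var, or raise."""
    value = os.environ.get(env_var)
    if not value:
        raise RuntimeError(f"{env_var} is not set; point it at the oxDNA executable")
    path = Path(value)
    # the path must exist, be a regular file, and be executable by this user
    if not path.is_file() or not os.access(path, os.X_OK):
        raise RuntimeError(f"{env_var}={value!r} is not an executable file")
    return path
```

Calling find_oxdna_binary() at the top of a script fails fast with a readable message instead of an error from deeper inside the run.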
from pathlib import Path
import jax_dna.input.trajectory as jdna_traj
import jax_dna.input.topology as jdna_top
import jax_dna.simulators.oxdna as jdna_oxdna
import jax_dna.utils.types as jdna_types
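# directory holding the oxDNA input files; oxDNA also writes its output here
input_dir = Path("path/to/oxdna-input/dir")
simulator = jdna_oxdna.oxDNASimulator(
input_dir=input_dir,
sim_type=jdna_types.oxDNASimulatorType.DNA1,
)
# launch the oxDNA executable named by OXDNA_BIN_PATH
simulator.run()
# load the trajectory written by oxDNA, using the strand lengths from the topology
trajectory = jdna_traj.from_file(
input_dir / "output.dat",
strand_lengths=jdna_top.from_oxdna_file(input_dir / "sys.top").strand_counts,
)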
print("Length of trajectory: ", trajectory.state_rigid_body.center.shape[0])
A runnable version of this example can be found in the examples folder in the repository.
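The trajectory read above is a plain-text oxDNA configuration file containing one state after another. Each state has three header lines (t = time, b = box size, E = energies) followed by one line per nucleotide with 15 numbers: center-of-mass position, the a1 and a3 orientation vectors, velocity, and angular velocity. The hypothetical reader below pulls out the time and positions of every state using only the standard library. It is a sketch for understanding the format, not a replacement for jdna_traj.from_file.

```python
def read_oxdna_trajectory(text, n_nucleotides):
    """Hypothetical reader: returns [(time, [(x, y, z) per nucleotide]), ...]."""
    lines = [ln for ln in text.splitlines() if ln.strip()]
    block = 3 + n_nucleotides  # 3 header lines + one line per nucleotide
    if len(lines) % block != 0:
        raise ValueError("file length is not a whole number of states")
    states = []
    for start in range(0, len(lines), block):
        header = lines[start]
        if not header.startswith("t"):
            raise ValueError(f"expected a 't = ...' header, got {header!r}")
        time = float(header.split("=")[1])
        positions = []
        for row in lines[start + 3 : start + block]:
            values = [float(v) for v in row.split()]
            positions.append(tuple(values[:3]))  # first 3 columns: center of mass
        states.append((time, positions))
    return states


TRAJ = """\
t = 0
b = 20.0 20.0 20.0
E = 0.0 0.0 0.0
0.0 0.0 0.0 1 0 0 0 0 1 0 0 0 0 0 0
0.0 0.0 0.7 1 0 0 0 0 1 0 0 0 0 0 0
t = 100
b = 20.0 20.0 20.0
E = 0.0 0.0 0.0
0.1 0.0 0.0 1 0 0 0 0 1 0 0 0 0 0 0
0.1 0.0 0.7 1 0 0 0 0 1 0 0 0 0 0 0
"""
states = read_oxdna_trajectory(TRAJ, n_nucleotides=2)
```

len(states) corresponds to the trajectory length printed in the example above.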
Running a simple optimization
The main advantage of jax_dna is the ability to run optimizations. An
optimization can either differentiate directly through a jax_md simulation or
use oxDNA together with the DiffTRe algorithm.
As an example, we will run a simple optimization that finds the energy function parameters that produce a desired propeller twist.
The setup is the same for the jax_md and oxDNA simulators, but the
implementation differs slightly.
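For intuition about the target quantity, here is a simplified sketch of a propeller-twist measurement for one base pair. It computes the unsigned angle between the base-plane normals of the two paired nucleotides, using the convention that perfectly coplanar bases (whose normals point in opposite directions) give 0 degrees. This definition and the name propeller_twist_deg are assumptions for illustration. jax_dna's own observable may use a different angle definition or sign convention.

```python
import math


def _normalize(v):
    n = math.sqrt(sum(c * c for c in v))
    return tuple(c / n for c in v)


def propeller_twist_deg(normal_i, normal_j):
    """Hypothetical, unsigned propeller twist (degrees) between paired base normals."""
    a, b = _normalize(normal_i), _normalize(normal_j)
    dot = sum(x * y for x, y in zip(a, b))
    dot = max(-1.0, min(1.0, dot))  # guard acos against rounding just outside [-1, 1]
    # Paired bases face each other, so coplanar bases have antiparallel normals (dot = -1)
    return 180.0 - math.degrees(math.acos(dot))


flat = propeller_twist_deg((0.0, 0.0, 1.0), (0.0, 0.0, -1.0))
tilt = math.radians(20.0)
twisted = propeller_twist_deg((0.0, 0.0, 1.0), (0.0, math.sin(tilt), -math.cos(tilt)))
```

Averaging such a quantity over base pairs and simulation states gives a single number that a loss can compare with the target.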
Using jax_md
Below is an example of running an optimization using jax_md. The example
will optimize the energy function parameters to produce the target propeller
twist.
First, set up the system:
Then set up the energy function and its configurations, and select the parameters that we want to optimize:
Next, set up the simulator:
Now define the loss to optimize and the function that we will use to calculate the loss and its gradients:
Finally, run the optimization:
As the optimization runs, you should see the propeller twist approach the target value, with some noise from one iteration to the next.
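As a hedged illustration of the idea behind these steps (differentiating a loss on a simulation's output with respect to energy parameters), here is a self-contained toy in JAX. A single coordinate relaxes in a harmonic potential whose rest position theta is the learnable parameter, and the loss compares the final position with a target. Everything here (the toy potential, step counts, learning rate) is hypothetical and much simpler than the dna1 energy and the propeller-twist loss.

```python
import jax
import jax.numpy as jnp

K = 1.0        # stiffness of the toy harmonic potential (hypothetical)
DT = 0.1       # step size of the toy overdamped dynamics
N_STEPS = 100  # number of simulation steps
TARGET = 1.5   # target value of the toy "observable"


def energy(x, theta):
    # Toy energy: harmonic well whose rest position `theta` is the learnable parameter
    return 0.5 * K * (x - theta) ** 2


# Derivative of the energy with respect to position (the negative force)
energy_grad = jax.grad(energy, argnums=0)


def simulate(theta):
    # Noise-free overdamped dynamics written with lax.scan so JAX can differentiate through every step
    def step(x, _):
        x = x - DT * energy_grad(x, theta)
        return x, None

    x_final, _ = jax.lax.scan(step, jnp.float32(0.0), None, length=N_STEPS)
    return x_final  # the "observable" is the final position


def loss(theta):
    return (simulate(theta) - TARGET) ** 2


loss_and_grad = jax.jit(jax.value_and_grad(loss))

theta = jnp.float32(0.0)
for _ in range(100):
    _, g = loss_and_grad(theta)
    theta = theta - 0.1 * g  # plain gradient descent with learning rate 0.1
```

Because the whole simulation is a differentiable JAX function, the gradient of the loss flows back through every integration step to theta. The jax_md workflow relies on the same property.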
Using oxDNA
Unlike with jax_md, we cannot differentiate through an oxDNA simulation.
Instead, we use the DiffTRe algorithm to optimize the energy function
parameters. Optimization with oxDNA and DiffTRe is more involved than the
jax_md optimization; for these kinds of optimizations, see Advanced Usage.
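The core idea of DiffTRe, as published, is reweighting. States sampled with reference parameters are reused to estimate an observable under new parameters, with weights proportional to exp(-(U_new - U_ref) / kT). The estimate is trusted only while the effective sample size stays large. The pure-Python sketch below shows just that reweighting step. The function name is hypothetical, and it is not jax_dna's implementation. In an actual DiffTRe optimization, the same estimator is written in JAX so that its gradient with respect to the energy parameters can be taken.

```python
import math


def reweighted_average(obs, u_ref, u_new, kT):
    """Hypothetical sketch: reweight observables sampled under u_ref to estimate them under u_new."""
    # log-weights: -(U_new - U_ref) / kT for each sampled state
    log_w = [-(un - ur) / kT for ur, un in zip(u_ref, u_new)]
    m = max(log_w)  # shift before exponentiating, for numerical stability
    w = [math.exp(lw - m) for lw in log_w]
    z = sum(w)
    w = [wi / z for wi in w]  # normalized weights
    avg = sum(wi * oi for wi, oi in zip(w, obs))
    # effective sample size: exp of the entropy of the weights
    n_eff = math.exp(-sum(wi * math.log(wi) for wi in w if wi > 0))
    return avg, n_eff
```

When the new energies equal the reference energies, all weights are equal, so the estimate reduces to a plain average with an effective sample size equal to the number of states. A small effective sample size signals that new reference simulations are needed.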