Basic Usage

Installation

To use jax_dna, first install it using pip:

pip install git+https://github.com/rkruegs123/jax-dna.git

Basic Usage

The two basic use cases for jax_dna are:

  1. Running a simulation

  2. Running an optimization

In this section, we’ll discuss how to run simple simulations and simple optimizations.

Running a single simulation

jax_dna currently supports two simulation engines: jax_md and oxDNA.

Regardless of which engine you choose, the system to simulate is set up the same way, because jax_dna reads oxDNA input, topology, and trajectory files. See jax_dna.input for more details on the input formats.

import jax_dna.input.topology as jdna_top
import jax_dna.input.trajectory as jdna_traj

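topology = jdna_top.from_oxdna_file("path/to/oxdna/topology.top")
# Use the first state of the trajectory file as the initial configuration
initial_positions = jdna_traj.from_file(
    "path/to/oxdna/trajectory.conf",
    strand_lengths=topology.strand_counts,
).states[0].to_rigid_body()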
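To make the topology input more concrete, the sketch below parses a topology in the classic oxDNA layout and computes the per-strand nucleotide counts. This is the kind of information that is passed as `strand_lengths` when reading a trajectory. The classic layout has a header line "N_nucleotides N_strands" followed by one line per nucleotide: "strand_id base 3'_neighbor 5'_neighbor". The function `strand_lengths_from_topology` is a hypothetical illustration, not jax_dna's own reader. It assumes the classic format rather than oxDNA's newer topology format.

```python
from collections import Counter


def strand_lengths_from_topology(text: str) -> list[int]:
    """Hypothetical helper: nucleotides per strand for a classic oxDNA topology.

    Assumed format: a header "N_nucleotides N_strands", then one line per
    nucleotide of the form "strand_id base 3'_neighbor 5'_neighbor".
    """
    rows = [line.split() for line in text.strip().splitlines() if line.strip()]
    n_nucleotides, n_strands = int(rows[0][0]), int(rows[0][1])
    body = rows[1:]
    if len(body) != n_nucleotides:
        raise ValueError(f"expected {n_nucleotides} nucleotides, found {len(body)}")
    # Strand ids in oxDNA topology files start at 1.
    counts = Counter(int(fields[0]) for fields in body)
    return [counts[strand_id] for strand_id in range(1, n_strands + 1)]


# A 3-nucleotide strand and a 2-nucleotide strand
topology_text = """5 2
1 G -1 1
1 C 0 2
1 A 1 -1
2 T -1 4
2 G 3 -1
"""
print(strand_lengths_from_topology(topology_text))  # [3, 2]
```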

Using jax_md

Running a simulation with jax_md requires a working jax installation, which pip installs alongside jax_dna if it is not already present. For information on installing jax_md and jax, please refer to their respective documentation. For more details on the jax_md simulator, see jax_dna.simulators.jax_md.

Running a simulation using jax_md involves reading some input data as shown above and then building the energy function:

import functools

import jax
import jax.numpy as jnp
import jax_md

import jax_dna.energy.dna1 as jdna_energy
import jax_dna.simulators.jax_md as jdna_jaxmd

# The arrays below are created as float64; unless 64-bit mode is enabled,
# JAX truncates them to float32 (with a warning).
jax.config.update("jax_enable_x64", True)

experiment_config, energy_config = jdna_energy.default_configs()

dt = experiment_config["dt"]
kT = experiment_config["kT"]
diff_coef = experiment_config["diff_coef"]
rot_diff_coef = experiment_config["rot_diff_coef"]

# These are special values for the jax_md simulator: the Langevin friction
# (gamma) and the mass for the translational (center) and rotational
# (orientation) degrees of freedom.
gamma = jax_md.rigid_body.RigidBody(
    center=jnp.array([kT / diff_coef], dtype=jnp.float64),
    orientation=jnp.array([kT / rot_diff_coef], dtype=jnp.float64),
)
mass = jax_md.rigid_body.RigidBody(
    center=jnp.array([experiment_config["nucleotide_mass"]], dtype=jnp.float64),
    orientation=jnp.array([experiment_config["moment_of_inertia"]], dtype=jnp.float64),
)

# Maps rigid bodies to Nucleotide objects using the dna1 interaction-site offsets
geometry = energy_config["geometry"]
transform_fn = functools.partial(
    jdna_energy.Nucleotide.from_rigid_body,
    com_to_backbone=geometry["com_to_backbone"],
    com_to_hb=geometry["com_to_hb"],
    com_to_stacking=geometry["com_to_stacking"],
)

# The jax_md simulator needs an energy function. We can use the default
# energy functions and configurations for dna1 simulations. For more
# information on energy functions and configurations, see the documentation.
energy_fn_configs = jdna_energy.default_energy_configs()
params = [{} for _ in range(len(energy_fn_configs))]
energy_fns = jdna_energy.default_energy_fns()

# Build the energy function
energy_function = jdna_jaxmd.build_energy_function(topology, initial_positions)

The variable energy_function is a function that takes a set of rigid bodies and returns the total energy of the system. To run a simulation, we pass the energy functions and their configurations to the jax_md simulator:

import jax
import jax.numpy as jnp
import jax_md

import jax_dna.simulators.jax_md as jdna_jaxmd

simulator = jdna_jaxmd.JaxMDSimulator(
    energy_configs=energy_fn_configs,
    energy_fns=energy_fns,
    topology=topology,
    simulator_params=jdna_jaxmd.StaticSimulatorParams(
        seq=jnp.array(topology.seq),
        mass=mass,
        bonded_neighbors=topology.bonded_neighbors,
        # Gradient checkpointing interval; it only matters when differentiating
        # through the simulation, which this example does not do.
        checkpoint_every=100,
        dt=dt,
        kT=kT,
        gamma=gamma,
    ),
    space=jax_md.space.free(),
    transform_fn=transform_fn,
    simulator_init=jax_md.simulate.nvt_langevin,
    neighbors=jdna_jaxmd.NoNeighborList(unbonded_nbrs=topology.unbonded_neighbors),
)

run_config = {"n_steps": 1_000}  # example value; choose the run length for your system
key = jax.random.PRNGKey(0)
sim_fn = jax.jit(lambda opts: simulator.run(opts, initial_positions, run_config["n_steps"], key))
trajectory = sim_fn(params)
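For reference, the `gamma` passed to `StaticSimulatorParams` is computed by dividing `kT` by the translational (`diff_coef`) and rotational (`rot_diff_coef`) diffusion coefficients from `experiment_config`. This has the form of the Einstein relation between a friction coefficient and a diffusion coefficient:

```latex
\gamma_{\mathrm{trans}} = \frac{k_B T}{D_{\mathrm{trans}}}, \qquad
\gamma_{\mathrm{rot}} = \frac{k_B T}{D_{\mathrm{rot}}}
```

Here `kT` plays the role of k_B T. A larger diffusion coefficient therefore corresponds to weaker friction with the implicit solvent. How `jax_md.simulate.nvt_langevin` converts these values into per-step damping and noise is described in the jax_md documentation.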

A runnable version of this example can be found in the examples folder in the repository.

Using oxDNA

When running oxDNA simulations, jax_dna acts as a thin wrapper around the oxDNA executable. To run a simulation, you need a working oxDNA installation; for installation instructions, please refer to the oxDNA documentation. In addition, the OXDNA_BIN_PATH environment variable must point to the oxDNA executable.
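Because the wrapper locates oxDNA through OXDNA_BIN_PATH, it can help to validate that variable before starting a run. A misconfigured environment then fails with a readable message. The helper `resolve_oxdna_binary` below is a hypothetical sketch using only the Python standard library; it is not part of jax_dna.

```python
import os
from pathlib import Path


def resolve_oxdna_binary(var_name: str = "OXDNA_BIN_PATH") -> Path:
    """Hypothetical helper: return the oxDNA executable named by `var_name`,
    raising RuntimeError if it is unset, missing, or not executable."""
    raw = os.environ.get(var_name)
    if not raw:
        raise RuntimeError(f"{var_name} is not set; point it at the oxDNA executable")
    path = Path(raw).expanduser()
    if not path.is_file():
        raise RuntimeError(f"{var_name}={raw!r} is not a file")
    if not os.access(path, os.X_OK):
        raise RuntimeError(f"{var_name}={raw!r} is not executable")
    return path
```

Calling `resolve_oxdna_binary()` once at the top of a script, before `simulator.run()`, surfaces configuration problems before any simulation work starts.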

from pathlib import Path

import jax_dna.input.trajectory as jdna_traj
import jax_dna.input.topology as jdna_top
import jax_dna.simulators.oxdna as jdna_oxdna
import jax_dna.utils.types as jdna_types

input_dir = Path("path/to/oxdna-input/dir")

simulator = jdna_oxdna.oxDNASimulator(
    input_dir=input_dir,
    sim_type=jdna_types.oxDNASimulatorType.DNA1,
)

simulator.run()

trajectory = jdna_traj.from_file(
    input_dir / "output.dat",
    strand_lengths=jdna_top.from_oxdna_file(input_dir / "sys.top").strand_counts,
)

print("Length of trajectory: ", trajectory.state_rigid_body.center.shape[0])

A runnable version of this example can be found in the examples folder in the repository.

Running a simple optimization

The main advantage of jax_dna is the ability to run optimizations. An optimization can either differentiate directly through a jax_md simulation or use oxDNA simulations together with the DiffTRe algorithm.

As an example, we will run a simple optimization that finds the energy function parameters that produce a target propeller twist.

The setup is the same for the jax_md and oxDNA simulators, but the implementations differ.

Using jax_md

Below is an example of running an optimization using jax_md. The example will optimize the energy function parameters to produce the target propeller twist.

First, set up the system:

Then set up the energy function and configurations, and get the parameters that we want to optimize:

Next, set up the simulator:

Now set up the loss to be optimized and the function that we will use to calculate the loss and its gradients:

Finally, run the optimization:

As the optimization runs, you should see the propeller twist approach the target value (with some noise).
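The sketch below illustrates what differentiating through the simulation means. It is self-contained and does not use jax_dna's API. The DNA simulation is replaced by a hypothetical one-parameter toy, `toy_simulation`, which relaxes a single coordinate toward the minimum `theta` of a harmonic potential using explicit Euler steps. The final coordinate stands in for the observable (the propeller twist in the real example). The target value, step size, and learning rate are illustrative assumptions. The sketch uses plain gradient descent, which may differ from the optimizer used in the real example.

```python
import jax
import jax.numpy as jnp

TARGET = -12.0  # hypothetical target value for the toy observable


def toy_simulation(theta, n_steps: int = 200, dt: float = 0.05):
    """Hypothetical stand-in for a differentiable simulation: Euler steps on
    U(x) = 0.5 * (x - theta)**2 starting from x = 0; returns the final x."""

    def step(_, x):
        return x - dt * (x - theta)  # x <- x - dt * dU/dx

    return jax.lax.fori_loop(0, n_steps, step, jnp.zeros_like(theta))


def loss_fn(theta):
    # Squared error between the simulated observable and the target
    return (toy_simulation(theta) - TARGET) ** 2


# value_and_grad differentiates through every step of the simulation loop.
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

theta = jnp.array(0.0)
learning_rate = 0.1
for i in range(100):
    loss, grad = loss_and_grad(theta)
    theta = theta - learning_rate * grad
    if i % 20 == 0:
        print(f"iter {i:3d}  loss {float(loss):.6f}")

print("final observable:", float(toy_simulation(theta)))
```

Unlike a Langevin simulation, this toy is deterministic, so its loss decreases smoothly rather than noisily.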

Using oxDNA

Unlike with jax_md, we cannot differentiate through an oxDNA simulation. Instead, we use the DiffTRe algorithm to optimize the energy function parameters. Optimization with oxDNA and DiffTRe is more involved than the jax_md optimization; for these kinds of optimizations, see Advanced Usage.
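The core idea of DiffTRe (Differentiable Trajectory Reweighting) is to sample states once with reference parameters. Observables under new parameters are then estimated by Boltzmann-reweighting those samples, so gradients only need to flow through energy evaluations, not through the simulator. The sketch below shows only that reweighting step, in plain Python; it is not jax_dna's implementation. The function name `difftre_reweight` and its inputs (per-sample energies under the reference and new parameters, and per-sample observable values) are hypothetical. It also reports the effective sample size, which DiffTRe-style methods typically use to decide when the reference trajectory is too far from the current parameters and must be resampled.

```python
import math


def difftre_reweight(ref_energies, new_energies, observables, kT):
    """Hypothetical sketch of DiffTRe-style reweighting.

    Weights: w_i proportional to exp(-(U_new(x_i) - U_ref(x_i)) / kT),
    normalized to sum to 1. Returns (reweighted <O>, effective sample size).
    """
    log_w = [-(u_new - u_ref) / kT for u_ref, u_new in zip(ref_energies, new_energies)]
    max_log_w = max(log_w)  # subtract the max before exponentiating for stability
    unnormalized = [math.exp(lw - max_log_w) for lw in log_w]
    total = sum(unnormalized)
    weights = [w / total for w in unnormalized]
    average = sum(w * o for w, o in zip(weights, observables))
    # Effective sample size exp(-sum w log w); equals N for uniform weights.
    n_eff = math.exp(-sum(w * math.log(w) for w in weights if w > 0))
    return average, n_eff


# Unchanged energies give uniform weights: the plain mean and n_eff == N.
print(difftre_reweight([0.0, 0.0], [0.0, 0.0], [3.0, 6.0], kT=1.0))
```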