Source code for jax_dna.input.trajectory

"""Trajectory information for RNA/DNA strands."""

import concurrent.futures as cf
import functools
import itertools
import multiprocessing as mp
from pathlib import Path

import chex
import jax.numpy as jnp
import jax_md
import numpy as np

import jax_dna.utils.math as jdm
import jax_dna.utils.types as typ

# Expected array layouts. ``None`` marks an axis whose length is not checked;
# only the number of dimensions and the size of axis 1 are validated.
TRAJECTORY_TIMES_DIMS = 1
TRAJECTORY_ENERGIES_SHAPE = (None, 3)
NUCLEOTIDE_STATE_SHAPE = (None, 15)

# Error messages raised while loading or validating a trajectory.
ERR_TRAJECTORY_FILE_NOT_FOUND = "Trajectory file not found: {}"
# NOTE: the identifier below is misspelled ("LEGNTHS"); it is kept as-is because
# other code refers to it by this name.
ERR_TRAJECTORY_N_NUCLEOTIDE_STRAND_LEGNTHS = "n_nucleotides and sum(strand_lengths) do not match"
ERR_TRAJECTORY_TIMES_TYPE = "times must be a numpy array"
ERR_TRAJECTORY_ENERGIES_TYPE = "energies must be a numpy array"
ERR_TRAJECTORY_T_E_S_LENGTHS = "times, energies, and states do not have the same length"
ERR_TRAJECTORY_TIMES_DIMS = "times must be a 1D array"
ERR_TRAJECTORY_ENERGIES_SHAPE = "energies must be a 2D array with shape (n_states, 3)"

# Prefixes for nucleotide-state validation errors; the offending type or shape
# is appended directly after the colon by the caller.
ERR_NUCLEOTIDE_STATE_TYPE = "Invalid type for nucleotide states:"
ERR_NUCLEOTIDE_STATE_SHAPE = "Invalid shape for nucleotide states:"

ERR_FIXED_BOX_SIZE = "Only trajectories in a fixed box size are supported"


@chex.dataclass(frozen=True)
class Trajectory:
    """Trajectory information for a RNA/DNA strand.

    The constructor validates that:

    * ``n_nucleotides`` equals ``sum(strand_lengths)``;
    * ``times`` is a 1D numpy array;
    * ``energies`` is a 2D numpy array with 3 columns;
    * ``times``, ``energies`` and ``states`` all have the same length.
    """

    n_nucleotides: int
    strand_lengths: list[int]
    times: typ.Arr_States
    energies: typ.Arr_States_3
    states: list["NucleotideState"]

    def __post_init__(self) -> None:
        """Validate the input.

        Raises:
            ValueError: if the nucleotide count, the sequence lengths, or the
                array shapes are inconsistent.
            TypeError: if ``times`` or ``energies`` is not a numpy array.
        """
        if self.n_nucleotides != sum(self.strand_lengths):
            raise ValueError(ERR_TRAJECTORY_N_NUCLEOTIDE_STRAND_LEGNTHS)

        if not isinstance(self.times, np.ndarray):
            raise TypeError(ERR_TRAJECTORY_TIMES_TYPE)

        if not isinstance(self.energies, np.ndarray):
            raise TypeError(ERR_TRAJECTORY_ENERGIES_TYPE)

        if len(self.times) != len(self.energies) or len(self.times) != len(self.states):
            raise ValueError(ERR_TRAJECTORY_T_E_S_LENGTHS)

        if len(self.times.shape) != TRAJECTORY_TIMES_DIMS:
            raise ValueError(ERR_TRAJECTORY_TIMES_DIMS)

        # The rank is checked first so that ``shape[1]`` is only read when it exists.
        if (
            len(self.energies.shape) != len(TRAJECTORY_ENERGIES_SHAPE)
            or self.energies.shape[1] != TRAJECTORY_ENERGIES_SHAPE[1]
        ):
            raise ValueError(ERR_TRAJECTORY_ENERGIES_SHAPE)

    @property
    def state_rigid_bodies(self) -> list[jax_md.rigid_body.RigidBody]:
        """Convert the states to a list of rigid bodies, one per state."""
        return [state.to_rigid_body() for state in self.states]

    @property
    def state_rigid_body(self) -> jax_md.rigid_body.RigidBody:
        """Convert the states to a single rigid body.

        The centers of mass and the quaternions of all states are stacked along
        a new leading axis.
        """
        return jax_md.rigid_body.RigidBody(
            center=jnp.stack([state.com for state in self.states]),
            orientation=jax_md.rigid_body.Quaternion(jnp.stack([state.quaternions for state in self.states])),
        )

    def slice(self, key: int | slice) -> "Trajectory":
        """Get a subset of the trajectory.

        Args:
            key (int | slice): a slice of states to keep, or the index of a
                single state. An integer yields a trajectory of length one.

        Returns:
            Trajectory: a new trajectory containing only the selected states.

        Raises:
            IndexError: if an integer ``key`` is out of range.
        """
        if not isinstance(key, slice):
            # Indexing with a bare integer would collapse ``times`` to a scalar
            # and ``states`` to a single object, which ``__post_init__`` rejects.
            # Normalise it to the equivalent one-element slice instead.
            # ``range(...)[key]`` resolves negative indices and raises
            # IndexError when the index is out of bounds.
            index = range(len(self.states))[key]
            key = slice(index, index + 1)

        return Trajectory(
            n_nucleotides=self.n_nucleotides,
            strand_lengths=self.strand_lengths,
            times=self.times[key],
            energies=self.energies[key],
            states=self.states[key],
        )

    def __repr__(self) -> str:
        """Return a string representation of the trajectory."""
        return "\n".join(
            [
                "Trajectory:",
                f"n_nucleotides: {self.n_nucleotides}",
                f"strand_lengths: {self.strand_lengths}",
                f"# times: {len(self.times)}",
                f"# energies: {len(self.energies)}",
                f"# states: {len(self.states)}",
            ]
        )
@chex.dataclass(frozen=True)
class NucleotideState:
    """State information for the nucleotides in a single state.

    ``array`` is a 2D numpy array with one row per nucleotide and 15 columns,
    read as five consecutive groups of three: center of mass (0:3), backbone
    base vector (3:6), base normal (6:9), velocity (9:12) and angular
    velocity (12:15).
    """

    array: typ.Arr_Nucleotide_15

    def __post_init__(self) -> None:
        """Validate the input array.

        Raises:
            TypeError: if ``array`` is not a numpy array.
            ValueError: if ``array`` is not 2D or does not have 15 columns.
        """
        data = self.array
        if not isinstance(data, np.ndarray):
            raise TypeError(ERR_NUCLEOTIDE_STATE_TYPE + str(type(data)))

        shape = data.shape
        has_expected_rank = len(shape) == len(NUCLEOTIDE_STATE_SHAPE)
        # ``and`` short-circuits, so the column count is only read for 2D input.
        if not (has_expected_rank and shape[1] == NUCLEOTIDE_STATE_SHAPE[1]):
            raise ValueError(ERR_NUCLEOTIDE_STATE_SHAPE + str(shape))

    @property
    def com(self) -> typ.Arr_Nucleotide_3:
        """Center of mass of the nucleotides (columns 0-2)."""
        return self.array[:, :3]

    @property
    def back_base_vector(self) -> typ.Arr_Nucleotide_3:
        """Backbone base vector (columns 3-5)."""
        return self.array[:, 3:6]

    @property
    def base_normal(self) -> typ.Arr_Nucleotide_3:
        """Base normal to the base plane (columns 6-8)."""
        return self.array[:, 6:9]

    @property
    def velocity(self) -> typ.Arr_Nucleotide_3:
        """Velocity of the nucleotides (columns 9-11)."""
        return self.array[:, 9:12]

    @property
    def angular_velocity(self) -> typ.Arr_Nucleotide_3:
        """Angular velocity of the nucleotides (columns 12-14)."""
        return self.array[:, 12:15]

    @property
    def euler_angles(self) -> tuple[typ.Arr_Nucleotide, typ.Arr_Nucleotide, typ.Arr_Nucleotide]:
        """Convert principal axes to Tait-Bryan Euler angles.

        The three axes handed to the conversion are the backbone base vector,
        the cross product of the base normal with the backbone base vector,
        and the base normal.
        """
        first_axis = self.back_base_vector
        third_axis = self.base_normal
        second_axis = np.cross(third_axis, first_axis)
        return jdm.principal_axes_to_euler_angles(first_axis, second_axis, third_axis)

    @property
    def quaternions(self) -> typ.Arr_Nucleotide_4:
        """Convert Euler angles to quaternions."""
        angles = self.euler_angles
        return jdm.euler_angles_to_quaternion(*angles)

    def to_rigid_body(self) -> jax_md.rigid_body.RigidBody:
        """Convert the nucleotide state to jax-md rigid bodies."""
        orientation = jax_md.rigid_body.Quaternion(vec=self.quaternions)
        return jax_md.rigid_body.RigidBody(self.com, orientation)
def validate_box_size(state_box_sizes: list[typ.Vector3D]) -> None:
    """Validate the volume for a simulation is fixed.

    Every box size is compared element-wise (exact equality) with the first one.

    Args:
        state_box_sizes (list[Vector3D]): the box size recorded for each state

    Raises:
        ValueError: if any box size differs from the first one
    """
    boxes = np.asarray(state_box_sizes)

    # With no states there is nothing to compare; an empty set of boxes is
    # trivially fixed, and indexing ``boxes[0]`` would raise IndexError.
    if boxes.size == 0:
        return

    if not np.all(boxes == boxes[0]):
        raise ValueError(ERR_FIXED_BOX_SIZE)
def from_file(
    path: typ.PathOrStr,
    strand_lengths: list[int],
    *,
    is_oxdna: bool = True,
    n_processes: int = 1,
) -> Trajectory:
    """Parse a trajectory file.

    Trajectory files are in the following format::

        t = number
        b = number number number
        E = number number number
        com_x com_y com_z a1_x a1_y a1_z a3_x a3_y a3_z v_x v_y v_z L_x L_y L_z
        ...repeated n_nucleotides times in total
        com_x com_y com_z a1_x a1_y a1_z a3_x a3_y a3_z v_x v_y v_z L_x L_y L_z

    where the com_x, ..., L_z are all floating point numbers. This can be
    repeated a total of "timestep" number of times.

    In oxDNA the states are stored in 3'->5' order so we flip the order per
    strand and need the topology to get the boundaries of each strand.

    Args:
        path (PathOrStr): path to the trajectory file
        strand_lengths (list[int]): if this is an oxDNA trajectory, the lengths
            of each strand, so that they can be flipped to 5'->3' order
        is_oxdna (bool): whether the trajectory is in oxDNA format
        n_processes (int): number of processors to use for reading the file

    Returns:
        Trajectory: trajectory information

    Raises:
        FileNotFoundError: if ``path`` does not exist
        ValueError: if ``n_processes`` is less than 1, or if the box size is
            not the same for every state
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(ERR_TRAJECTORY_FILE_NOT_FOUND.format(path))

    if n_processes < 1:
        raise ValueError(f"n_processes must be at least 1, got {n_processes}")

    file_size = path.stat().st_size

    if n_processes == 1:
        # A single reader gains nothing from a worker pool. Reading in-process
        # avoids starting a fresh interpreter with the "spawn" method, which is
        # slow and requires the caller's main module to be importable.
        chunks = [_read_file(path, 0, file_size, strand_lengths, is_3p_5p=is_oxdna)]
    else:
        # Split the file into ``n_processes`` contiguous byte ranges; each
        # worker parses the states found in its own range.
        boundaries = np.linspace(0, file_size, n_processes + 1, dtype=np.int64)
        n_runs = len(boundaries) - 1

        with cf.ProcessPoolExecutor(n_processes, mp_context=mp.get_context("spawn")) as pool:
            chunks = list(
                pool.map(
                    _read_file_process_wrapper,
                    zip(
                        itertools.repeat(path, times=n_runs),
                        boundaries[:-1],
                        boundaries[1:],
                        itertools.repeat(strand_lengths, times=n_runs),
                        itertools.repeat(is_oxdna, times=n_runs),
                        strict=True,
                    ),
                ),
            )

    # Every chunk is a (times, boxes, energies, states) tuple. Transpose the
    # chunks and join each of the four columns in file order.
    ts, bs, es, states = (list(itertools.chain.from_iterable(column)) for column in zip(*chunks, strict=True))

    validate_box_size(bs)

    return Trajectory(
        n_nucleotides=sum(strand_lengths),
        strand_lengths=strand_lengths,
        times=np.array(ts, dtype=np.float64),
        energies=np.array(es, dtype=np.float64),
        states=[NucleotideState(array=s) for s in states],
    )
def _read_file_process_wrapper(
    args: tuple[Path, int, int, list[int], bool],
) -> tuple[
    list[typ.Scalar],
    list[typ.Vector3D],
    list[typ.Vector3D],
    list[typ.Arr_Nucleotide_15],
]:
    """Adapt ``_read_file`` to a single-argument callable.

    Args:
        args: ``(file_path, start, end, strand_lengths, is_3p_5p)`` packed into
            one tuple, in that order.

    Returns:
        Whatever ``_read_file`` returns for those arguments.
    """
    trajectory_path, first_byte, last_byte, lengths, flip_strands = args
    return _read_file(
        trajectory_path,
        first_byte,
        last_byte,
        lengths,
        is_3p_5p=flip_strands,
    )
def _read_file(
    file_path: Path,
    start: int,
    end: int,
    strand_lengths: list[int],
    *,
    is_3p_5p: bool,
) -> tuple[
    list[typ.Scalar],
    list[typ.Vector3D],
    list[typ.Vector3D],
    list[typ.Arr_Nucleotide_15],
]:
    """Read the states of a trajectory file that begin inside ``[start, end)``.

    A state is owned by the reader whose range contains the first character of
    the state's ``t = ...`` header line. The reader parses every state it owns
    to completion, even when that means reading past ``end``, and leaves any
    state whose header begins at or after ``end`` to the reader of the next
    range. Adjacent ranges therefore yield every state exactly once.

    Args:
        file_path (Path): path to the trajectory file
        start (int): offset at which to start looking for a state header
        end (int): offset at which ownership of new states stops
        strand_lengths (list[int]): number of nucleotides in each strand
        is_3p_5p (bool): if True, reverse the nucleotide order within each
            strand of every state

    Returns:
        A ``(times, box_sizes, energies, states)`` tuple of lists with one
        entry per parsed header/box/energy line and one float64 array (one row
        per nucleotide) per completed state.
    """
    parse_str = functools.partial(np.fromstring, sep=" ", dtype=np.float64)
    state_length = sum(strand_lengths)
    # (first, last + 1) row indices of each strand within a state
    strand_bounds = list(itertools.pairwise([0, *itertools.accumulate(strand_lengths)]))

    ts, bs, es, states = [], [], [], []
    state = []

    # The context manager guarantees the handle is closed, also on errors.
    with file_path.open() as file_obj:
        # NOTE(review): seeking to an arbitrary offset of a text-mode file
        # assumes a single-byte encoding with plain newlines -- confirm.
        file_obj.seek(start)

        # ``start`` may fall in the middle of a state. Skip ahead to the next
        # header line, remembering where each candidate line begins. The
        # ``line`` check stops the search at end-of-file, where ``readline``
        # returns an empty string forever.
        line_start = file_obj.tell()
        line = file_obj.readline()
        while line and not line.startswith("t"):
            line_start = file_obj.tell()
            line = file_obj.readline()

        while line:
            tag = line[0]
            if tag == "t":
                # A header at or beyond ``end`` belongs to the next reader.
                if line_start >= end:
                    break
                ts.append(float(line.strip().split("=")[1]))
            elif tag == "b":
                bs.append(parse_str(line.strip().split("=")[1]))
            elif tag == "E":
                es.append(parse_str(line.strip().split("=")[1]))
            elif line.strip():
                # Nucleotide row; blank lines are ignored.
                state.append(parse_str(line.strip()))
                if len(state) == state_length:
                    # if the trajectory is stored in 3'->5' order, we need to
                    # flip the order of the nucleotides in each strand
                    if is_3p_5p:
                        state = list(itertools.chain.from_iterable(state[s:e][::-1] for s, e in strand_bounds))
                    states.append(np.array(state, dtype=np.float64))
                    state = []

            line_start = file_obj.tell()
            line = file_obj.readline()

    return ts, bs, es, states