Source code for jax_dna.simulators.oxdna.oxdna

"""OXDNA sampler module.

Run an jax_dna simulation using an oxDNA sampler.
"""

import logging
import os
import subprocess
import time
import typing
import warnings
from collections.abc import Callable
from pathlib import Path

import chex
import numpy as np
import ray

import jax_dna.energy.configuration as jd_energy
import jax_dna.input.oxdna_input as jd_oxdna
import jax_dna.input.topology as jd_top
import jax_dna.input.trajectory as jd_traj
import jax_dna.simulators.base as jd_base
import jax_dna.simulators.io as jd_sio
import jax_dna.simulators.oxdna.utils as oxdna_utils
import jax_dna.utils.types as jd_types

REQUIRED_KEYS = {
    "oxdna_bin",
    "input_directory",
}

ERR_OXDNA_NOT_FOUND = "OXDNA binary not found at: {}"
ERR_MISSING_REQUIRED_KEYS = "Missing required keys: {}"
ERR_INPUT_FILE_NOT_FOUND = "Input file not found: {}"
ERR_OXDNA_FAILED = "OXDNA simulation failed"
OXDNA_TRAJECTORY_FILE_KEY = "trajectory_file"
OXDNA_TOPOLOGY_FILE_KEY = "topology"

BIN_PATH_ENV_VAR = "OXDNA_BIN_PATH"
ERR_BIN_PATH_NOT_SET = "OXDNA_BIN_PATH environment variable not set"
BUILD_PATH_ENV_VAR = "OXDNA_BUILD_PATH"
ERR_BUILD_PATH_NOT_SET = "OXDNA_BUILD_PATH environment variable not set"
ERR_BUILD_SETUP_FAILED = "OXDNA build setup failed wiht return code: {}"
WARN_CANT_GUESS_BIN_LOC = (
    "Could not guess the location of the {} binary, be sure {} is set to its location for oxDNA recompilation."
)
ERR_ORIG_MODEL_H_NOT_FOUND = "Original model.h file not found, looked at {}"

MAKE_BIN_ENV_VAR = "MAKE_BIN_PATH"
CMAKE_BIN_ENV_VAR = "CMAKE_BIN_PATH"

CMAKE_MAKE_BIN_LOC_GUESSES = [
    "/bin/{}",
    "/usr/bin/{}",
    "/snap/bin/{}",
    r"C:\Program Files (x86)\GnuWin32\bin\{}.exe",
]

logger = logging.getLogger(__name__)


# We do not force the user the set this because they may not be recompiling oxDNA
[docs] def _guess_binary_location(bin_name: str, env_var: str) -> Path | None: """Guess the location of a binary.""" guessed_path = None for guess in CMAKE_MAKE_BIN_LOC_GUESSES: pth = Path(guess.format(bin_name)) if pth.exists(): guessed_path = pth break if guessed_path is None: warnings.warn(WARN_CANT_GUESS_BIN_LOC.format(bin_name, env_var), stacklevel=2) logger.debug(WARN_CANT_GUESS_BIN_LOC.format(bin_name, env_var)) return os.environ.get(env_var, None) or guessed_path
[docs] def _default_build_ready() -> bool: return True
[docs] def _default_set_build_ready(_: bool) -> None: # noqa: FBT001 pass
[docs] class oxDNABinarySemaphore: # noqa: N801 oxDNA is a special word """A semaphore for the oxDNA binary.""" def __init__(self) -> None: """Initialize the semaphore, defaults to False.""" self._ready = False
[docs] def check(self) -> bool: """Check if the semaphore is ready.""" return self._ready
[docs] def set(self, ready: bool) -> None: # noqa: FBT001 -- The way this gets used is easier this way """Set the value of the semaphore.""" self._ready = ready
[docs] @ray.remote class oxDNABinarySemaphoreActor(oxDNABinarySemaphore): # noqa: N801 oxDNA is a special word """A ray actor wrapped oxDNA binary semaphore."""
[docs] @chex.dataclass class oxDNASimulator(jd_base.BaseSimulation): # noqa: N801 oxDNA is a special word """A sampler base on running an oxDNA simulation.""" input_dir: str sim_type: jd_types.oxDNASimulatorType energy_configs: list[jd_energy.BaseConfiguration] | None = None n_build_threads: int = 4 logger_config: dict[str, typing.Any] | None = None disable_build: bool = False check_build_ready: Callable[[None], bool] = _default_build_ready set_build_ready: Callable[[bool], None] = _default_set_build_ready build_wait_interval: int = 15
[docs] def __post_init__(self, *args, **kwds) -> None: """Check the validity of the configuration.""" self._initialize_logger()
[docs] def _initialize_logger(self) -> None: config = self.logger_config if self.logger_config is not None else {} level = config.get("level", logging.INFO) logger = logging.getLogger(__name__) logger.setLevel(level) if config.get("filename", None): handler = logging.FileHandler(config["filename"]) handler.setLevel(level) else: handler = logging.StreamHandler() handler.setLevel(level) handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(name)s:%(message)s")) logger.addHandler(handler) self._logger = logger
[docs] def run( self, opt_params: list[jd_types.Params] | None = None, seed: np.ndarray | None = None, **kwargs, # noqa: ARG002 we want to satisfy the interface ) -> jd_traj.Trajectory: """Run an oxDNA simulation.""" # The user may want to override the current parameters in the oxdna binary # If so, we need to update the parameters in the src/model.h file if not self._logger.handlers: self._initialize_logger() # It's possible that there are multiple oxDNA simulators sharing the same # binary per step. We need to ensure that all of the other simulators # that aren't responsible for building the binary wait for the # recompilation to finish before running. while not self.check_build_ready() and self.disable_build: self._logger.debug("Waiting for build to be ready") time.sleep(self.build_wait_interval) # if we are the building simulator, we need to update the src/model.h file if opt_params is not None and not self.disable_build: self._update_params(new_params=opt_params) # after building the binary, put the original model file back. self._restore_params() # let the other simulators know that the binary is ready self._logger.debug("Setting build ready") self.set_build_ready(True) init_dir = Path(self.input_dir) input_file = init_dir / "input" self._logger.info("oxDNA input file: %s", input_file) if not input_file.exists(): raise FileNotFoundError(ERR_INPUT_FILE_NOT_FOUND.format(input_file)) # overwrite the seed input_config = jd_oxdna.read(input_file) input_config["seed"] = seed or np.random.default_rng().integers(0, 2**32) jd_oxdna.write(input_config, input_file) if BIN_PATH_ENV_VAR not in os.environ: raise ValueError(ERR_BIN_PATH_NOT_SET) oxdna_config = jd_oxdna.read(init_dir / "input") output_file = init_dir / oxdna_config["trajectory_file"] std_out_file = init_dir / "oxdna.out.log" std_err_file = init_dir / "oxdna.err.log" self._logger.info("Starting oxDNA simulation") self._logger.debug( "oxDNA std_out->%s, std_err->%s", std_out_file, std_err_file, ) with std_out_file.open("w") as f_std, std_err_file.open("w") as f_err: subprocess.run( [ # noqa: S603 os.environ[BIN_PATH_ENV_VAR], "input", ], stdout=f_std, stderr=f_err, check=True, cwd=init_dir, ) self._logger.info("oxDNA simulation complete") # read the output trajectory file topology = jd_top.from_oxdna_file(init_dir / oxdna_config["topology"]) # return the trajectory trajectory = jd_traj.from_file(output_file, topology.strand_counts, is_oxdna=True) self._logger.debug( "oxDNA trajectory com size: %s", str(trajectory.state_rigid_body.center.shape), ) return jd_sio.SimulatorTrajectory( rigid_body=trajectory.state_rigid_body, )
[docs] def _update_params(self, *, new_params: list[dict]) -> None: """Update the simulation. This function will recompile the oxDNA binary with the new parameters. """ if BUILD_PATH_ENV_VAR not in os.environ: raise ValueError(ERR_BUILD_PATH_NOT_SET) _cmake_bin = _guess_binary_location("cmake", CMAKE_BIN_ENV_VAR) _make_bin = _guess_binary_location("make", MAKE_BIN_ENV_VAR) logger.debug("cmake_bin: %s", _cmake_bin) logger.debug("make_bin: %s", _make_bin) self._logger.info("Updating oxDNA parameters") build_dir = Path(os.environ[BUILD_PATH_ENV_VAR]) self._logger.debug("build_dir: %s", build_dir) std_out = build_dir / "jax_dna.cmake.std.log" std_err = build_dir / "jax_dna.cmake.err.log" self._logger.debug( "running cmake: std_out->%s, std_err->%s", std_out, std_err, ) with std_out.open("w") as f_std, std_err.open("w") as f_err: if _cmake_bin is None: raise FileNotFoundError(ERR_OXDNA_NOT_FOUND.format("cmake")) completed_proc = subprocess.run( [_cmake_bin, ".."], shell=False, # noqa: S603 false positive cwd=build_dir, stdout=f_std, stderr=f_err, check=True, ) self._logger.debug("cmake completed") if completed_proc.returncode != 0: raise ValueError(ERR_BUILD_SETUP_FAILED.format(completed_proc.returncode)) updated_params = [(ec | np).init_params() for ec, np in zip(self.energy_configs, new_params, strict=True)] # check for existing src/model.h file save a copy if we haven't already old_model_h = build_dir.parent.joinpath("src/model.h.old") model_h = build_dir.parent.joinpath("src/model.h") orig_text = model_h.read_text() if not old_model_h.exists(): # copy the original, so we can restore it later old_model_h.write_text(orig_text) # update the values in the src/model.h new_params = [up.to_dictionary(include_dependent=True, exclude_non_optimizable=True) for up in updated_params] oxdna_utils.update_params(model_h, new_params) # rebuild the binary std_out = build_dir / "jax_dna.make.std.log" std_err = build_dir / "jax_dna.make.err.log" self._logger.debug( "running make with %d processes: std_out->%s, std_err->%s", self.n_build_threads, std_out, std_err, ) with std_out.open("w") as f_std, std_err.open("w") as f_err: completed_proc = subprocess.run( [_make_bin, f"-j{self.n_build_threads}"], shell=False, # noqa: S603 false positive cwd=build_dir, check=True, stdout=f_std, stderr=f_err, ) if completed_proc.returncode != 0: # restore the original src/model.h model_h.write_text(orig_text) raise ValueError(ERR_BUILD_SETUP_FAILED.format(completed_proc.returncode)) self._logger.info("oxDNA binary rebuilt")
[docs] def _restore_params(self) -> None: """Restore the original parameters.""" logger.debug("Restoring oxDNA parameters to original values") build_dir = Path(os.environ[BUILD_PATH_ENV_VAR]) old_model_h = build_dir.parent.joinpath("src/model.h.old") model_h = build_dir.parent.joinpath("src/model.h") if old_model_h.exists(): # restore the original src/model.h old_model_h.replace(model_h) else: raise FileNotFoundError(ERR_ORIG_MODEL_H_NOT_FOUND.format(old_model_h))