# Source code for jax_dna.energy.dna1

"""oxDNA1 energy implementation in jax_dna."""

from pathlib import Path
from types import MappingProxyType

import jax
import jax.numpy as jnp

from jax_dna.energy.base import BaseEnergyFunction
from jax_dna.energy.configuration import BaseConfiguration
from jax_dna.energy.dna1.bonded_excluded_volume import BondedExcludedVolume, BondedExcludedVolumeConfiguration
from jax_dna.energy.dna1.coaxial_stacking import CoaxialStacking, CoaxialStackingConfiguration
from jax_dna.energy.dna1.cross_stacking import CrossStacking, CrossStackingConfiguration
from jax_dna.energy.dna1.expected_hydrogen_bonding import ExpectedHydrogenBonding, ExpectedHydrogenBondingConfiguration
from jax_dna.energy.dna1.expected_stacking import ExpectedStacking, ExpectedStackingConfiguration
from jax_dna.energy.dna1.fene import Fene, FeneConfiguration
from jax_dna.energy.dna1.hydrogen_bonding import HydrogenBonding, HydrogenBondingConfiguration
from jax_dna.energy.dna1.nucleotide import Nucleotide
from jax_dna.energy.dna1.stacking import Stacking, StackingConfiguration
from jax_dna.energy.dna1.unbonded_excluded_volume import UnbondedExcludedVolume, UnbondedExcludedVolumeConfiguration
from jax_dna.input import toml
from jax_dna.utils.types import PyTree


def default_configs() -> tuple[PyTree, PyTree]:
    """Load the packaged default configurations for dna1 simulations.

    Parses ``default_simulation.toml`` and ``default_energy.toml`` from the
    ``input/dna1`` directory found three directory levels above this file,
    then maps every leaf of each parsed tree through
    ``jnp.array(..., dtype=jnp.float64)``.

    Returns:
        A ``(simulation_config, energy_config)`` pair of PyTrees.
    """
    # This file is jax_dna/energy/dna1/__init__.py, so parents[2] is jax_dna/.
    package_root = Path(__file__).resolve().parents[2]
    # jax_dna/input/dna1
    defaults_dir = package_root / "input" / "dna1"

    def to_float64(leaf: float | list[float]) -> jnp.ndarray:
        return jnp.array(leaf, dtype=jnp.float64)

    def load(filename: str) -> PyTree:
        parsed = toml.parse_toml(defaults_dir / filename)
        return jax.tree.map(to_float64, parsed)

    simulation_config = load("default_simulation.toml")
    energy_config = load("default_energy.toml")
    return simulation_config, energy_config
def default_energy_configs(
    overrides: dict = MappingProxyType({}), opts: dict = MappingProxyType({})
) -> list[BaseConfiguration]:
    """Return the default configurations for the energy functions.

    Args:
        overrides: Parameter overrides keyed by energy-term name (``"fene"``,
            ``"stacking"``, ...). Each value is merged on top of the packaged
            defaults for that term. A top-level ``"kT"`` entry, if present,
            replaces the default simulation ``kT`` handed to the stacking
            configuration.
        opts: Per-term values forwarded as the second argument of
            ``from_dict`` (presumably the names of the parameters to
            optimize -- confirm against ``BaseConfiguration.from_dict``).
            Terms without an entry fall back to ``BaseConfiguration.OPT_ALL``;
            stacking instead falls back to every default stacking key except
            ``"kT"`` and ``"ss_stack_weights"``.

    Returns:
        One configuration per energy term, in the order FENE, bonded excluded
        volume, stacking, unbonded excluded volume, hydrogen bonding, cross
        stacking, coaxial stacking.
    """
    default_sim_config, default_config = default_configs()

    def get_param(x: str) -> dict:
        # Packaged defaults for term ``x`` with the caller's overrides on top.
        return default_config[x] | overrides.get(x, {})

    def get_opts(x: str, defaults: tuple[str] = BaseConfiguration.OPT_ALL) -> tuple[str]:
        return opts.get(x, defaults)

    # Filter the keys in their existing order instead of round-tripping through
    # a set: set iteration order for strings changes between interpreter runs
    # (hash randomisation), which made this tuple's order non-reproducible.
    excluded_stacking_opts = {"kT", "ss_stack_weights"}
    default_stacking_opts = tuple(
        key for key in default_config["stacking"].keys() if key not in excluded_stacking_opts
    )

    return [
        FeneConfiguration.from_dict(get_param("fene"), get_opts("fene")),
        BondedExcludedVolumeConfiguration.from_dict(
            get_param("bonded_excluded_volume"), get_opts("bonded_excluded_volume")
        ),
        StackingConfiguration.from_dict(
            # NOTE(review): the injected key is lowercase "kt" while the lookups
            # use "kT" -- confirm this matches the StackingConfiguration field.
            get_param("stacking") | {"kt": overrides.get("kT", default_sim_config["kT"])},
            get_opts("stacking", default_stacking_opts),
        ),
        # NOTE(review): unlike every other term, no opts are forwarded here, so
        # an ``opts["unbonded_excluded_volume"]`` entry is ignored -- confirm
        # this is intentional.
        UnbondedExcludedVolumeConfiguration.from_dict(get_param("unbonded_excluded_volume")),
        HydrogenBondingConfiguration.from_dict(get_param("hydrogen_bonding"), get_opts("hydrogen_bonding")),
        CrossStackingConfiguration.from_dict(get_param("cross_stacking"), get_opts("cross_stacking")),
        CoaxialStackingConfiguration.from_dict(get_param("coaxial_stacking"), get_opts("coaxial_stacking")),
    ]
def default_energy_fns() -> list[BaseEnergyFunction]:
    """Return the default energy functions.

    The entries appear in the same order as the configurations built by
    ``default_energy_configs``. A new list is created on every call.
    """
    energy_fns = (
        Fene,
        BondedExcludedVolume,
        Stacking,
        UnbondedExcludedVolume,
        HydrogenBonding,
        CrossStacking,
        CoaxialStacking,
    )
    return list(energy_fns)
# Public API of the dna1 energy package.
__all__ = [
    "Nucleotide",
    "Fene",
    "FeneConfiguration",
    "HydrogenBonding",
    "HydrogenBondingConfiguration",
    "Stacking",
    "StackingConfiguration",
    "CoaxialStacking",
    "CoaxialStackingConfiguration",
    "CrossStacking",
    "CrossStackingConfiguration",
    "BondedExcludedVolume",
    "BondedExcludedVolumeConfiguration",
    "UnbondedExcludedVolume",
    "UnbondedExcludedVolumeConfiguration",
    "ExpectedHydrogenBondingConfiguration",
    "ExpectedHydrogenBonding",
    "ExpectedStackingConfiguration",
    "ExpectedStacking",
    "default_configs",
    # Defined in this module as a public helper; it was previously missing
    # from the export list, so star-imports did not pick it up.
    "default_energy_configs",
    "default_energy_fns",
]