# Source code for jax_dna.simulators.oxdna.utils

"""Utilities for the oxDNA simulator."""

import datetime
import functools
import operator
import tokenize
from pathlib import Path

import jax
import jax.numpy as jnp
import sympy

import jax_dna.utils.types as jd_types
from jax_dna.utils.types import oxDNAModelHType

# Message templates, filled in with ``str.format``.
# {} -> the offending src/model.h line.
ERR_CANNOT_PROCESS_SRC_H: str = "Cannot process src/model.h file. Failed parsing: {}"
# {} -> what could not be rendered, {} -> the value itself.
ERR_INVALID_HEADER_TYPE: str = "Invalid header value variable {} with value {}"

# Significant digits requested from sympy's ``evalf`` when numerically
# evaluating macro expressions such as ``(PI*0.5f)``.
SYMPY_EVAL_N: int = 32

# Maps jax_dna parameter names (keys) to the names of the ``#define`` macros in
# oxDNA's src/model.h (values). ``update_params`` consults it when writing new
# parameters into the header: incoming names that are not keys here are
# skipped, and a mapped macro that is missing from the parsed header raises
# ``ValueError``.
# Entries are grouped by interaction term; the ``# func ...`` comments label
# the function (and its argument) that each group of parameters belongs to.
# NOTE(review): the ``_OXDNA`` suffix on some macros presumably selects the
# oxDNA1 variant where model.h also defines other model versions -- confirm.
DEFAULT_OXDNA_VARIABLE_MAPPER: dict[str, str] = {
    # fene
    "eps_backbone": "FENE_EPS",
    "delta_backbone": "FENE_DELTA",
    "r0_backbone": "FENE_R0_OXDNA",
    # excluded_volume
    "eps_exc": "EXCL_EPS",
    "sigma_backbone": "EXCL_S1",
    "sigma_base": "EXCL_S2",
    "sigma_back_base": "EXCL_S3",
    "sigma_base_back": "EXCL_S4",
    "dr_star_backbone": "EXCL_R1",
    "dr_star_base": "EXCL_R2",
    "dr_star_back_base": "EXCL_R3",
    "dr_star_base_back": "EXCL_R4",
    "b_backbone": "EXCL_B1",
    "b_base": "EXCL_B2",
    "b_back_base": "EXCL_B3",
    "b_base_back": "EXCL_B4",
    "dr_c_backbone": "EXCL_RC1",
    "dr_c_base": "EXCL_RC2",
    "dr_c_back_base": "EXCL_RC3",
    "dr_c_base_back": "EXCL_RC4",
    # stacking
    # func f1(dr_stack)
    "eps_stack_base": "STCK_BASE_EPS_OXDNA",
    "eps_stack_kt_coeff": "STCK_FACT_EPS_OXDNA",
    "a_stack": "STCK_A",
    "dr0_stack": "STCK_R0",
    "dr_c_stack": "STCK_RC",
    "dr_low_stack": "STCK_RLOW",
    "dr_high_stack": "STCK_RHIGH",
    "b_low_stack": "STCK_BLOW",
    "b_high_stack": "STCK_BHIGH",
    "dr_c_low_stack": "STCK_RCLOW",
    "dr_c_high_stack": "STCK_RCHIGH",
    # func f4(theta_4)
    "a_stack_4": "STCK_THETA4_A",
    "theta0_stack_4": "STCK_THETA4_T0",
    "delta_theta_star_stack_4": "STCK_THETA4_TS",
    "b_stack_4": "STCK_THETA4_B",
    "delta_theta_stack_4_c": "STCK_THETA4_TC",
    # func f4(theta_5p)
    "a_stack_5": "STCK_THETA5_A",
    "theta0_stack_5": "STCK_THETA5_T0",
    "delta_theta_star_stack_5": "STCK_THETA5_TS",
    "b_stack_5": "STCK_THETA5_B",
    "delta_theta_stack_5_c": "STCK_THETA5_TC",
    # func f4(theta_6p)
    "a_stack_6": "STCK_THETA6_A",
    "theta0_stack_6": "STCK_THETA6_T0",
    "delta_theta_star_stack_6": "STCK_THETA6_TS",
    "b_stack_6": "STCK_THETA6_B",
    "delta_theta_stack_6_c": "STCK_THETA6_TC",
    # func f5(-cos(phi1))
    "a_stack_1": "STCK_PHI1_A",
    "neg_cos_phi1_star_stack": "STCK_PHI1_XS",
    "b_neg_cos_phi1_stack": "STCK_PHI1_B",
    "neg_cos_phi1_c_stack": "STCK_PHI1_XC",
    # func f5(-cos(phi2))
    "a_stack_2": "STCK_PHI2_A",
    "neg_cos_phi2_star_stack": "STCK_PHI2_XS",
    "b_neg_cos_phi2_stack": "STCK_PHI2_B",
    "neg_cos_phi2_c_stack": "STCK_PHI2_XC",
    # hydrogen_bonding
    # func f1(dr_hb)
    "eps_hb": "HYDR_EPS_OXDNA",
    "a_hb": "HYDR_A",
    "dr0_hb": "HYDR_R0",
    "dr_c_hb": "HYDR_RC",
    "dr_low_hb": "HYDR_RLOW",
    "dr_high_hb": "HYDR_RHIGH",
    "b_low_hb": "HYDR_BLOW",
    "dr_c_low_hb": "HYDR_RCLOW",
    "b_high_hb": "HYDR_BHIGH",
    "dr_c_high_hb": "HYDR_RCHIGH",
    # func f4(theta_1)
    "a_hb_1": "HYDR_THETA1_A",
    "theta0_hb_1": "HYDR_THETA1_T0",
    "delta_theta_star_hb_1": "HYDR_THETA1_TS",
    "b_hb_1": "HYDR_THETA1_B",
    "delta_theta_hb_1_c": "HYDR_THETA1_TC",
    # func f4(theta_2)
    "a_hb_2": "HYDR_THETA2_A",
    "theta0_hb_2": "HYDR_THETA2_T0",
    "delta_theta_star_hb_2": "HYDR_THETA2_TS",
    "b_hb_2": "HYDR_THETA2_B",
    "delta_theta_hb_2_c": "HYDR_THETA2_TC",
    # func f4(theta_3)
    "a_hb_3": "HYDR_THETA3_A",
    "theta0_hb_3": "HYDR_THETA3_T0",
    "delta_theta_star_hb_3": "HYDR_THETA3_TS",
    "b_hb_3": "HYDR_THETA3_B",
    "delta_theta_hb_3_c": "HYDR_THETA3_TC",
    # func f4(theta_4)
    "a_hb_4": "HYDR_THETA4_A",
    "theta0_hb_4": "HYDR_THETA4_T0",
    "delta_theta_star_hb_4": "HYDR_THETA4_TS",
    "b_hb_4": "HYDR_THETA4_B",
    "delta_theta_hb_4_c": "HYDR_THETA4_TC",
    # func f4(theta_7)
    "a_hb_7": "HYDR_THETA7_A",
    "theta0_hb_7": "HYDR_THETA7_T0",
    "delta_theta_star_hb_7": "HYDR_THETA7_TS",
    "b_hb_7": "HYDR_THETA7_B",
    "delta_theta_hb_7_c": "HYDR_THETA7_TC",
    # func f4(theta_8)
    "a_hb_8": "HYDR_THETA8_A",
    "theta0_hb_8": "HYDR_THETA8_T0",
    "delta_theta_star_hb_8": "HYDR_THETA8_TS",
    "b_hb_8": "HYDR_THETA8_B",
    "delta_theta_hb_8_c": "HYDR_THETA8_TC",
    # cross_stacking
    # func f2(dr_cross)
    "k_cross": "CRST_K",
    "r0_cross": "CRST_R0",
    "dr_c_cross": "CRST_RC",
    "dr_low_cross": "CRST_RLOW",
    "dr_high_cross": "CRST_RHIGH",
    "b_low_cross": "CRST_BLOW",
    "dr_c_low_cross": "CRST_RCLOW",
    "b_high_cross": "CRST_BHIGH",
    "dr_c_high_cross": "CRST_RCHIGH",
    # func f4(theta_1)
    "a_cross_1": "CRST_THETA1_A",
    "theta0_cross_1": "CRST_THETA1_T0",
    "delta_theta_star_cross_1": "CRST_THETA1_TS",
    "b_cross_1": "CRST_THETA1_B",
    "delta_theta_cross_1_c": "CRST_THETA1_TC",
    # func f4(theta_2)
    "a_cross_2": "CRST_THETA2_A",
    "theta0_cross_2": "CRST_THETA2_T0",
    "delta_theta_star_cross_2": "CRST_THETA2_TS",
    "b_cross_2": "CRST_THETA2_B",
    "delta_theta_cross_2_c": "CRST_THETA2_TC",
    # func f4(theta_3)
    "a_cross_3": "CRST_THETA3_A",
    "theta0_cross_3": "CRST_THETA3_T0",
    "delta_theta_star_cross_3": "CRST_THETA3_TS",
    "b_cross_3": "CRST_THETA3_B",
    "delta_theta_cross_3_c": "CRST_THETA3_TC",
    # func f4(theta_4) + f4(pi - theta_4)
    "a_cross_4": "CRST_THETA4_A",
    "theta0_cross_4": "CRST_THETA4_T0",
    "delta_theta_star_cross_4": "CRST_THETA4_TS",
    "b_cross_4": "CRST_THETA4_B",
    "delta_theta_cross_4_c": "CRST_THETA4_TC",
    # func f4(theta_7) + f4(pi - theta_7)
    "a_cross_7": "CRST_THETA7_A",
    "theta0_cross_7": "CRST_THETA7_T0",
    "delta_theta_star_cross_7": "CRST_THETA7_TS",
    "b_cross_7": "CRST_THETA7_B",
    "delta_theta_cross_7_c": "CRST_THETA7_TC",
    # func f4(theta_8) + f4(pi - theta_8)
    "a_cross_8": "CRST_THETA8_A",
    "theta0_cross_8": "CRST_THETA8_T0",
    "delta_theta_star_cross_8": "CRST_THETA8_TS",
    "b_cross_8": "CRST_THETA8_B",
    "delta_theta_cross_8_c": "CRST_THETA8_TC",
    # coaxial_stacking
    # func f2(dr_coax)
    "k_coax": "CXST_K_OXDNA",
    "dr0_coax": "CXST_R0",
    "dr_c_coax": "CXST_RC",
    "dr_low_coax": "CXST_RLOW",
    "dr_high_coax": "CXST_RHIGH",
    "b_low_coax": "CXST_BLOW",
    "dr_c_low_coax": "CXST_RCLOW",
    "b_high_coax": "CXST_BHIGH",
    "dr_c_high_coax": "CXST_RCHIGH",
    # func f4(theta_1) + f4(2*pi - theta_1)
    "a_coax_1": "CXST_THETA1_A",
    "theta0_coax_1": "CXST_THETA1_T0_OXDNA",
    "delta_theta_star_coax_1": "CXST_THETA1_TS",
    "b_coax_1": "CXST_THETA1_B",
    "delta_theta_coax_1_c": "CXST_THETA1_TC",
    # func f4(theta_4)
    "a_coax_4": "CXST_THETA4_A",
    "theta0_coax_4": "CXST_THETA4_T0",
    "delta_theta_star_coax_4": "CXST_THETA4_TS",
    "b_coax_4": "CXST_THETA4_B",
    "delta_theta_coax_4_c": "CXST_THETA4_TC",
    # func f4(theta_5) + f4(pi - theta_5)
    "a_coax_5": "CXST_THETA5_A",
    "theta0_coax_5": "CXST_THETA5_T0",
    "delta_theta_star_coax_5": "CXST_THETA5_TS",
    "b_coax_5": "CXST_THETA5_B",
    "delta_theta_coax_5_c": "CXST_THETA5_TC",
    # func f4(theta_6) + f4(pi - theta_6)
    "a_coax_6": "CXST_THETA6_A",
    "theta0_coax_6": "CXST_THETA6_T0",
    "delta_theta_star_coax_6": "CXST_THETA6_TS",
    "b_coax_6": "CXST_THETA6_B",
    "delta_theta_coax_6_c": "CXST_THETA6_TC",
    # func f5(cos(phi3))
    "a_coax_3p": "CXST_PHI3_A",
    "cos_phi3_star_coax": "CXST_PHI3_XS",
    "b_cos_phi3_coax": "CXST_PHI3_B",
    "cos_phi3_c_coax": "CXST_PHI3_XC",
    # func f5(cos(phi4))
    "a_coax_4p": "CXST_PHI4_A",
    "cos_phi4_star_coax": "CXST_PHI4_XS",
    "b_cos_phi4_coax": "CXST_PHI4_B",
    "cos_phi4_c_coax": "CXST_PHI4_XC",
}

# Minimum number of whitespace-separated tokens in a usable
# ``#define NAME VALUE`` line: the directive, the macro name, and a value.
MIN_VALID_HEADER_TOKEN_COUNT: int = 3


def _parse_value_in(value: str) -> int | float | str:
    """Convert the value text of a ``#define`` into a Python value.

    Conversion is attempted in order:

    1. ``int(value)`` -- integer literals such as ``0`` or ``6``.
    2. sympy numeric evaluation -- every lowercase ``f`` is removed (to drop C
       float suffixes such as ``0.5f``) and the text is lower-cased (so ``PI``
       becomes sympy's ``pi``). The expression is then evaluated to
       ``SYMPY_EVAL_N`` significant digits and converted to ``float``. This
       covers literals like ``-0.4f`` and expressions like ``(PI*0.5f)``.

    If neither step succeeds, the text is returned unchanged (e.g. a reference
    to another macro such as ``HYDR_T2_MESH_POINTS``) so it can be written back
    verbatim.

    Args:
        value: the macro value, without the ``#define NAME`` prefix.

    Returns:
        The parsed ``int`` or ``float``, or ``value`` itself.
    """
    # Any of these means "not a number we can evaluate": fall back to the text.
    # ``tokenize.TokenError`` is raised by sympy's tokenizer before evaluation,
    # e.g. for unbalanced parentheses or a trailing line-continuation
    # backslash. It is not a subclass of the other types listed here.
    not_numeric = (AttributeError, ValueError, SyntaxError, TypeError, tokenize.TokenError)

    try:
        return int(value)
    except not_numeric:
        pass

    try:
        # NOTE: removing *every* "f", not only numeric suffixes, is lenient.
        # It also alters identifiers and hex literals containing "f". These
        # usually fail to evaluate and fall back to the raw text, but may
        # occasionally evaluate to an unintended number.
        expression = value.replace("f", "").lower()
        # Security: sympy.parse_expr evaluates its input with eval(), so the
        # header being parsed must come from a trusted source.
        return float(sympy.parse_expr(expression).evalf(n=SYMPY_EVAL_N))
    except not_numeric:
        return value
[docs] def _parse_value_out(value: int | float | str) -> str: # noqa: PYI041 -- this is documentation specific if isinstance(value, int) or (isinstance(value, jax.Array) and (jnp.issubdtype(value.dtype, jnp.integer))): parsed_value = str(value) elif isinstance(value, float) or (isinstance(value, jax.Array) and (jnp.issubdtype(value.dtype, jnp.floating))): parsed_value = f"{value}f" elif isinstance(value, str): parsed_value = value else: raise TypeError(ERR_INVALID_HEADER_TYPE.format(type(value), value)) return parsed_value
def read_src_h(src_h: Path) -> dict[str, int | float | str]:
    """Parse the src/model.h file to get the parameters.

    Every line that starts with ``#define`` and does not mention ``MODEL_H_``
    (the include guard) is split on whitespace. The second token is the macro
    name. The remaining tokens are re-joined with single spaces, cut at the
    first ``//`` comment, and converted with ``_parse_value_in``. All other
    lines are ignored, including ``#ifndef``/``#endif``, comments, blank lines,
    and ``#define`` lines with leading whitespace.

    Lines of the following varieties are handled::

        #define POS_BACK -0.4f
        #define HYDR_F1 0
        #define HYDR_THETA8_T0 (PI*0.5f)
        #define HYDR_T3_MESH_POINTS HYDR_T2_MESH_POINTS
        #define CXST_T5_MESH_POINTS 6 // perfetto

    If a macro is defined more than once, the last definition's value wins.

    Args:
        src_h: path to the ``model.h`` header.

    Returns:
        Mapping from macro name to its parsed value, in order of first
        appearance in the file.

    Raises:
        ValueError: for a ``#define`` line with fewer than
            ``MIN_VALID_HEADER_TOKEN_COUNT`` tokens (e.g. a bare flag macro).
    """
    params = {}
    with src_h.open("r") as f:
        for line in f:
            if not line.startswith("#define") or "MODEL_H_" in line:
                continue

            parts = line.split()
            if len(parts) < MIN_VALID_HEADER_TOKEN_COUNT:
                raise ValueError(ERR_CANNOT_PROCESS_SRC_H.format(line))

            # Drop a trailing C++-style comment, e.g. "6 // perfetto" -> "6".
            raw_value = " ".join(parts[2:]).split("//")[0].strip()
            # _parse_value_in always returns something (at worst the raw text),
            # so every well-formed line yields an entry.
            params[parts[1]] = _parse_value_in(raw_value)
    return params
def write_src_h(src_h: Path, params: dict[str, tuple[oxDNAModelHType, int | float | str]]) -> None:
    """Write the src/model.h file with the new parameters.

    Each entry of ``params`` becomes one ``#define KEY VALUE`` line, with the
    value rendered by ``_parse_value_out``. The lines are wrapped in a
    ``MODEL_H_`` include guard and preceded by a generated file header.

    ``FENE_DELTA2`` is derived. Whenever ``FENE_DELTA`` is written, a matching
    ``#define FENE_DELTA2 <FENE_DELTA**2>f`` line follows it. Any
    ``FENE_DELTA2`` entry already present in ``params`` (for example one read
    back from a previously written header) is then dropped. The file therefore
    never defines that macro twice, where the later, possibly stale,
    definition would win.

    The whole file is rendered in memory before ``src_h`` is opened. An
    unsupported value therefore raises without truncating the existing file.

    Args:
        src_h: path of the ``model.h`` file to (over)write.
        params: mapping from macro name to the value to write.
            NOTE(review): callers in this module pass plain scalars/strings
            (as produced by ``read_src_h``), not the tuples the annotation
            describes -- confirm the intended type.

    Raises:
        TypeError: if a value cannot be rendered; the message names the
            offending macro.
    """
    # When FENE_DELTA is present, FENE_DELTA2 is always regenerated from it.
    derive_delta2 = "FENE_DELTA" in params

    lines = [
        "/**",
        " * @file model.h",
        f" * @date {datetime.datetime.now(tz=datetime.UTC).strftime('%b %d, %Y')}",
        " * @author fromano -- modified by jax_dna",
        " */",
        "",
        "#ifndef MODEL_H_",
        "#define MODEL_H_",
    ]

    for key, value in params.items():
        if derive_delta2 and key == "FENE_DELTA2":
            # Superseded by the value derived from FENE_DELTA.
            continue
        try:
            parsed_value = _parse_value_out(value)
        except TypeError as e:
            # _parse_value_out signals unsupported types with TypeError.
            # Re-raise with the macro name so the bad entry can be located.
            raise TypeError(ERR_INVALID_HEADER_TYPE.format(key, value)) from e
        lines.append(f"#define {key} {parsed_value}")
        if key == "FENE_DELTA":
            lines.append(f"#define FENE_DELTA2 {value**2}f")

    lines.append("#endif /* MODEL_H_ */")

    with src_h.open("w") as f:
        f.write("\n".join(lines) + "\n")
def update_params(src_h: Path, new_params: list[jd_types.Params]) -> None:
    """Update the src/model.h file with the new parameters.

    The process has four steps:

    1. Parse the header with ``read_src_h``.
    2. Merge the dictionaries in ``new_params`` left to right with ``|``;
       later dictionaries win on duplicate names.
    3. For every merged name that is a key of
       ``DEFAULT_OXDNA_VARIABLE_MAPPER``, overwrite the mapped macro's value.
       Names without a mapping are ignored.
    4. Write the result back with ``write_src_h``.

    Args:
        src_h: path to the ``model.h`` header; it is rewritten in place.
        new_params: parameter dictionaries keyed by jax_dna parameter names.

    Raises:
        ValueError: if a mapped macro is not defined in ``src_h``; the file is
            not rewritten in that case.
    """
    header_values = read_src_h(src_h)

    merged = {}
    for param_set in new_params:
        merged = merged | param_set

    for jax_dna_name, new_value in merged.items():
        macro_name = DEFAULT_OXDNA_VARIABLE_MAPPER.get(jax_dna_name)
        if macro_name is None:
            continue  # not a model.h parameter
        if macro_name not in header_values:
            raise ValueError(f"Parameter {jax_dna_name} not found in src/model.h")
        header_values[macro_name] = new_value

    write_src_h(src_h, header_values)