"""Persistence length observable."""
import dataclasses as dc
import functools
from collections.abc import Callable
import chex
import jax
import jax.numpy as jnp
from jax import vmap
from jax_md import space
import jax_dna.energy.dna1 as jd_energy
import jax_dna.input.toml as jd_toml
import jax_dna.input.trajectory as jd_traj
import jax_dna.observables.base as jd_obs
import jax_dna.simulators.io as jd_sio
import jax_dna.utils.types as jd_types
# Target values keyed by model name; units are nm per the inline note.
# NOTE(review): presumably reference persistence lengths, given this module's purpose -- confirm.
TARGETS = {
"oxDNA": 47.5, # nm
}
[docs]
def persistence_length_fit(correlations: jnp.ndarray, l0_av: float) -> tuple[float, float]:
    """Fit the persistence length (Lp) from the decay of alignment correlations.

    The model is `<l_n * l_0> = exp(-n<l_0> / Lp)`, where `<l_n * l_0>` is the average
    correlation between base-pair steps separated by `n` steps and `<l_0>` is the average
    distance between adjacent base pairs. Taking the logarithm gives a straight line in `n`
    with slope `-<l_0> / Lp`, so Lp is recovered from the slope of a least-squares line
    through `log(correlations)`. An intercept is fitted alongside the slope and returned.

    Args:
        correlations (jnp.ndarray): a (max_dist,) array whose entry `n` is the average
            correlation at separation `n`.
        l0_av (float): the average distance between adjacent base pairs.

    Returns:
        A `(Lp, offset)` pair: the fitted persistence length (in the units of `l0_av`)
        and the intercept of the log-space line.
    """
    separations = jnp.arange(correlations.shape[0])

    # Design matrix with columns [1, n]: the solution vector is (intercept, slope).
    design = jnp.stack([jnp.ones_like(separations), separations], axis=1)
    coeffs = jnp.linalg.lstsq(design, jnp.log(correlations))[0]
    intercept, gradient = coeffs[0], coeffs[1]

    # gradient = -l0_av / Lp  =>  Lp = -l0_av / gradient
    return -l0_av / gradient, intercept
[docs]
def compute_l_vector(base_sites: jnp.ndarray, quartet: jnp.ndarray) -> tuple[jnp.ndarray, float]:
    """Computes the unit vector and the distance between two adjacent base-pair midpoints.

    Args:
        base_sites (jnp.ndarray): array of site positions, indexed by the entries of `quartet`.
        quartet (jnp.ndarray): two index pairs `((a1, b1), (a2, b2))`; each pair holds the
            indices of two paired bases.

    Returns:
        A tuple of the normalized vector pointing from the midpoint of the first base pair
        to the midpoint of the second, and the length of that vector before normalization.
    """
    # Each row of the quartet is one base pair: (index of base a, index of base b)
    (a1, b1), (a2, b2) = quartet

    # Midpoint of each base pair
    first_mid = (base_sites[b1] + base_sites[a1]) / 2.0
    second_mid = (base_sites[b2] + base_sites[a2]) / 2.0

    # Displacement between the midpoints, split into length and direction
    separation = second_mid - first_mid
    length = jnp.linalg.norm(separation)
    direction = separation / length
    return direction, length


# Batched form: `base_sites` is shared, `quartet` is mapped over its leading axis.
get_all_l_vectors = vmap(compute_l_vector, in_axes=(None, 0))
[docs]
def vector_autocorrelate(vecs: jnp.ndarray) -> jnp.ndarray:
    """Computes the average alignment correlation between vectors as a function of separation.

    Given an ordered list of n vectors (representing vectors between adjacent base pairs),
    entry `d` of the result is the mean dot product over all pairs `(vecs[i], vecs[i + d])`
    for every separation `0 <= d < n`. Separation `d` is averaged over its `n - d`
    available pairs, so only the largest separation comes from a single pair.

    Args:
        vecs (jnp.ndarray): a (n, 3) array of vectors corresponding to displacements between
            midpoints of adjacent base pairs.

    Returns:
        A (n,) array of average correlations indexed by separation.
    """
    n_vecs = vecs.shape[0]
    indices = jnp.arange(n_vecs)

    def aligned_row(start: int) -> jnp.ndarray:
        # Dot product of the anchor vector with every vector in the list
        dots = vecs @ vecs[start]
        # Keep only partners at or after the anchor so each pair is counted once
        forward_only = jnp.where(indices >= start, dots, 0.0)
        # Shift so that position d holds the correlation at separation d from the anchor
        return jnp.roll(forward_only, -start)

    summed = jnp.sum(vmap(aligned_row)(indices), axis=0)

    # Separation d has n - d contributing pairs: n, n-1, ..., 1
    pair_counts = jnp.arange(n_vecs, 0, -1)
    return summed / pair_counts
if __name__ == "__main__":
    # Manual check: load a template system, evaluate the observable on it, fit Lp over
    # the first few separations and plot the measured vs. fitted log-correlations.
    import matplotlib.pyplot as plt

    import jax_dna.input.topology as jd_top

    # Build the rigid-body -> nucleotide transform from the default geometry parameters
    geometry = jd_toml.parse_toml("jax_dna/input/dna1/default_energy.toml")["geometry"]
    rigid_body_to_nucleotide = functools.partial(
        jd_energy.Nucleotide.from_rigid_body,
        com_to_backbone=geometry["com_to_backbone"],
        com_to_hb=geometry["com_to_hb"],
        com_to_stacking=geometry["com_to_stacking"],
    )

    # Load the template topology and configuration and wrap them as a simulator trajectory
    topology = jd_top.from_oxdna_file("data/templates/persistence-length/sys.top")
    loaded_traj = jd_traj.from_file(
        path="data/templates/persistence-length/init.conf",
        strand_lengths=topology.strand_counts,
    )
    trajectory = jd_sio.SimulatorTrajectory(
        seq=jnp.array(topology.seq_idx),
        strand_lengths=topology.strand_counts,
        rigid_body=loaded_traj.state_rigid_body,
    )

    # Set up and evaluate the observable
    duplex_quartets = jd_obs.get_duplex_quartets(202)
    free_displacement, _ = space.free()
    lp_observable = LpMetadata(
        rigid_body_transform_fn=rigid_body_to_nucleotide,
        quartets=duplex_quartets,
        displacement_fn=free_displacement,
    )
    corrs, l0_vals = lp_observable(trajectory)

    # Average over the leading axis before fitting
    avg_corrs = jnp.mean(corrs, axis=0)
    avg_l0 = jnp.mean(l0_vals, axis=0)

    # Only the first `n_fit_points` separations enter the fit and the plot
    n_fit_points = 40
    fit_window = avg_corrs[:n_fit_points]
    fit_lp, fit_offset = persistence_length_fit(fit_window, avg_l0)

    # Fitted line in log space: log-correlation = -n * l0 / Lp + offset
    separations = jnp.arange(fit_window.shape[0])
    fitted_log_corrs = -separations * avg_l0 / fit_lp + fit_offset

    plt.plot(jnp.log(fit_window))
    plt.plot(fitted_log_corrs, linestyle="--")
    plt.xlabel("Distance")
    plt.ylabel("Log-Correlation")
    plt.show()