"""Topological information for DNA/RNA."""
import dataclasses as dc
import itertools
import warnings
from collections.abc import Callable
from enum import IntEnum
from pathlib import Path
import jax.numpy as jnp
import numpy as np
import jax_dna.utils.constants as jd_const
import jax_dna.utils.types as typ
# Field counts associated with the first line of an oxDNA topology file:
# the classic header is "n_nucleotides n_strands" and the new header is
# "n_nucleotides n_strands 5->3".
# NOTE(review): presumably consumed by the format detection helper -- confirm.
N_1ST_LINE_OXDNA_CLASSIC = 2
N_1ST_LINE_OXDNA_NEW = 3

# Neighbor arrays must be 2-D with a second axis of length 2 (one index pair
# per row); Topology.__post_init__ compares both quantities to this value.
VALID_NEIGHBOR_SECOND_DIM = 2

# Messages raised while validating a Topology or a sequence.
ERR_TOPOLOGY_INVALID_NUMBER_NUCLEOTIDES = "Invalid number of nucleotides"
ERR_TOPOLOGY_INVALID_STRAND_COUNTS = "Invalid strand counts"
ERR_TOPOLOGY_SEQ_NOT_MATCH_NUCLEOTIDES = "Sequence does not match number of nucleotides"
ERR_TOPOLOGY_STRAND_COUNTS_NOT_MATCH = "Strand counts do not match number of nucleotides"
ERR_TOPOLOGY_BONDED_NEIGHBORS_INVALID_SHAPE = "Invalid bonded neighbors shape"
ERR_TOPOLOGY_UNBONDED_NEIGHBORS_INVALID_SHAPE = "Invalid unbonded neighbors shape"
ERR_TOPOLOGY_INVALID_SEQUENCE_LENGTH = "Invalid sequence length"
ERR_TOPOLOGY_INVALID_SEQUENCE_NUCLEOTIDES = "Invalid sequence nucleotides"
ERR_TOPOLOGY_INVALID_DISCRETE_SEQUENCE_SHAPE = "Invalid discrete sequence shape"
ERR_TOPOLOGY_INVALID_UNPAIRED_PSEQ_SHAPE = "Invalid unpaired probabilistic sequence shape"
ERR_TOPOLOGY_MISMATCH_PSEQ_SHAPE_NUM_NUCLEOTIDES = "Pseq shape does not match number of nucleotides"
ERR_TOPOLOGY_INVALID_BP_PSEQ_SHAPE = "Invalid base-paired probabilistic sequence shape"
# The validation only rejects negative entries, so zero is a legal probability.
ERR_TOPOLOGY_INVALID_PROBABILITIES = "Probabilities must be >= 0"
ERR_TOPOLOGY_PSEQ_NOT_NORMALIZED = "Probabilities must be normalized"

# Messages raised while reading topology files.
ERR_INVALID_OXDNA_FORMAT = (
    "Invalid oxDNA topology format. See "
    "https://lorenzo-rovigatti.github.io/oxDNA/configurations.html#topology-file for more information."
)
ERR_STRAND_COUNTS_CIRCULAR_MISMATCH = "Strand counts and circularity do not match"
ERR_FILE_NOT_FOUND = "Topology file not found"
ERR_TOPOLOGY_INVALID_SEQUENCE_TYPE = "Invalid sequence type. Must be discrete or probabilistic"

# Warning templates; fill in with ``.format(strand_idx=...)``.
WARN_CLASSIC_UNSPECIFIED_NT_TYPE = "Type of strand {strand_idx} not specified, and did not find T/U for autodetect"
WARN_NEW_UNSPECIFIED_NT_TYPE = "Type of strand {strand_idx} not specified"
[docs]
class NucleotideType(IntEnum):
    """Integer codes identifying the chemistry of a nucleotide.

    Being an ``IntEnum``, members compare equal to their plain integer
    values.
    """

    # No type was given for the strand and none could be detected.
    UNSPECIFIED = 0
    # Deoxyribonucleic acid.
    DNA = 1
    # Ribonucleic acid.
    RNA = 2
[docs]
def check_valid_seq(seq: typ.Sequence, n_nucleotides: int) -> None:
    """Check that a sequence is well-formed for ``n_nucleotides`` nucleotides.

    Two representations are accepted:

    * a discrete sequence (an instance of ``typ.Discrete_Sequence``), whose
      entries must all be one of 0, 1, 2, 3 and whose shape must be
      ``(n_nucleotides,)``;
    * a probabilistic sequence, given as a 2-tuple ``(up_pseq, bp_pseq)`` of
      2-D arrays holding per-row distributions for unpaired nucleotides
      (``jd_const.N_NT`` columns) and base pairs (``jd_const.N_BP_TYPES``
      columns).

    Args:
        seq: the sequence to validate.
        n_nucleotides: number of nucleotides the sequence must describe.

    Raises:
        ValueError: if ``seq`` is neither representation, or violates any of
            the constraints above (invalid values, wrong shape, negative or
            non-normalized probabilities).
    """
    if isinstance(seq, typ.Discrete_Sequence):
        # Vectorised membership test. Unlike building a Python ``set`` from the
        # array it also copes with input that is not 1-D, so malformed shapes
        # reach the shape check below instead of failing with a TypeError.
        if not np.isin(np.asarray(seq), (0, 1, 2, 3)).all():
            raise ValueError(ERR_TOPOLOGY_INVALID_SEQUENCE_NUCLEOTIDES)
        if seq.shape != (n_nucleotides,):
            raise ValueError(ERR_TOPOLOGY_INVALID_DISCRETE_SEQUENCE_SHAPE)
        return

    # typ.Probabilistic_Sequence: a pair (unpaired pseq, base-paired pseq)
    if not (isinstance(seq, tuple) and len(seq) == jd_const.TWO_DIMENSIONS):
        raise ValueError(ERR_TOPOLOGY_INVALID_SEQUENCE_TYPE)

    up_pseq, bp_pseq = seq

    if len(up_pseq.shape) != jd_const.TWO_DIMENSIONS or up_pseq.shape[1] != jd_const.N_NT:
        raise ValueError(ERR_TOPOLOGY_INVALID_UNPAIRED_PSEQ_SHAPE)
    if len(bp_pseq.shape) != jd_const.TWO_DIMENSIONS or bp_pseq.shape[1] != jd_const.N_BP_TYPES:
        raise ValueError(ERR_TOPOLOGY_INVALID_BP_PSEQ_SHAPE)

    # Each base-pair row accounts for N_NT_PER_BP nucleotides, each unpaired
    # row for one; together they must cover the whole system.
    n_unpaired = up_pseq.shape[0]
    n_bp = bp_pseq.shape[0]
    if n_unpaired + jd_const.N_NT_PER_BP * n_bp != n_nucleotides:
        raise ValueError(ERR_TOPOLOGY_MISMATCH_PSEQ_SHAPE_NUM_NUCLEOTIDES)

    if (up_pseq < 0).any() or (bp_pseq < 0).any():
        raise ValueError(ERR_TOPOLOGY_INVALID_PROBABILITIES)

    # Every row must be a probability distribution (sum to one).
    rows_normalized = np.allclose(np.sum(up_pseq, axis=1), 1) and np.allclose(np.sum(bp_pseq, axis=1), 1)
    if not rows_normalized:
        raise ValueError(ERR_TOPOLOGY_PSEQ_NOT_NORMALIZED)
[docs]
@dc.dataclass(frozen=True)
class Topology:
    """Immutable description of the connectivity of a DNA/RNA system.

    The instance is validated on construction (see ``__post_init__``).
    """

    # Total number of nucleotides; must be at least 1.
    n_nucleotides: int
    # Number of nucleotides in each strand; must sum to ``n_nucleotides``.
    strand_counts: np.ndarray
    # Index pairs of bonded nucleotides, one pair per row.
    bonded_neighbors: np.ndarray
    # Index pairs of nucleotides that are not bonded, one pair per row.
    unbonded_neighbors: np.ndarray
    # Discrete or probabilistic sequence, checked with ``check_valid_seq``.
    seq: typ.Sequence
    # Per-nucleotide end-of-strand flags (not validated here).
    is_end: np.ndarray
    # Per-nucleotide type codes (not validated here).
    nt_type: np.ndarray

    def __post_init__(self) -> None:
        """Validate the fields, raising ``ValueError`` on the first problem found."""
        if self.n_nucleotides < 1:
            raise ValueError(ERR_TOPOLOGY_INVALID_NUMBER_NUCLEOTIDES)

        # There must be at least one strand and the strands must not all be empty.
        if len(self.strand_counts) == 0:
            raise ValueError(ERR_TOPOLOGY_INVALID_STRAND_COUNTS)
        n_from_strands = sum(self.strand_counts)
        if n_from_strands == 0:
            raise ValueError(ERR_TOPOLOGY_INVALID_STRAND_COUNTS)
        if n_from_strands != self.n_nucleotides:
            raise ValueError(ERR_TOPOLOGY_STRAND_COUNTS_NOT_MATCH)

        # Both neighbor arrays must be 2-D with exactly two columns.
        neighbor_checks = (
            (self.bonded_neighbors, ERR_TOPOLOGY_BONDED_NEIGHBORS_INVALID_SHAPE),
            (self.unbonded_neighbors, ERR_TOPOLOGY_UNBONDED_NEIGHBORS_INVALID_SHAPE),
        )
        for neighbors, message in neighbor_checks:
            shape = neighbors.shape
            is_pair_array = len(shape) == VALID_NEIGHBOR_SECOND_DIM and shape[1] == VALID_NEIGHBOR_SECOND_DIM
            if not is_pair_array:
                raise ValueError(message)

        check_valid_seq(self.seq, self.n_nucleotides)
[docs]
def from_oxdna_file(path: typ.PathOrStr) -> Topology:
    """Read topology information from an oxDNA topology file.

    The first non-blank line is handed to ``_determine_oxdna_format``; the
    parser it returns is then applied to the file's non-blank lines.

    Args:
        path: location of the topology file.

    Returns:
        Topology: the parsed topology.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file contains no non-blank lines.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(ERR_FILE_NOT_FOUND)

    with path.open() as f:
        # Blank lines (typically a trailing one) carry no information and
        # would otherwise break the whitespace-splitting line parsers.
        lines = [line for line in f if line.strip()]

    # An empty file has no header to inspect; report it as a format error
    # rather than letting ``lines[0]`` fail with an IndexError.
    if not lines:
        raise ValueError(ERR_INVALID_OXDNA_FORMAT)

    _, parse_f = _determine_oxdna_format(lines[0])
    return parse_f(lines)
[docs]
def _get_bonded_neighbors(
strand_lengths: list[int],
is_circular: list[bool],
) -> list[tuple[int, int]]:
"""Convert 5' neighbors to bonded neighbors by index."""
if len(strand_lengths) != len(is_circular):
raise ValueError(ERR_STRAND_COUNTS_CIRCULAR_MISMATCH)
bonded_neighbors = []
init_idx = 0
for i, length in enumerate(strand_lengths):
pairs = list(itertools.pairwise(range(init_idx, init_idx + length)))
if is_circular[i]:
# the ordering here in intentional
pairs.append((init_idx, init_idx + length - 1))
bonded_neighbors.extend(pairs)
init_idx += length
return bonded_neighbors
[docs]
def _get_unbonded_neighbors(n_nucleotides: int, bonded_neighbors: list[tuple[int, int]]) -> list[tuple[int, int]]:
"""Get unbonded neighbors."""
all_possible_pairs = set(itertools.combinations(range(n_nucleotides), 2))
self_bonds = {(i, i) for i in range(n_nucleotides)}
return list(all_possible_pairs - set(bonded_neighbors) - self_bonds)
[docs]
def _from_file_oxdna_classic(lines: list[str]) -> Topology:
    """Read topology information from a file in the classic oxDNA format.

    See https://lorenzo-rovigatti.github.io/oxDNA/configurations.html#topology-file
    for more information.

    Args:
        lines (list[str]): lines from the topology file, header line first.
            Blank lines after the header are ignored.

    Returns:
        Topology: Topology object

    Raises:
        ValueError: if a nucleotide line does not have exactly four columns,
            or the strand ids found are not exactly ``1..n_strands``.
    """
    n_nucleotides, n_strands = map(int, lines[0].strip().split())

    # after the first line the topology files are space delimited with the
    # following columns:
    # - strand id (1 indexed)
    # - nucleotide base (A=0, C=1, G=2, T=3, U=3), use char for now
    # - 3' neighbor (0-indexed), -1 if none
    # - 5' neighbor (0-indexed), -1 if none, -1 indicates the strand isn't circular
    #
    # Group the nucleotides by strand in a single pass over the file instead
    # of re-scanning every line once per strand.
    strands: dict[int, list[tuple[str, int]]] = {}
    for line in lines[1:]:
        fields = line.split()
        if not fields:
            continue
        strand_id, base, _, neighbor_5p = fields
        strands.setdefault(int(strand_id), []).append((base, int(neighbor_5p)))

    # The header must agree with the strand ids actually present.
    if sorted(strands) != list(range(1, n_strands + 1)):
        raise ValueError(ERR_INVALID_OXDNA_FORMAT)

    # A more common convention is to store the nucleotides in the 5' -> 3'
    # direction so we need to reverse the order, which seems to be as easy as
    # reversing the order of the nucleotides per strand.
    reversed_bases = []
    strand_counts = []
    is_circular = []
    is_end = []
    nt_type = []
    for strand_idx in range(1, n_strands + 1):
        strand_bases, strand_5p = zip(*strands[strand_idx], strict=True)
        strand_length = len(strand_bases)
        strand_counts.append(strand_length)

        strand_is_circular = strand_5p[-1] != -1
        is_circular.append(strand_is_circular)

        reversed_bases.extend(strand_bases[::-1])

        # Only linear strands have ends: flag their first and last nucleotide.
        strand_is_end = [0] * strand_length
        if not strand_is_circular:
            strand_is_end[0] = 1
            strand_is_end[-1] = 1
        is_end.extend(strand_is_end)

        # Classic format doesn't permit type specification, so infer it from
        # the presence of T (DNA) or U (RNA).
        if "T" in strand_bases:
            strand_type = NucleotideType.DNA
        elif "U" in strand_bases:
            strand_type = NucleotideType.RNA
        else:
            warnings.warn(WARN_CLASSIC_UNSPECIFIED_NT_TYPE.format(strand_idx=strand_idx), stacklevel=1)
            strand_type = NucleotideType.UNSPECIFIED
        nt_type.extend([strand_type] * strand_length)

    sequence = "".join(reversed_bases)
    bonded_neighbors = _get_bonded_neighbors(strand_counts, is_circular)
    unbonded_neighbors = _get_unbonded_neighbors(n_nucleotides, bonded_neighbors)

    return Topology(
        n_nucleotides=n_nucleotides,
        strand_counts=np.array(strand_counts),
        # reshape keeps the (n_pairs, 2) layout even when there are no pairs,
        # which the Topology shape validation requires
        bonded_neighbors=np.array(bonded_neighbors, dtype=int).reshape(-1, 2),
        unbonded_neighbors=np.array(unbonded_neighbors, dtype=int).reshape(-1, 2),
        seq=jnp.array([jd_const.NUCLEOTIDES_IDX[s] for s in sequence], dtype=jnp.int32),
        is_end=jnp.array(is_end).astype(jnp.int32),
        nt_type=jnp.array(nt_type).astype(jnp.int32),
    )
[docs]
def _from_file_oxdna_new(lines: list[str]) -> Topology:
    """Read topology information from a file in the new oxDNA format.

    See https://lorenzo-rovigatti.github.io/oxDNA/configurations.html#topology-file
    for more information.

    Args:
        lines (list[str]): lines from the topology file, header line first.
            Blank lines after the header are ignored.

    Returns:
        Topology: Topology object
    """
    # the first line of the new oxDNA format is:
    #   n_nucleotides n_strands 5->3
    # we don't need the 5->3, so we'll just ignore it; the strand count is
    # parsed (so a malformed header still fails) but strands are counted
    # from the lines themselves
    n_nucleotides, _n_strands = map(int, lines[0].strip().split()[:-1])

    # the rest of the new oxDNA file format is laid out as follows:
    #   nucleotides k=v
    #   ...
    #   nucleotides k=v
    # Where `nucleotides` is a string of ACTG and `k=v` is a set of key value
    # pairs; there is one such line per strand
    strand_lines = [line for line in lines[1:] if line.strip()]

    sequence = []
    strand_counts = []
    is_circular = []
    is_end = []
    nt_type = []
    # strands are numbered from 1 in warnings, as in the classic format
    for strand_idx, line in enumerate(strand_lines, start=1):
        nucleotides = line.split()[0]
        strand_length = len(nucleotides)
        sequence.append(nucleotides)
        strand_counts.append(strand_length)

        strand_is_circular = "circular=true" in line
        is_circular.append(strand_is_circular)

        # Only linear strands have ends: flag their first and last nucleotide.
        strand_is_end = [0] * strand_length
        if not strand_is_circular:
            strand_is_end[0] = 1
            strand_is_end[-1] = 1
        is_end.extend(strand_is_end)

        if "type=DNA" in line:
            strand_type = NucleotideType.DNA
        elif "type=RNA" in line:
            strand_type = NucleotideType.RNA
        else:
            # report the strand's index, not the raw text of its line
            warnings.warn(WARN_NEW_UNSPECIFIED_NT_TYPE.format(strand_idx=strand_idx), stacklevel=1)
            strand_type = NucleotideType.UNSPECIFIED
        nt_type.extend([strand_type] * strand_length)

    sequence = "".join(sequence)
    bonded_neighbors = _get_bonded_neighbors(strand_counts, is_circular)
    unbonded_neighbors = _get_unbonded_neighbors(n_nucleotides, bonded_neighbors)

    return Topology(
        n_nucleotides=n_nucleotides,
        strand_counts=np.array(strand_counts),
        # reshape keeps the (n_pairs, 2) layout even when there are no pairs,
        # which the Topology shape validation requires
        bonded_neighbors=np.array(bonded_neighbors, dtype=int).reshape(-1, 2),
        unbonded_neighbors=np.array(unbonded_neighbors, dtype=int).reshape(-1, 2),
        seq=jnp.array([jd_const.NUCLEOTIDES_IDX[s] for s in sequence], dtype=jnp.int32),
        is_end=jnp.array(is_end).astype(jnp.int32),
        nt_type=jnp.array(nt_type).astype(jnp.int32),
    )