Source code for jax_dna.energy.base

"""Base classes for energy functions."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, Union

import chex
import jax.numpy as jnp
import jax_md

import jax_dna.utils.types as typ

# Error-message templates shared by the classes in this module.

# Template with `{key}` and `{class_name}` placeholders, to be filled in via str.format.
ERR_PARAM_NOT_FOUND = "Parameter '{key}' not found in {class_name}"
# Raised (NotImplementedError) by BaseEnergyFunction.__call__, which subclasses must override.
ERR_CALL_NOT_IMPLEMENTED = "Subclasses must implement this method"
# Raised (ValueError) by ComposedEnergyFunction.__post_init__ when len(weights) != len(energy_fns).
ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH = "Weights must have the same length as energy functions"
# Raised (TypeError) by ComposedEnergyFunction.__post_init__ when energy_fns is not a list
# whose elements are all BaseEnergyFunction instances.
ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS = "energy_fns must be a list of energy functions"


@chex.dataclass(frozen=True)
class BaseEnergyFunction:
    """Common parent of all single-term energy functions.

    This class is not meant to be called directly: ``__call__`` raises
    ``NotImplementedError`` here and subclasses are expected to override it.

    Parameters:
        displacement_fn (Callable): an instance of a displacement function
            from jax_md.space
    """

    displacement_fn: Callable

    @property
    def displacement_mapped(self) -> Callable:
        """Return ``displacement_fn`` wrapped with ``jax_md.space.map_bond``."""
        mapped_fn = jax_md.space.map_bond(self.displacement_fn)
        return mapped_fn

    def __add__(self, other: "BaseEnergyFunction") -> "ComposedEnergyFunction":
        """Combine this energy function with another one.

        Args:
            other (BaseEnergyFunction): the energy function to add

        Returns:
            ComposedEnergyFunction: an unweighted composition of ``self`` and
            ``other`` (in that order). ``NotImplemented`` is returned for any
            operand that is not a ``BaseEnergyFunction`` so that Python can
            fall back to the operand's reflected ``__radd__``.
        """
        if isinstance(other, BaseEnergyFunction):
            return ComposedEnergyFunction(energy_fns=[self, other])
        return NotImplemented

    def __mul__(self, other: float) -> "ComposedEnergyFunction":
        """Scale this energy function by a scalar.

        Args:
            other (float): the scalar weight; ``int`` values are accepted too

        Returns:
            ComposedEnergyFunction: a composition holding only ``self`` with a
            one-element weight array. ``NotImplemented`` is returned when
            ``other`` is neither a ``float`` nor an ``int``.
        """
        if not isinstance(other, (float, int)):
            return NotImplemented

        scale = jnp.array([other], dtype=float)
        return ComposedEnergyFunction(energy_fns=[self], weights=scale)

    def __call__(
        self,
        body: jax_md.rigid_body.RigidBody,
        seq: typ.Sequence,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
        unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
    ) -> float:
        """Calculate the energy of the system.

        Raises:
            NotImplementedError: always; subclasses must provide the
                implementation.
        """
        raise NotImplementedError(ERR_CALL_NOT_IMPLEMENTED)
@chex.dataclass(frozen=True)
class ComposedEnergyFunction:
    """Represents a linear combination of energy functions.

    The total energy is the plain sum of the individual terms when ``weights``
    is ``None``; otherwise it is the dot product of ``weights`` with the terms.

    Parameters:
        energy_fns (list[BaseEnergyFunction]): a list of energy functions.
            Must be an actual ``list``; this is enforced in ``__post_init__``.
        weights (jnp.ndarray): optional, the weights of the energy functions.
            When given, it must have one entry per energy function.
        rigid_body_transform_fn (Callable): optional, a function to transform
            the rigid body into something that can be used by the energy
            functions, like a DNA1 nucleotide
    """

    # Annotated as a list because __post_init__ rejects every other container.
    energy_fns: list[BaseEnergyFunction]
    weights: jnp.ndarray | None = None
    rigid_body_transform_fn: Callable[[jax_md.rigid_body.RigidBody], Any] | None = None

    def __post_init__(self) -> None:
        """Check that the input is valid.

        Raises:
            TypeError: if ``energy_fns`` is not a list or contains anything
                that is not a ``BaseEnergyFunction``.
            ValueError: if ``weights`` is given and its length differs from
                the number of energy functions.
        """
        if not isinstance(self.energy_fns, list) or not all(
            isinstance(fn, BaseEnergyFunction) for fn in self.energy_fns
        ):
            raise TypeError(ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS)

        if self.weights is not None and len(self.weights) != len(self.energy_fns):
            raise ValueError(ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH)

    def compute_terms(
        self,
        body: jax_md.rigid_body.RigidBody,
        seq: typ.Sequence,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
        unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
    ) -> jnp.ndarray:
        """Compute each of the energy terms in the energy function.

        If ``rigid_body_transform_fn`` is set, ``body`` is passed through it
        once and the transformed value is handed to every energy function.

        Returns:
            jnp.ndarray: the unweighted value of every term, in the order of
            ``energy_fns``.
        """
        if self.rigid_body_transform_fn:
            body = self.rigid_body_transform_fn(body)
        return jnp.array([fn(body, seq, bonded_neighbors, unbonded_neighbors) for fn in self.energy_fns])

    def __call__(
        self,
        body: jax_md.rigid_body.RigidBody,
        seq: typ.Sequence,
        bonded_neighbors: typ.Arr_Bonded_Neighbors_2,
        unbonded_neighbors: typ.Arr_Unbonded_Neighbors_2,
    ) -> float:
        """Calculate the energy of the system using all of the functions in `energy_fns`.

        Args:
            body (jax_md.rigid_body.RigidBody): The rigid body(ies) of the system
            seq (typ.Sequence): the sequence of the system
            bonded_neighbors (typ.Arr_Bonded_Neighbors_2): the bonded neighbors
            unbonded_neighbors (typ.Arr_Unbonded_Neighbors_2): the unbonded neighbors

        Returns:
            float: the energy of the system
        """
        energy_vals = self.compute_terms(body, seq, bonded_neighbors, unbonded_neighbors)
        return jnp.sum(energy_vals) if self.weights is None else jnp.dot(self.weights, energy_vals)

    def add_energy_fn(self, energy_fn: BaseEnergyFunction, weight: float = 1.0) -> "ComposedEnergyFunction":
        """Add an energy function to the list of energy functions.

        Args:
            energy_fn (BaseEnergyFunction): the energy function to add
            weight (float): the weight of the energy function

        Returns:
            ComposedEnergyFunction: a new ComposedEnergyFunction with the added
            energy function. This object's ``rigid_body_transform_fn`` is
            carried over to the result.
        """
        if self.weights is None:
            # Stay unweighted while every weight is 1.0; otherwise materialise
            # the implicit unit weights of the existing terms.
            weights = None if weight == 1.0 else jnp.array([1.0] * len(self.energy_fns) + [weight])
        else:
            weights = jnp.concatenate([self.weights, jnp.array([weight])])

        return ComposedEnergyFunction(
            energy_fns=[*self.energy_fns, energy_fn],
            weights=weights,
            # Keep the transform so the new terms list is still evaluated on
            # the transformed body rather than on the raw rigid body.
            rigid_body_transform_fn=self.rigid_body_transform_fn,
        )

    def add_composable_energy_fn(self, energy_fn: "ComposedEnergyFunction") -> "ComposedEnergyFunction":
        """Add a ComposedEnergyFunction to the list of energy functions.

        Args:
            energy_fn (ComposedEnergyFunction): the ComposedEnergyFunction to add

        Returns:
            ComposedEnergyFunction: a new ComposedEnergyFunction holding this
            object's terms followed by the terms of ``energy_fn``. This
            object's ``rigid_body_transform_fn`` is carried over to the result;
            the transform of ``energy_fn`` is not used.
        """
        own_weights = self.weights
        other_weights = energy_fn.weights

        if own_weights is None and other_weights is None:
            weights = None
        else:
            # A missing weight vector means "all ones". Each default must be
            # sized by the operand it stands in for, so the concatenated
            # weights line up with the concatenated energy functions.
            if own_weights is None:
                own_weights = jnp.ones(len(self.energy_fns))
            if other_weights is None:
                other_weights = jnp.ones(len(energy_fn.energy_fns))
            weights = jnp.concatenate([own_weights, other_weights])

        return ComposedEnergyFunction(
            energy_fns=self.energy_fns + energy_fn.energy_fns,
            weights=weights,
            rigid_body_transform_fn=self.rigid_body_transform_fn,
        )

    def __add__(self, other: Union[BaseEnergyFunction, "ComposedEnergyFunction"]) -> "ComposedEnergyFunction":
        """Create a new ComposedEnergyFunction by adding another energy function.

        This is a convenience method for the add_energy_fn and
        add_composable_energy_fn methods. ``NotImplemented`` is returned for
        any other operand type.
        """
        if isinstance(other, BaseEnergyFunction):
            energy_fn = self.add_energy_fn
        elif isinstance(other, ComposedEnergyFunction):
            energy_fn = self.add_composable_energy_fn
        else:
            return NotImplemented

        return energy_fn(other)

    def __radd__(self, other: Union[BaseEnergyFunction, "ComposedEnergyFunction"]) -> "ComposedEnergyFunction":
        """Create a new ComposedEnergyFunction by adding another energy function.

        This is a convenience method for the add_energy_fn and
        add_composable_energy_fn methods. It delegates to ``__add__``, so the
        terms of ``other`` are placed after this object's terms.
        """
        return self.__add__(other)
@chex.dataclass(frozen=True)
class BaseNucleotide(jax_md.rigid_body.RigidBody, ABC):
    """Abstract base for nucleotide representations.

    Extends ``jax_md.rigid_body.RigidBody`` with the additional fields
    declared below. Concrete subclasses must implement ``from_rigid_body``.
    """

    center: typ.Arr_Nucleotide_3
    orientation: typ.Arr_Nucleotide_3 | jax_md.rigid_body.Quaternion
    stack_sites: typ.Arr_Nucleotide_3
    back_sites: typ.Arr_Nucleotide_3
    base_sites: typ.Arr_Nucleotide_3
    back_base_vectors: typ.Arr_Nucleotide_3
    base_normals: typ.Arr_Nucleotide_3
    cross_prods: typ.Arr_Nucleotide_3

    @staticmethod
    @abstractmethod
    def from_rigid_body(rigid_body: jax_md.rigid_body.RigidBody, **kwargs) -> "BaseNucleotide":
        """Create an instance of the subclass from a RigidBody.

        Args:
            rigid_body (jax_md.rigid_body.RigidBody): the rigid body to convert
            **kwargs: subclass-specific options

        Returns:
            BaseNucleotide: an instance of the implementing subclass
        """