jax_dna.energy.base

Base classes for energy functions.

Attributes

ERR_PARAM_NOT_FOUND

ERR_CALL_NOT_IMPLEMENTED

ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH

ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS

Classes

BaseEnergyFunction

Base class for energy functions.

ComposedEnergyFunction

Represents a linear combination of energy functions.

BaseNucleotide

Base nucleotide class.

Module Contents

jax_dna.energy.base.ERR_PARAM_NOT_FOUND = "Parameter '{key}' not found in {class_name}"
jax_dna.energy.base.ERR_CALL_NOT_IMPLEMENTED = 'Subclasses must implement this method'
jax_dna.energy.base.ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH = 'Weights must have the same length as energy functions'
jax_dna.energy.base.ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS = 'energy_fns must be a list of energy functions'
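ERR_PARAM_NOT_FOUND is a str.format template with {key} and {class_name} placeholders. The snippet below shows how it renders. The values "eps" and "MyEnergyFunction" are hypothetical and are used only to fill the placeholders.

```python
# Template copied from the module constant documented above.
ERR_PARAM_NOT_FOUND = "Parameter '{key}' not found in {class_name}"

# Hypothetical parameter and class names, only to show the placeholders being filled.
message = ERR_PARAM_NOT_FOUND.format(key="eps", class_name="MyEnergyFunction")
print(message)  # Parameter 'eps' not found in MyEnergyFunction
```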
class jax_dna.energy.base.BaseEnergyFunction

Base class for energy functions.

This class should not be used directly; subclasses must implement the __call__ method.

Parameters:

displacement_fn (Callable) – an instance of a displacement function from jax_md.space

displacement_fn: collections.abc.Callable
property displacement_mapped: collections.abc.Callable

Returns the displacement function mapped to the space.
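For intuition, here is a plain-Python sketch of what mapping a displacement function can mean. It assumes that "mapped" refers to vectorizing the function over matching rows of two position arrays, which is what jax_md.space.map_bond does for a displacement function. free_displacement and map_over_pairs are hypothetical names, and the code uses lists instead of JAX arrays.

```python
# Toy, plain-Python illustration; jax_md and JAX are not used here.
Vec3 = tuple[float, float, float]


def free_displacement(ra: Vec3, rb: Vec3) -> Vec3:
    """Displacement in free (non-periodic) space: ra - rb, component-wise."""
    return tuple(a - b for a, b in zip(ra, rb))


def map_over_pairs(displacement_fn):
    """Hypothetical helper: apply displacement_fn to matching rows of two position lists."""

    def mapped(ra_all: list[Vec3], rb_all: list[Vec3]) -> list[Vec3]:
        if len(ra_all) != len(rb_all):
            raise ValueError("position lists must have the same length")
        return [displacement_fn(ra, rb) for ra, rb in zip(ra_all, rb_all)]

    return mapped


displacement_mapped = map_over_pairs(free_displacement)
dr = displacement_mapped(
    [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0)],
    [(0.0, 0.0, 0.0), (0.0, 0.5, 0.0)],
)
print(dr)  # [(1.0, 0.0, 0.0), (0.0, 1.5, 0.0)]
```

In JAX the same row-wise mapping would be written as jax.vmap(displacement_fn, in_axes=(0, 0)).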

__add__(other: BaseEnergyFunction) → ComposedEnergyFunction

Add two energy functions together to create a ComposedEnergyFunction.

__mul__(other: float) → ComposedEnergyFunction

Multiply an energy function by a scalar to create a ComposedEnergyFunction.
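The following toy sketch shows how __add__ and __mul__ can turn single energy functions into a weighted composition. ToyEnergy and ToyComposed are hypothetical stand-ins, not jax_dna classes. Each toy function takes one scalar x instead of the four documented arguments. Treating weights=None as unit weights is also an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyComposed:
    energy_fns: tuple
    weights: tuple | None = None

    def __call__(self, x: float) -> float:
        # Assumption: no weights means every term has weight 1.0.
        weights = self.weights if self.weights is not None else (1.0,) * len(self.energy_fns)
        return sum(w * fn(x) for w, fn in zip(weights, self.energy_fns))


@dataclass(frozen=True)
class ToyEnergy:
    k: float

    def __call__(self, x: float) -> float:
        return 0.5 * self.k * x * x

    def __add__(self, other: ToyEnergy) -> ToyComposed:
        # fn_a + fn_b -> an unweighted composition of both terms.
        return ToyComposed(energy_fns=(self, other))

    def __mul__(self, other: float) -> ToyComposed:
        # fn * c -> a one-term composition carrying weight c.
        return ToyComposed(energy_fns=(self,), weights=(float(other),))


a, b = ToyEnergy(k=2.0), ToyEnergy(k=4.0)
print((a + b)(1.0))  # 1.0 + 2.0 = 3.0
print((a * 3.0)(2.0))  # 3.0 * 4.0 = 12.0
```

This toy defines only __mul__, mirroring the methods listed above, so it supports fn * 3.0 but not 3.0 * fn.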

abstract __call__(body: jax_md.rigid_body.RigidBody, seq: jax_dna.utils.types.Sequence, bonded_neighbors: jax_dna.utils.types.Arr_Bonded_Neighbors_2, unbonded_neighbors: jax_dna.utils.types.Arr_Unbonded_Neighbors_2) → float

Calculate the energy of the system.
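The "subclass implements __call__" pattern can be sketched with an abstract base class. ToyBaseEnergy and ToyHarmonicBond are hypothetical. The toy __call__ takes 1D positions and a list of bonded index pairs instead of the documented (body, seq, bonded_neighbors, unbonded_neighbors) signature. The harmonic bond is an illustrative energy and is not taken from jax_dna.

```python
import abc
from dataclasses import dataclass


class ToyBaseEnergy(abc.ABC):
    """Hypothetical stand-in for an abstract energy base class."""

    @abc.abstractmethod
    def __call__(self, positions: list[float], bonded_neighbors: list[tuple[int, int]]) -> float:
        """Return the scalar energy of the system."""


@dataclass(frozen=True)
class ToyHarmonicBond(ToyBaseEnergy):
    """Hypothetical 1D bond energy: sum of 0.5 * k * (r - r0)**2 over bonded pairs."""

    k: float = 1.0
    r0: float = 1.0

    def __call__(self, positions, bonded_neighbors) -> float:
        energy = 0.0
        for i, j in bonded_neighbors:
            r = abs(positions[i] - positions[j])
            energy += 0.5 * self.k * (r - self.r0) ** 2
        return energy


bond_fn = ToyHarmonicBond(k=2.0, r0=1.0)
energy = bond_fn([0.0, 1.5, 2.5], [(0, 1), (1, 2)])
print(energy)  # 0.25
```

Instantiating ToyBaseEnergy directly raises TypeError because __call__ is still abstract.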

class jax_dna.energy.base.ComposedEnergyFunction

Represents a linear combination of energy functions.

Parameters:
  • energy_fns (list[BaseEnergyFunction]) – a list of energy functions

  • weights (jnp.ndarray) – optional, the weights of the energy functions

  • rigid_body_transform_fn (Callable) – optional, a function that transforms the rigid body into something the energy functions can use, such as a DNA1 nucleotide

energy_fns: tuple[BaseEnergyFunction]
weights: jax.numpy.ndarray | None = None
rigid_body_transform_fn: collections.abc.Callable[[jax_md.rigid_body.RigidBody], Any] | None = None
__post_init__() → None

Check that the input is valid.
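The two ERR_COMPOSED_ENERGY_FN_* constants suggest what "valid" means here: energy_fns must be a collection of energy functions, and weights, if given, must have one entry per function. The sketch below is a toy version of those checks. The exception types (TypeError, ValueError) and the use of callable() in place of a BaseEnergyFunction type check are assumptions, and ToyComposed is hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass

# Messages copied from the module constants documented above.
ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH = "Weights must have the same length as energy functions"
ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS = "energy_fns must be a list of energy functions"


@dataclass(frozen=True)
class ToyComposed:
    energy_fns: tuple
    weights: tuple | None = None

    def __post_init__(self) -> None:
        # energy_fns must be a sequence whose items are all callable.
        if not isinstance(self.energy_fns, (list, tuple)) or not all(
            callable(fn) for fn in self.energy_fns
        ):
            raise TypeError(ERR_COMPOSED_ENERGY_FN_TYPE_ENERGY_FNS)
        # weights are optional, but when given they need one entry per function.
        if self.weights is not None and len(self.weights) != len(self.energy_fns):
            raise ValueError(ERR_COMPOSED_ENERGY_FN_LEN_MISMATCH)


ok = ToyComposed(energy_fns=(abs, round), weights=(1.0, 0.5))
```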

compute_terms(body: jax_md.rigid_body.RigidBody, seq: jax_dna.utils.types.Sequence, bonded_neighbors: jax_dna.utils.types.Arr_Bonded_Neighbors_2, unbonded_neighbors: jax_dna.utils.types.Arr_Unbonded_Neighbors_2) → jax.numpy.ndarray

Compute each of the energy terms in the energy function.

__call__(body: jax_md.rigid_body.RigidBody, seq: jax_dna.utils.types.Sequence, bonded_neighbors: jax_dna.utils.types.Arr_Bonded_Neighbors_2, unbonded_neighbors: jax_dna.utils.types.Arr_Unbonded_Neighbors_2) → float

Calculate the energy of the system by combining all of the functions in energy_fns, weighted by weights when they are provided.

Parameters:
  • body (jax_md.rigid_body.RigidBody) – the rigid body or bodies of the system

  • seq (typ.Sequence) – the sequence of the system

  • bonded_neighbors (typ.Arr_Bonded_Neighbors_2) – the bonded neighbors

  • unbonded_neighbors (typ.Arr_Unbonded_Neighbors_2) – the unbonded neighbors

Returns:

the energy of the system

Return type:

float
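A toy sketch of how compute_terms and __call__ can relate for a linear combination: compute_terms returns one value per energy function, and __call__ reduces them to a scalar. The sketch makes two assumptions: that compute_terms returns unweighted terms, and that weights=None means a plain sum. ToyComposed and its single scalar argument x are hypothetical simplifications of the documented four-argument signature.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyComposed:
    energy_fns: tuple[Callable[[float], float], ...]
    weights: tuple[float, ...] | None = None

    def compute_terms(self, x: float) -> list[float]:
        # One unweighted energy value per component function.
        return [fn(x) for fn in self.energy_fns]

    def __call__(self, x: float) -> float:
        terms = self.compute_terms(x)
        if self.weights is None:
            return sum(terms)
        # Linear combination: sum_i w_i * E_i(x).
        return sum(w * t for w, t in zip(self.weights, terms))


combo = ToyComposed(energy_fns=(lambda x: x, lambda x: x * x), weights=(2.0, 0.5))
print(combo.compute_terms(3.0))  # [3.0, 9.0]
print(combo(3.0))  # 2.0 * 3.0 + 0.5 * 9.0 = 10.5
print(ToyComposed(energy_fns=combo.energy_fns)(3.0))  # unweighted: 12.0
```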

add_energy_fn(energy_fn: BaseEnergyFunction, weight: float = 1.0) → ComposedEnergyFunction

Add an energy function to the list of energy functions.

Parameters:
  • energy_fn (BaseEnergyFunction) – the energy function to add

  • weight (float) – the weight of the energy function

Returns:

a new ComposedEnergyFunction with the added energy function

Return type:

ComposedEnergyFunction
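Because add_energy_fn returns a new ComposedEnergyFunction, a frozen dataclass and dataclasses.replace can model the behavior without mutating the original. The sketch below is hypothetical (ToyComposed is not the jax_dna class). Filling in unit weights when the original has weights=None is an assumption made so that the new weight lines up with its function.

```python
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Callable


@dataclass(frozen=True)
class ToyComposed:
    energy_fns: tuple
    weights: tuple | None = None

    def add_energy_fn(self, energy_fn: Callable, weight: float = 1.0) -> ToyComposed:
        # Spell out implicit unit weights so the new weight lines up with its function.
        old_weights = self.weights if self.weights is not None else (1.0,) * len(self.energy_fns)
        # replace() builds a new frozen instance; self is left unchanged.
        return replace(
            self,
            energy_fns=self.energy_fns + (energy_fn,),
            weights=old_weights + (weight,),
        )


base_fn = ToyComposed(energy_fns=(abs,))
extended = base_fn.add_energy_fn(round, weight=0.5)
print(len(base_fn.energy_fns), len(extended.energy_fns))  # 1 2
print(extended.weights)  # (1.0, 0.5)
```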

add_composable_energy_fn(energy_fn: ComposedEnergyFunction) → ComposedEnergyFunction

Add a ComposedEnergyFunction to the list of energy functions.

Parameters:

energy_fn (ComposedEnergyFunction) – the ComposedEnergyFunction to add

Returns:

a new ComposedEnergyFunction with the added energy function

Return type:

ComposedEnergyFunction
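One way to combine two compositions is to flatten them, concatenating their function lists and their weight lists instead of nesting one composition inside the other. The toy below does this. The flattening strategy and the unit_weights helper are assumptions, and ToyComposed is hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyComposed:
    energy_fns: tuple
    weights: tuple | None = None

    def unit_weights(self) -> tuple:
        # Treat weights=None as one unit weight per energy function.
        return self.weights if self.weights is not None else (1.0,) * len(self.energy_fns)

    def add_composable_energy_fn(self, energy_fn: ToyComposed) -> ToyComposed:
        # Flatten: concatenate the functions and their weights instead of nesting.
        return ToyComposed(
            energy_fns=self.energy_fns + energy_fn.energy_fns,
            weights=self.unit_weights() + energy_fn.unit_weights(),
        )


first = ToyComposed(energy_fns=(abs,), weights=(2.0,))
second = ToyComposed(energy_fns=(round, float))
merged = first.add_composable_energy_fn(second)
print(merged.weights)  # (2.0, 1.0, 1.0)
```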

__add__(other: BaseEnergyFunction | ComposedEnergyFunction) → ComposedEnergyFunction

Create a new ComposedEnergyFunction by adding another energy function.

This is a convenience method for the add_energy_fn and add_composable_energy_fn methods.

__radd__(other: BaseEnergyFunction | ComposedEnergyFunction) → ComposedEnergyFunction

Create a new ComposedEnergyFunction by adding another energy function.

This is a convenience method for the add_energy_fn and add_composable_energy_fn methods.
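A toy sketch of the dispatch described above: + routes a composition to add_composable_energy_fn and a single function to add_energy_fn. Python calls __radd__ for other + self when the left operand does not handle +. ToyComposed is hypothetical. Keeping the left operand first in __radd__ is a design choice of this sketch, not something documented for jax_dna.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyComposed:
    energy_fns: tuple
    weights: tuple | None = None

    def unit_weights(self) -> tuple:
        return self.weights if self.weights is not None else (1.0,) * len(self.energy_fns)

    def add_energy_fn(self, energy_fn: Callable, weight: float = 1.0) -> ToyComposed:
        return ToyComposed(self.energy_fns + (energy_fn,), self.unit_weights() + (weight,))

    def add_composable_energy_fn(self, energy_fn: ToyComposed) -> ToyComposed:
        return ToyComposed(
            self.energy_fns + energy_fn.energy_fns,
            self.unit_weights() + energy_fn.unit_weights(),
        )

    def __add__(self, other) -> ToyComposed:
        # composed + composed merges; composed + single function appends.
        if isinstance(other, ToyComposed):
            return self.add_composable_energy_fn(other)
        return self.add_energy_fn(other)

    def __radd__(self, other) -> ToyComposed:
        # Runs for `other + self` when `other` does not handle `+`.
        # The toy keeps `other` first so terms stay in left-to-right order.
        return ToyComposed(energy_fns=(other,)).add_composable_energy_fn(self)


combo = ToyComposed(energy_fns=(abs,))
left = combo + round  # uses __add__
right = round + combo  # builtin functions have no __add__, so __radd__ runs
print(left.energy_fns, right.energy_fns)
```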

class jax_dna.energy.base.BaseNucleotide

Bases: jax_md.rigid_body.RigidBody, abc.ABC

Base nucleotide class.

center: jax_dna.utils.types.Arr_Nucleotide_3
orientation: jax_dna.utils.types.Arr_Nucleotide_3 | jax_md.rigid_body.Quaternion
stack_sites: jax_dna.utils.types.Arr_Nucleotide_3
back_sites: jax_dna.utils.types.Arr_Nucleotide_3
base_sites: jax_dna.utils.types.Arr_Nucleotide_3
back_base_vectors: jax_dna.utils.types.Arr_Nucleotide_3
base_normals: jax_dna.utils.types.Arr_Nucleotide_3
cross_prods: jax_dna.utils.types.Arr_Nucleotide_3
abstract static from_rigid_body(rigid_body: jax_md.rigid_body.RigidBody, **kwargs) → BaseNucleotide

Create an instance of the subclass from a RigidBody.
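An abstract static factory like from_rigid_body can be declared in Python by stacking @staticmethod on top of @abc.abstractmethod. The toy below shows the pattern. ToyRigidBody, ToyBaseNucleotide, and ToyNucleotide are hypothetical: the toy orientation is a unit axis vector rather than a quaternion, and the toy nucleotide derives only one site. The back_offset value is a placeholder, not a model parameter.

```python
from __future__ import annotations

import abc
from dataclasses import dataclass

Vec3 = tuple[float, float, float]


@dataclass(frozen=True)
class ToyRigidBody:
    center: Vec3
    orientation: Vec3  # toy: a unit axis vector, not a quaternion


class ToyBaseNucleotide(abc.ABC):
    @staticmethod
    @abc.abstractmethod
    def from_rigid_body(rigid_body: ToyRigidBody, **kwargs) -> ToyBaseNucleotide:
        """Build a nucleotide, including its interaction sites, from a rigid body."""


@dataclass(frozen=True)
class ToyNucleotide(ToyBaseNucleotide):
    center: Vec3
    back_site: Vec3

    @staticmethod
    def from_rigid_body(rigid_body: ToyRigidBody, back_offset: float = -0.5, **kwargs) -> ToyNucleotide:
        # Hypothetical rule: the backbone site sits back_offset along the orientation axis.
        back_site = tuple(
            c + back_offset * o for c, o in zip(rigid_body.center, rigid_body.orientation)
        )
        return ToyNucleotide(center=rigid_body.center, back_site=back_site)


body = ToyRigidBody(center=(1.0, 2.0, 3.0), orientation=(0.0, 0.0, 1.0))
nucleotide = ToyNucleotide.from_rigid_body(body)
print(nucleotide.back_site)  # (1.0, 2.0, 2.5)
```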