jax_dna.gradient_estimators.difftre

Optimization using the Differentiable Trajectory Reweighting (DiffTRe) method.

DiffTRe: https://www.nature.com/articles/s41467-021-27241-4
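For orientation, the reweighting at the core of DiffTRe, as described in the linked paper, can be summarized as follows. S_i are N states sampled with reference parameters \hat\theta, U_\theta is the potential energy under parameters \theta, O is an observable, and \beta = 1/(k_B T). This summarizes the paper's formulation; it is not a transcription of this module's code.

```latex
w_i(\theta) = \frac{\exp\!\left[-\beta\left(U_\theta(S_i) - U_{\hat\theta}(S_i)\right)\right]}
                   {\sum_{j=1}^{N} \exp\!\left[-\beta\left(U_\theta(S_j) - U_{\hat\theta}(S_j)\right)\right]},
\qquad
\langle O \rangle_\theta \approx \sum_{i=1}^{N} w_i(\theta)\, O(S_i),
\qquad
N_{\mathrm{eff}} = \exp\!\left(-\sum_{i=1}^{N} w_i \ln w_i\right).
```

In the paper, a new reference trajectory is sampled whenever N_eff falls below a threshold, because the reweighted estimate becomes unreliable. The min_n_eff attribute of DiffTRe presumably plays this role, although this page does not state it.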

Attributes

_loss_w_grads

Classes

DiffTRe

DiffTRe optimizer.

Functions

build_energy_function(...)

Builds the energy function for the given parameters.

_compute_states_energies(...)

Calls the simulation function to get the states and energies.

_compute_loss(→ tuple[float, int])

Module Contents

jax_dna.gradient_estimators.difftre.build_energy_function(opt_params: list[dict[str, float]], displacement_fn: collections.abc.Callable, energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction], energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration], rigid_body_transform_fn: collections.abc.Callable, seq: jax_dna.utils.types.Sequence, bonded_neighbors: jax.numpy.ndarray, unbonded_neighbors: jax.numpy.ndarray) → jax_dna.energy.base.ComposedEnergyFunction

Builds the energy function for the given parameters.
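_compute_states_energies and DiffTRe take an energy_fn_builder that maps parameters to an energy function. One plausible way to obtain such a builder is to fix every argument of build_energy_function except opt_params with functools.partial. The sketch below shows that pattern with a toy stand-in. toy_build_energy_function, the two harmonic terms, and the parameter names k and r0 are hypothetical and are not part of jax_dna.

```python
import functools

import jax.numpy as jnp


def harmonic_bond(params, positions):
    """Hypothetical energy term: harmonic springs between consecutive particles."""
    r = jnp.linalg.norm(positions[1:] - positions[:-1], axis=-1)
    return 0.5 * params["k"] * jnp.sum((r - params["r0"]) ** 2)


def harmonic_tether(params, positions):
    """Hypothetical energy term: tethers every particle to the origin."""
    return 0.5 * params["k"] * jnp.sum(positions ** 2)


def toy_build_energy_function(opt_params, energy_terms):
    """Hypothetical stand-in for build_energy_function.

    Pairs the i-th parameter dict with the i-th energy term and returns one
    function of the state that sums all terms.
    """
    def composed(positions):
        return sum(term(p, positions) for term, p in zip(energy_terms, opt_params))
    return composed


# Fix everything except the parameters, giving a params -> energy_fn builder.
energy_fn_builder = functools.partial(
    toy_build_energy_function, energy_terms=(harmonic_bond, harmonic_tether)
)

opt_params = [{"k": 2.0, "r0": 1.0}, {"k": 0.5}]
energy_fn = energy_fn_builder(opt_params)
positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
print(float(energy_fn(positions)))  # prints 0.8125 (bond 0.25 + tether 0.5625)
```

Because the builder is a pure function of the parameters, jax.grad can differentiate energies with respect to every entry of opt_params.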

jax_dna.gradient_estimators.difftre._compute_states_energies(params: list[dict[str, float]], key: jax.random.PRNGKey, sim_init_fn: collections.abc.Callable, energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration], energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction], init_state: jax_md.rigid_body.RigidBody, n_steps: jax_dna.utils.types.Scalar, n_eq_steps: jax_dna.utils.types.Scalar, sample_every: jax_dna.utils.types.Scalar, energy_fn_builder: collections.abc.Callable[[list[dict[str, float]]], collections.abc.Callable]) → tuple[jax_dna.simulators.io.SimulatorTrajectory, jax_md.rigid_body.RigidBody, jax.numpy.ndarray]

Calls the simulation function to get the states and energies.

Do not JIT-compile this function: not every simulation function is JIT-compatible. Instead, JIT-compile sim_init and run separately.
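A minimal sketch of compiling the inner pieces while leaving the driver as plain Python. Everything here is hypothetical: the toy overdamped Langevin integrator, the names make_sim and run_sampling, and the reading of n_steps, n_eq_steps and sample_every (discard the first n_eq_steps steps as equilibration, then record every sample_every-th state up to n_steps total steps) are assumptions, not the jax_dna implementation.

```python
import jax
import jax.numpy as jnp


def make_sim(energy_fn, dt=1e-3, beta=1.0):
    """Hypothetical sim_init-style factory returning a JIT-compiled step function.

    Toy overdamped Langevin (Euler-Maruyama) update:
    x <- x - dt * grad U(x) + sqrt(2 * dt / beta) * noise
    """
    grad_fn = jax.grad(energy_fn)

    @jax.jit
    def step(x, key):
        noise = jax.random.normal(key, x.shape)
        return x - dt * grad_fn(x) + jnp.sqrt(2.0 * dt / beta) * noise

    return step


def run_sampling(energy_fn, x0, key, n_steps, n_eq_steps, sample_every):
    """Un-jitted driver: a Python loop around the compiled step.

    Assumes the first n_eq_steps steps are discarded and every
    sample_every-th state afterwards is recorded.
    """
    step = make_sim(energy_fn)
    x = x0
    states = []
    for i in range(n_steps):
        key, subkey = jax.random.split(key)
        x = step(x, subkey)
        if i >= n_eq_steps and (i - n_eq_steps) % sample_every == 0:
            states.append(x)
    states = jnp.stack(states)
    # Energies of the recorded states, evaluated with a vectorized energy function.
    energies = jax.vmap(energy_fn)(states)
    return states, energies


def energy(x):
    return 0.5 * jnp.sum(x ** 2)


states, energies = run_sampling(energy, jnp.zeros(3), jax.random.PRNGKey(0),
                                n_steps=1000, n_eq_steps=200, sample_every=100)
print(states.shape, energies.shape)  # (8, 3) (8,)
```

Only step is traced and compiled; the driver can therefore call code that JAX cannot trace (file I/O, external simulators) without breaking compilation.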

jax_dna.gradient_estimators.difftre._compute_loss(opt_params: list[dict[str, float]], energy_fn_builder: collections.abc.Callable[[dict[str, float]], collections.abc.Callable], beta: float, loss_fns: tuple[collections.abc.Callable], trajectory: jax_dna.simulators.io.SimulatorTrajectory, ref_states: jax_md.rigid_body.RigidBody, ref_energies: jax.numpy.ndarray, losses_reduce_fn: collections.abc.Callable = jnp.mean) → tuple[float, int]

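_compute_loss has no docstring. Its parameters (beta, ref_states, ref_energies, loss_fns, losses_reduce_fn) suggest it evaluates a reweighted loss in the DiffTRe sense. The sketch below shows one way such a function could work; it is not the jax_dna implementation. The name reweighted_loss, the convention that each loss function receives per-state observables plus the weights, and returning N_eff as the auxiliary value are all assumptions. Normalizing the weights with a softmax is a numerically stable choice made here, not something the source specifies.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy


def reweighting_weights(new_energies, ref_energies, beta):
    """Normalized weights w_i proportional to exp(-beta * (U_new(S_i) - U_ref(S_i)))."""
    log_w = -beta * (new_energies - ref_energies)
    # softmax normalizes in log space, avoiding overflow for large energy gaps.
    return jax.nn.softmax(log_w)


def effective_sample_size(weights):
    """N_eff = exp(-sum_i w_i ln w_i); equals N when all weights are equal."""
    # xlogy(0, 0) == 0, so vanishing weights do not produce NaNs.
    return jnp.exp(-jnp.sum(xlogy(weights, weights)))


def reweighted_loss(opt_params, energy_fn_builder, beta, loss_fns, observables,
                    ref_states, ref_energies, losses_reduce_fn=jnp.mean):
    """Hypothetical analogue of _compute_loss: returns (reduced loss, N_eff)."""
    energy_fn = energy_fn_builder(opt_params)
    new_energies = jax.vmap(energy_fn)(ref_states)  # re-evaluate stored states
    weights = reweighting_weights(new_energies, ref_energies, beta)
    losses = jnp.stack([loss_fn(observables, weights) for loss_fn in loss_fns])
    return losses_reduce_fn(losses), effective_sample_size(weights)


# Toy usage: 1D harmonic energy U(x) = 0.5 * k * x**2, observable x**2.
def builder(params):
    return lambda x: 0.5 * params["k"] * x ** 2


def msd_loss(obs, w):
    """Hypothetical loss: squared error of the reweighted <x**2> against 1.0."""
    return (jnp.sum(w * obs) - 1.0) ** 2


ref_states = jnp.linspace(-2.0, 2.0, 5)
ref_energies = jax.vmap(builder({"k": 1.0}))(ref_states)
loss, n_eff = reweighted_loss({"k": 1.0}, builder, 1.0, (msd_loss,),
                              ref_states ** 2, ref_states, ref_energies)
# At the reference parameters every weight is 1/5: n_eff is about 5.0,
# <x**2> is 2.0, and the loss is about 1.0.
```

The stored states are never re-simulated: only their energies are recomputed under the new parameters, which is what makes the loss cheap to differentiate.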
jax_dna.gradient_estimators.difftre._loss_w_grads
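The name _loss_w_grads suggests a loss-with-gradients transform of _compute_loss. Because _compute_loss returns a tuple, a natural construction is jax.value_and_grad with has_aux=True, which differentiates only the first output and passes the second through unchanged. How the attribute is actually defined is an assumption; toy_loss below is a hypothetical stand-in for _compute_loss.

```python
import jax
import jax.numpy as jnp


def toy_loss(opt_params, target):
    """Hypothetical loss returning (scalar loss, auxiliary output)."""
    k = opt_params[0]["k"]
    loss = (k - target) ** 2
    aux = jnp.abs(k - target)  # auxiliary value: passed through, not differentiated
    return loss, aux


# Differentiate w.r.t. argument 0 only; has_aux=True returns ((loss, aux), grads).
toy_loss_w_grads = jax.value_and_grad(toy_loss, argnums=0, has_aux=True)

(loss, aux), grads = toy_loss_w_grads([{"k": 3.0}], 1.0)
# loss == 4.0, aux == 2.0, grads[0]["k"] == 4.0
```

The gradients have the same list-of-dicts structure as opt_params, so they can be applied with jax.tree_util.tree_map.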

class jax_dna.gradient_estimators.difftre.DiffTRe

DiffTRe optimizer.

energy_fn_builder: collections.abc.Callable[[list[dict[str, jax_dna.utils.types.ARR_OR_SCALAR]]], collections.abc.Callable]
beta: jax_dna.utils.types.Scalar
min_n_eff: jax_dna.utils.types.Scalar
loss_fns: tuple[collections.abc.Callable]
losses_reduce_fn: collections.abc.Callable
sim_init_fn: collections.abc.Callable
energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration]
energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction]
init_state: jax_md.rigid_body.RigidBody
n_steps: jax_dna.utils.types.Scalar
n_eq_steps: jax_dna.utils.types.Scalar
sample_every: jax_dna.utils.types.Scalar
_trajectory: jax_dna.simulators.io.SimulatorTrajectory | None = None
_ref_states: jax_md.rigid_body.RigidBody | None = None
_ref_energies: jax.numpy.ndarray | None = None

initialize(opt_params: list[dict[str, float]], key: jax.random.PRNGKey) → DiffTRe

Initialize the reference states and energies.

__call__(opt_params: list[dict[str, float]], key: jax.random.PRNGKey) → tuple[DiffTRe, list[dict[str, float]], float, tuple[float, ...]]

Compute the loss and gradients for the given parameters.
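Because __call__ returns an updated DiffTRe alongside the gradients, loss, and auxiliary values, the estimator appears to be used functionally: each call's returned instance replaces the old one. The sketch below mimics that interface with a hypothetical, simplified ToyDiffTRe (a single parameter dict, exact sampling of a 1D harmonic system instead of a simulation). Resampling the reference when N_eff drops below min_n_eff follows the DiffTRe paper; whether jax_dna's __call__ does exactly this is an assumption, as is the plain gradient-descent update.

```python
from __future__ import annotations

import dataclasses

import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy


@dataclasses.dataclass(frozen=True)
class ToyDiffTRe:
    """Hypothetical stand-in mirroring the initialize / __call__ interface.

    Samples U(x) = 0.5 * k * x**2 exactly and fits k so that the reweighted
    <x**2> matches `target`.
    """
    beta: float
    min_n_eff: float
    target: float
    n_samples: int = 1000
    _ref_states: jax.Array | None = None
    _ref_energies: jax.Array | None = None

    @staticmethod
    def energy(params, x):
        return 0.5 * params["k"] * x ** 2

    def initialize(self, opt_params, key):
        # Exact Boltzmann sampling for the harmonic toy: x ~ N(0, 1 / (beta * k)).
        std = 1.0 / jnp.sqrt(self.beta * opt_params["k"])
        states = std * jax.random.normal(key, (self.n_samples,))
        energies = self.energy(opt_params, states)
        return dataclasses.replace(self, _ref_states=states, _ref_energies=energies)

    def _loss(self, opt_params):
        log_w = -self.beta * (self.energy(opt_params, self._ref_states) - self._ref_energies)
        w = jax.nn.softmax(log_w)
        n_eff = jnp.exp(-jnp.sum(xlogy(w, w)))
        loss = (jnp.sum(w * self._ref_states ** 2) - self.target) ** 2
        return loss, n_eff

    def __call__(self, opt_params, key):
        est = self
        _, n_eff = est._loss(opt_params)
        if n_eff < self.min_n_eff:  # reference too stale: sample a new one
            est = est.initialize(opt_params, key)
        (loss, n_eff), grads = jax.value_and_grad(est._loss, has_aux=True)(opt_params)
        return est, grads, loss, (n_eff,)


key = jax.random.PRNGKey(0)
params = {"k": 1.0}
learning_rate = 1.0
est = ToyDiffTRe(beta=1.0, min_n_eff=500.0, target=0.5)
key, subkey = jax.random.split(key)
est = est.initialize(params, subkey)
for _ in range(100):
    key, subkey = jax.random.split(key)
    est, grads, loss, (n_eff,) = est(params, subkey)
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(float(params["k"]))  # should end near 2.0, since <x**2> = 1 / (beta * k)
```

The exact final value depends on sampling noise. Returning a new instance instead of mutating self keeps the reference data explicit, which fits JAX's functional style.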