jax_dna.gradient_estimators.difftre
Optimization using the DiffTRe method.
DiffTRe: https://www.nature.com/articles/s41467-021-27241-4
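According to the paper linked above, DiffTRe does not differentiate through the simulation. It samples N states S_i once under reference parameters θ̂. It then estimates ensemble averages at nearby parameters θ by Boltzmann reweighting, and re-simulates when the effective sample size N_eff becomes too small. The relations are summarized below for orientation. Matching β to the beta attribute and the N_eff threshold to min_n_eff is our reading of the names; this page does not state it.

$$
w_i = \frac{\exp\!\left[-\beta\left(U_\theta(S_i) - U_{\hat\theta}(S_i)\right)\right]}{\sum_{j=1}^{N}\exp\!\left[-\beta\left(U_\theta(S_j) - U_{\hat\theta}(S_j)\right)\right]},
\qquad
\langle O \rangle_\theta \approx \sum_{i=1}^{N} w_i\, O(S_i),
\qquad
N_\mathrm{eff} = \exp\!\left(-\sum_{i=1}^{N} w_i \ln w_i\right)
$$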
Attributes
- _loss_w_grads
Classes
- DiffTRe: DiffTRe optimizer.
Functions
- build_energy_function: Builds the energy function for the given parameters.
- _compute_states_energies: Calls the simulation function to get the states and energies.
- _compute_loss
Module Contents
- jax_dna.gradient_estimators.difftre.build_energy_function(opt_params: list[dict[str, float]], displacement_fn: collections.abc.Callable, energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction], energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration], rigid_body_transform_fn: collections.abc.Callable, seq: jax_dna.utils.types.Sequence, bonded_neighbors: jax.numpy.ndarray, unbonded_neighbors: jax.numpy.ndarray) -> jax_dna.energy.base.ComposedEnergyFunction
Builds the energy function for the given parameters.
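The following pure-JAX sketch roughly illustrates how a builder can turn a list of parameter dicts into a single energy function. It does not use jax_dna's BaseEnergyFunction or ComposedEnergyFunction. The names build_total_energy, harmonic and confinement are hypothetical stand-ins. Pairing opt_params[i] with the i-th energy term is also an assumption. The sketch shows that the built energy is differentiable with respect to every parameter, which is the property DiffTRe relies on.

```python
import jax
import jax.numpy as jnp


def build_total_energy(opt_params, term_fns):
    """Hypothetical builder: sum every energy term, each with its own params.

    Assumes opt_params[i] is the parameter dict for term_fns[i].
    """
    def total_energy(positions):
        return sum(fn(positions, p) for fn, p in zip(term_fns, opt_params))

    return total_energy


def harmonic(positions, params):
    # Hypothetical bonded term: springs between consecutive particles.
    d = jnp.linalg.norm(positions[1:] - positions[:-1], axis=-1)
    return 0.5 * params["k"] * jnp.sum((d - params["r0"]) ** 2)


def confinement(positions, params):
    # Hypothetical external term: quadratic pull toward the origin.
    return params["eps"] * jnp.sum(positions**2)


terms = (harmonic, confinement)
opt_params = [{"k": 2.0, "r0": 1.0}, {"eps": 0.1}]
positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])

energy = build_total_energy(opt_params, terms)(positions)
# Gradient of the energy with respect to every parameter in the list of dicts.
param_grads = jax.grad(lambda p: build_total_energy(p, terms)(positions))(opt_params)
# energy ≈ 0.475; gradients: k ≈ 0.125, r0 ≈ -1.0, eps ≈ 2.25
```

The gradient has the same list-of-dicts structure as opt_params, because jax.grad treats it as a pytree.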
- jax_dna.gradient_estimators.difftre._compute_states_energies(params: list[dict[str, float]], key: jax.random.PRNGKey, sim_init_fn: collections.abc.Callable, energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration], energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction], init_state: jax_md.rigid_body.RigidBody, n_steps: jax_dna.utils.types.Scalar, n_eq_steps: jax_dna.utils.types.Scalar, sample_every: jax_dna.utils.types.Scalar, energy_fn_builder: collections.abc.Callable[[list[dict[str, float]]], collections.abc.Callable]) -> tuple[jax_dna.simulators.io.SimulatorTrajectory, jax_md.rigid_body.RigidBody, jax.numpy.ndarray]
Calls the simulation function to get the states and energies.
Do not JIT this function: not all simulation functions are JIT-compatible. Instead, JIT sim_init and run.
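The toy sketch below follows the pattern of JIT-compiling the pieces rather than the driver: a plain Python loop calls JIT-compiled init and step functions. All function names are hypothetical, and an overdamped Langevin integrator in a harmonic well stands in for sim_init_fn and a real DNA simulation. The reading of n_steps, n_eq_steps and sample_every is an assumption: drop the first n_eq_steps steps, then keep every sample_every-th state.

```python
import jax
import jax.numpy as jnp

DT = 0.01   # hypothetical integrator time step
BETA = 1.0  # hypothetical inverse temperature


def energy_fn(x):
    # Toy energy: isotropic harmonic well.
    return 0.5 * jnp.sum(x**2)


@jax.jit
def init_fn(key):
    # Stand-in for a JIT-compiled sim_init: random starting coordinates.
    return jax.random.normal(key, (3,))


@jax.jit
def step_fn(x, key):
    # One overdamped Langevin step: x <- x - DT * grad U + sqrt(2 DT / BETA) * noise.
    noise = jax.random.normal(key, x.shape)
    return x - DT * jax.grad(energy_fn)(x) + jnp.sqrt(2.0 * DT / BETA) * noise


def compute_states_energies(key, n_steps, n_eq_steps, sample_every):
    # The driver stays un-jitted: a Python loop calling compiled pieces.
    key, init_key = jax.random.split(key)
    x = init_fn(init_key)
    states = []
    for step in range(n_steps):
        key, step_key = jax.random.split(key)
        x = step_fn(x, step_key)
        # Assumed semantics: skip equilibration, then keep every sample_every-th state.
        if step >= n_eq_steps and (step - n_eq_steps) % sample_every == 0:
            states.append(x)
    states = jnp.stack(states)
    # Energies of the sampled states, usable later as reference energies.
    energies = jax.vmap(energy_fn)(states)
    return states, energies


states, energies = compute_states_energies(
    jax.random.PRNGKey(0), n_steps=500, n_eq_steps=100, sample_every=40
)
print(states.shape, energies.shape)  # (10, 3) (10,)
```

Keeping the loop in Python lets the driver call code that cannot be traced. Each compiled step is still compiled only once.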
- jax_dna.gradient_estimators.difftre._compute_loss(opt_params: list[dict[str, float]], energy_fn_builder: collections.abc.Callable[[dict[str, float]], collections.abc.Callable], beta: float, loss_fns: tuple[collections.abc.Callable], trajectory: jax_dna.simulators.io.SimulatorTrajectory, ref_states: jax_md.rigid_body.RigidBody, ref_energies: jax.numpy.ndarray, losses_reduce_fn: collections.abc.Callable = jnp.mean) -> tuple[float, int]
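_compute_loss has no docstring here. Its beta, ref_states and ref_energies arguments suggest that it applies DiffTRe's Boltzmann reweighting, and the sketch below shows that computation in isolation. Three details are assumptions, not the module's actual code: the function name reweighted_loss, the squared-error loss, and returning the effective sample size as the second value.

```python
import jax
import jax.numpy as jnp


def reweighted_loss(new_energies, ref_energies, observables, target, beta):
    """Hypothetical DiffTRe-style loss on a fixed set of reference states."""
    # Log of the Boltzmann weights exp(-beta * (U_new - U_ref)),
    # normalized in log space for numerical stability.
    log_w = jax.nn.log_softmax(-beta * (new_energies - ref_energies))
    w = jnp.exp(log_w)
    # Reweighted ensemble average of a per-state observable.
    expected = jnp.sum(w * observables)
    loss = (expected - target) ** 2
    # Effective sample size: exponential of the entropy of the weights.
    n_eff = jnp.exp(-jnp.sum(w * log_w))
    return loss, n_eff


obs = jnp.array([1.0, 2.0, 3.0, 4.0])
u_ref = jnp.array([0.0, 1.0, 2.0, 3.0])

# Unchanged energies: uniform weights, so n_eff equals the number of states.
loss, n_eff = reweighted_loss(u_ref, u_ref, obs, target=2.5, beta=1.0)
# Raising one state's energy down-weights it and lowers n_eff.
loss2, n_eff2 = reweighted_loss(
    u_ref + jnp.array([0.0, 0.0, 0.0, 2.0]), u_ref, obs, target=2.5, beta=1.0
)
```

Normalizing with log_softmax avoids overflow when energy differences are large. It also keeps w * log_w finite when individual weights underflow.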
- jax_dna.gradient_estimators.difftre._loss_w_grads
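_loss_w_grads is listed without a definition. A common JAX idiom for getting a loss and its gradient in one call, when the loss also returns auxiliary data, is jax.value_and_grad with has_aux=True. The sketch below shows that idiom on a hypothetical toy loss. This page does not say whether _loss_w_grads is built this way.

```python
import jax
import jax.numpy as jnp


def toy_loss(opt_params, xs, target):
    # Hypothetical loss returning (loss, aux), the shape has_aux=True expects.
    pred = opt_params["scale"] * jnp.mean(xs)
    loss = (pred - target) ** 2
    return loss, pred


# One compiled call returns ((loss, aux), grads w.r.t. the first argument).
loss_w_grads = jax.jit(jax.value_and_grad(toy_loss, has_aux=True))

(loss, pred), grads = loss_w_grads({"scale": 2.0}, jnp.array([1.0, 2.0, 3.0]), 3.0)
# loss = 1.0, pred = 4.0, d loss / d scale = 4.0
```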
- class jax_dna.gradient_estimators.difftre.DiffTRe
DiffTRe optimizer.
- energy_fn_builder: collections.abc.Callable[[list[dict[str, jax_dna.utils.types.ARR_OR_SCALAR]]], collections.abc.Callable]
- beta: jax_dna.utils.types.Scalar
- min_n_eff: jax_dna.utils.types.Scalar
- loss_fns: tuple[collections.abc.Callable]
- losses_reduce_fn: collections.abc.Callable
- sim_init_fn: collections.abc.Callable
- energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration]
- energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction]
- init_state: jax_md.rigid_body.RigidBody
- n_steps: jax_dna.utils.types.Scalar
- n_eq_steps: jax_dna.utils.types.Scalar
- sample_every: jax_dna.utils.types.Scalar
- _trajectory: jax_dna.simulators.io.SimulatorTrajectory | None = None
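These attributes fit the usual DiffTRe loop: simulate once, optimize by reweighting the stored states, and re-simulate only when the effective sample size falls below min_n_eff. The self-contained toy version below differs from jax_dna in several assumed ways. Exact Gaussian samples of a 1D harmonic well stand in for sim_init_fn and a real trajectory. MIN_N_EFF is treated as an absolute sample count. Plain gradient descent replaces whatever optimizer jax_dna actually uses.

```python
import jax
import jax.numpy as jnp

BETA = 1.0
N_SAMPLES = 5000
MIN_N_EFF = 0.95 * N_SAMPLES  # assumed: threshold as an absolute sample count
TARGET_MSD = 0.5  # target <x^2>; the exact optimum is k = 1 / (BETA * TARGET_MSD) = 2


def energy(k, x):
    # Toy energy: 1D harmonic well with stiffness k (the parameter to optimize).
    return 0.5 * k * x**2


def simulate(k, key):
    # Stand-in for a real simulation: exact Boltzmann samples of the well.
    return jax.random.normal(key, (N_SAMPLES,)) / jnp.sqrt(BETA * k)


def loss_and_n_eff(k, samples, ref_energies):
    # Reweight the fixed reference samples to the current k.
    log_w = jax.nn.log_softmax(-BETA * (energy(k, samples) - ref_energies))
    w = jnp.exp(log_w)
    msd = jnp.sum(w * samples**2)
    n_eff = jnp.exp(-jnp.sum(w * log_w))
    return (msd - TARGET_MSD) ** 2, n_eff


loss_and_grad = jax.jit(jax.value_and_grad(loss_and_n_eff, has_aux=True))


def optimize(k, key, n_iters=100, lr=1.0):
    key, sim_key = jax.random.split(key)
    samples = simulate(k, sim_key)
    ref_energies = energy(k, samples)
    n_sims = 1
    for _ in range(n_iters):
        (loss, n_eff), grad = loss_and_grad(k, samples, ref_energies)
        if n_eff < MIN_N_EFF:
            # Reweighting has degraded: re-simulate at the current k and retry.
            key, sim_key = jax.random.split(key)
            samples = simulate(k, sim_key)
            ref_energies = energy(k, samples)
            n_sims += 1
            (loss, n_eff), grad = loss_and_grad(k, samples, ref_energies)
        k = k - lr * grad
    return k, n_sims


k_opt, n_sims = optimize(1.0, jax.random.PRNGKey(0))
print(float(k_opt), n_sims)  # k_opt should land near 2.0
```

Most iterations reuse the stored samples. The gradient comes from differentiating the reweighting weights, not the simulation, so the simulator itself never needs to be differentiable.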