jax_dna.gradient_estimators.difftre
===================================

.. py:module:: jax_dna.gradient_estimators.difftre

.. autoapi-nested-parse::

   Optimization using the DiffTRe (Differentiable Trajectory Reweighting) method.

   DiffTRe: https://www.nature.com/articles/s41467-021-27241-4



Attributes
----------

.. autoapisummary::

   jax_dna.gradient_estimators.difftre._loss_w_grads


Classes
-------

.. autoapisummary::

   jax_dna.gradient_estimators.difftre.DiffTRe


Functions
---------

.. autoapisummary::

   jax_dna.gradient_estimators.difftre.build_energy_function
   jax_dna.gradient_estimators.difftre._compute_states_energies
   jax_dna.gradient_estimators.difftre._compute_loss


Module Contents
---------------

.. py:function:: build_energy_function(opt_params: list[dict[str, float]], displacement_fn: collections.abc.Callable, energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction], energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration], rigid_body_transform_fn: collections.abc.Callable, seq: jax_dna.utils.types.Sequence, bonded_neighbors: jax.numpy.ndarray, unbonded_neighbors: jax.numpy.ndarray) -> jax_dna.energy.base.ComposedEnergyFunction

   Builds the energy function for the given parameters.


.. py:function:: _compute_states_energies(params: list[dict[str, float]], key: jax.random.PRNGKey, sim_init_fn: collections.abc.Callable, energy_configs: tuple[jax_dna.energy.configuration.BaseConfiguration], energy_fns: tuple[jax_dna.energy.base.BaseEnergyFunction], init_state: jax_md.rigid_body.RigidBody, n_steps: jax_dna.utils.types.Scalar, n_eq_steps: jax_dna.utils.types.Scalar, sample_every: jax_dna.utils.types.Scalar, energy_fn_builder: collections.abc.Callable[[list[dict[str, float]]], collections.abc.Callable]) -> tuple[jax_dna.simulators.io.SimulatorTrajectory, jax_md.rigid_body.RigidBody, jax.numpy.ndarray]

   Calls the simulation function to get the states and energies.

   Do not JIT this function: not all simulation functions are JIT-compatible.
   Instead, JIT the ``sim_init`` and ``run`` functions.


.. py:function:: _compute_loss(opt_params: list[dict[str, float]], energy_fn_builder: collections.abc.Callable[[dict[str, float]], collections.abc.Callable], beta: float, loss_fns: tuple[collections.abc.Callable], trajectory: jax_dna.simulators.io.SimulatorTrajectory, ref_states: jax_md.rigid_body.RigidBody, ref_energies: jax.numpy.ndarray, losses_reduce_fn: collections.abc.Callable = jnp.mean) -> tuple[float, int]

.. py:data:: _loss_w_grads

.. py:class:: DiffTRe

   DiffTRe optimizer.


   .. py:attribute:: energy_fn_builder
      :type: collections.abc.Callable[[list[dict[str, jax_dna.utils.types.ARR_OR_SCALAR]]], collections.abc.Callable]


   .. py:attribute:: beta
      :type: jax_dna.utils.types.Scalar


   .. py:attribute:: min_n_eff
      :type: jax_dna.utils.types.Scalar


   .. py:attribute:: loss_fns
      :type: tuple[collections.abc.Callable]


   .. py:attribute:: losses_reduce_fn
      :type: collections.abc.Callable


   .. py:attribute:: sim_init_fn
      :type: collections.abc.Callable


   .. py:attribute:: energy_configs
      :type: tuple[jax_dna.energy.configuration.BaseConfiguration]


   .. py:attribute:: energy_fns
      :type: tuple[jax_dna.energy.base.BaseEnergyFunction]


   .. py:attribute:: init_state
      :type: jax_md.rigid_body.RigidBody


   .. py:attribute:: n_steps
      :type: jax_dna.utils.types.Scalar


   .. py:attribute:: n_eq_steps
      :type: jax_dna.utils.types.Scalar


   .. py:attribute:: sample_every
      :type: jax_dna.utils.types.Scalar


   .. py:attribute:: _trajectory
      :type: jax_dna.simulators.io.SimulatorTrajectory | None
      :value: None


   .. py:attribute:: _ref_states
      :type: jax_md.rigid_body.RigidBody | None
      :value: None


   .. py:attribute:: _ref_energies
      :type: jax.numpy.ndarray | None
      :value: None


   .. py:method:: initialize(opt_params: list[dict[str, float]], key: jax.random.PRNGKey) -> DiffTRe

      Initialize the reference states and energies.


   .. py:method:: __call__(opt_params: list[dict[str, float]], key: jax.random.PRNGKey) -> tuple[DiffTRe, list[dict[str, float]], float, tuple[float, ...]]

      Compute the loss and gradients for the given parameters.
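
DiffTRe avoids re-simulating at every parameter update by reweighting a stored reference trajectory. Each reference state :math:`x_i` receives a normalized weight :math:`w_i \propto \exp[-\beta (U_\theta(x_i) - U_{\mathrm{ref}}(x_i))]`, and the quality of the reweighted sample is tracked with the effective sample size :math:`n_{\mathrm{eff}} = \exp(-\sum_i w_i \ln w_i)`, as defined in the DiffTRe paper linked above. Presumably this is what the ``min_n_eff`` attribute is compared against, although that is an assumption. The following is a minimal, standalone JAX sketch of these two quantities. It is not the ``jax_dna`` implementation, and the names ``reweighting_weights`` and ``effective_sample_size`` are hypothetical.

.. code-block:: python

   import jax.numpy as jnp
   from jax.nn import softmax
   from jax.scipy.special import xlogy


   def reweighting_weights(energies_new, energies_ref, beta):
       """Normalized weights for states sampled under the reference parameters.

       w_i is proportional to exp(-beta * (U_new(x_i) - U_ref(x_i))).
       """
       log_w = -beta * (energies_new - energies_ref)
       # softmax normalizes in log space, so large energy differences do not overflow exp
       return softmax(log_w)


   def effective_sample_size(weights):
       """n_eff = exp(-sum_i w_i * ln w_i); equals N when all N weights are uniform."""
       # xlogy returns 0 where a weight is exactly 0, instead of 0 * -inf = nan
       return jnp.exp(-jnp.sum(xlogy(weights, weights)))


   # Unchanged energies: uniform weights, so n_eff equals the number of states.
   w = reweighting_weights(jnp.zeros(4), jnp.zeros(4), beta=1.0)
   print(w, effective_sample_size(w))

   # One state dominating the weights: n_eff collapses toward 1.
   w = reweighting_weights(jnp.array([0.0, 100.0]), jnp.zeros(2), beta=1.0)
   print(w, effective_sample_size(w))

Normalizing with ``softmax`` rather than exponentiating the energy differences directly keeps the weights finite even when the new parameters move far from the reference.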
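
Reweighting keeps the loss differentiable in the parameters while the sampled states stay fixed. ``jax.grad`` flows only through the weights, so no gradient has to be propagated through the simulation itself. The toy loop below sketches this idea under clearly stated assumptions:

* a hypothetical one-dimensional harmonic energy :math:`U_k(x) = k x^2 / 2`, whose Boltzmann distribution is sampled exactly in place of a simulated trajectory;
* a single observable, :math:`\langle x^2 \rangle`;
* a hypothetical resampling threshold expressed as a fraction of the sample count.

The loop loosely mirrors the ``DiffTRe`` attributes (``beta``, ``min_n_eff``, and the stored reference states and energies), but it is not the ``jax_dna`` API.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax.nn import softmax
   from jax.scipy.special import xlogy

   BETA = 1.0
   MIN_N_EFF_FRACTION = 0.9  # hypothetical: resample when n_eff < 90% of the samples


   def energy(k, x):
       # Hypothetical toy energy: harmonic potential U_k(x) = k * x**2 / 2
       return 0.5 * k * x**2


   def sample_reference(k, key, n=4096):
       # Exact Boltzmann samples stand in for a simulated trajectory:
       # x ~ Normal(0, 1 / (BETA * k))
       return jax.random.normal(key, (n,)) / jnp.sqrt(BETA * k)


   def reweighted_loss(k, xs, ref_energies, target):
       # Reweight the fixed reference samples to parameter k
       w = softmax(-BETA * (energy(k, xs) - ref_energies))
       mean_x2 = jnp.sum(w * xs**2)  # reweighted estimate of <x^2> at k
       n_eff = jnp.exp(-jnp.sum(xlogy(w, w)))
       return (mean_x2 - target) ** 2, n_eff


   # Differentiate the loss with respect to k; n_eff is returned as auxiliary data
   loss_and_grad = jax.jit(jax.value_and_grad(reweighted_loss, has_aux=True))


   def optimize(k, target, key, steps=200, lr=0.5):
       key, sub = jax.random.split(key)
       xs = sample_reference(k, sub)
       ref_energies = energy(k, xs)
       for _ in range(steps):
           (_, n_eff), grad = loss_and_grad(k, xs, ref_energies, target)
           if n_eff < MIN_N_EFF_FRACTION * xs.shape[0]:
               # Too few effective samples: regenerate the reference at the current k
               key, sub = jax.random.split(key)
               xs = sample_reference(k, sub)
               ref_energies = energy(k, xs)
               (_, n_eff), grad = loss_and_grad(k, xs, ref_energies, target)
           k = k - lr * grad  # plain gradient-descent step
       return k


   k_opt = optimize(1.0, 0.5, jax.random.PRNGKey(0))
   print(float(k_opt))

With ``BETA = 1`` this energy gives :math:`\langle x^2 \rangle = 1/k`. A target of 0.5 should therefore drive ``k`` toward 2, up to the Monte Carlo noise of the reweighted average.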