jax_dna.optimization.objective
Objectives implemented as ray actors.
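Attributes
- ERR_DIFFTRE_MISSING_KWARGS
- ERR_MISSING_ARG
- ERR_OBJECTIVE_NOT_READY
- EnergyFn
- empty_dict
- compute_loss_and_grad
Classes
- Objective – Base class for objectives that calculate gradients.
- SimGradObjectiveActor – Objective that calculates the gradients of a simulation.
- DiffTReObjective – Objective that calculates the gradients of an objective using DiffTRe.
- DiffTReObjectiveActor – Objective that calculates the gradients of an objective using DiffTRe and ray.
Functions
- compute_weights_and_neff – Compute the weights and normalized effective sample size of a trajectory.
- compute_loss – Compute the loss and auxiliary values.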
Module Contents
- jax_dna.optimization.objective.ERR_DIFFTRE_MISSING_KWARGS = 'Missing required kwargs: {missing_kwargs}.'
- jax_dna.optimization.objective.ERR_MISSING_ARG = 'Missing required argument: {missing_arg}.'
- jax_dna.optimization.objective.ERR_OBJECTIVE_NOT_READY = 'Not all required observables have been obtained.'
- jax_dna.optimization.objective.EnergyFn
- jax_dna.optimization.objective.empty_dict
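The error constants are plain `str.format` templates. The sketch below shows how the `{missing_kwargs}` placeholder might be filled when keyword arguments are validated. The helper `check_required_kwargs` and the example keyword names are hypothetical and not part of jax_dna. Only the template string is copied from the module.

```python
# Template copied from the module attribute above.
ERR_DIFFTRE_MISSING_KWARGS = "Missing required kwargs: {missing_kwargs}."


def check_required_kwargs(kwargs: dict, required: list[str]) -> None:
    """Hypothetical validator: raise if any required keyword is absent."""
    missing = [key for key in required if key not in kwargs]
    if missing:
        raise ValueError(ERR_DIFFTRE_MISSING_KWARGS.format(missing_kwargs=missing))


try:
    check_required_kwargs({"beta": 1.0}, ["beta", "opt_params", "energy_fn_builder"])
except ValueError as err:
    message = str(err)

print(message)  # Missing required kwargs: ['opt_params', 'energy_fn_builder'].
```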
- class jax_dna.optimization.objective.Objective(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[str, ...]], tuple[jax_dna.utils.types.Grads, list[tuple[str, Any]]]], logger_config: dict[str, Any] = empty_dict)
Base class for objectives that calculate gradients.
- _name
- _required_observables
- _needed_observables
- _grad_or_loss_fn
- _obtained_observables = []
- _logging_observables
- _logger
- required_observables() -> list[str]
Return the observables that are required to calculate the gradients.
- obtained_observables() -> list[tuple[str, jax_dna.utils.types.SimulatorActorOutput]]
Return the latest observed values for all observables.
- logging_observables() -> list[tuple[str, Any]]
Return the latest observed values for the logging observables.
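Together, `required_observables`, `obtained_observables`, and `ERR_OBJECTIVE_NOT_READY` suggest that an objective checks it is ready before it computes anything. The sketch below shows that bookkeeping in plain Python. It assumes the objective becomes ready only once every required observable has been obtained, and that the user function receives the observables positionally in declared order. `ToyObjective`, `update`, `is_ready`, and `calculate` are hypothetical names, not the jax_dna API.

```python
from collections.abc import Callable
from typing import Any

# Message copied from the module attribute ERR_OBJECTIVE_NOT_READY.
ERR_OBJECTIVE_NOT_READY = "Not all required observables have been obtained."


class ToyObjective:
    """Hypothetical stand-in for the observable bookkeeping of Objective."""

    def __init__(self, required_observables: list[str], grad_or_loss_fn: Callable[..., Any]) -> None:
        self._required_observables = required_observables
        self._grad_or_loss_fn = grad_or_loss_fn
        self._obtained_observables: list[tuple[str, Any]] = []

    def update(self, name: str, value: Any) -> None:
        # Keep only the latest value for each observable name.
        kept = [(n, v) for n, v in self._obtained_observables if n != name]
        self._obtained_observables = [*kept, (name, value)]

    def is_ready(self) -> bool:
        obtained = {n for n, _ in self._obtained_observables}
        return set(self._required_observables) <= obtained

    def calculate(self) -> Any:
        if not self.is_ready():
            raise ValueError(ERR_OBJECTIVE_NOT_READY)
        values = dict(self._obtained_observables)
        # Hand the observables to the user function in their declared order.
        return self._grad_or_loss_fn(*(values[n] for n in self._required_observables))


obj = ToyObjective(["trajectory", "energies"], lambda traj, energies: sum(energies) / len(traj))
obj.update("energies", [1.0, 2.0, 3.0])
print(obj.is_ready())  # False: "trajectory" has not been obtained yet
obj.update("trajectory", [0, 1, 2])
print(obj.calculate())  # 2.0
```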
- class jax_dna.optimization.objective.SimGradObjectiveActor(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[str, ...]], tuple[jax_dna.utils.types.Grads, list[tuple[str, Any]]]], logger_config: dict[str, Any] = empty_dict)
Bases: Objective
Objective that calculates the gradients of a simulation.
- jax_dna.optimization.objective.compute_weights_and_neff(beta: float, new_energies: jax_dna.utils.types.Arr_N, ref_energies: jax_dna.utils.types.Arr_N) -> tuple[jax.numpy.ndarray, float]
Compute the weights and normalized effective sample size of a trajectory.
The calculation follows the DiffTRe (Differentiable Trajectory Reweighting) algorithm. See equations 4 and 5 of https://www.nature.com/articles/s41467-021-27241-4.
- Parameters:
beta – The inverse temperature.
new_energies – The new energies of the trajectory.
ref_energies – The reference energies of the trajectory.
- Returns:
The weights and the normalized effective sample size.
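In the DiffTRe paper, each reference state S_i gets the weight w_i = exp(-β ΔU_i) / Σ_j exp(-β ΔU_j), where ΔU_i = U_new(S_i) - U_ref(S_i). The effective sample size is N_eff = exp(-Σ_i w_i ln w_i). The JAX sketch below computes both quantities. It assumes "normalized" means N_eff divided by the number of states N. It uses `jax.nn.log_softmax` for numerical stability, which may differ from how jax_dna implements the function. The name `compute_weights_and_neff_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def compute_weights_and_neff_sketch(
    beta: float, new_energies: jnp.ndarray, ref_energies: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Sketch of DiffTRe reweighting: per-state weights and N_eff / N."""
    # Log of exp(-beta * dU_i) / sum_j exp(-beta * dU_j), computed stably.
    log_weights = jax.nn.log_softmax(-beta * (new_energies - ref_energies))
    weights = jnp.exp(log_weights)
    # Effective sample size from the entropy of the weights: exp(-sum_i w_i ln w_i).
    n_eff = jnp.exp(-jnp.sum(weights * log_weights))
    # Divide by the number of states so that 1.0 means no loss of statistics.
    return weights, n_eff / weights.shape[0]


ref = jnp.array([1.0, 2.0, 3.0, 4.0])
# Unchanged energies: uniform weights (0.25 each) and a normalized N_eff of 1.0.
weights_same, neff_same = compute_weights_and_neff_sketch(1.0, ref, ref)

# Raising three of the four energies concentrates weight on state 0,
# so the normalized N_eff drops to about 0.28.
shifted = ref + jnp.array([0.0, 5.0, 5.0, 5.0])
weights_shift, neff_shift = compute_weights_and_neff_sketch(1.0, shifted, ref)
```

A normalized N_eff close to 1.0 means the reference trajectory still represents the new parameters well. Values near 1/N mean that a single state dominates the reweighted averages.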
- jax_dna.optimization.objective.compute_loss(opt_params: jax_dna.utils.types.Params, energy_fn_builder: callable, beta: float, loss_fn: collections.abc.Callable[[jax_md.rigid_body.RigidBody, jax_dna.utils.types.Arr_N, EnergyFn], tuple[jax.numpy.ndarray, tuple[str, Any]]], ref_states: jax_md.rigid_body.RigidBody, ref_energies: jax_dna.utils.types.Arr_N) -> tuple[float, tuple[float, jax.numpy.ndarray]]
Compute the loss and auxiliary values.
- Parameters:
opt_params – The optimization parameters.
energy_fn_builder – A function that builds the energy function.
beta – The inverse temperature.
loss_fn – The loss function.
ref_states – The reference states of the trajectory.
ref_energies – The reference energies of the trajectory.
- Returns:
The loss and the auxiliary values: the normalized effective sample size, the measured value of the trajectory, and the new energies.
- jax_dna.optimization.objective.compute_loss_and_grad
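The name `compute_loss_and_grad` suggests a differentiated counterpart of `compute_loss`. The sketch below builds such a pair with `jax.value_and_grad(..., has_aux=True)`. It makes several simplifying assumptions: a toy harmonic energy with a single parameter `k`, reference "states" given as a plain array of positions rather than `jax_md.rigid_body.RigidBody`, and a hypothetical loss that matches a reweighted mean of x² to a target of 0.5. None of these helpers come from jax_dna.

```python
from collections.abc import Callable

import jax
import jax.numpy as jnp


def energy_fn_builder(opt_params: dict) -> Callable[[jnp.ndarray], jnp.ndarray]:
    # Hypothetical toy energy U(x) = 0.5 * k * x^2, evaluated per state.
    def energy_fn(states: jnp.ndarray) -> jnp.ndarray:
        return 0.5 * opt_params["k"] * states**2

    return energy_fn


def loss_fn(states, weights, energy_fn):
    # Reweighted observable <x^2> compared with a hypothetical target of 0.5.
    measured = jnp.sum(weights * states**2)
    return (measured - 0.5) ** 2, measured


def compute_loss(opt_params, beta, ref_states, ref_energies):
    energy_fn = energy_fn_builder(opt_params)
    new_energies = energy_fn(ref_states)
    # DiffTRe weights and normalized effective sample size.
    log_w = jax.nn.log_softmax(-beta * (new_energies - ref_energies))
    weights = jnp.exp(log_w)
    n_eff = jnp.exp(-jnp.sum(weights * log_w)) / weights.shape[0]
    loss, measured = loss_fn(ref_states, weights, energy_fn)
    # Only the first output is differentiated; the tuple is passed through as aux.
    return loss, (n_eff, measured, new_energies)


# Differentiate with respect to the first argument (opt_params) and keep the aux tuple.
compute_loss_and_grad = jax.value_and_grad(compute_loss, has_aux=True)

ref_states = jnp.linspace(-1.5, 1.5, 7)
params = {"k": 1.0}
ref_energies = energy_fn_builder(params)(ref_states)
(loss, (n_eff, measured, new_energies)), grads = compute_loss_and_grad(
    params, 1.0, ref_states, ref_energies
)
# loss ≈ 0.25, n_eff ≈ 1.0, grads["k"] ≈ -0.375
```

The gradient flows through the reweighting factors alone, because the reference states stay fixed. Here the negative gradient indicates that a stiffer spring (larger `k`) lowers <x²> toward the target.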
- class jax_dna.optimization.objective.DiffTReObjective(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[jax_dna.utils.types.SimulatorActorOutput]], jax_dna.utils.types.Grads], energy_fn_builder: collections.abc.Callable[[jax_dna.utils.types.Params], collections.abc.Callable[[jax.numpy.ndarray], jax.numpy.ndarray]], opt_params: jax_dna.utils.types.Params, beta: float, n_equilibration_steps: int, min_n_eff_factor: float = 0.95, max_valid_opt_steps: int = math.inf, logging_config: dict[str, Any] = empty_dict)
Bases: Objective
Objective that calculates the gradients of an objective using DiffTRe.
- _energy_fn_builder
- _opt_params
- _beta
- _n_eq_steps
- _n_eff_factor = 0.95
- _max_valid_opt_steps = inf
- _opt_steps = 1
- _reference_states = None
- _reference_energies = None
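DiffTRe reuses a reference trajectory across optimization steps until reweighting stops being reliable. The constructor arguments `min_n_eff_factor` and `max_valid_opt_steps` point to a criterion of this kind. The sketch below assumes the reference is regenerated in two cases: when the normalized effective sample size falls below `min_n_eff_factor`, or once the step counter reaches `max_valid_opt_steps`. The helper `needs_new_reference` is hypothetical, and the exact comparisons used by `DiffTReObjective` may differ.

```python
import math


def needs_new_reference(
    normalized_n_eff: float,
    opt_steps: int,
    min_n_eff_factor: float = 0.95,
    max_valid_opt_steps: float = math.inf,
) -> bool:
    """Hypothetical check for whether a fresh reference trajectory is required."""
    # Too few effective samples: reweighted averages are no longer trustworthy.
    if normalized_n_eff < min_n_eff_factor:
        return True
    # Cap on how many optimization steps may reuse the same reference trajectory.
    return opt_steps >= max_valid_opt_steps


print(needs_new_reference(0.97, 3))  # False
print(needs_new_reference(0.90, 3))  # True
print(needs_new_reference(0.97, 5, max_valid_opt_steps=5))  # True
```

With the defaults shown in the signature (`min_n_eff_factor = 0.95`, `max_valid_opt_steps = inf`), only the effective sample size would trigger a new simulation.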
- class jax_dna.optimization.objective.DiffTReObjectiveActor(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[jax_dna.utils.types.SimulatorActorOutput]], jax_dna.utils.types.Grads], energy_fn_builder: collections.abc.Callable[[jax_dna.utils.types.Params], collections.abc.Callable[[jax.numpy.ndarray], jax.numpy.ndarray]], opt_params: jax_dna.utils.types.Params, beta: float, n_equilibration_steps: int, min_n_eff_factor: float = 0.95, max_valid_opt_steps: int = math.inf, logging_config: dict[str, Any] = empty_dict)
Bases: DiffTReObjective
Objective that calculates the gradients of an objective using DiffTRe and ray.