jax_dna.optimization.objective

Objectives implemented as ray actors.

Attributes

ERR_DIFFTRE_MISSING_KWARGS

ERR_MISSING_ARG

ERR_OBJECTIVE_NOT_READY

EnergyFn

empty_dict

compute_loss_and_grad

Classes

Objective

Base class for objectives that calculate gradients.

SimGradObjectiveActor

Objective that calculates the gradients of a simulation.

DiffTReObjective

Objective that calculates the gradients of an objective using DiffTRe.

DiffTReObjectiveActor

Objective that calculates the gradients of an objective using DiffTRe and ray.

Functions

compute_weights_and_neff(→ tuple[jax.numpy.ndarray, float])

Compute the weights and normalized effective sample size of a trajectory.

compute_loss(→ tuple[float, tuple[float, ...)

Compute the loss and auxiliary values.

Module Contents

jax_dna.optimization.objective.ERR_DIFFTRE_MISSING_KWARGS = 'Missing required kwargs: {missing_kwargs}.'
jax_dna.optimization.objective.ERR_MISSING_ARG = 'Missing required argument: {missing_arg}.'
jax_dna.optimization.objective.ERR_OBJECTIVE_NOT_READY = 'Not all required observables have been obtained.'
jax_dna.optimization.objective.EnergyFn
jax_dna.optimization.objective.empty_dict
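The three ERR_* attributes are str.format templates; the placeholder in braces names the keyword to pass. The minimal sketch below fills two of them in, copying the template strings verbatim from the values listed above. Joining the missing names with ", " is an assumption, since this reference does not say how the library assembles that list.

```python
# Template strings copied verbatim from the module attributes.
ERR_DIFFTRE_MISSING_KWARGS = "Missing required kwargs: {missing_kwargs}."
ERR_MISSING_ARG = "Missing required argument: {missing_arg}."

# Assumption: missing names are reported as one comma-separated string.
missing = ["beta", "opt_params"]
msg = ERR_DIFFTRE_MISSING_KWARGS.format(missing_kwargs=", ".join(missing))
print(msg)  # Missing required kwargs: beta, opt_params.

arg_msg = ERR_MISSING_ARG.format(missing_arg="energy_fn_builder")
print(arg_msg)  # Missing required argument: energy_fn_builder.
```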
class jax_dna.optimization.objective.Objective(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[str, Ellipsis]], tuple[jax_dna.utils.types.Grads, list[tuple[str, Any]]]], logger_config: dict[str, Any] = empty_dict)

Base class for objectives that calculate gradients.

_name
_required_observables
_needed_observables
_grad_or_loss_fn
_obtained_observables = []
_logging_observables
_logger
name() → str

Return the name of the objective.

required_observables() → list[str]

Return the observables that are required to calculate the gradients.

needed_observables() → list[str]

Return the observables that are still needed.

obtained_observables() → list[tuple[str, jax_dna.utils.types.SimulatorActorOutput]]

Return the latest observed values for all observables.

logging_observables() → list[tuple[str, Any]]

Return the latest observed values for the logging observables.

is_ready() → bool

Check if the objective is ready to calculate its gradients.

update(sim_results: list[tuple[list[str], list[str]]]) → None

Update the observables with the latest simulation results.

calculate() → list[jax_dna.utils.types.Grads]

Calculate the gradients of the objective.

post_step(opt_params: dict) → None

Reset the needed observables for the next step.
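Taken together, the methods above describe a simple lifecycle: update records simulation results, is_ready reports whether every required observable has been seen, calculate produces gradients, and post_step resets the needed set for the next step. The sketch below mimics that flow with a hypothetical ToyObjective class; it is not the library's implementation. Reading each sim_results entry as a (names, values) pair of parallel lists is an assumption based on the update signature.

```python
from typing import Any, Callable


class ToyObjective:
    """Hypothetical stand-in mirroring the documented Objective lifecycle."""

    def __init__(self, required: list[str], grad_fn: Callable[..., Any]) -> None:
        self._required = list(required)
        self._needed = list(required)  # observables still outstanding
        self._obtained: dict[str, Any] = {}  # latest value per observable
        self._grad_fn = grad_fn

    def needed_observables(self) -> list[str]:
        return list(self._needed)

    def is_ready(self) -> bool:
        # Ready once no required observable is still outstanding.
        return not self._needed

    def update(self, sim_results: list[tuple[list[str], list[Any]]]) -> None:
        # Assumption: each entry pairs a list of names with a list of values.
        for names, values in sim_results:
            for name, value in zip(names, values):
                if name in self._required:
                    self._obtained[name] = value
                    if name in self._needed:
                        self._needed.remove(name)

    def calculate(self) -> Any:
        if not self.is_ready():
            raise ValueError("Not all required observables have been obtained.")
        # Pass observables to the gradient function in declaration order.
        return self._grad_fn(*(self._obtained[n] for n in self._required))

    def post_step(self, opt_params: dict) -> None:
        # opt_params is unused here; the next step needs fresh observables.
        self._needed = list(self._required)


obj = ToyObjective(["energy", "rise"], grad_fn=lambda e, r: e + r)
obj.update([(["energy"], [1.5])])
print(obj.is_ready(), obj.needed_observables())  # False ['rise']
obj.update([(["rise"], [2.0])])
grads = obj.calculate()
print(grads)  # 3.5
obj.post_step({})
print(obj.needed_observables())  # ['energy', 'rise']
```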

class jax_dna.optimization.objective.SimGradObjectiveActor(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[str, Ellipsis]], tuple[jax_dna.utils.types.Grads, list[tuple[str, Any]]]], logger_config: dict[str, Any] = empty_dict)

Bases: Objective

Objective that calculates the gradients of a simulation.

jax_dna.optimization.objective.compute_weights_and_neff(beta: float, new_energies: jax_dna.utils.types.Arr_N, ref_energies: jax_dna.utils.types.Arr_N) → tuple[jax.numpy.ndarray, float]

Compute the weights and normalized effective sample size of a trajectory.

Calculation derived from the DiffTRe algorithm.

See equations 4 and 5 of https://www.nature.com/articles/s41467-021-27241-4.

Parameters:
  • beta – The inverse temperature.

  • new_energies – The new energies of the trajectory.

  • ref_energies – The reference energies of the trajectory.

Returns:

The weights and the normalized effective sample size.
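As a rough illustration of the DiffTRe reweighting cited above, the JAX sketch below computes Boltzmann weights of the reference states under the new energies, normalized to sum to one, and an effective sample size as the exponential of the weights' entropy. The function name is hypothetical, and dividing by the number of states to get the "normalized" value is an assumption, not something this reference confirms.

```python
import jax
import jax.numpy as jnp


def weights_and_normalized_neff(
    beta: float, new_energies: jnp.ndarray, ref_energies: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Hypothetical sketch of DiffTRe reweighting over N reference states."""
    # Log Boltzmann factors of the new potential relative to the reference one.
    log_factors = -beta * (new_energies - ref_energies)
    # softmax normalizes the factors into weights that sum to one.
    weights = jax.nn.softmax(log_factors)
    log_weights = jax.nn.log_softmax(log_factors)
    # Effective sample size: exp of the Shannon entropy of the weights.
    n_eff = jnp.exp(-jnp.sum(weights * log_weights))
    # Assumption: "normalized" means divided by N, giving a value in (0, 1].
    return weights, n_eff / new_energies.shape[0]


ref = jnp.array([1.0, 2.0, 3.0, 4.0])
# Unchanged energies: uniform weights and a normalized n_eff close to 1.
w_same, neff_same = weights_and_normalized_neff(1.0, ref, ref)
# Raising one state's energy down-weights it and lowers n_eff.
shifted = ref + jnp.array([0.0, 0.0, 0.0, 5.0])
w_shift, neff_shift = weights_and_normalized_neff(1.0, shifted, ref)
print(w_same, neff_same)
print(w_shift, neff_shift)
```

Working in log space through softmax and log_softmax avoids overflow in exp when beta times the energy difference is large.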

jax_dna.optimization.objective.compute_loss(opt_params: jax_dna.utils.types.Params, energy_fn_builder: callable, beta: float, loss_fn: collections.abc.Callable[[jax_md.rigid_body.RigidBody, jax_dna.utils.types.Arr_N, EnergyFn], tuple[jax.numpy.ndarray, tuple[str, Any]]], ref_states: jax_md.rigid_body.RigidBody, ref_energies: jax_dna.utils.types.Arr_N) → tuple[float, tuple[float, jax.numpy.ndarray]]

Compute the loss and auxiliary values.

Parameters:
  • opt_params – The optimization parameters.

  • energy_fn_builder – A function that builds the energy function.

  • beta – The inverse temperature.

  • loss_fn – The loss function.

  • ref_states – The reference states of the trajectory.

  • ref_energies – The reference energies of the trajectory.

Returns:

The loss and a tuple of auxiliary values: the normalized effective sample size, the measured value of the trajectory, and the new energies.
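The parameters above suggest a loss that rebuilds the energy function from opt_params, reweights the fixed reference states, and returns auxiliary diagnostics. A function of that shape can be differentiated with jax.value_and_grad(..., has_aux=True), which returns ((loss, aux), grads). The sketch below assumes that pattern for a pairing like compute_loss_and_grad; the harmonic energy builder, the observable, and the target value are all hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_energy_fn_builder(opt_params: dict):
    """Hypothetical builder: harmonic energy with a learnable stiffness k."""
    def energy_fn(states: jnp.ndarray) -> jnp.ndarray:
        # One energy per state (row): E = 0.5 * k * |x|^2.
        return 0.5 * opt_params["k"] * jnp.sum(states**2, axis=-1)
    return energy_fn


def reweighted_loss(opt_params, beta, ref_states, ref_energies, target):
    energy_fn = toy_energy_fn_builder(opt_params)
    new_energies = energy_fn(ref_states)
    log_factors = -beta * (new_energies - ref_energies)
    weights = jax.nn.softmax(log_factors)
    n_eff = jnp.exp(-jnp.sum(weights * jax.nn.log_softmax(log_factors)))
    # Hypothetical observable: reweighted ensemble average of |x|^2.
    measured = jnp.sum(weights * jnp.sum(ref_states**2, axis=-1))
    loss = (measured - target) ** 2
    # Auxiliary outputs are returned alongside the loss but not differentiated.
    return loss, (n_eff / ref_states.shape[0], measured, new_energies)


# Differentiates with respect to the first argument (opt_params) only.
loss_and_grad = jax.value_and_grad(reweighted_loss, has_aux=True)

params = {"k": 1.0}
ref_states = jnp.array([[0.1, 0.2], [0.5, -0.3], [1.0, 0.4]])
ref_energies = toy_energy_fn_builder(params)(ref_states)
(loss, (neff, measured, new_energies)), grads = loss_and_grad(
    params, 1.0, ref_states, ref_energies, 0.2
)
print(loss, neff, grads["k"])
```

Because the reference states are fixed, the gradient flows only through the weights; no simulation is differentiated.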

jax_dna.optimization.objective.compute_loss_and_grad
class jax_dna.optimization.objective.DiffTReObjective(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[jax_dna.utils.types.SimulatorActorOutput]], jax_dna.utils.types.Grads], energy_fn_builder: collections.abc.Callable[[jax_dna.utils.types.Params], collections.abc.Callable[[jax.numpy.ndarray], jax.numpy.ndarray]], opt_params: jax_dna.utils.types.Params, beta: float, n_equilibration_steps: int, min_n_eff_factor: float = 0.95, max_valid_opt_steps: int = math.inf, logging_config: dict[str, Any] = empty_dict)

Bases: Objective

Objective that calculates the gradients of an objective using DiffTRe.

_energy_fn_builder
_opt_params
_beta
_n_eq_steps
_n_eff_factor = 0.95
_max_valid_opt_steps = inf
_opt_steps = 1
_reference_states = None
_reference_energies = None
calculate() → list[jax_dna.utils.types.Grads]

Calculate the gradients of the objective.

is_ready() → bool

Check if the objective is ready to calculate its gradients.

post_step(opt_params: jax_dna.utils.types.Params) → None

Reset the needed observables for the next step.
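DiffTReObjective takes min_n_eff_factor (default 0.95) and max_valid_opt_steps (default math.inf) and tracks _opt_steps. One plausible reading, which is an assumption this reference does not confirm, is that the reference trajectory is reused until the normalized effective sample size falls below min_n_eff_factor or the reference has served max_valid_opt_steps steps. The hypothetical helper below encodes that rule.

```python
import math


def needs_new_reference(
    normalized_neff: float,
    opt_steps: int,
    min_n_eff_factor: float = 0.95,
    max_valid_opt_steps: float = math.inf,
) -> bool:
    """Hypothetical rule for when DiffTRe should resample its reference states."""
    # Reweighting is unreliable once a few states carry most of the weight.
    degenerate = normalized_neff < min_n_eff_factor
    # Optional cap on how many optimization steps one reference may serve.
    stale = opt_steps >= max_valid_opt_steps
    return degenerate or stale


print(needs_new_reference(0.99, opt_steps=3))  # False
print(needs_new_reference(0.80, opt_steps=3))  # True
print(needs_new_reference(0.99, opt_steps=5, max_valid_opt_steps=5))  # True
```

With the default max_valid_opt_steps of math.inf, only the effective-sample-size check can trigger a new reference trajectory.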

class jax_dna.optimization.objective.DiffTReObjectiveActor(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[jax_dna.utils.types.SimulatorActorOutput]], jax_dna.utils.types.Grads], energy_fn_builder: collections.abc.Callable[[jax_dna.utils.types.Params], collections.abc.Callable[[jax.numpy.ndarray], jax.numpy.ndarray]], opt_params: jax_dna.utils.types.Params, beta: float, n_equilibration_steps: int, min_n_eff_factor: float = 0.95, max_valid_opt_steps: int = math.inf, logging_config: dict[str, Any] = empty_dict)

Bases: DiffTReObjective

Objective that calculates the gradients of an objective using DiffTRe and ray.