"""Objectives implemented as ray actors."""
import functools
import logging
import math
import operator
import types
import typing
from collections.abc import Callable
import jax
import jax.numpy as jnp
import jax_md
import ray
import typing_extensions
import jax_dna.energy as jdna_energy
import jax_dna.input.tree as jdna_tree
import jax_dna.utils.types as jdna_types
# Error-message templates. Placeholders are filled with ``str.format`` at the
# raise site (e.g. ``ERR_MISSING_ARG.format(missing_arg="name")``).
# NOTE(review): ERR_DIFFTRE_MISSING_KWARGS is not referenced in the visible
# part of this module — confirm it is still used elsewhere.
ERR_DIFFTRE_MISSING_KWARGS = "Missing required kwargs: {missing_kwargs}."
ERR_MISSING_ARG = "Missing required argument: {missing_arg}."
# Raised by ``calculate`` when ``is_ready()`` reports False.
ERR_OBJECTIVE_NOT_READY = "Not all required observables have been obtained."
# Either of the two energy-function classes from ``jax_dna.energy.base``.
EnergyFn = jdna_energy.base.BaseEnergyFunction | jdna_energy.base.ComposedEnergyFunction
# Read-only empty mapping, used as the default for the logger-config arguments
# so that the default value cannot be mutated between calls.
empty_dict = types.MappingProxyType({})
[docs]
class Objective:
    """Base class for objectives that calculate gradients.

    An objective keeps track of the observables it requires, collects them
    from simulation results through ``update``, and once every required
    observable has been obtained evaluates ``grad_or_loss_fn`` on them in
    ``calculate``. ``post_step`` resets the bookkeeping for the next step.
    """

    def __init__(
        self,
        name: str,
        required_observables: list[str],
        needed_observables: list[str],
        logging_observables: list[str],
        grad_or_loss_fn: typing.Callable[[tuple[str, ...]], tuple[jdna_types.Grads, list[tuple[str, typing.Any]]]],
        logger_config: dict[str, typing.Any] = empty_dict,
    ) -> None:
        """Initialize the objective.

        Args:
            name (str): The name of the objective.
            required_observables (list[str]): The observables that are required
                to calculate the gradients.
            needed_observables (list[str]): The observables that are still
                needed before the gradients can be calculated.
            logging_observables (list[str]): The observables that are used for
                logging.
            grad_or_loss_fn (typing.Callable): Called in ``calculate`` with the
                obtained values of the required observables (one positional
                argument per observable, in ``required_observables`` order);
                must return ``(grads, aux)`` where ``aux`` is an iterable of
                entries to expose as obtained observables.
            logger_config (dict[str, typing.Any]): Keyword arguments forwarded
                to ``logging.basicConfig``.

        Raises:
            ValueError: If any argument other than ``logger_config`` is None.
        """
        # Validate in declaration order so the first missing argument is the
        # one that gets reported.
        for arg_name, arg_value in (
            ("name", name),
            ("required_observables", required_observables),
            ("needed_observables", needed_observables),
            ("logging_observables", logging_observables),
            ("grad_or_loss_fn", grad_or_loss_fn),
        ):
            if arg_value is None:
                raise ValueError(ERR_MISSING_ARG.format(missing_arg=arg_name))

        self._name = name
        # Store private copies: ``update`` removes entries from the needed
        # list in place, which must neither leak into the caller's list nor
        # into the required list when the caller passed the same object twice.
        self._required_observables = list(required_observables)
        self._needed_observables = list(needed_observables)
        self._logging_observables = list(logging_observables)
        self._grad_or_loss_fn = grad_or_loss_fn
        self._obtained_observables = []

        logging.basicConfig(**logger_config)
        self._logger = logging.getLogger(__name__)

    def name(self) -> str:
        """Return the name of the objective."""
        return self._name

    def required_observables(self) -> list[str]:
        """Return the observables that are required to calculate the gradients."""
        return self._required_observables

    def needed_observables(self) -> list[str]:
        """Return the observables that are still needed."""
        return self._needed_observables

    def obtained_observables(self) -> list[tuple[str, jdna_types.SimulatorActorOutput]]:
        """Return the latest observed values for all observables."""
        return self._obtained_observables

    def logging_observables(self) -> list[tuple[str, typing.Any]]:
        """Return the latest observed values for the logging observables.

        For every logging observable the first obtained entry with a matching
        name is returned; logging observables without a match are skipped.
        """
        latest_observed = self._obtained_observables
        found = []
        for wanted in self._logging_observables:
            match = next((obs for obs in latest_observed if obs[0] == wanted), None)
            if match is not None:
                found.append(match)
        return found

    def is_ready(self) -> bool:
        """Check if the objective is ready to calculate its gradients."""
        obtained_keys = [obs[0] for obs in self._obtained_observables]
        return all(required in obtained_keys for required in self._required_observables)

    def update(
        self,
        sim_results: list[tuple[list[str], list[str]]],
    ) -> None:
        """Update the observables with the latest simulation results.

        Args:
            sim_results: One ``(exposes, outputs)`` pair per simulation, where
                ``exposes`` are observable names and ``outputs`` the matching
                values, which are loaded with ``jdna_tree.load_pytree``. Both
                sequences must have the same length.
        """
        for sim_exposes, sim_output in sim_results:
            for exposed, output in zip(sim_exposes, sim_output, strict=True):
                # Only the first occurrence of a needed observable is taken;
                # once removed from the needed list later duplicates are skipped.
                if exposed not in self._needed_observables:
                    continue
                self._obtained_observables.append((exposed, jdna_tree.load_pytree(output)))
                self._needed_observables.remove(exposed)

    def calculate(self) -> list[jdna_types.Grads]:
        """Calculate the gradients of the objective.

        Returns:
            The gradients returned by ``grad_or_loss_fn``.

        Raises:
            ValueError: If not all required observables have been obtained.
        """
        if not self.is_ready():
            raise ValueError(ERR_OBJECTIVE_NOT_READY)

        required = self._required_observables
        # Only required observables are fed to the function, in the order they
        # were requested. Other entries (e.g. aux values kept from a previous
        # call) are ignored instead of breaking the ``index`` lookup.
        sorted_obtained_observables = sorted(
            (obs for obs in self._obtained_observables if obs[0] in required),
            key=lambda obs: required.index(obs[0]),
        )
        sorted_obs = [obs[1] for obs in sorted_obtained_observables]

        grads, aux = self._grad_or_loss_fn(*sorted_obs)

        # Expose the auxiliary values alongside the inputs that produced them.
        self._obtained_observables = [
            *aux,
            *zip(required, sorted_obs, strict=True),
        ]
        return grads

    def post_step(self, opt_params: dict) -> None:  # noqa: ARG002 - not all objectives need params
        """Reset the needed observables for the next step."""
        self._needed_observables = list(self._required_observables)
        self._obtained_observables = []
[docs]
@ray.remote
class SimGradObjectiveActor(Objective):
    """Objective that calculates the gradients of a simulation.

    Exposes ``Objective`` through ``ray.remote``; it adds no attributes or
    methods of its own.
    """
[docs]
def compute_weights_and_neff(
    beta: float,
    new_energies: jdna_types.Arr_N,
    ref_energies: jdna_types.Arr_N,
) -> tuple[jnp.ndarray, float]:
    """Compute the weights and normalized effective sample size of a trajectory.

    Calculation derived from the DiffTRe algorithm.

    https://www.nature.com/articles/s41467-021-27241-4

    See equations 4 and 5.

    The weights are the normalized Boltzmann factors
    ``exp(-beta * (new_energies - ref_energies))``. They are evaluated with a
    numerically stable softmax so that large energy differences do not
    overflow to ``inf / inf``, and zero weights contribute nothing to the
    entropy instead of producing ``0 * log(0) = nan``.

    Args:
        beta: The inverse temperature.
        new_energies: The new energies of the trajectory (expected to be a
            one-dimensional array, one entry per state).
        ref_energies: The reference energies of the trajectory, same shape as
            ``new_energies``.

    Returns:
        The weights and the normalized effective sample size
    """
    diffs = new_energies - ref_energies
    # softmax(x) == exp(x) / sum(exp(x)), computed after subtracting max(x).
    weights = jax.nn.softmax(-beta * diffs)
    # Replace zero weights by 1 inside the log only: log(1) == 0, so those
    # entries add exactly 0 to the entropy rather than nan.
    safe_weights = jnp.where(weights > 0, weights, 1.0)
    entropy = -jnp.sum(weights * jnp.log(safe_weights))
    n_eff = jnp.exp(entropy)
    return weights, n_eff / len(weights)
[docs]
def compute_loss(
    opt_params: jdna_types.Params,
    energy_fn_builder: Callable[[jdna_types.Params], EnergyFn],
    beta: float,
    loss_fn: Callable[
        [jax_md.rigid_body.RigidBody, jdna_types.Arr_N, EnergyFn], tuple[jnp.ndarray, tuple[str, typing.Any]]
    ],
    ref_states: jax_md.rigid_body.RigidBody,
    ref_energies: jdna_types.Arr_N,
) -> tuple[jnp.ndarray, tuple[float, typing.Any, jdna_types.Arr_N]]:
    """Compute the loss and auxiliary values for a reference trajectory.

    The reference states are re-weighted under the energy function built from
    ``opt_params`` and the weights are handed to ``loss_fn``.

    Args:
        opt_params: The optimization parameters.
        energy_fn_builder: A function that builds the energy function.
        beta: The inverse temperature.
        loss_fn: The loss function. Called as
            ``loss_fn(ref_states, weights, energy_fn)`` and must return
            ``(loss, (measured_value, meta))``; ``meta`` is not used here.
        ref_states: The reference states of the trajectory.
        ref_energies: The reference energies of the trajectory.

    Returns:
        ``(loss, (neff, measured_value, new_energies))``: the loss, then a
        tuple with the normalized effective sample size, the measured value
        reported by ``loss_fn``, and the energies of the reference states
        under the current parameters.
    """
    # Build the energy function once and reuse it for both the re-weighting
    # and the loss evaluation.
    energy_fn = energy_fn_builder(opt_params)
    new_energies = energy_fn(ref_states)
    weights, neff = compute_weights_and_neff(
        beta,
        new_energies,
        ref_energies,
    )
    loss, (measured_value, _meta) = loss_fn(ref_states, weights, energy_fn)
    return loss, (neff, measured_value, new_energies)
# Differentiated form of ``compute_loss``. ``has_aux=True`` tells JAX that
# ``compute_loss`` returns ``(loss, aux)``, so a call yields
# ``((loss, aux), grads)`` — ``DiffTReObjective.calculate`` unpacks it that way.
# Gradients are taken with respect to the first positional argument
# (``opt_params``), which is the ``jax.value_and_grad`` default.
compute_loss_and_grad = jax.value_and_grad(compute_loss, has_aux=True)
[docs]
class DiffTReObjective(Objective):
    """Objective that calculates the gradients of an objective using DiffTRe.

    A reference trajectory is built from the required observables and then
    reused across optimization steps. It is discarded, and new observables are
    requested, once the normalized effective sample size drops below
    ``min_n_eff_factor`` or the trajectory has been used for
    ``max_valid_opt_steps`` optimization steps.
    """

    def __init__(
        self,
        name: str,
        required_observables: list[str],
        needed_observables: list[str],
        logging_observables: list[str],
        grad_or_loss_fn: typing.Callable[[tuple[jdna_types.SimulatorActorOutput]], jdna_types.Grads],
        energy_fn_builder: Callable[[jdna_types.Params], Callable[[jnp.ndarray], jnp.ndarray]],
        opt_params: jdna_types.Params,
        beta: float,
        n_equilibration_steps: int,
        min_n_eff_factor: float = 0.95,
        max_valid_opt_steps: int | float = math.inf,
        logging_config: dict[str, typing.Any] = empty_dict,
    ) -> None:
        """Initialize the DiffTRe objective.

        Args:
            name: The name of the objective.
            required_observables: The observables that are required to calculate the gradients.
            needed_observables: The observables that are needed to calculate the gradients.
            logging_observables: The observables that are used for logging.
            grad_or_loss_fn: The function that calculates the loss of the objective.
            energy_fn_builder: A function that builds the energy function.
            opt_params: The optimization parameters.
            beta: The inverse temperature.
            n_equilibration_steps: The number of equilibration steps; this many
                leading states are dropped from each trajectory.
            min_n_eff_factor: The minimum normalized effective sample size.
            max_valid_opt_steps: The maximum number of steps a trajectory is valid for.
            logging_config: The configuration for the logger.

        Raises:
            ValueError: If a required argument is None.
        """
        super().__init__(
            name,
            required_observables,
            needed_observables,
            logging_observables,
            grad_or_loss_fn,
            logger_config=logging_config,
        )
        for arg_name, arg_value in (
            ("energy_fn_builder", energy_fn_builder),
            ("opt_params", opt_params),
            ("beta", beta),
            ("n_equilibration_steps", n_equilibration_steps),
        ):
            if arg_value is None:
                raise ValueError(ERR_MISSING_ARG.format(missing_arg=arg_name))

        self._energy_fn_builder = energy_fn_builder
        self._opt_params = opt_params
        self._beta = beta
        self._n_eq_steps = n_equilibration_steps
        self._n_eff_factor = min_n_eff_factor
        self._max_valid_opt_steps = max_valid_opt_steps
        # Number of optimization steps the current reference trajectory has
        # been used for (counting the one in progress).
        self._opt_steps = 1
        self._reference_states = None
        self._reference_energies = None

    def _sorted_required_observations(self) -> list[tuple[str, typing.Any]]:
        """Return the obtained required observables in the order they were requested."""
        required = self._required_observables
        return sorted(
            (obs for obs in self._obtained_observables if obs[0] in required),
            key=lambda obs: required.index(obs[0]),
        )

    @typing_extensions.override
    def calculate(self) -> list[jdna_types.Grads]:
        """Calculate the gradients of the loss on the reference trajectory.

        Afterwards the obtained observables hold the loss, the latest ``neff``,
        the measured value reported by the loss function, and the required
        observables.

        Returns:
            The gradients of the loss with respect to the optimization parameters.

        Raises:
            ValueError: If the objective is not ready.
        """
        if not self.is_ready():
            raise ValueError(ERR_OBJECTIVE_NOT_READY)

        # want the required observables in the order they are requested
        sorted_obs = [obs[1] for obs in self._sorted_required_observations()]

        (loss, (_, measured_value, _)), grads = compute_loss_and_grad(
            self._opt_params,
            self._energy_fn_builder,
            self._beta,
            self._grad_or_loss_fn,
            self._reference_states,
            self._reference_energies,
        )

        # ``is_ready`` always records a "neff" entry before returning True.
        latest_neff = next(obs for obs in self._obtained_observables if obs[0] == "neff")
        self._obtained_observables = [
            ("loss", loss),
            latest_neff,
            measured_value,
            *zip(self._required_observables, sorted_obs, strict=True),
        ]
        return grads

    @typing_extensions.override
    def is_ready(self) -> bool:
        """Check whether a valid reference trajectory is available.

        Builds the reference states and energies from the obtained
        trajectories if none exist yet, records the current normalized
        effective sample size as the ``"neff"`` observable, and invalidates
        the trajectory (requesting new observables) when it is no longer
        usable.

        Returns:
            True if gradients can be calculated with the current trajectory.
        """
        if not super().is_ready():
            return False

        # The parameters do not change within this call, so one energy
        # function serves both the reference and the current energies.
        energy_fn = self._energy_fn_builder(self._opt_params)

        if self._reference_states is None:
            new_trajectories = [obs[1] for obs in self._sorted_required_observations()]
            # Drop the equilibration prefix of every trajectory and
            # concatenate what is left.
            self._reference_states = functools.reduce(
                operator.add,
                [
                    traj.slice(slice(self._n_eq_steps, len(traj.rigid_body.center), None))
                    for traj in new_trajectories
                ],
            )
            self._reference_energies = energy_fn(self._reference_states)
            self._logger.info("trajectory length is %d", len(self._reference_states.rigid_body.center))

        _, neff = compute_weights_and_neff(
            beta=self._beta,
            new_energies=energy_fn(self._reference_states),
            ref_energies=self._reference_energies,
        )

        # Overwrite an existing "neff" entry in place, otherwise add one.
        if any(obs[0] == "neff" for obs in self._obtained_observables):
            self._obtained_observables = [
                (obs[0], neff) if obs[0] == "neff" else obs for obs in self._obtained_observables
            ]
        else:
            self._obtained_observables.append(("neff", neff))

        # if the trajectory is no longer valid remove it from obtained
        # and add it to needed so that a new trajectory is run.
        self._logger.info("checking neff %f neff_factory %f", neff, self._n_eff_factor)
        self._logger.info("checking opt steps %d vs %f", self._opt_steps, float(self._max_valid_opt_steps))
        # ``>=`` rather than ``==`` so the trajectory is still retired if the
        # step counter has already moved past the limit.
        if (neff < self._n_eff_factor) or (self._opt_steps >= self._max_valid_opt_steps):
            self._obtained_observables = []
            self._needed_observables = self._required_observables[:]
            self._reference_states = None
            self._reference_energies = None
            self._opt_steps = 1
            return False

        return True

    @typing_extensions.override
    def post_step(
        self,
        opt_params: jdna_types.Params,
    ) -> None:
        """Store the updated parameters and advance the step counter.

        Args:
            opt_params: The optimization parameters after the step.
        """
        # DiffTre objectives may not need to update the trajectory depending on neff
        # the need for a new trajectory is checked in `is_ready`
        self._obtained_observables = [oo for oo in self._obtained_observables if oo[0] not in ("neff", "loss")]
        self._opt_params = opt_params
        self._opt_steps += 1
[docs]
@ray.remote
class DiffTReObjectiveActor(DiffTReObjective):
    """Objective that calculates the gradients of an objective using DiffTRe and ray.

    Exposes ``DiffTReObjective`` through ``ray.remote``; it adds no attributes
    or methods of its own.
    """