jax_dna.optimization.objective
==============================

.. py:module:: jax_dna.optimization.objective

.. autoapi-nested-parse::

   Objectives implemented as Ray actors.


Attributes
----------

.. autoapisummary::

   jax_dna.optimization.objective.ERR_DIFFTRE_MISSING_KWARGS
   jax_dna.optimization.objective.ERR_MISSING_ARG
   jax_dna.optimization.objective.ERR_OBJECTIVE_NOT_READY
   jax_dna.optimization.objective.EnergyFn
   jax_dna.optimization.objective.empty_dict
   jax_dna.optimization.objective.compute_loss_and_grad


Classes
-------

.. autoapisummary::

   jax_dna.optimization.objective.Objective
   jax_dna.optimization.objective.SimGradObjectiveActor
   jax_dna.optimization.objective.DiffTReObjective
   jax_dna.optimization.objective.DiffTReObjectiveActor


Functions
---------

.. autoapisummary::

   jax_dna.optimization.objective.compute_weights_and_neff
   jax_dna.optimization.objective.compute_loss


Module Contents
---------------

.. py:data:: ERR_DIFFTRE_MISSING_KWARGS
   :value: 'Missing required kwargs: {missing_kwargs}.'


.. py:data:: ERR_MISSING_ARG
   :value: 'Missing required argument: {missing_arg}.'


.. py:data:: ERR_OBJECTIVE_NOT_READY
   :value: 'Not all required observables have been obtained.'


.. py:data:: EnergyFn


.. py:data:: empty_dict


.. py:class:: Objective(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[str, ...]], tuple[jax_dna.utils.types.Grads, list[tuple[str, Any]]]], logger_config: dict[str, Any] = empty_dict)

   Base class for objectives that calculate gradients.


   .. py:attribute:: _name


   .. py:attribute:: _required_observables


   .. py:attribute:: _needed_observables


   .. py:attribute:: _grad_or_loss_fn


   .. py:attribute:: _obtained_observables
      :value: []


   .. py:attribute:: _logging_observables


   .. py:attribute:: _logger


   .. py:method:: name() -> str

      Return the name of the objective.


   .. py:method:: required_observables() -> list[str]

      Return the observables that are required to calculate the gradients.
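
   A minimal sketch of the bookkeeping cycle an ``Objective`` follows: it
   records incoming observables until none are still needed, is not ready to
   calculate before then, and resets its needed observables after each
   optimization step. The class ``ToyObjective``, its dict-based ``update``
   input, and the ``grad_fn`` callable are hypothetical simplifications, not
   the ``jax_dna`` API (the real ``update`` receives simulation results and
   ``calculate`` returns a list of gradients).

   .. code-block:: python

      from typing import Any, Callable


      class ToyObjective:
          """Hypothetical stand-in mirroring the bookkeeping of ``Objective``."""

          def __init__(self, required: list[str], grad_fn: Callable[..., Any]) -> None:
              self._required = list(required)
              self._needed = list(required)  # everything is needed at the start
              self._obtained: dict[str, Any] = {}
              self._grad_fn = grad_fn

          def needed_observables(self) -> list[str]:
              return list(self._needed)

          def is_ready(self) -> bool:
              # ready once no required observable is still outstanding
              return not self._needed

          def update(self, results: dict[str, Any]) -> None:
              # record only the observables this objective still needs
              for name, value in results.items():
                  if name in self._needed:
                      self._obtained[name] = value
                      self._needed.remove(name)

          def calculate(self) -> Any:
              if not self.is_ready():
                  raise ValueError("Not all required observables have been obtained.")
              # pass observables in the order they were declared as required
              return self._grad_fn(*(self._obtained[n] for n in self._required))

          def post_step(self, opt_params: dict) -> None:
              # this toy ignores opt_params; it only resets the bookkeeping
              self._needed = list(self._required)
              self._obtained = {}


      obj = ToyObjective(["energy", "rise"], grad_fn=lambda e, r: {"eps": e - r})
      obj.update({"energy": 3.0})
      print(obj.is_ready())            # False
      obj.update({"rise": 1.0})
      print(obj.calculate())           # {'eps': 2.0}
      obj.post_step({})
      print(obj.needed_observables())  # ['energy', 'rise']

   Raising on an incomplete set of observables mirrors the message stored in
   ``ERR_OBJECTIVE_NOT_READY``.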

   .. py:method:: needed_observables() -> list[str]

      Return the observables that are still needed.


   .. py:method:: obtained_observables() -> list[tuple[str, jax_dna.utils.types.SimulatorActorOutput]]

      Return the latest observed values for all observables.


   .. py:method:: logging_observables() -> list[tuple[str, Any]]

      Return the latest observed values for the logging observables.


   .. py:method:: is_ready() -> bool

      Check if the objective is ready to calculate its gradients.


   .. py:method:: update(sim_results: list[tuple[list[str], list[str]]]) -> None

      Update the observables with the latest simulation results.


   .. py:method:: calculate() -> list[jax_dna.utils.types.Grads]

      Calculate the gradients of the objective.


   .. py:method:: post_step(opt_params: dict) -> None

      Reset the needed observables for the next step.


.. py:class:: SimGradObjectiveActor(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[str, ...]], tuple[jax_dna.utils.types.Grads, list[tuple[str, Any]]]], logger_config: dict[str, Any] = empty_dict)

   Bases: :py:obj:`Objective`

   Objective that calculates the gradients of a simulation.


.. py:function:: compute_weights_and_neff(beta: float, new_energies: jax_dna.utils.types.Arr_N, ref_energies: jax_dna.utils.types.Arr_N) -> tuple[jax.numpy.ndarray, float]

   Compute the weights and normalized effective sample size of a trajectory.

   The calculation is derived from the DiffTRe algorithm
   (https://www.nature.com/articles/s41467-021-27241-4); see equations 4 and 5.

   :param beta: The inverse temperature.
   :param new_energies: The new energies of the trajectory.
   :param ref_energies: The reference energies of the trajectory.

   :returns: The weights and the normalized effective sample size.


.. py:function:: compute_loss(opt_params: jax_dna.utils.types.Params, energy_fn_builder: callable, beta: float, loss_fn: collections.abc.Callable[[jax_md.rigid_body.RigidBody, jax_dna.utils.types.Arr_N, EnergyFn], tuple[jax.numpy.ndarray, tuple[str, Any]]], ref_states: jax_md.rigid_body.RigidBody, ref_energies: jax_dna.utils.types.Arr_N) -> tuple[float, tuple[float, jax.numpy.ndarray]]

   Compute the loss and auxiliary values.

   :param opt_params: The optimization parameters.
   :param energy_fn_builder: A function that builds the energy function.
   :param beta: The inverse temperature.
   :param loss_fn: The loss function.
   :param ref_states: The reference states of the trajectory.
   :param ref_energies: The reference energies of the trajectory.

   :returns: The loss together with auxiliary values: the normalized effective
             sample size, the measured value of the trajectory, and the new
             energies.


.. py:data:: compute_loss_and_grad


.. py:class:: DiffTReObjective(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[jax_dna.utils.types.SimulatorActorOutput]], jax_dna.utils.types.Grads], energy_fn_builder: collections.abc.Callable[[jax_dna.utils.types.Params], collections.abc.Callable[[jax.numpy.ndarray], jax.numpy.ndarray]], opt_params: jax_dna.utils.types.Params, beta: float, n_equilibration_steps: int, min_n_eff_factor: float = 0.95, max_valid_opt_steps: int = math.inf, logging_config: dict[str, Any] = empty_dict)

   Bases: :py:obj:`Objective`

   Objective that calculates gradients using DiffTRe.


   .. py:attribute:: _energy_fn_builder


   .. py:attribute:: _opt_params


   .. py:attribute:: _beta


   .. py:attribute:: _n_eq_steps


   .. py:attribute:: _n_eff_factor
      :value: 0.95


   .. py:attribute:: _max_valid_opt_steps
      :value: inf


   .. py:attribute:: _opt_steps
      :value: 1


   .. py:attribute:: _reference_states
      :value: None


   .. py:attribute:: _reference_energies
      :value: None


   .. py:method:: calculate() -> list[jax_dna.utils.types.Grads]

      Calculate the gradients of the objective.


   .. py:method:: is_ready() -> bool

      Check if the objective is ready to calculate its gradients.


   .. py:method:: post_step(opt_params: jax_dna.utils.types.Params) -> None

      Reset the needed observables for the next step.


.. py:class:: DiffTReObjectiveActor(name: str, required_observables: list[str], needed_observables: list[str], logging_observables: list[str], grad_or_loss_fn: Callable[[tuple[jax_dna.utils.types.SimulatorActorOutput]], jax_dna.utils.types.Grads], energy_fn_builder: collections.abc.Callable[[jax_dna.utils.types.Params], collections.abc.Callable[[jax.numpy.ndarray], jax.numpy.ndarray]], opt_params: jax_dna.utils.types.Params, beta: float, n_equilibration_steps: int, min_n_eff_factor: float = 0.95, max_valid_opt_steps: int = math.inf, logging_config: dict[str, Any] = empty_dict)

   Bases: :py:obj:`DiffTReObjective`

   Objective that calculates gradients using DiffTRe, run as a Ray actor.
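
A minimal sketch of the DiffTRe reweighting that ``compute_weights_and_neff``
and ``compute_loss`` document, assuming the standard DiffTRe formulas: weights
:math:`w_i = \exp(-\beta\,\Delta U_i) / \sum_j \exp(-\beta\,\Delta U_j)` with
:math:`\Delta U_i = U_{\text{new},i} - U_{\text{ref},i}`, and an effective
sample size :math:`N_{\text{eff}} = \exp(-\sum_i w_i \ln w_i)` divided by the
number of states :math:`N`. The functions ``reweight`` and ``toy_loss``, the
harmonic energy model, and all numeric values are hypothetical illustrations,
not the ``jax_dna`` implementation. Gradients come from
``jax.value_and_grad(..., has_aux=True)``, which returns
``((loss, aux), grads)``.

.. code-block:: python

   import jax
   import jax.numpy as jnp


   def reweight(beta, new_energies, ref_energies):
       """Return normalized DiffTRe weights and N_eff / N (sketch)."""
       # log w_i = -beta * dU_i - logsumexp(-beta * dU); log_softmax keeps it stable
       log_w = jax.nn.log_softmax(-beta * (new_energies - ref_energies))
       weights = jnp.exp(log_w)
       # effective sample size = exp(entropy of the weights), normalized by N
       n_eff = jnp.exp(-jnp.sum(weights * log_w))
       return weights, n_eff / new_energies.shape[0]


   def toy_loss(params, ref_states, ref_energies, beta, target):
       """Hypothetical loss: squared error of a reweighted observable."""
       # hypothetical energy model: harmonic springs with stiffness params["k"]
       new_energies = 0.5 * params["k"] * ref_states**2
       weights, n_eff = reweight(beta, new_energies, ref_energies)
       # reweighted ensemble average of the observable x**2
       measured = jnp.sum(weights * ref_states**2)
       return (measured - target) ** 2, (n_eff, measured)


   # differentiate w.r.t. the first argument (params); aux is passed through
   loss_and_grad = jax.value_and_grad(toy_loss, has_aux=True)

   ref_states = jnp.array([0.5, -1.0, 1.5, 0.2])
   ref_params = {"k": 1.0}
   ref_energies = 0.5 * ref_params["k"] * ref_states**2
   (loss, (n_eff, measured)), grads = loss_and_grad(
       ref_params, ref_states, ref_energies, beta=1.0, target=0.5
   )
   print(n_eff)       # ~1.0: at the reference parameters every weight is 1/N
   print(grads["k"])

Because the stored reference states are reused, differentiating the loss needs
no new simulation. As the parameters drift from the reference, the normalized
:math:`N_{\text{eff}}` falls below 1, which is presumably the quantity that
``min_n_eff_factor`` is compared against.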