Source code for jax_dna.optimization.optimization

"""Runs an optimization loop using Ray actors for objectives and simulators."""

import dataclasses as dc
import itertools
import typing

import chex
import optax
import ray

import jax_dna.optimization.objective as jdna_objective
import jax_dna.optimization.simulator as jdna_actor
import jax_dna.utils.types as jdna_types
from jax_dna.ui.loggers import logger as jdna_logger

# Messages for the ValueError raised by Optimization.__post_init__ when the
# corresponding required field is empty or None.
ERR_MISSING_OBJECTIVES = "At least one objective is required."
ERR_MISSING_SIMULATORS = "At least one simulator is required."
ERR_MISSING_AGG_GRAD_FN = "An aggregate gradient function is required."
ERR_MISSING_OPTIMIZER = "An optimizer is required."

# we assign at the global level to make it easier to mock for testing
# get_fn resolves Ray object refs, wait_fn waits on a list of refs, and
# grad_update_fn applies optimizer updates to parameters.
get_fn = ray.get
wait_fn = ray.wait
grad_update_fn = optax.apply_updates


def split_by_ready(
    objectives: list[jdna_objective.Objective],
) -> tuple[list[jdna_objective.Objective], list[jdna_objective.Objective]]:
    """Split objectives into those that report ready and those that do not.

    Every ``is_ready`` query is submitted before any result is collected, and
    all of them are resolved with a single ``get_fn`` call, so the queries are
    not serialised one objective at a time.

    Args:
        objectives: The objectives to query via ``is_ready.remote()``.

    Returns:
        A ``(ready, not_ready)`` tuple of lists; each list keeps the relative
        order of ``objectives``.
    """
    ready_flags = get_fn([objective.is_ready.remote() for objective in objectives])

    ready, not_ready = [], []
    for objective, is_ready in zip(objectives, ready_flags, strict=True):
        if is_ready:
            ready.append(objective)
        else:
            not_ready.append(objective)
    return ready, not_ready
@chex.dataclass(frozen=True)
class Optimization:
    """Optimization of a list of objectives using a list of simulators.

    Parameters:
        objectives: A list of objectives to optimize.
        simulators: A list of simulators to use for the optimization.
        aggregate_grad_fn: A function that aggregates the gradients from the
            objectives into a single gradient.
        optimizer: An optax optimizer.
        optimizer_state: The state of the optimizer. When ``None``, ``step``
            initialises a fresh state from the parameters it is given.
        logger: A logger to use for the optimization.
    """

    objectives: list[jdna_objective.Objective]
    # NOTE(review): the annotation says (actor, metadata) tuples, but step()
    # calls ``.name.remote()`` / ``.exposes.remote()`` / ``.run.remote()``
    # directly on each element -- confirm which shape callers actually pass.
    simulators: list[tuple[jdna_actor.SimulatorActor, jdna_types.MetaData]]
    aggregate_grad_fn: typing.Callable[[list[jdna_types.Grads]], jdna_types.Grads]
    optimizer: optax.GradientTransformation
    optimizer_state: optax.OptState | None = None
    logger: jdna_logger.Logger = dc.field(default_factory=lambda: jdna_logger.Logger())

    def __post_init__(self) -> None:
        """Validate the initialization of the Optimization.

        Raises:
            ValueError: If ``objectives`` or ``simulators`` is empty, or if
                ``aggregate_grad_fn`` or ``optimizer`` is ``None``.
        """
        if not self.objectives:
            raise ValueError(ERR_MISSING_OBJECTIVES)

        if not self.simulators:
            raise ValueError(ERR_MISSING_SIMULATORS)

        if self.aggregate_grad_fn is None:
            raise ValueError(ERR_MISSING_AGG_GRAD_FN)

        if self.optimizer is None:
            raise ValueError(ERR_MISSING_OPTIMIZER)

    def step(self, params: jdna_types.Params) -> tuple[optax.OptState, jdna_types.Params, jdna_types.Grads]:
        """Perform a single optimization step.

        Args:
            params: The current parameters.

        Returns:
            A tuple containing the updated optimizer state, the new params,
            and the aggregated gradients.

        Raises:
            RuntimeError: If some objectives are still not ready but no
                simulator runs remain pending to supply them with observables.
        """
        # Some objectives might use difftre and not actually need something
        # rerun, so first check which objectives are ready right away.
        ready_objectives, not_ready_objectives = split_by_ready(self.objectives)
        grad_refs = [objective.calculate.remote() for objective in ready_objectives]

        ready_names = get_fn([objective.name.remote() for objective in ready_objectives])
        not_ready_names = get_fn([objective.name.remote() for objective in not_ready_objectives])
        sim_names = get_fn([sim.name.remote() for sim in self.simulators])

        for name in ready_names:
            self.logger.set_objective_running(name)
        for name in not_ready_names:
            self.logger.set_objective_started(name)
        for name in sim_names:
            self.logger.set_simulator_started(name)

        # Observables required by the objectives that are not ready yet.
        need_observables = set(
            itertools.chain.from_iterable(get_fn([co.needed_observables.remote() for co in not_ready_objectives]))
        )

        # Resolve what every simulator exposes in one batch and keep only the
        # simulators that provide at least one needed observable.
        all_exposes = get_fn([sim.exposes.remote() for sim in self.simulators])
        needed = [
            (sim, name, exposes)
            for sim, name, exposes in zip(self.simulators, sim_names, all_exposes, strict=True)
            if need_observables.intersection(exposes)
        ]

        sim_remotes = [sim.run.remote(params) for sim, _, _ in needed]

        # Index the submitted runs by task id so a finished ref can be mapped
        # back to its simulator name and the observables it exposes.
        simid_exposes = {}
        simid_name = {}
        for sim_ref, (_, name, exposes) in zip(sim_remotes, needed, strict=True):
            task_id = sim_ref.task_id().hex()
            simid_exposes[task_id] = exposes
            simid_name[task_id] = name
            self.logger.set_simulator_running(name)
            for observable in exposes:
                self.logger.set_observable_running(observable)

        # wait for the simulators to finish
        while not_ready_objectives:
            if not sim_remotes:
                # Nothing is left to wait on, so the waiting objectives can
                # never receive new observables; fail instead of looping.
                msg = "Objectives are still waiting for observables, but no simulator runs remain pending."
                raise RuntimeError(msg)

            # `done` is a list of object refs that are ready to collect.
            # sim_remotes is a list of object refs that are not ready to collect.
            done, sim_remotes = wait_fn(sim_remotes)
            if not done:
                continue

            captured_results = []
            for done_ref in done:
                task_id = done_ref.task_id().hex()
                exposes = simid_exposes[task_id]
                result = get_fn(done_ref)
                captured_results.append((exposes, result))
                # NOTE(review): the logger is used unconditionally elsewhere in
                # this method; the guard is kept to preserve existing behavior.
                if self.logger:
                    self.logger.set_simulator_complete(simid_name[task_id])
                    for expose in exposes:
                        self.logger.set_observable_complete(expose)

            # update the objectives with the new observables and check if they are ready
            get_fn([objective.update.remote(captured_results) for objective in not_ready_objectives])
            ready_objectives, not_ready_objectives = split_by_ready(not_ready_objectives)
            for name in get_fn([objective.name.remote() for objective in ready_objectives]):
                self.logger.set_objective_running(name)
            grad_refs += [objective.calculate.remote() for objective in ready_objectives]

        grads_resolved = get_fn(grad_refs)
        for name in get_fn([o.name.remote() for o in self.objectives]):
            self.logger.set_objective_complete(name)

        grads = self.aggregate_grad_fn(grads_resolved)

        # Lazily create the optimizer state the first time step() is called.
        opt_state = self.optimizer.init(params) if self.optimizer_state is None else self.optimizer_state
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        new_params = grad_update_fn(params, updates)

        return opt_state, new_params, grads

    def post_step(
        self,
        optimizer_state: optax.OptState,
        opt_params: jdna_types.Params,
    ) -> "Optimization":
        """An update step intended to be called after an optimization step.

        Calls ``post_step`` on every objective with ``opt_params`` and waits
        for all of those calls to finish.

        Args:
            optimizer_state: The optimizer state to store on the returned copy.
            opt_params: The parameters passed to each objective's ``post_step``.

        Returns:
            A copy of this Optimization with ``optimizer_state`` replaced.
        """
        _ = get_fn([o.post_step.remote(opt_params) for o in self.objectives])
        return self.replace(optimizer_state=optimizer_state)
@chex.dataclass(frozen=True)
class SimpleOptimizer:
    """A simple optimizer that uses a single objective and simulator.

    The objective and simulator are called directly (no ``.remote()`` calls
    and no ``get_fn``), so each step runs synchronously in the caller.

    Parameters:
        objective: The objective to optimize.
        simulator: The simulator run when the objective is not ready.
        optimizer: An optax optimizer.
        optimizer_state: The state of the optimizer. When ``None``, ``step``
            initialises a fresh state from the parameters it is given.
        logger: A logger; not used by the methods of this class.
    """

    objective: jdna_objective.Objective
    simulator: jdna_actor.SimulatorActor
    optimizer: optax.GradientTransformation
    optimizer_state: optax.OptState | None = None
    logger: jdna_logger.Logger = dc.field(default_factory=lambda: jdna_logger.Logger())

    def step(self, params: jdna_types.Params) -> tuple[optax.OptState, jdna_types.Params, jdna_types.Grads]:
        """Perform a single optimization step.

        Args:
            params: The current parameters.

        Returns:
            A tuple containing the updated optimizer state, the new params,
            and the gradients.
        """
        # The objective might use difftre and not actually need something
        # rerun, so only run the simulator when the objective is not ready.
        if not self.objective.is_ready():
            observables = self.simulator.run(params)
            exposes = self.simulator.exposes()
            self.objective.update([(exposes, observables)])

        grads = self.objective.calculate()

        # Lazily create the optimizer state the first time step() is called.
        opt_state = self.optimizer.init(params) if self.optimizer_state is None else self.optimizer_state
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        new_params = grad_update_fn(params, updates)

        return opt_state, new_params, grads

    def post_step(self, optimizer_state: optax.OptState, opt_params: jdna_types.Params) -> "SimpleOptimizer":
        """An update step intended to be called after an optimization step.

        Args:
            optimizer_state: The optimizer state to store on the returned copy.
            opt_params: The parameters passed to the objective's ``post_step``.

        Returns:
            A copy of this SimpleOptimizer with ``optimizer_state`` replaced.
        """
        self.objective.post_step(opt_params)
        return self.replace(optimizer_state=optimizer_state)