jax_dna.utils.helpers
=====================

.. py:module:: jax_dna.utils.helpers

.. autoapi-nested-parse::

   Helper functions for the JAX-DNA package.


Attributes
----------

.. autoapisummary::

   jax_dna.utils.helpers.ERR_BATCHED_N


Functions
---------

.. autoapisummary::

   jax_dna.utils.helpers.batched
   jax_dna.utils.helpers.tree_stack
   jax_dna.utils.helpers.tree_concatenate


Module Contents
---------------

.. py:data:: ERR_BATCHED_N
   :value: 'n must be at least one'


.. py:function:: batched(iterable: collections.abc.Iterable[Any], n: int) -> collections.abc.Iterable[Any]

   Batch an iterable into chunks of size n.

   :param iterable: iterable to batch
   :type iterable: iter[Any]
   :param n: batch size
   :type n: int

   :returns: batched iterable
   :rtype: iter[Any]


.. py:function:: tree_stack(trees: list[jaxtyping.PyTree]) -> jaxtyping.PyTree

   Stacks corresponding leaves of PyTrees into arrays along a new axis.


.. py:function:: tree_concatenate(trees: list[jaxtyping.PyTree]) -> jaxtyping.PyTree

   Concatenates corresponding leaves of PyTrees along the first axis.
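
The following is a minimal sketch of how ``batched`` could work. It assumes the helper mirrors :func:`itertools.batched` from Python 3.12, whose error message for a batch size below one is the same as ``ERR_BATCHED_N``. This page does not confirm several details: the exception type (``ValueError``), batches returned as tuples, and a final batch that may be shorter than ``n`` are all assumptions. ``_batches`` is a hypothetical helper name.

.. code-block:: python

   import itertools
   from collections.abc import Iterable, Iterator
   from typing import Any

   ERR_BATCHED_N = "n must be at least one"


   def _batches(it: Iterator[Any], n: int) -> Iterator[tuple[Any, ...]]:
       # Hypothetical helper: islice pulls at most n items per step; an empty
       # tuple means the underlying iterator is exhausted.
       while batch := tuple(itertools.islice(it, n)):
           yield batch


   def batched(iterable: Iterable[Any], n: int) -> Iterator[tuple[Any, ...]]:
       """Batch an iterable into tuples of up to n items (assumed behavior)."""
       # Assumption: an invalid n raises ValueError with ERR_BATCHED_N,
       # checked eagerly before any iteration happens.
       if n < 1:
           raise ValueError(ERR_BATCHED_N)
       return _batches(iter(iterable), n)

With this sketch, ``list(batched(range(7), 3))`` yields ``[(0, 1, 2), (3, 4, 5), (6,)]``. Validating ``n`` before returning the generator makes an invalid batch size fail at call time instead of on first iteration, which matches how ``itertools.batched`` behaves.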
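
``tree_stack`` and ``tree_concatenate`` differ only in how corresponding leaves are combined. The following is a minimal sketch of both using :func:`jax.tree_util.tree_map`. It assumes that every tree in ``trees`` has the same structure and that the list is non-empty. It is not necessarily how JAX-DNA implements them. ``PyTree`` is aliased to ``Any`` here so the snippet does not depend on ``jaxtyping``.

.. code-block:: python

   from typing import Any

   import jax
   import jax.numpy as jnp

   # Stand-in for jaxtyping.PyTree so the sketch has no extra dependency.
   PyTree = Any


   def tree_stack(trees: list[PyTree]) -> PyTree:
       """Stack corresponding leaves along a new leading axis."""
       # tree_map receives the matching leaf from each tree as *leaves.
       return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


   def tree_concatenate(trees: list[PyTree]) -> PyTree:
       """Concatenate corresponding leaves along their existing first axis."""
       # jnp.concatenate joins along axis 0 by default; leaves must be >= 1-D.
       return jax.tree_util.tree_map(lambda *leaves: jnp.concatenate(leaves), *trees)


   a = {"x": jnp.zeros(3), "y": jnp.ones((2, 2))}
   b = {"x": jnp.ones(3), "y": jnp.zeros((2, 2))}
   stacked = tree_stack([a, b])      # stacked["x"].shape == (2, 3)
   joined = tree_concatenate([a, b])  # joined["x"].shape == (6,)

Stacking adds a new leading axis whose length is the number of trees, so its result is convenient to pass to :func:`jax.vmap`. Concatenation keeps the rank of each leaf and therefore needs leaves with at least one dimension, while stacking also accepts scalar leaves.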