flax.jax_utils package
========================

.. currentmodule:: flax.jax_utils

.. automodule:: flax.jax_utils

.. autofunction:: partial_eval_by_shape


Multi-device utilities
------------------------

.. autofunction:: replicate

.. autofunction:: unreplicate

.. autofunction:: prefetch_to_device

.. autofunction:: pmean

.. autofunction:: pad_shard_unpad
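

The multi-device helpers work on pytrees whose leaves carry a leading device axis. The following is a minimal sketch of a ``replicate`` / ``unreplicate`` round trip. It assumes that JAX and Flax are installed and that your Flax version still exports both functions from ``flax.jax_utils``. The ``params`` dictionary is hypothetical example data standing in for model parameters.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from flax import jax_utils

   # Hypothetical pytree standing in for model parameters.
   params = {"w": jnp.ones((3,)), "b": jnp.zeros((2,))}

   # replicate() places a copy of every leaf on each local device,
   # adding a leading axis of size jax.local_device_count().
   replicated = jax_utils.replicate(params)
   print(replicated["w"].shape)  # (number of local devices, 3)

   # unreplicate() keeps the first copy of every leaf, removing that axis.
   restored = jax_utils.unreplicate(replicated)
   print(restored["w"].shape)  # (3,)

On a single-device CPU host the leading axis has size 1. The same code produces one copy per accelerator when several local devices are visible.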