flax.jax_utils package

Utilities we could consider upstreaming to JAX.

flax.jax_utils.partial_eval_by_shape(fn, input_spec, *args, **kwargs)[source]

Lazily evaluate a function by using the shapes of the inputs.

This function is similar to jax.eval_shape, with the key difference that function outputs that can be computed without concrete input values are returned as is, instead of only their shapes. See for example module.init_by_shape, where this functionality is used to initialize a model without using input data or computation.

  • fn – the function to be lazily evaluated.

  • input_spec – an iterable of shapes or (shape, dtype) tuples specifying the shape and type of the inputs. If unspecified the dtype is float32.

  • *args – other arguments passed to the module’s apply function

  • **kwargs – keyword arguments passed to the module’s apply function


A pair consisting of the model output and an instance of Model
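For contrast, the sketch below shows what jax.eval_shape alone returns. The function `fn` and its input spec are made up for illustration. Its second output depends only on the input's shape, yet jax.eval_shape still reports it as a shape/dtype placeholder. According to the description above, partial_eval_by_shape would return such an output as a concrete value instead.

```python
import jax
import jax.numpy as jnp


def fn(x):
    # The first output depends on the values of x.
    # The second output depends only on the shape of x.
    return x * 2.0, jnp.zeros(x.shape[-1])


# Describe the input by shape and dtype only; no data is allocated.
spec = jax.ShapeDtypeStruct((4, 3), jnp.float32)

# jax.eval_shape traces fn abstractly and returns placeholders for every output.
out_a, out_b = jax.eval_shape(fn, spec)

print(out_a.shape, out_a.dtype)
print(out_b.shape, out_b.dtype)
```

Both results are jax.ShapeDtypeStruct instances, so neither carries array values.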

Multi device utilities

flax.jax_utils.replicate(tree, devices=None)[source]

Replicates arrays to multiple devices.

  • tree – a pytree containing the arrays that should be replicated.

  • devices – the devices the data is replicated to (default: same order as expected by jax.pmap()).


A new pytree containing the replicated arrays.


flax.jax_utils.unreplicate(tree)[source]

Returns a single instance of a replicated array.
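The array layout involved in replicating a pytree and taking a single instance back can be sketched in plain NumPy. This is an illustration only, not the library's implementation: the helper names `replicate_sketch` and `single_instance_sketch` are hypothetical, the pytree is assumed to be a flat dict, and everything stays in host memory, whereas the real utility places one copy on each device.

```python
import numpy as np


def replicate_sketch(tree, num_devices):
    # Hypothetical stand-in: give each leaf a new leading axis of size
    # num_devices by stacking identical copies.
    return {name: np.stack([leaf] * num_devices) for name, leaf in tree.items()}


def single_instance_sketch(tree):
    # Hypothetical stand-in: take the first copy along the leading axis.
    return {name: leaf[0] for name, leaf in tree.items()}


params = {"w": np.ones((3, 2)), "b": np.zeros((2,))}

replicated = replicate_sketch(params, num_devices=4)
restored = single_instance_sketch(replicated)

print(replicated["w"].shape)  # leading axis is the assumed device count
print(restored["w"].shape)
```

The leading axis is what a pmap'ed function maps over, which is why replicated parameters can be passed to it directly.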

flax.jax_utils.prefetch_to_device(iterator, size, devices=None)[source]

Shard and prefetch batches on device.

This utility takes an iterator and returns a new iterator that fills an on-device prefetch buffer. Eager prefetching can significantly improve the performance of training loops by overlapping compute and data transfer.

This utility is mostly useful for GPUs. For TPUs and CPUs it should not be necessary: the TPU and CPU memory allocators (normally) don’t pick a memory location that isn’t free yet, so they don’t block. Instead, those allocators OOM.

  • iterator – an iterator that yields a pytree of ndarrays where the first dimension is sharded across devices.

  • size

    the size of the prefetch buffer.

    If you’re training on GPUs, 2 is generally the best choice because this guarantees that you can overlap a training step on GPU with a data prefetch step on CPU.

  • devices

    the list of devices to which the arrays should be prefetched.

    Defaults to the order of devices expected by jax.pmap.


The original items from the iterator, where each ndarray is now sharded across the specified devices.
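The buffering pattern described above can be sketched with the standard library alone. This is a minimal sketch, not the library's code: `prefetch_sketch` and `fake_transfer` are hypothetical names, and `transfer` stands in for whatever moves a batch to the devices. On real hardware the transfer would run asynchronously, which is what lets it overlap with the training step.

```python
import collections
import itertools


def prefetch_sketch(iterator, size, transfer):
    """Yield items from `iterator`, keeping up to `size` transferred items queued."""
    queue = collections.deque()

    def enqueue(n):
        # Pull up to n items from the source and start their transfer.
        for item in itertools.islice(iterator, n):
            queue.append(transfer(item))

    enqueue(size)  # fill the buffer before yielding anything
    while queue:
        yield queue.popleft()
        enqueue(1)  # refill the slot that was just freed


transfer_log = []


def fake_transfer(batch):
    # Hypothetical stand-in for a host-to-device copy; it only records the call.
    transfer_log.append(batch)
    return batch


batches = iter(range(5))
prefetched = prefetch_sketch(batches, size=2, transfer=fake_transfer)

first = next(prefetched)
log_after_first = list(transfer_log)  # two transfers were started before the first yield
rest = list(prefetched)
```

With `size=2`, two batches have already been handed to `transfer` by the time the first one is consumed, so the next batch is in flight while the current one is processed.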

flax.jax_utils.pmean(xs, axis_name)[source]

flax.jax_utils.pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=(), static_return=False)[source]

Wraps a function with code that pads, shards, then un-shards, un-pads.

  • wrapped – the function to be wrapped. Signature is params, *args, **kwargs.

  • static_argnums – indices of arguments to wrapped that should _not_ be padded and sharded, but instead be forwarded as-is. The default is (0,) because by far the most common use-case is to pass params first.

  • static_argnames – names of kwargs to wrapped that should _not_ be padded and sharded, but instead be forwarded as-is.

  • static_return – if True, the return value is not un-sharded and un-padded. Static return values are typically used with eval steps that compute metrics.


A new function that pads and shards its arguments before passing them to the wrapped function, and un-shards and un-pads the returned pytree.

This is useful for calling a pmap’ed function with inputs that aren’t divisible by the number of devices. A typical use is:

    @pad_shard_unpad
    @jax.pmap
    def forward(params, x): …


The padding is done in host-memory before being passed to the function, and the values returned by the function are transferred back to host memory.

The returned function is augmented with a new keyword-only argument min_device_batch that, if specified, forces padding inputs to at least this size per device. This can be useful to avoid recompiles for the last batch and reduce memory fragmentation.
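The pad, shard, un-shard, un-pad sequence can be sketched in NumPy for a single array. This is an illustration under stated assumptions rather than the library's implementation: the helpers `pad_shard` and `unshard_unpad` are hypothetical, padding is assumed to use zeros, and multiplying by 2 stands in for the pmap'ed function.

```python
import numpy as np


def pad_shard(x, num_devices, min_device_batch=None):
    """Pad the batch axis, then reshape to (num_devices, per_device, ...)."""
    batch = x.shape[0]
    per_device = -(-batch // num_devices)  # ceiling division
    if min_device_batch is not None:
        # Force at least this many examples per device.
        per_device = max(per_device, min_device_batch)
    pad = per_device * num_devices - batch
    padding = np.zeros((pad, *x.shape[1:]), x.dtype)
    padded = np.concatenate([x, padding], axis=0)
    return padded.reshape(num_devices, per_device, *x.shape[1:])


def unshard_unpad(y, batch):
    """Merge the device axis back into the batch axis and drop the padding."""
    merged = y.reshape(y.shape[0] * y.shape[1], *y.shape[2:])
    return merged[:batch]


# 5 examples are not divisible by the assumed 4 devices.
x = np.arange(10, dtype=np.float32).reshape(5, 2)

sharded = pad_shard(x, num_devices=4)
out = unshard_unpad(sharded * 2.0, batch=x.shape[0])  # "* 2.0" stands in for the model

fixed = pad_shard(x, num_devices=4, min_device_batch=4)
```

Passing the same `min_device_batch` for every batch keeps the per-device shape constant, so a smaller final batch does not change the input shape and trigger a recompile.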

For more information refer to https://flax.readthedocs.io/en/latest/howtos/full_eval.html