flax.jax_utils package

Utilities we could consider upstreaming to JAX.

flax.jax_utils.partial_eval_by_shape(fn, input_spec, *args, **kwargs)[source]

Lazily evaluate a function by using the shapes of the inputs.

This function is similar to jax.eval_shape, with the key difference that function outputs that can be computed without concrete input values are returned as is, instead of only their shapes. See, for example, module.init_by_shape, where this functionality is used to initialize a model without using input data or computation.

Parameters:

  • fn – the function to be lazily evaluated.

  • input_spec – an iterable of shapes or (shape, dtype) tuples specifying the shape and type of the inputs. If unspecified, the dtype is float32.

  • *args – other arguments passed to the module’s apply function.

  • **kwargs – keyword arguments passed to the module’s apply function.

Returns:

A pair consisting of the model output and an instance of Model.
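For comparison, the shape-only behavior of jax.eval_shape mentioned in the description can be sketched as follows. The dense function and the input shapes are illustrative assumptions and are not part of Flax; no array values are computed here.

```python
import jax
import jax.numpy as jnp


def dense(x, w):
    # Illustrative function (an assumption for this sketch): a matrix product.
    return jnp.dot(x, w)


# Describe the inputs by shape and dtype only; no data is allocated.
x_spec = jax.ShapeDtypeStruct((8, 16), jnp.float32)
w_spec = jax.ShapeDtypeStruct((16, 4), jnp.float32)

# eval_shape traces `dense` abstractly and returns a shape/dtype description.
out = jax.eval_shape(dense, x_spec, w_spec)
print(out.shape, out.dtype)
```

According to the description, partial_eval_by_shape differs in that outputs which do not depend on concrete input values are returned as actual values rather than as shape descriptions.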

Multi device utilities

flax.jax_utils.replicate(tree, devices=None)[source]

Replicates arrays to multiple devices.

Parameters:

  • tree – a pytree containing the arrays that should be replicated.

  • devices – the devices the data is replicated to (default: jax.local_devices()).

Returns:

A new pytree containing the replicated arrays.


Returns a single instance of a replicated array.

flax.jax_utils.prefetch_to_device(iterator, size, devices=None)[source]

Shard and prefetch batches on device.

This utility takes an iterator and returns a new iterator that fills an on-device prefetch buffer. Eager prefetching can significantly improve the performance of training loops by overlapping compute and data transfer.

This utility is mostly useful for GPUs; for TPUs it should not be necessary.

Parameters:

  • iterator – an iterator that yields a pytree of ndarrays where the first dimension is sharded across devices.

  • size – the size of the prefetch buffer.

  • devices – the list of devices to which the arrays should be prefetched.

Returns:

The original items from the iterator, where each ndarray is now sharded across the specified devices.
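The buffering idea can be illustrated with a simplified sketch in plain JAX. This is not necessarily how Flax implements prefetch_to_device: the helper name prefetch_sketch is hypothetical, the sharding step is omitted, and batches are placed on the default device with jax.device_put.

```python
import collections
import itertools

import jax
import numpy as np


def prefetch_sketch(iterator, size):
    """Hypothetical helper: keep up to `size` batches queued on device."""
    queue = collections.deque()

    def enqueue(n):
        # Pull up to n host batches and dispatch their transfer to the device.
        for batch in itertools.islice(iterator, n):
            queue.append(jax.device_put(batch))

    enqueue(size)  # fill the buffer up front
    while queue:
        yield queue.popleft()
        enqueue(1)  # top the buffer up before handing out the next batch


# Five made-up host batches; batch i is filled with the value i.
host_batches = ({"x": np.full((4, 2), i, dtype=np.float32)} for i in range(5))
sums = [float(batch["x"].sum()) for batch in prefetch_sketch(host_batches, size=2)]
print(sums)
```

Because jax.device_put dispatches transfers asynchronously, batches waiting in the queue can be copied while the consumer is still computing on an earlier batch, which is the overlap the description refers to.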

flax.jax_utils.pmean(xs, axis_name)[source]