flax.linen.vmap

flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)

A lifted version of jax.vmap.

See jax.vmap for the unlifted batch transform in JAX.

vmap can be used to add a batch axis to a Module. For example, we could create a version of Dense with a batch axis that does not share parameters:

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})

By using variable_axes={'params': 0}, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the ‘params’ RNG, otherwise the parameters would be initialized identically along the mapped axis.

Similarly, vmap could be used to add a batch axis with parameter sharing:

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': None},
...     split_rngs={'params': False})

Here we use variable_axes={'params': None} to indicate that the parameter variables are shared along the mapped axis. Consequently, the ‘params’ RNG must also be shared.

Parameters
  • target – a Module or a function taking a Module as its first argument.

  • variable_axes – A mapping from variable collections to axes; the listed collections are lifted into the batching transformation. Use None to broadcast a collection or an integer to map over that axis.

  • split_rngs – A mapping from PRNG sequences to booleans. Split PRNG sequences will be different for each index of the batch dimension; unsplit PRNGs will be broadcast.

  • in_axes – Specifies the mapping of the input arguments (see jax.vmap).

  • out_axes – Specifies the mapping of the return value (see jax.vmap).

  • axis_size – Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments.

  • axis_name – Specifies a name for the batch axis. Can be used together with parallel reduction primitives (e.g. jax.lax.pmean, jax.lax.ppermute). Note that this is only used for pmap and shard_map. For SPMD jit, you do not need to synchronize manually: just make sure that the axes are correctly annotated, and XLA:SPMD will insert the necessary collectives.

  • methods – If target is a Module, the methods of the Module to vmap over.

  • spmd_axis_name – Axis name added to any pjit sharding constraints appearing in target. See also google/flax.

  • metadata_params – A dict of arguments passed to AxisMetadata instances in the variable tree.

Returns

A batched/vectorized version of target, with the same arguments but with extra axes at positions indicated by in_axes, and the same return value, but with extra axes at positions indicated by out_axes.