flax.linen.vmap
- flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)
A lifted version of jax.vmap.
See jax.vmap for the unlifted batch transform in JAX.
vmap can be used to add a batch axis to a Module. For example, we could create a version of Dense with a batch axis that does not share parameters:
>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})
By using variable_axes={'params': 0}, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the ‘params’ RNG; otherwise the parameters would be initialized identically along the mapped axis.
Similarly, vmap could be used to add a batch axis with parameter sharing:
>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': None},
...     split_rngs={'params': False})
Here we use variable_axes={'params': None} to indicate the parameter variables are shared along the mapped axis. Consequently, the ‘params’ RNG must also be shared.
- Parameters
target – a Module or a function taking a Module as its first argument.
variable_axes – the variable collections that are lifted into the batching transformation. Use None to indicate a broadcasted collection or an integer to map over an axis.
split_rngs – a mapping from PRNG sequence name to a boolean, as in the examples above. Split PRNG sequences are different for each index of the batch dimension; unsplit PRNGs are broadcast.
in_axes – Specifies the mapping of the input arguments (see jax.vmap).
out_axes – Specifies the mapping of the return value (see jax.vmap).
axis_size – Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments.
axis_name – Specifies a name for the batch axis. Can be used together with parallel reduction primitives (e.g. jax.lax.pmean, jax.lax.ppermute, etc.). Note that this is only used for pmap and shard_map. For SPMD jit, you do not need to synchronize manually; just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
methods – If target is a Module, the methods of Module to vmap over.
spmd_axis_name – Axis name added to any pjit sharding constraints appearing in fn. See also google/flax.
metadata_params – arguments dict passed to AxisMetadata instances in the variable tree.
- Returns
A batched/vectorized version of target, with the same arguments but with extra axes at positions indicated by in_axes, and the same return value, but with extra axes at positions indicated by out_axes.