flax.linen.vmap
- flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)
A lifted version of jax.vmap. See jax.vmap for the unlifted batch transform in JAX.

vmap can be used to add a batch axis to a Module. For example, we could create a version of Dense with a batch axis that does not share parameters:

```python
BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},
    split_rngs={'params': True})
```
By using variable_axes={'params': 0}, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the 'params' RNG; otherwise the parameters would be initialized identically along the mapped axis.

Similarly, vmap could be used to add a batch axis with parameter sharing:

```python
BatchFoo = nn.vmap(
    Foo,
    in_axes=0, out_axes=0,
    variable_axes={'params': None},
    split_rngs={'params': False})
```
Here we use variable_axes={'params': None} to indicate that the parameter variables are shared along the mapped axis. Consequently, the 'params' RNG must also be shared.

- Parameters
target – a Module or a function taking a Module as its first argument.
variable_axes – maps each variable collection that is lifted into the batching transformation to an axis. Use None to indicate a broadcast collection or an integer to map over that axis.
split_rngs – maps PRNG sequence names to booleans. Split PRNG sequences will be different for each index of the batch dimension; unsplit PRNG sequences are broadcast.
in_axes – Specifies the mapping of the input arguments (see jax.vmap).
out_axes – Specifies the mapping of the return value (see jax.vmap).
axis_size – Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments.
axis_name – Specifies a name for the batch axis. Can be used together with parallel collective primitives (e.g. jax.lax.pmean, jax.lax.ppermute).
methods – If target is a Module, the methods of the Module to vmap over.
spmd_axis_name – Axis name added to any pjit sharding constraints appearing in fn. See also google/flax.
metadata_params – arguments dict passed to AxisMetadata instances in the variable tree.
- Returns
A batched/vectorized version of target, with the same arguments but with extra axes at positions indicated by in_axes, and the same return value, but with extra axes at positions indicated by out_axes.