flax.linen.vmap

flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, methods=None)

A lifted version of jax.vmap.

See jax.vmap for the unlifted batch transform in JAX.

vmap can be used to add a batch axis to a Module. For example, we could create a version of Dense with a batch axis that does not share parameters:

import flax.linen as nn

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},
    split_rngs={'params': True})

By using variable_axes={'params': 0}, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the 'params' RNG; otherwise, the parameters would be initialized identically along the mapped axis.

Similarly, vmap can be used to add a batch axis with parameter sharing:

BatchFoo = nn.vmap(
    Foo,
    in_axes=0, out_axes=0,
    variable_axes={'params': None},
    split_rngs={'params': False})

Here we use variable_axes={'params': None} to indicate that the parameter variables are shared along the mapped axis. Consequently, the 'params' RNG must also be shared.

Parameters
  • target – a Module or a function taking a Module as its first argument.

  • variable_axes – the variable collections that are lifted into the batching transformation, given as a mapping from collection name to axis. Use None to indicate a broadcast (shared) collection or an integer to map over that axis.

  • split_rngs – a mapping from PRNG sequence name to a bool. Split PRNG sequences will be different for each index of the batch dimension; unsplit PRNGs will be broadcast.

  • in_axes – Specifies the mapping of the input arguments (see jax.vmap).

  • out_axes – Specifies the mapping of the return value (see jax.vmap).

  • axis_size – Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments.

  • axis_name – Specifies a name for the batch axis. Can be used together with parallel collective primitives (e.g. jax.lax.pmean, jax.lax.ppermute).

  • methods – If target is a Module, the methods of Module to vmap over.

  • spmd_axis_name – Axis name added to any pjit sharding constraints appearing in target. See also https://github.com/google/flax/blob/main/flax/linen/partitioning.py.

Returns

A batched/vectorized version of target, with the same arguments but with extra axes at positions indicated by in_axes, and the same return value, but with extra axes at positions indicated by out_axes.