flax.linen.vmap

flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, methods=None)

A lifted version of jax.vmap.

See jax.vmap for the unlifted batch transform in JAX.

vmap can be used to add a batch axis to a Module. For example, we could create a version of Dense with a batch axis that does not share parameters:

import flax.linen as nn

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},
    split_rngs={'params': True})

By using variable_axes={'params': 0}, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the ‘params’ RNG; otherwise, the parameters would be initialized identically along the mapped axis.

Similarly, vmap can be used to add a batch axis with parameter sharing:

BatchFoo = nn.vmap(
    Foo,
    in_axes=0, out_axes=0,
    variable_axes={'params': None},
    split_rngs={'params': False})

Here we use variable_axes={'params': None} to indicate that the parameter variables are shared along the mapped axis. Consequently, the ‘params’ RNG must also be shared.

Parameters
  • target (flax.linen.transforms.Target) – a Module or a function taking a Module as its first argument.

  • variable_axes (Mapping[Union[bool, str, Collection[str], DenyList], Union[int, None, flax.core.lift.In[Optional[int]], flax.core.lift.Out[Optional[int]]]]) – the variable collections that are lifted into the batching transformation. Use None to indicate a broadcasted collection or an integer to map over an axis.

  • split_rngs (Mapping[Union[bool, str, Collection[str], DenyList], bool]) – the PRNG sequences to split. Split PRNG sequences will be different for each index of the batch dimension; unsplit PRNG sequences are broadcast.

  • in_axes – Specifies the mapping of the input arguments (see jax.vmap).

  • out_axes – Specifies the mapping of the return value (see jax.vmap).

  • axis_size (Optional[int]) – Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments.

  • axis_name (Optional[str]) – Specifies a name for the batch axis. Can be used together with parallel collective primitives such as jax.lax.pmean or jax.lax.ppermute.

  • methods – If target is a Module, the methods of the Module to vmap over.

Returns

A batched (vectorized) version of target that takes the same arguments, with extra axes at the positions indicated by in_axes, and returns the same value, with extra axes at the positions indicated by out_axes.

Return type

flax.linen.transforms.Target