class flax.linen.BatchApply(f, num_dims=2)

Temporarily merges leading dimensions of input tensors.

Merges the leading num_dims dimensions of each input array into a single dimension, runs the given callable f, then splits the leading dimension of the result back into the original leading shape.

Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified.
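To make the merge, call, and split steps concrete, here is a minimal NumPy sketch of the idea. It is not Flax's implementation: the name batch_apply_sketch is hypothetical, and the sketch assumes every positional argument is an array, ignores keyword arguments, and expects f to return a single array. The real BatchApply may handle more general inputs and outputs.

```python
import numpy as np


def batch_apply_sketch(f, *args, num_dims=2):
    """Hypothetical sketch: merge leading dims, call f, split the result."""
    leading_shape = None
    merged_args = []
    for x in args:
        x = np.asarray(x)
        if x.ndim < num_dims:
            # Too few dimensions to collapse: pass the array through unmodified.
            merged_args.append(x)
            continue
        if leading_shape is None:
            leading_shape = x.shape[:num_dims]
        elif x.shape[:num_dims] != leading_shape:
            raise ValueError("all merged inputs must share the same leading shape")
        # Collapse the first num_dims axes into a single axis.
        merged_size = int(np.prod(leading_shape))
        merged_args.append(x.reshape((merged_size,) + x.shape[num_dims:]))
    if leading_shape is None:
        raise ValueError("no input has at least num_dims dimensions")
    out = np.asarray(f(*merged_args))
    # Split the merged leading axis of the output back into the original shape.
    return out.reshape(leading_shape + out.shape[1:])


a = np.arange(24.0).reshape(2, 3, 4)  # leading dims [2, 3] are merged into 6
b = np.arange(4.0)                    # rank 1 < num_dims, so passed unmodified
out = batch_apply_sketch(lambda a, b: a @ b, a, b)
print(out.shape)  # (2, 3)
```

Here f sees a as a (6, 4) array and b unchanged, returns a (6,) array, and the sketch reshapes that back to (2, 3).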

This is useful, for example, for applying a module to each timestep of a [Time, Batch, ...] array.
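As a sketch of the time-major case, the NumPy snippet below applies a per-timestep dense layer to a [Time, Batch, Features] array by merging Time and Batch into one axis, then checks the result against an explicit loop over timesteps. The sizes, the weight matrix w, and the function dense are illustrative assumptions, not Flax APIs.

```python
import numpy as np

T, B, F, H = 5, 3, 4, 2  # hypothetical sizes: time, batch, features, hidden
rng = np.random.default_rng(0)
x = rng.normal(size=(T, B, F))  # time-major input [Time, Batch, Features]
w = rng.normal(size=(F, H))     # weights of a hypothetical per-timestep dense layer


def dense(inputs):
    # Only accepts a rank-2 [N, F] batch, like a layer written for one timestep.
    assert inputs.ndim == 2
    return inputs @ w


# Merge Time and Batch into a single axis, apply once, then split again.
merged_out = dense(x.reshape(T * B, F)).reshape(T, B, H)

# Reference: apply dense to each [B, F] timestep slice separately.
looped_out = np.stack([dense(x[t]) for t in range(T)])
```

Because reshape uses row-major order, row t * B + b of the merged array is x[t, b], so splitting the output restores the [Time, Batch, ...] layout.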

For some choices of f and some platforms, this may be more efficient than jax.vmap(), especially when combined with other transformations such as jax.grad().

Example usage:

>>> import jax, jax.numpy as jnp
>>> import numpy as np
>>> from flax.linen import BatchApply

>>> a = jax.random.normal(jax.random.key(0), [2, 3, 4])
>>> b = jax.random.normal(jax.random.key(1), [4])

>>> def raises(a, b):
...   # Only accepts a rank-2 `a` and a rank-1 `b`.
...   if a.ndim != 2:
...     raise ValueError("a must have rank 2")
...   if b.ndim != 1:
...     raise ValueError("b must have rank 1")
...   return jnp.dot(a, b)

>>> # The leading [2, 3] dims of `a` are merged into 6; `b` (rank 1) is passed unmodified.
>>> out = BatchApply(raises)(a, b)
>>> expected_merged_leading = raises(a.reshape(2 * 3, 4), b)
>>> expected = expected_merged_leading.reshape((2, 3) + expected_merged_leading.shape[1:])
>>> np.testing.assert_array_equal(out, expected)

__call__(*args, **kwargs)

Call self as a function.