flax.linen.scan

flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None)

A lifted version of jax.lax.scan.

See jax.lax.scan for the unlifted scan in JAX.
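For comparison, here is a minimal sketch of the unlifted jax.lax.scan contract: a body function maps (carry, x) to (new_carry, y), the carry is threaded through the iterations, and the per-step outputs y are stacked along a new leading axis. The name body is illustrative only.

```python
import jax
import jax.numpy as jnp

def body(carry, x):
  # Thread a running sum through the loop and emit twice the running sum per step.
  carry = carry + x
  return carry, 2 * carry

init = jnp.zeros((), jnp.float32)  # the carry keeps this shape and dtype every step
final, ys = jax.lax.scan(body, init, jnp.arange(4.0))
print(final, ys)
```

nn.scan wraps this same pattern, but additionally manages Flax variable collections and PRNG sequences inside the loop body.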

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.

scan distinguishes between three types of values inside the loop:

  1. scan: A value that is iterated over in the loop. All scan values must have the same size along the axis they are scanned over. Scanned outputs are stacked along the scan axis.

  2. carry: A carried value is updated at each loop iteration. It must have the same shape and dtype throughout the loop.

  3. broadcast: A value that is closed over by the loop. A broadcast variable is typically initialized inside the loop body but does not depend on the loop variables.

The target should have the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.

Example:

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape = x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)

Note that when providing a function to nn.scan, the scanning happens over all arguments starting from the third argument, as specified by in_axes. The previous example could also be written using the functional form as:

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape = x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))

You can also use scan to reduce the compilation time of your JAX program by merging multiple layers into a single scan loop. This works when you have a sequence of identical layers that you want to apply iteratively to an input: the layer body is traced and compiled once instead of once per layer. For example:

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))

To reduce both compilation time and memory usage, you can use remat_scan(), which additionally checkpoints (rematerializes) each layer in the scan loop.

Parameters
  • target – a Module or a function taking a Module as its first argument.

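  • variable_axes – A mapping from variable collections to the axis they are scanned over (e.g. {'params': 0} stacks the parameters of each iteration along a new leading axis).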

  • variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.

  • variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.

  • split_rngs – A mapping from PRNG sequence names to booleans. Split PRNG sequences will be different for each loop iteration; if split is False, the PRNGs will be the same across iterations.

  • in_axes – Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.

  • out_axes – Specifies the axis to scan over for the return value. Should be a prefix tree of the return value.

  • length – Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.

  • reverse – If True, scan from the end to the start, in reverse order.

  • unroll – how many scan iterations to unroll within a single iteration of a loop (default: 1).

  • data_transform – Optional function to transform the raw functional-core variable and RNG groups inside the lifted scan body_fn, intended for inline SPMD annotations.

  • metadata_params – A dict of arguments passed to AxisMetadata instances in the variable tree.

  • methods – If target is a Module, the methods of Module to scan over.

Returns

The scan function with the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.