flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, methods=None)

A lifted version of jax.lax.scan.

See jax.lax.scan for the unlifted scan in JAX.
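For orientation, the unlifted jax.lax.scan behaves like an ordinary Python loop that threads a carry and stacks per-step outputs (the JAX documentation describes it with a similar pure-Python equivalent). The sketch below is a NumPy model of those semantics, not JAX's implementation; the names scan_reference and cumsum_step are illustrative.

```python
import numpy as np


def scan_reference(f, init, xs, length=None, reverse=False):
  """Pure-Python model of jax.lax.scan over the leading axis of xs."""
  if xs is None:
    xs = [None] * length  # no scanned input: iterate `length` times
  n = len(xs)
  carry = init
  ys = [None] * n
  order = range(n - 1, -1, -1) if reverse else range(n)
  for i in order:
    carry, ys[i] = f(carry, xs[i])  # each y keeps the position of its x
  return carry, np.stack(ys)


def cumsum_step(carry, x):
  carry = carry + x
  return carry, carry  # (new carry, per-step output)


xs = np.arange(1.0, 5.0)  # [1., 2., 3., 4.]
carry_fwd, ys_fwd = scan_reference(cumsum_step, 0.0, xs)
carry_rev, ys_rev = scan_reference(cumsum_step, 0.0, xs, reverse=True)
carry_len, counts = scan_reference(lambda c, _: (c + 1, c), 0, None, length=3)
```

Running forward gives prefix sums, running with reverse=True gives suffix sums in the original positions, and length drives the loop when there is nothing to scan over.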

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.

scan distinguishes between 3 different types of values inside the loop:

  1. scan: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis.

  2. carry: A carried value is updated at each loop iteration. It must have the same shape and dtype throughout the loop.

  3. broadcast: a value that is closed over by the loop. When a variable is broadcast, it is typically initialized inside the loop body but is independent of the loop variables.

The loop body should have the signature (module, carry, *xs) -> (carry, ys), where module is the Module the body runs in, and xs and ys are the scan values that go in and out of the loop.


import flax
import flax.linen as nn
from jax import random

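class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    LSTM = nn.scan(nn.LSTMCell,
                   variable_broadcast="params",  # share LSTM weights across steps
                   split_rngs={"params": False},
                   in_axes=1,    # xs is (batch, seq_len, in_feat): scan over seq_len
                   out_axes=1)   # stack outputs as (batch, seq_len, out_feat)
    return LSTM()(c, xs)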

seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
init_carry = nn.LSTMCell.initialize_carry(key_2, (batch_size,), out_feat)

model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)

assert out_val.shape == (batch_size, seq_len, out_feat)

Note that when a function is passed to nn.scan, in_axes applies to all arguments starting from the third one (the first is the Module and the second is the carry). In the following example, the inputs that are scanned over are therefore xs, *args, and **kwargs:

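def body_fn(cls, carry, xs, *args, **kwargs):
  # cls is the Module instance; some_fn stands in for a Module method
  # that returns the (carry, ys) pair for one iteration.
  extended_states = cls.some_fn(xs, carry, *args, **kwargs)
  return extended_states

scan_fn = nn.scan(
    body_fn,
    in_axes=0,  # scan over axis 0 from third arg of body_fn onwards.
)

Parameters: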
  • target – a Module or a function taking a Module as its first argument.

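  • variable_axes – the variable collections that are scanned over, given as a mapping from collection name to the axis along which that collection's variables are stacked.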

  • variable_broadcast – Specifies the broadcast variable collections. A broadcast variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the loop body.

  • variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.

  • split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations.

  • in_axes – Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.

  • out_axes – Specifies the axis to scan over for the return value. Should be a prefix tree of the return value.

  • length – Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.

  • reverse – If True, scan from end to start in reverse order.

  • unroll – how many scan iterations to unroll within a single iteration of a loop (default: 1).

  • data_transform – optional function to transform the raw functional-core variable and rng groups inside the lifted scan body_fn, intended for inline SPMD annotations.

  • methods – If target is a Module, the methods of the Module to scan over.


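Returns:

The scan-transformed target: when target is a Module, a new Module class whose scanned methods map (carry, *xs) -> (carry, ys); when target is a function, a function with the signature (module, carry, *xs) -> (carry, ys). In both cases xs and ys are the scan values that go in and out of the loop.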