flax.linen.scan

flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, data_transform=None, methods=None)

A lifted version of jax.lax.scan.

See jax.lax.scan for the unlifted scan in JAX.

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.
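As a rough illustration of what in_axes=1 and out_axes=1 mean, the sketch below reproduces the effect with the unlifted jax.lax.scan, which always scans over the leading axis. The body function, names, and shapes are made up for this example, and it is not a claim about how Flax handles the axes internally.

```python
import jax
import jax.numpy as jnp


def body(carry, x):
  # Running sum over the scanned axis; the new carry is also emitted as output.
  carry = carry + x
  return carry, carry


batch, seq_len, feat = 2, 4, 3
xs = jnp.arange(batch * seq_len * feat, dtype=jnp.float32).reshape(
    batch, seq_len, feat)
init = jnp.zeros((batch, feat), dtype=jnp.float32)

# jax.lax.scan iterates over axis 0, so move the time axis (axis 1) to the front.
xs_time_major = jnp.moveaxis(xs, 1, 0)
final, ys_time_major = jax.lax.scan(body, init, xs_time_major)

# Move the stacked outputs back so that time is axis 1 again (like out_axes=1).
ys = jnp.moveaxis(ys_time_major, 0, 1)
```

With the lifted scan, passing in_axes=1 and out_axes=1 is meant to spare you this manual axis shuffling.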

scan distinguishes between 3 different types of values inside the loop:

  1. scan: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis.

  2. carry: a value that is updated at each loop iteration. It must have the same shape and dtype throughout the loop.

  3. broadcast: a value that is closed over by the loop. Broadcast variables are typically initialized inside the loop body but are independent of the loop variables.

The loop body should have the signature (scope, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.
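The three kinds of values can be seen in a small example using the unlifted jax.lax.scan, whose body has no scope argument. This is only a sketch: the names scale and body are hypothetical, and the broadcast value here is a plain closed-over array rather than a Flax variable collection.

```python
import jax
import jax.numpy as jnp

# broadcast: closed over by the loop body, identical in every iteration.
scale = jnp.asarray(2.0, dtype=jnp.float32)


def body(carry, x):
  # carry: updated each iteration, keeping the same shape and dtype.
  new_carry = carry + scale * x
  # The second return value is the per-iteration output, stacked by scan.
  return new_carry, new_carry


# scan: iterated over along the leading axis.
xs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
init = jnp.zeros((), dtype=jnp.float32)

final, ys = jax.lax.scan(body, init, xs)
```

Here final is the last carry and ys has one entry per element of xs.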

Example:

import flax
import flax.linen as nn
from jax import random

class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    LSTM = nn.scan(nn.LSTMCell,
                   variable_broadcast="params",
                   split_rngs={"params": False},
                   in_axes=1,
                   out_axes=1)
    return LSTM()(c, xs)

seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
init_carry = nn.LSTMCell.initialize_carry(key_2, (batch_size,), out_feat)

model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)

assert out_val.shape == (batch_size, seq_len, out_feat)

Note that when providing a function to nn.scan, the scanning happens over all arguments starting from the third argument, as specified by in_axes. So in the following example, the inputs that are scanned over are xs, *args, and **kwargs:

def body_fn(cls, carry, xs, *args, **kwargs):
  extended_states = cls.some_fn(xs, carry, *args, **kwargs)
  return extended_states

scan_fn = nn.scan(
    body_fn,
    in_axes=0,  # scan over axis 0 from third arg of body_fn onwards.
    variable_axes=SCAN_VARIABLE_AXES,
    split_rngs=SCAN_SPLIT_RNGS)

Parameters
  • target (flax.linen.transforms.Target) – a Module or a function taking a Module as its first argument.

  • variable_axes (Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.core.lift.In[int], flax.core.lift.Out[int]]]) – the variable collections that are scanned over.

  • variable_broadcast (Union[bool, str, Collection[str], DenyList]) – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.

  • variable_carry (Union[bool, str, Collection[str], DenyList]) – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.

  • split_rngs (Mapping[Union[bool, str, Collection[str], DenyList], bool]) – Split PRNG sequences will be different for each loop iteration. If split is False the PRNGs will be the same across iterations.

  • in_axes – Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.

  • out_axes – Specifies the axis to scan over for the return value. Should be a prefix tree of the return value.

  • length (Optional[int]) – Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.

  • reverse (bool) – If true, scan from end to start in reverse order.

  • data_transform (Optional[Callable[[...], Any]]) – optional function to transform raw functional-core variable and rng groups inside lifted scan body_fn, intended for inline SPMD annotations.

  • methods – If target is a Module, the methods of Module to scan over.

Returns

The scan function with the signature (scope, carry, *xxs) -> (carry, yys), where xxs and yys are the scan values that go in and out of the loop.

Return type

flax.linen.transforms.Target