flax.linen.scan
- flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None)
A lifted version of jax.lax.scan.
See jax.lax.scan for the unlifted scan in JAX.
To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.
scan distinguishes between 3 different types of values inside the loop:
- scan: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis.
- carry: a carried value is updated at each loop iteration. It must have the same shape and dtype throughout the loop.
- broadcast: a value that is closed over by the loop. When a variable is broadcast, it is typically initialized inside the loop body but is independent of the loop variables.
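The three kinds of values can be illustrated with the unlifted jax.lax.scan. The following is a minimal sketch, not part of the Flax API: the names step, scale, xs, total and running are made up for this example. Here xs is the scanned value, the running total is the carry, and scale plays the role of a broadcast value because it is closed over by the loop body.

```python
import jax
import jax.numpy as jnp

# Broadcast-like value: closed over by the loop body, identical at every step.
scale = jnp.float32(2.0)

def step(carry, x):
    # carry: running total; keeps the same shape and dtype at every iteration.
    new_carry = carry + scale * x
    # The second return value is stacked along the scan axis.
    return new_carry, new_carry

# Scanned value: one element is consumed per iteration.
xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]

total, running = jax.lax.scan(step, jnp.zeros(()), xs)
# total is the final carry; running holds the carry after each step.
```

The lifted nn.scan follows the same pattern, with the addition that Flax variable collections and PRNG sequences are also assigned to one of these three roles.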
The target should have the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.

Example:
>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape = x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)
Note that when providing a function to nn.scan, the scanning happens over all arguments starting from the third argument, as specified by in_axes. The previous example could also be written using the functional form as:

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape = x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))
You can also use scan to reduce the compilation time of your JAX program by merging multiple layers into a single scan loop. This works when you have a sequence of identical layers that you want to apply iteratively to an input. For example:

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))
To reduce both compilation time and memory usage, you can use remat_scan(), which additionally checkpoints each layer in the scan loop.
- Parameters
target – a Module or a function taking a Module as its first argument.
variable_axes – the variable collections that are scanned over.
variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.
variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.
split_rngs – PRNG sequences marked as split will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations.
in_axes – Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.
out_axes – Specifies the axis to scan over for the return value. Should be a prefix tree of the return value.
length – Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.
reverse – If true, scan from end to start in reverse order.
unroll – how many scan iterations to unroll within a single iteration of a loop (default: 1).
data_transform – optional function to transform raw functional-core variable and rng groups inside lifted scan body_fn, intended for inline SPMD annotations.
metadata_params – arguments dict passed to AxisMetadata instances in the variable tree.
methods – If target is a Module, the methods of Module to scan over.
- Returns
The scan function with the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.