flax.linen.scan
- flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, methods=None)
A lifted version of jax.lax.scan.

See jax.lax.scan for the unlifted scan in JAX.

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.

scan distinguishes between 3 different types of values inside the loop:

scan: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis.
carry: a carried value is updated at each loop iteration. It must have the same shape and dtype throughout the loop.
broadcast: a value that is closed over by the loop. A broadcast variable is typically initialized inside the loop body but is independent of the loop variables.
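The three kinds of values can be seen in the unlifted jax.lax.scan before any Flax machinery is involved. The following is a minimal sketch, not taken from the Flax documentation; the names scale, body, and xs are illustrative assumptions. Here xs is the scan value, the running total is the carry, and scale is closed over, which is the plain-JAX analogue of a broadcast value.

```python
import jax
import jax.numpy as jnp

# Broadcast-like value: closed over by the loop body, identical in every iteration.
scale = 2.0

def body(carry, x):
    # carry: updated each iteration, keeps the same shape and dtype.
    # x: one slice of the scanned input xs along axis 0.
    new_carry = carry + scale * x
    # The second return value is stacked along the scan axis.
    return new_carry, new_carry

xs = jnp.arange(1.0, 5.0)  # scan value with 4 steps: [1, 2, 3, 4]
final_carry, ys = jax.lax.scan(body, jnp.zeros(()), xs)
```

The lifted version applies the same classification to Flax variable collections and PRNG sequences through variable_axes, variable_carry, variable_broadcast, and split_rngs.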
The loop body should have the signature (scope, body, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.

Example:
    import flax
    import flax.linen as nn
    from jax import random

    class SimpleScan(nn.Module):
      @nn.compact
      def __call__(self, c, xs):
        LSTM = nn.scan(nn.LSTMCell,
                       variable_broadcast="params",
                       split_rngs={"params": False},
                       in_axes=1,
                       out_axes=1)
        return LSTM()(c, xs)

    seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
    key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

    xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
    init_carry = nn.LSTMCell.initialize_carry(key_2, (batch_size,), out_feat)

    model = SimpleScan()
    variables = model.init(key_3, init_carry, xs)
    out_carry, out_val = model.apply(variables, init_carry, xs)

    assert out_val.shape == (batch_size, seq_len, out_feat)
Note that when providing a function to nn.scan, the scanning happens over all arguments starting from the third argument, as specified by in_axes. So in the following example, the inputs that are being scanned over are xs, *args, and **kwargs:

    def body_fn(cls, carry, xs, *args, **kwargs):
      extended_states = cls.some_fn(xs, carry, *args, **kwargs)
      return extended_states

    scan_fn = nn.scan(
        body_fn,
        in_axes=0,  # scan over axis 0 from third arg of body_fn onwards.
        variable_axes=SCAN_VARIABLE_AXES,
        split_rngs=SCAN_SPLIT_RNGS)
- Parameters
target – a Module or a function taking a Module as its first argument.
variable_axes – the variable collections that are scanned over.
variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.
variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.
split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations.
in_axes – Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.
out_axes – Specifies the axis to scan over for the return value. Should be a prefix tree of the return value.
length – Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.
reverse – If true, scan from end to start in reverse order.
unroll – how many scan iterations to unroll within a single iteration of a loop (default: 1).
data_transform – optional function to transform raw functional-core variable and rng groups inside lifted scan body_fn, intended for inline SPMD annotations.
methods – If target is a Module, the methods of Module to scan over.
- Returns
The scan function with the signature (scope, carry, *xxs) -> (carry, yys), where xxs and yys are the scan values that go in and out of the loop.