flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))

Combines remat and scan for memory efficiency and constant time compilation.

remat_scan allows for constant compile times and sublinear memory usage with respect to model depth, at the cost of a small constant compute penalty. This is typically beneficial for very deep models.


>>> import flax.linen as nn

>>> class BigModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
...     # 100x dense with O(sqrt(N)) memory for gradient computation
...     return DenseStack(8, name="dense_stack")(x)

Parameters

  • target – a Module or a function taking a Module as its first argument.

  • lengths – number of loop iterations at each nesting level. The total number of iterations is n = prod(lengths). Each loop is rematerialized, so the memory consumption is proportional to n^(1 / d), where d = len(lengths). Minimal memory consumption requires tuning the lengths such that the same amount of memory is consumed at each level of the nested loop.

  • policy – Experimental checkpoint policy, see jax.checkpoint.

  • variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside target.

  • variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.

  • variable_axes – the variable collections that are scanned over. Defaults to {True: 0}.

  • split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations. Defaults to {True: True}.


Returns

A wrapped version of target that repeats itself prod(lengths) times.
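
The effect of the lengths argument on memory can be illustrated with a little arithmetic. The sketch below is a rough cost model, not a measurement: it assumes that each nesting level keeps one saved carry per iteration, so the number of live carries scales with sum(lengths). The helper names `total_iterations`, `stored_carries`, and `balanced_length` are hypothetical and not part of Flax.

```python
import math


def total_iterations(lengths):
  """Total number of times the target is applied: n = prod(lengths)."""
  return math.prod(lengths)


def stored_carries(lengths):
  """Assumed cost model: one saved carry per iteration at each level."""
  return sum(lengths)


def balanced_length(n, d):
  """Per-level length n^(1/d) that spreads n iterations over d levels."""
  return n ** (1 / d)


# Several ways to arrange 100 iterations.
for lengths in [(100,), (10, 10), (4, 25), (5, 5, 4)]:
  print(lengths, total_iterations(lengths), stored_carries(lengths))
```

Under this model, a single loop of 100 iterations keeps 100 carries, while lengths=(10, 10) keeps 20 and the unbalanced (4, 25) keeps 29, which is why balancing the levels matters.
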