flax.linen.remat_scan
- flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))
Combines remat and scan for memory efficiency and constant-time compilation.
remat_scan allows for constant compile times and sublinear memory usage with respect to model depth, at the cost of a small constant performance penalty. This is typically beneficial for very deep models.
Example:
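    import flax.linen as nn

    class BigModel(nn.Module):
      @nn.compact
      def __call__(self, x):
        # 10 * 10 = 100 Dense layers with O(sqrt(N)) memory for gradient computation.
        # The last dimension of x must be 8, because each layer's output is fed
        # back in as the next layer's input (the scan carry).
        DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
        return DenseStack(8, name="dense_stack")(x)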
- Parameters
target – a Module or a function taking a Module as its first argument.
lengths – number of loop iterations at each level of the nested loop; lengths[0] is the outermost loop. The total number of iterations is n = prod(lengths). Each loop is rematerialized, so memory consumption is proportional to n^(1/d), where d = len(lengths). Minimal memory consumption requires tuning the lengths so that each level of the nested loop consumes the same amount of memory.
policy – Experimental checkpoint policy; see jax.checkpoint.
variable_broadcast – Specifies the broadcast variable collections. A broadcast variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define parameters shared across all iterations of the loop body.
variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.
variable_axes – the variable collections that are scanned over, each mapped to the axis along which its per-iteration variables are stacked. Defaults to {True: 0}.
split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations. Defaults to {True: True}.
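The lengths trade-off can be worked through numerically. The pure-Python sketch below uses hypothetical helper names (scan_stats, equal_lengths, not part of Flax); it computes n = prod(lengths) and the n^(1/d) scale from the lengths description, ignoring constant factors. Splitting n into equal per-level lengths is only a heuristic for balancing memory across levels, assuming every iteration stores a similar amount.

```python
import math


def scan_stats(lengths):
  """Hypothetical helper: total iterations n and the n ** (1 / d) memory scale."""
  n = math.prod(lengths)
  d = len(lengths)
  return n, n ** (1 / d)


def equal_lengths(n, d):
  """Hypothetical helper: split n iterations into d equal loop levels.

  Only works when n is a perfect d-th power; otherwise raises ValueError.
  """
  root = round(n ** (1 / d))
  if root ** d != n:
    raise ValueError(f"{n} is not a perfect {d}-th power")
  return (root,) * d


# Same 100 or 1000 layers, different nesting depths.
for lengths in [(100,), (10, 10), (4, 5, 5), equal_lengths(1000, 3)]:
  n, scale = scan_stats(lengths)
  print(f"lengths={lengths}: n={n}, n^(1/d)={scale:.2f}")
```

Deeper nesting lowers the n^(1/d) term, but each extra level adds another round of recomputation, so two or three levels is usually where the trade-off is explored.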
- Returns
A wrapped version of target that applies target prod(lengths) times in sequence, passing each output in as the next input (the loop carry).