flax.linen.remat_scan

flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))

Combines remat and scan for memory efficiency and constant time compilation.

remat_scan allows for constant compile times and sublinear memory usage with respect to model depth, at the cost of a small constant penalty from recomputing activations during the backward pass. This is typically beneficial for very deep models.

Example:

import flax.linen as nn

class BigModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Two nested loops of 10 iterations each: 10 * 10 = 100 Dense layers.
    DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
    # 100x dense with O(sqrt(N)) memory for gradient computation
    return DenseStack(8, name="dense_stack")(x)

Parameters
  • target (flax.linen.transforms.Target) – a Module or a function taking a Module as its first argument.

  • lengths (Optional[Sequence[int]]) – number of loop iterations at each nesting level. The total number of iterations is n = prod(lengths). Each loop is rematerialized, so memory consumption is proportional to n^(1 / d), where d = len(lengths). Minimal memory consumption requires tuning the lengths so that the same amount of memory is consumed at each level of the nested loop.

  • policy (Optional[Callable[[...], bool]]) – Experimental checkpoint policy, see jax.checkpoint.

  • variable_broadcast (Union[bool, str, Collection[str], DenyList]) – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define parameters that are shared across iterations inside the target.

  • variable_carry (Union[bool, str, Collection[str], DenyList]) – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.

  • variable_axes (Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.core.lift.In[int], flax.core.lift.Out[int]]]) – the variable collections that are scanned over. Defaults to {True: 0}.

  • split_rngs (Mapping[Union[bool, str, Collection[str], DenyList], bool]) – Split PRNG sequences will be different for each loop iteration. If split is False, the PRNGs will be the same across iterations. Defaults to {True: True}.
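To make the effect of lengths more concrete, here is a small pure-Python sketch of the arithmetic described above. The helper names total_iterations and stored_carries_estimate are hypothetical and not part of Flax, and the cost model (each nesting level keeps roughly one carry per iteration of that level) is a simplifying assumption, not a measurement of actual memory use.

```python
import math


def total_iterations(lengths):
    """Total number of times the target is applied: n = prod(lengths)."""
    return math.prod(lengths)


def stored_carries_estimate(lengths):
    """Rough estimate of carries kept alive for the backward pass.

    Assumption (not taken from Flax): each nesting level stores about one
    carry per iteration of that level, so the estimate is sum(lengths).
    With balanced lengths this is d * n^(1 / d), where d = len(lengths).
    """
    return sum(lengths)


# Three ways to arrange the same 100 iterations.
for lengths in [(100,), (10, 10), (5, 5, 4)]:
    n = total_iterations(lengths)
    est = stored_carries_estimate(lengths)
    print(f"lengths={lengths}: n={n}, estimated stored carries={est}")
```

Under this simplified model, a single loop of 100 keeps on the order of 100 carries, while the balanced two-level split (10, 10) keeps on the order of 20, which is the sqrt(N) behavior mentioned in the example.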

Returns

A wrapped version of target that repeats itself prod(lengths) times.

Return type

flax.linen.transforms.Target