flax.linen.remat


flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)

Lifted version of jax.checkpoint.

Checkpointing is a technique for reducing memory usage by recomputing activations during backpropagation instead of keeping them in memory after the forward pass. When training large models, it can be helpful to checkpoint parts of the model to trade additional computation for lower memory usage.
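To make the trade-off concrete, here is a minimal sketch that uses the underlying jax.checkpoint on a plain function, without any Flax Module. The function name block and the array shapes are illustrative assumptions. The point is only that checkpointing changes what is stored for the backward pass, not the gradient values.

```python
import jax
import jax.numpy as jnp

def block(w, x):
    # `h` is an intermediate activation that autodiff would normally
    # keep in memory for the backward pass.
    h = jnp.tanh(x @ w)
    return jnp.sum(jnp.sin(h))

# jax.checkpoint marks `block` so that its intermediates are recomputed
# during the backward pass rather than stored.
block_remat = jax.checkpoint(block)

w = jnp.full((4, 4), 0.1)
x = jnp.ones((2, 4))

g_plain = jax.grad(block)(w, x)
g_remat = jax.grad(block_remat)(w, x)
```

Both gradients should agree up to floating-point tolerance; only the memory and compute profile of the backward pass differs.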

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
...   @nn.checkpoint
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(128)(x)
...     x = nn.relu(x)
...     x = nn.Dense(1)(x)
...     return x
...
>>> model = CheckpointedMLP()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))

flax.linen.checkpoint and flax.linen.remat are aliases of the same function, just like jax.checkpoint and jax.remat.

Parameters
  • target – a Module or a function taking a Module as its first argument. Intermediate computations will be recomputed when computing gradients for the target.

  • variables – The variable collections that are lifted. By default all collections are lifted.

  • rngs – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.

  • concrete – Optional, boolean indicating whether target may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.

  • prevent_cse – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse should be set to False.

  • static_argnums – Optional, int or sequence of ints, indicating which argument values to specialize on for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overhead.

  • policy – Experimental checkpoint policy, see jax.checkpoint.

  • methods – An optional list of method names that will be lifted. If methods is None (default), only the __call__ method will be lifted. If target is a function, methods is ignored.

Returns

A wrapped version of target. When computing gradients, intermediate computations will be recomputed on the backward pass.