flax.linen.remat
- flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)
Lifted version of jax.checkpoint.
Checkpointing is a technique for reducing memory usage by recomputing activations during backpropagation. When training large models, it can be helpful to checkpoint parts of the model to trade off memory usage for additional computation.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     CheckpointDense = nn.checkpoint(nn.Dense)
...     x = CheckpointDense(128)(x)
...     x = nn.relu(x)
...     x = CheckpointDense(1)(x)
...     return x
...
>>> model = CheckpointedMLP()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))
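This function is also available as flax.linen.checkpoint (the name used in the example above). The two names are aliases, just like jax.checkpoint and jax.remat.
- Parameters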
target – a Module or a function taking a Module as its first argument. Intermediate computations will be re-computed when computing gradients for the target.
variables – The variable collections that are lifted. By default all collections are lifted.
rngs – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.
concrete – Optional, boolean indicating whether fun may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.
prevent_cse – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse should be set to False.
static_argnums – Optional, int or sequence of ints, indicating which argument values to specialize on for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overhead.
policy – Experimental checkpoint policy, see jax.checkpoint.
methods – An optional list of method names that will be lifted. If methods is None (default), only the __call__ method will be lifted. If target is a function, methods is ignored.
- Returns
A wrapped version of target. When computing gradients, intermediate computations will be re-computed on the backward pass.