flax.linen.remat
- flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, policy=None, methods=None)
Lifted version of jax.checkpoint.
This function is aliased to lift.remat, just as jax.checkpoint is aliased to jax.remat.
- Parameters
target – A Module or a function taking a Module as its first argument. Intermediate computations will be re-computed when computing gradients for the target.
variables – The variable collections that are lifted. By default all collections are lifted.
rngs – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.
concrete – Optional, boolean indicating whether target may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.
prevent_cse – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse should be set to False.
policy – Experimental checkpoint policy, see jax.checkpoint.
methods – If target is a Module, the methods of Module to checkpoint.
- Returns
A wrapped version of target. When computing gradients, intermediate computations will be re-computed on the backward pass.