- flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, policy=None, methods=None)
Lifted version of
This function is aliased to
- Parameters
target (flax.linen.transforms.Target) – a Module or a function taking a Module as its first argument. Intermediate computations will be re-computed when computing gradients for the target.
variables (Union[bool, str, Collection[str], DenyList]) – The variable collections that are lifted. By default all collections are lifted.
rngs (Union[bool, str, Collection[str], DenyList]) – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.
concrete (bool) – Optional boolean indicating whether fun (the function being checkpointed) may involve value-dependent Python control flow (default False). Support for such control flow is optional and disabled by default because, in some edge-case compositions with jax.jit(), it can lead to extra computation.
prevent_cse (bool) – Optional boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs: it can foil other optimizations, and it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a pmap, CSE can merge the recomputed values with those from the original forward pass and defeat the purpose of this decorator. In some settings, such as use inside a scan, this CSE prevention mechanism is unnecessary, and prevent_cse should be set to False.
policy (Optional[Callable[[...], bool]]) – Experimental checkpoint policy, see
methods – If target is a Module class, the names of its methods to checkpoint. By default, only __call__ is checkpointed.
- Returns
A wrapped version of target. When computing gradients, intermediate computations will be re-computed on the backward pass.
- Return type