flax.linen.remat

flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, policy=None, methods=None)

Lifted version of jax.checkpoint.

This function is aliased to lift.remat, just as jax.remat is an alias of jax.checkpoint.

Parameters
  • target (flax.linen.transforms.Target) – A Module or a function taking a Module as its first argument. Intermediate computations will be recomputed when computing gradients for the target.

  • variables (Union[bool, str, Collection[str], DenyList]) – The variable collections that are lifted. By default all collections are lifted.

  • rngs (Union[bool, str, Collection[str], DenyList]) – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.

  • concrete (bool) – Optional, boolean indicating whether target may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.

  • prevent_cse (bool) – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse should be set to False.

  • policy (Optional[Callable[[...], bool]]) – Experimental checkpoint policy; see jax.checkpoint.

  • methods – If target is a Module, the methods of the Module to checkpoint.

Returns

A wrapped version of target. When computing gradients, intermediate computations are recomputed on the backward pass.

Return type

flax.linen.transforms.Target