Frequently Asked Questions (FAQ)#

This page collects answers to frequently asked questions about Flax. You can contribute to the Flax FAQ by starting a new topic in GitHub Discussions.

How to take the derivative with respect to an intermediate value (using Module.perturb)?#

To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use flax.linen.Module.perturb(). In the forward pass, register a zero-valued “perturbation” with self.perturb(...), giving it the same shape as the intermediate activation (perturb() returns the activation plus the perturbation, so it leaves the forward pass unchanged). Then define the loss function with the 'perturbations' collection as an additional standalone argument, and take the derivative with jax.grad with respect to that argument. Because the perturbation is initialized to zero and added to the activation, its gradient equals the gradient with respect to the intermediate activation itself.

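A minimal sketch of this pattern follows; the model, loss, and data shapes are illustrative assumptions, not part of the FAQ:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    # Register a zero-valued perturbation with the same shape as the
    # intermediate activation; perturb() returns the activation plus the
    # perturbation, so the forward pass is unchanged.
    x = self.perturb('dense3', x)
    return nn.Dense(1)(x)

def loss_fn(params, perturbations, inputs, targets):
  # 'perturbations' is passed as a standalone argument so jax.grad can
  # differentiate with respect to it.
  preds = model.apply({'params': params, 'perturbations': perturbations}, inputs)
  return jnp.square(preds - targets).mean()

x = jnp.ones((10, 2))  # illustrative inputs
y = jnp.ones((10, 1))  # illustrative targets

model = Model()
variables = model.init(jax.random.PRNGKey(0), x)

# Gradient with respect to the perturbation argument (argnums=1), which
# equals the gradient with respect to the intermediate activation 'dense3'.
intermediate_grads = jax.grad(loss_fn, argnums=1)(
    variables['params'], variables['perturbations'], x, y)
print(intermediate_grads['dense3'].shape)  # (10, 3)
```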
For full examples and detailed documentation, see the flax.linen.Module.perturb() API docs.

Is Flax Linen remat_scan() the same as scan(remat(...))?#

Flax remat_scan() (flax.linen.remat_scan()) and scan(remat(...)) (flax.linen.scan() over flax.linen.remat()) are not the same, and remat_scan() is limited in the cases it supports. Namely, remat_scan() treats both the inputs and the outputs as carries (hidden states that are carried through the scan loop). It is recommended to use scan(remat(...)) instead, since you will typically need the extra parameters that flax.linen.remat_scan() does not expose, such as in_axes (for input array axes) or out_axes (for output array axes). A sketch of this pattern follows.
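
The following sketch shows one way to compose flax.linen.scan() over flax.linen.remat(); the per-step module, feature sizes, and shapes are illustrative assumptions, not part of the Flax API:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLPBlock(nn.Module):
  """One scan step: takes (carry, x), returns (new_carry, per-step output)."""
  features: int

  @nn.compact
  def __call__(self, carry, x):
    # Combine the carried hidden state with the current input slice.
    h = nn.relu(nn.Dense(self.features)(carry) + nn.Dense(self.features)(x))
    return h, h  # new carry, per-step output

class ScannedRematMLP(nn.Module):
  features: int

  @nn.compact
  def __call__(self, carry, xs):
    # scan(remat(...)): rematerialize each step's activations on the
    # backward pass while scanning over axis 0 of xs. in_axes/out_axes
    # control which input/output axes are scanned over; remat_scan()
    # does not expose them.
    ScanRemat = nn.scan(
        nn.remat(MLPBlock),
        variable_broadcast='params',   # share one set of weights across steps
        split_rngs={'params': False},
        in_axes=0,
        out_axes=0,
    )
    return ScanRemat(self.features)(carry, xs)

carry = jnp.zeros((4,))    # illustrative initial hidden state
xs = jnp.ones((8, 3))      # 8 scan steps of 3-dimensional inputs

model = ScannedRematMLP(features=4)
variables = model.init(jax.random.PRNGKey(0), carry, xs)
final_carry, ys = model.apply(variables, carry, xs)
print(ys.shape)  # (8, 4)
```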