Frequently Asked Questions (FAQ)#
This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in GitHub Discussions.
How to take the derivative with respect to an intermediate value (using Module.perturb)?#
To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use flax.linen.Module.perturb(): define a zero-valued flax.linen.Module "perturbation" parameter with perturb(...) in the forward pass, giving it the same shape as the intermediate activation; define the loss function with 'perturbations' as an added standalone argument; then perform a JAX derivative operation with jax.grad on the perturbation argument.
For full examples and detailed documentation, refer to the flax.linen.Module.perturb() API docs.
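A minimal sketch of this pattern (the model, perturbation name, and shapes here are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    # Register a zero-valued perturbation with the same shape as the
    # intermediate activation we want gradients for.
    x = self.perturb('dense3', x)
    return nn.Dense(1)(x)

# The loss takes 'perturbations' as a standalone argument.
def loss_fn(params, perturbations, inputs, targets):
  preds = model.apply(
      {'params': params, 'perturbations': perturbations}, inputs)
  return jnp.square(preds - targets).mean()

x = jnp.ones((2, 9))
y = jnp.ones((2, 1))
model = Model()
variables = model.init(jax.random.PRNGKey(0), x)

# Differentiating w.r.t. the perturbation argument yields the gradient
# of the loss w.r.t. the intermediate activation it was added to.
intermediate_grads = jax.grad(loss_fn, argnums=1)(
    variables['params'], variables['perturbations'], x, y)
```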
Is Flax Linen remat_scan() the same as scan(remat(...))?#
Flax Linen remat_scan() (flax.linen.remat_scan()) and scan(remat(...)) (flax.linen.scan() over flax.linen.remat()) are not the same, and remat_scan() is limited in the cases it supports. Namely, remat_scan() treats the inputs and outputs as carries (hidden states that are carried through the training loop). It is recommended that you use scan(remat(...)), as you will typically need the extra parameters, such as in_axes (for input array axes) or out_axes (for output array axes), which flax.linen.remat_scan() does not expose.
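A minimal sketch of the recommended scan(remat(...)) combination, scanning a rematerialized residual block over a stack of layers (the block, layer count, and shapes are illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
  @nn.compact
  def __call__(self, x, _):
    h = nn.relu(nn.Dense(features=4)(x))
    return x + h, None  # (carry, per-step output)

class StackedModel(nn.Module):
  num_layers: int = 4

  @nn.compact
  def __call__(self, x):
    # Rematerialize each block to save memory, then scan it over
    # num_layers stacked copies of its parameters. Unlike remat_scan(),
    # nn.scan() also exposes in_axes/out_axes for per-step inputs and
    # outputs when you need them.
    ScanRematBlock = nn.scan(
        nn.remat(Block),
        variable_axes={'params': 0},   # stack params along axis 0
        variable_broadcast=False,
        split_rngs={'params': True},   # fresh init RNG per layer
        length=self.num_layers,
    )
    x, _ = ScanRematBlock()(x, None)
    return x

model = StackedModel()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
```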
What are the recommended training loop libraries?#
Consider using CLU (Common Loop Utils, google/CommonLoopUtils). To get started, go to the CLU Synopsis Colab. You can find answers to common questions about CLU with Flax on the google/flax GitHub Discussions.
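As a flavor of what CLU provides, here is a sketch of its metrics utilities in the style of the CLU synopsis; the collection and metric names ('loss', 'accuracy') are illustrative:

```python
import flax
import jax.numpy as jnp
from clu import metrics

# Bundle per-batch metrics into a single dataclass that can be
# merged across steps and computed at the end of an epoch.
@flax.struct.dataclass
class TrainMetrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')
  accuracy: metrics.Accuracy

batch_metrics = TrainMetrics.single_from_model_output(
    logits=jnp.array([[2.0, 1.0], [0.5, 3.0]]),
    labels=jnp.array([0, 1]),
    loss=jnp.array(0.25),
)
print(batch_metrics.compute())  # e.g. {'loss': 0.25, 'accuracy': 1.0}
```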
For computer vision research, consider google-research/scenic. Scenic is a set of shared lightweight libraries that solve common tasks in training large-scale vision models (with examples from several projects). Scenic is developed in JAX with Flax. To get started, go to the README page on GitHub.