Flax NNX glossary

For additional terms, refer to the JAX glossary.

Filter

A way to extract only certain nnx.Variable objects (for example, only those of type nnx.Param) out of a Flax NNX Module (nnx.Module). This is usually done by passing one or more filters to nnx.split when calling it on the nnx.Module. Refer to the Filter guide to learn more.

Folding in

In Flax, folding in means generating a new JAX pseudorandom number generator (PRNG) key from an input PRNG key and an integer (in JAX, jax.random.fold_in). This is typically used when you want to generate a new key but still be able to use the original PRNG key afterwards. You can also do this in JAX with jax.random.split, but that method effectively creates two PRNG keys, which is slower. Learn how Flax generates new PRNG keys automatically in the Randomness/PRNG guide.
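A short sketch in plain JAX of folding in a step index to derive one key per step. It assumes a JAX version that provides jax.random.key (older versions can use jax.random.PRNGKey instead). The point to notice is that fold_in is deterministic and leaves the original key usable.

```python
import jax

key = jax.random.key(0)

# Derive one key per training step by folding the step index into `key`.
step_keys = [jax.random.fold_in(key, step) for step in range(3)]
samples = [float(jax.random.uniform(k)) for k in step_keys]

# fold_in is deterministic: the same (key, integer) pair gives the same new key.
same = float(jax.random.uniform(jax.random.fold_in(key, 1)))

# `key` itself is not consumed, so it can still be used afterwards.
original = float(jax.random.uniform(key))
```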

GraphDef

nnx.GraphDef is a class that represents all the static, stateless, and Pythonic parts of a Flax Module (nnx.Module).

Merge

Refer to Split and merge.

Module

nnx.Module is the Python base class for Flax NNX models; it enables defining and initializing parameters in a referentially transparent form. It is responsible for storing and updating Variable objects and parameters within itself.

Params / parameters

nnx.Param is a particular subclass of nnx.Variable that generally contains the trainable weights.

PRNG states

A Flax nnx.Module can keep a reference to a pseudorandom number generator (PRNG) state object, nnx.Rngs, that can generate new JAX PRNG keys. These keys are used to generate random JAX arrays through JAX's functional PRNGs. You can use a PRNG state with different seeds for separately named streams to add more fine-grained control to your model (for example, to have independent random numbers for parameters and dropout masks). Refer to the Flax Randomness/PRNG guide for more details.

Split and merge

nnx.split is a way to represent an nnx.Module as two parts: 1) a static Flax NNX GraphDef that captures its Pythonic static information; and 2) one or more Variable states that capture its JAX arrays (jax.Array) in the form of JAX pytrees. They can be merged back into the original nnx.Module using nnx.merge.

Transformation

A Flax NNX transformation (transform) is a wrapped version of a JAX transformation that allows the transformed function to take Flax NNX Modules (nnx.Module) as inputs or return them as outputs. For example, a "lifted" version of jax.jit is nnx.jit. Check out the Flax NNX transforms guide to learn more.

Variable

An nnx.Variable is a container for the weights, parameters, data, or arrays residing in a Flax Module. Variables are defined inside modules as nnx.Variable or one of its subclasses.

Variable state

nnx.VariableState is a purely functional JAX pytree of all the Variables inside a Module. Because it is pure, it can be an input or output of a JAX transformation function. nnx.VariableState is obtained by using nnx.split on the nnx.Module. (Refer to Split and merge and Module to learn more.)