Flax NNX glossary
For additional terms, refer to the JAX glossary.
- Filter
A way to extract only certain nnx.Variable objects out of a Flax NNX Module (nnx.Module). This is usually done by calling nnx.split on the nnx.Module and passing it one or more filters. Refer to the Filter guide to learn more.
- Folding in
In Flax, folding in means generating a new JAX pseudorandom number generator (PRNG) key from an input PRNG key and an integer (in JAX, jax.random.fold_in). This is typically used when you want to generate a new key but still be able to use the original PRNG key afterwards. You can also do this in JAX with jax.random.split, but that effectively creates two PRNG keys, which is slower. Learn how Flax generates new PRNG keys automatically in the Randomness/PRNG guide.
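A minimal sketch of folding in with plain JAX, using jax.random.fold_in. The seed 0 and the folded-in integers are arbitrary values chosen for illustration. The helper key_bits is hypothetical and exists only to compare typed keys through jax.random.key_data.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # the original PRNG key

# Derive a new key by folding an integer (e.g., a step or layer index) into `key`.
new_key = jax.random.fold_in(key, 1)

# Folding in is deterministic: the same key and integer give the same new key.
same_key = jax.random.fold_in(key, 1)

# Use the derived key to draw random numbers.
b = jax.random.normal(new_key, (3,))

# For comparison, jax.random.split derives two new keys at once.
k1, k2 = jax.random.split(key)


def key_bits(k):
  """Hypothetical helper: raw uint32 data behind a typed key, for comparisons."""
  return jax.random.key_data(k)


print(jnp.array_equal(key_bits(new_key), key_bits(same_key)))  # True
print(jnp.array_equal(key_bits(new_key), key_bits(key)))       # False
```

Because the result depends only on the key and the integer, folding in a counter such as a training step yields a reproducible per-step key without splitting the original.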
- GraphDef
nnx.GraphDef is a class that represents all the static, stateless, and Pythonic parts of a Flax Module (nnx.Module).
- Merge
Refer to Split and merge.
- Module
nnx.Module is a dataclass that enables defining and initializing parameters in a referentially transparent form. It is responsible for storing and updating Variable objects and parameters within itself.
- Params / parameters
nnx.Param is a particular subclass of nnx.Variable that generally contains the trainable weights.
- PRNG states
A Flax nnx.Module can keep a reference to a pseudorandom number generator (PRNG) state object, nnx.Rngs, that can generate new JAX PRNG keys. These keys are used to generate random JAX arrays through JAX’s functional PRNGs. You can use a PRNG state with different seeds to get more fine-grained control over your model (for example, to have independent random numbers for parameters and dropout masks). Refer to the Flax Randomness/PRNG guide for more details.
- Split and merge
nnx.split is a way to represent an nnx.Module as two parts: 1) a static Flax NNX GraphDef that captures its Pythonic static information; and 2) one or more Variable state(s) that capture its JAX arrays (jax.Array) in the form of JAX pytrees. They can be merged back into the original nnx.Module using nnx.merge.
- Transformation
A Flax NNX transformation (transform) is a wrapped version of a JAX transformation that allows the function being transformed to take a Flax NNX Module (nnx.Module) as input or output. For example, a “lifted” version of jax.jit is nnx.jit. Check out the Flax NNX transforms guide to learn more.
- Variable
The weights / parameters / data / arrays (nnx.Variable) residing in a Flax Module. Variables are defined inside Modules as nnx.Variable or one of its subclasses.
- Variable state
nnx.VariableState is a purely functional JAX pytree of all the Variables inside a Module. Since it is pure, it can be an input or output of a JAX transformation function. nnx.VariableState is obtained by using nnx.split on the nnx.Module. (Refer to Split and merge and Module to learn more.)