For additional terms, refer to the JAX glossary.
- Bound Module
When a Module is created through regular Python object construction (e.g. module = SomeModule(args…)), it is in an unbound state. This means that only dataclass attributes are set, and no variables are bound to the module. When the pure functions Module.init() or Module.apply() are called, Flax clones the Module and binds the variables to it, and the module’s method code is executed in a locally bound state, allowing things like calling submodules directly without providing variables. For more details, refer to the module lifecycle.
- Compact / Non-compact Module
Modules with a single method are able to declare submodules and variables inline by using the @nn.compact decorator. These are referred to as “compact-style modules”, whereas modules defining a setup() method (usually but not always with multiple callable methods) are referred to as “setup-style modules”. To learn more, refer to the setup vs compact guide.
- Folding in
Generating a new PRNG key from an input PRNG key and an integer (jax.random.fold_in). Typically used when you want to generate a new key but still be able to use the original PRNG key afterwards. You can also do this with jax.random.split, but that effectively creates two PRNG keys, which is slower.
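A short sketch of folding in with `jax.random.fold_in`, assuming the common pattern of deriving one key per training step from a single base key (the step loop is illustrative only):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Derive one key per step by folding the step index into the same base key.
step_keys = [jax.random.fold_in(key, step) for step in range(3)]

# The base key is untouched and can still be used afterwards.
base_sample = jax.random.normal(key, (2,))

# fold_in is deterministic: the same (key, integer) pair always yields the same key.
again = jax.random.fold_in(key, 1)
```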
- FrozenDict
An immutable dictionary that can be “unfrozen” to a regular, mutable dictionary. Internally, Flax uses FrozenDicts to ensure variable dicts aren’t accidentally mutated. Note: We are considering returning to regular dicts from our APIs, and only using FrozenDicts internally (see #1223).
- Functional core
The flax core library implements the simple container Scope API for threading variables and PRNGs through a model, as well as the lifting machinery needed to transform functions passing Scope objects. The Python class-based module API is built on top of this core library.
- Lazy initialization
Variables in Flax are initialized late, only when needed. That is, during normal execution of a module, if a requested variable name isn’t found in the provided variable collection data, we call the initializer function to create it. This allows us to treat initialization and application under the same code paths, simplifying the use of JAX transforms with layers.
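- Lifted transformation
A JAX transformation such as vmap or scan lifted to operate on Flax Modules (e.g. nn.vmap, nn.scan), which also controls how variable collections and RNG sequences are handled inside the transformation. For details, refer to the Flax docs.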
- Module
A dataclass allowing the definition and initialization of parameters in a referentially transparent form. It is responsible for storing and updating variables and parameters within itself. Modules can be readily transformed into functions, allowing them to be trivially used with JAX transformations like vmap and scan.
- Params / parameters
“params” is the canonical variable collection in the variable dictionary (dict). The “params” collection generally contains the trainable weights.
- RNG sequences
Inside Modules, you can obtain a new PRNG key through Module.make_rng(). These keys can be used to generate random numbers through JAX’s functional random number generators. Having different RNG sequences (e.g. for “params” and “dropout”) allows fine-grained control in a multi-host setup (e.g. initializing parameters identically on different hosts while generating different dropout masks) and treating these sequences differently when lifting transformations.
- Scope
A container class for holding the variables and PRNG keys for each layer.
- Shape inference
Modules do not need to specify the shape of the input array in their definitions. Upon initialization, Flax inspects the input array and infers the correct shapes for the parameters in the model.
- Variables
The weights / parameters / data / arrays residing in the leaves of variable collections. Variables are defined inside modules using Module.variable(). A variable of collection “params” is simply called a param and can be set using Module.param().
- Variable collections
Entries in the variable dict, containing weights / parameters / data / arrays that are used by the model. “params” is the canonical collection in the variable dict. Parameters in this collection are typically differentiable and updated by an outer SGD-like loop / optimizer, rather than modified directly by forward-pass code.
- Variable dictionary
A dictionary containing variable collections. It maps each collection name (e.g. “params” or “batch_stats”) to a (possibly nested) dictionary with Variables as leaves, whose nesting matches the submodule tree structure. Read more about pytrees and leaves in the JAX docs.