The Flax Module lifecycle#

This design note is intended for users who are already familiar with Flax Linen Modules but want to understand more about the design principles behind the abstraction. This note should give you a good understanding of the assumptions and guarantees the Module API is built upon. If you have no practical experience with Modules yet, check out the Quickstart guide.

Flax Linen Modules offer a Pythonic abstraction on top of Flax core. The Module abstraction allows you to create classes that have state, parameters and randomness on top of JAX. This is a practical guide to the design and behavior of the Module class. By the end, you should feel comfortable going off the beaten track and using Modules in new ways.

Overview#

Definition#

Let’s start with a high-level overview of the Module lifecycle. First, define a simple Module:

import jax
from jax import random
import flax.linen as nn

class MLP(nn.Module):
  # 1. Attribute annotations
  hidden_size: int
  out_size: int

  # 2. The ``setup`` method
  def setup(self):
    self.hidden = nn.Dense(self.hidden_size)
    self.out = nn.Dense(self.out_size)

  # 3. User methods
  def __call__(self, x):
    a = self.hidden(x)
    h = nn.relu(a)
    return self.out(h)

This Module consists of:

  1. Attribute annotations, defined as dataclass fields. These annotations automatically define a constructor.

  2. The ``setup`` method, which creates submodules and assigns them to attributes.

  3. User methods. By convention, most Modules have just one __call__ method, but you can define multiple methods or use different method names.

Construction/initialization#

Now we want to construct and use the MLP Module:

mlp = MLP(hidden_size=5, out_size=3)
x = jax.numpy.ones((1, 2))
variables = mlp.init(random.key(0), x)
y = mlp.apply(variables, x)

First, we construct an instance of MLP and pass the construction attributes. Note that construction here is different from what you might expect if you are not used to functional programming patterns. The MLP constructor does not actually create variables or any internal state whatsoever. It’s best to think of it as a specification or template of the Module that contains functionality but no data.

Let’s take a closer look at initialization. Surprisingly, there actually is no separate initialization path in Flax. Calling init is just a special case of apply, which you can also write as:

# equivalent to: variables = mlp.init(random.key(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True)

Thus, init is nothing more than a wrapper around apply where:

  1. We call a Module without any initial variables (an empty dict).

  2. A PRNG generator named "params" is always passed for randomly initializing parameters (using the parameter initialization function).

  3. All variable collections are set to mutable (mutable=True). When a collection is mutable, existing variables can be updated and new variables can be created. Thus, inside init variables can be initialized in any variable collection and they are all added to the returned variable dictionary.

Lifecycle#

Now that you have learned about init being a special case of apply, let’s look at .apply(...) in more detail. In fact, most of the complexity of Modules resides in the apply method. The “Module lifecycle” consists of constructing and apply-ing a Module. We can summarize the Module lifecycle as follows:

  1. We construct mlp = MLP(hidden_size=5, out_size=3), such that mlp.hidden_size=5 and mlp.out_size=3.

  2. Then, call mlp.apply, which:

    1. Makes a clone of mlp, let’s call it mlp_copy.

    2. Calls mlp_copy.setup().

    3. Returns the output of mlp_copy.__call__() and optionally the variable collections that were specified as mutable using the keyword argument mutable=.

Notice that the lifecycle includes cloning the Module instance. This is done to ensure that apply can be treated as a pure function (i.e., if you pass the same arguments in, it will return the same outputs). You will learn about this in more detail later in the Top-level Modules section.

Variables#

The word “variable” is ubiquitous in programming and math. However, it’s important to have a good understanding of what variables are in the context of JAX and Flax. Inside Flax Modules, variables act like you expect from Python. They are initialized once, read, and perhaps even updated every so often. However, JAX has no concept of variables. Instead, values are stored in arrays similar to NumPy arrays - with one important difference: they are immutable.

The init and apply methods return the variables as a nested dictionary with string keys and JAX arrays at the leaves. At the top level, each key corresponds to a variable collection. Inside each collection, the nested dict structure corresponds to the Module hierarchy. The variable dict is immutable and is therefore really just a snapshot of the state the variables are in. When apply is called again, the variable dict is passed as an argument, so that the variables are in the same state as when the previous init / apply call finished.

Note

Module fields are declared using the field_name: TypeHint syntax (same as dataclasses). Without a type hint, an attribute is considered a static property of the class. In case you cannot specify the type you can use typing.Any as a wildcard type.

Compact Modules#

Linen provides an alternative API for defining modules more compactly. This is especially useful for the common case where the Module consists of only one method that uses parameters and/or sub-modules. Using the compact API the MLP can be rewritten as follows:

class CompactMLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x):
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)

A compact Module is similar in spirit to a function. It offers a concise notation and restricts external interaction to the inputs and return values of the function. In this case the concise notation might make it easier for others to understand what the Module does. There is no need to jump back and forth between the setup and __call__ method to understand what the submodules are doing. Instead, simply reading the __call__ method from top to bottom once should provide a concise overview. This can make a significant difference if you are implementing complex Modules with many hyperparameters. See setup or compact for a practical guide on deciding between setup and compact.

Another benefit of defining submodules and/or variables inline is that you can add arguments to your method when constructing variables. The most common example of this is using shape information to determine the shape of a parameter like this:

class CompactScaledMLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x):
    scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:])
    x *= scale[None]
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)

Many of the standard Linen Modules like nn.Dense use shape inference already to avoid the need to specify input shapes (like the number of input features to a Dense layer).

Compact control flow#

The order in which you define submodules determines the name of a submodule if none is provided explicitly (using the name= keyword argument passed to the Module’s constructor). Because the name determines how parameters are mapped to submodules, you must be careful about mixing control flow with auto-generated names. Using control flow can change the order or remove certain submodules altogether. This is useful in case a submodule should only exist depending on some construction argument. However, when control flow depends on the input arguments to the Module, you should be careful. For example, the following Module will break:

class WrongModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    if mode == "encode":
      return nn.Dense(features=8)(x)
    elif mode == "decode":
      return nn.Dense(features=4)(x)

The above Module will break because either the encoder or decoder path will construct a Module named “Dense_0”. This means the two Modules will share parameters which is not intended here. Actually, the two Modules cannot share parameters because they each have a different number of features.

This problem can be solved in various ways:

  • Provide explicit names.

  • Create the modules in setup.

  • Move the constructor out of the control flow.

The last option is done as follows:

class CorrectModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    encoder = nn.Dense(8)
    decoder = nn.Dense(4)
    if mode == "encode":
      return encoder(x)
    elif mode == "decode":
      return decoder(x)

In the above example the construction order is fixed. After construction the submodules can be used in an arbitrary order.

Note

compact modules show a strong resemblance to React hooks.

Top-level Modules#

When a Module instance is created at the “top-level”, it will be in an “unbound” state - that is, it has no variables attached. “Top-level” means it is not constructed as a sub-Module inside another Module class. Apart from calling init and apply, there is not much you can do with an unbound Module. Note also that setup is not called on unbound Modules, so you can only access the construction arguments. Refer to the Future work section to learn how this might change in the future.

Why are top-level Modules always unbound?#

When we call apply, a copy of the top-level Module is created which will actually hold the variables and PRNG sequences. This stateful, “bound”, clone only exists while we are executing the apply method. The reason for this is that if you create a stateful object and destroy it before the apply function returns, the apply function itself behaves like a pure function. A pure function has two constraints:

  1. If you put the same arguments in, it will return the same outputs

  2. It does not change anything outside the function. This means you cannot manipulate stateful objects that are accessible outside the pure function.

Pure functions have many advantages, and when using JAX they are often essential. For example, most code requires compilation using jax.jit to be fast, and once you have created a Module you probably want to optimize its parameters using jax.grad. However, these APIs expect a pure function and don’t work on stateful bound Module instances directly. Moreover, pure functions allow for flexible interoperability with other libraries. For example, we recommend Optax for optimizing parameters. The optimizers in Optax expect and return a PyTree of JAX arrays to optimize, just like the apply function of a Linen Module.

Cloning#

To make this approach work reliably we need well-defined cloning behavior. Rather than relying on a complex nested cloning procedure like Python’s deepcopy, Flax enforces that a Module is exactly defined by its construction arguments. Therefore cloning a Module reduces to calling the constructor with its original construction arguments. Because Module acts as an immutable dataclass, the construction arguments are mapped directly to instance attributes. Non-construction attributes that are computed in setup or __post_init__ should also depend only on the construction arguments to ensure a well-defined clone.

Bind#

Sometimes it’s useful to have a bound, top-level Module without having to wrap the code in a function, for example to interact with a Module inside a Jupyter notebook. The bind method returns a bound clone with an unlimited lifetime. The downside of this is that you cannot combine it with JAX transformations or integrate it into a vanilla JAX codebase that expects stateless code. For example, Optax can optimize a Pytree of parameters but it cannot directly optimize a bound Module instance created with .bind (because that’s not a Pytree). Thus, you cannot combine the bind API with a functional optimizer API like Optax.

Setup#

The setup method is often used like the constructor hook (__init__) in normal Python classes. However, for more advanced use cases it’s good to realize that it is not quite the same as a constructor.

setup is only called after a Module becomes bound. Normally, this is not an issue because most Modules are bound (almost) immediately (as part of init and apply). Inside setup, sub-modules become bound when they are assigned to an attribute. Inside an nn.compact decorated method, sub-modules are bound immediately when constructed. As explained in the previous section, top-level Modules are never bound and thus setup is not called when they are constructed. This means you cannot access attributes assigned in setup from an unbound, top-level module.

class TopLevelAccess(nn.Module):

  def setup(self):
    self.foo = nn.Dense(2)

mdl = TopLevelAccess()
assert not hasattr(mdl, "foo")  # foo is not defined because setup is not called

The setup method is not called immediately after the Module becomes bound, but only when you interact with the Module instance (e.g., by calling a method or accessing an attribute). This should not impact the behavior of a Module, but the lazy execution does sometimes affect log statements and stack traces during debugging. The section on Functionalization will explain why we need setup to be lazy in the first place.

Functionalization#

So far, we have a pure apply function that is typically transformed with JAX transformations, and inside apply we have a stateful Module instance to work with. In other words, outside of a Module we are in a functional world where we have the power of JAX’s functional transformations, and inside the Module we get the power of Flax’s stateful variables and PRNG sequences. The apply method is our bridge between these two worlds.

But what if we want to use JAX transformations inside Modules? The answer to this is functionalization.

This procedure itself is tedious and error-prone but handled internally by Flax. At a high-level we can summarize it as follows. For a method fn defined within a Module:

  1. Collect the state (variables & PRNG sequences) of the Module(s) that should be available inside the JAX transformation and take a snapshot of it.

  2. Call the JAX transformation with the original arguments and the collected state. Then inside the transformation:

    1. Unpack the state and recreate the Modules

    2. Call the user code fn

    3. Collect the updated variables and RNGs and return them together with the original return values from fn

  3. Update the original state with the updated state returned from the transformation.

A more in depth explanation of functionalization and lifting can be found in the Lifted Transformation design note.

Practical consequences#

For the most part, functionalization is handled automatically for you. Still, there are some constraints that you must take into account. Most importantly, Flax only handles the stateful primitives (Linen variables and RNGs) and not arbitrary stateful Python code. In particular, you cannot close over stateful objects and Module objects, because they are invisible to Flax’s internals (and to JAX in general).

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda x: dense(x) + 1
    # simply calling inner works fine
    # return self.inner(x, fn)
    # but applying a transformation doesn't:
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, fn)

  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

Here inner takes a function that closes over a Module instance. Calling inner directly (the commented-out line) works fine because the method is not transformed. Wrapping it with the lifted transformation nn.vmap, as the code above does, fails because the closed-over Module is hidden from the transformation. Most methods are not transformed, but it is good to know how to make Module methods transformable.

The main obstacle to transformability is types that JAX does not recognize. JAX only understands Pytree arguments, i.e. arbitrarily nested Python containers (dict, list, tuple) of JAX or NumPy arrays and Python numbers/bools. Flax lets you define Pytree-compatible dataclasses using the flax.struct API.

Function closure is the most common way to accidentally hide a JAX array or Linen Module from a transformation. There is however an easy workaround if you want to pass closures that are also compatible with JAX and Linen transformations:

from typing import Any, Callable, Iterable
import flax

class Partial(flax.struct.PyTreeNode):
  fn: Callable = flax.struct.field(pytree_node=False)
  args: Iterable[Any]

  def __call__(self, *args, **kwargs):
    return self.fn(*(tuple(self.args) + args), **kwargs)

class Foo(nn.Module):

  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda mdl, x: mdl(x) + 1
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, Partial(fn, [dense]))

  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

Here the closure is implemented using a Flax dataclass. The function itself is annotated with flax.struct.field(pytree_node=False) to indicate that it does not contain JAX Arrays or Linen Modules. The partially applied args, on the other hand, are treated as a pytree container. We rewrite the closure to use Partial. Now the inner method can be transformed using lifted transformations.

Future work#

Setup for unbound Modules#

The current Module abstraction is particularly restrictive when it comes to initializing fields after construction. In the current Module API, the setup method is the place to initialize the fields of the Module instance. Because setup is only called on a bound Module, the full Module API is available inside setup, including variable declaration. However, oftentimes we don’t actually require any stateful APIs to initialize a field. In fact, most commonly we simply want to declare a submodule. More importantly, it’s often useful to inspect submodules for debugging or to partially run the model. Consider for example:

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder(...)
    self.decoder = Decoder(...)

Imagine we want to call just the decoder using auto_encoder.decoder.apply(decoder_variables, x). With the current setup API this does not work because we must first bind the variables before setup is called and the decoder attribute is defined. Of course we can manually construct the Decoder Module with the same attributes as in setup but this is not ideal in many cases.

There are two possible solutions to make this use case more ergonomic. First, setup could be made to run immediately after construction, before the Module becomes bound. This means you can still create submodules, but you can no longer define or manipulate variables. Therefore, this would be a breaking change and it would require a new API for defining variables lazily.

Alternatively, an additional special method could be introduced that runs right away after Module construction and before it becomes bound. In this case, the setup method would preserve its original semantics.