Evolution from Flax Linen to NNX#

This guide demonstrates the differences between Flax Linen and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Flax Linen.

This document mainly teaches how to convert arbitrary Flax Linen code to Flax NNX. If you want to play it “safe” and convert your codebase iteratively, check out the Use Flax NNX and Linen together via nnx.bridge guide.

To get the most out of this guide, it is highly recommended that you go through the Flax NNX basics document, which covers the nnx.Module system, Flax transformations, and the Functional API with examples.

Basic Module definition#

Both Flax Linen and Flax NNX use the Module class as the default unit to express a neural network library layer. In the example below, you first create a Block (by subclassing Module) composed of one linear layer with dropout and a ReLU activation function; then you use it as a sub-Module when creating a Model (also by subclassing Module), which is made up of Block and a linear layer.

There are a few fundamental differences between Flax Linen and Flax NNX Module objects:

  • Stateless vs. stateful: A flax.linen.Module (nn.Module) instance is stateless - the variables are returned from a purely functional Module.init() call and managed separately. A flax.nnx.Module, however, owns its variables as attributes of this Python object.

  • Lazy vs. eager: A flax.linen.Module only allocates space for its variables when it actually sees its input (lazy). A flax.nnx.Module instance creates its variables the moment it is instantiated, before seeing a sample input (eager).

  • Flax Linen can use the @nn.compact decorator to define the model in a single method, and use shape inference from the input sample. A Flax NNX Module generally requests additional shape information to create all parameters during __init__, and separately defines the computation in the __call__ method.

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

Variable creation#

Next, let’s discuss instantiating the model and initializing its parameters:

  • To generate model parameters for a Flax Linen model, you call the flax.linen.Module.init (nn.Module.init) method with a jax.random.key (doc) plus some sample inputs that the model shall take. This results in a nested dictionary of JAX Arrays (jax.Array data types) to be carried around and maintained separately.

  • In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (nnx.Variable objects) are stored inside the nnx.Module (or its sub-Module) as attributes. You still need to provide it with a pseudorandom number generator (PRNG) key, but that key will be wrapped inside an nnx.Rngs class and stored inside, generating more PRNG keys when needed.

If you want to access Flax NNX model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the Flax NNX split/merge API (nnx.split / nnx.merge); a minimal sketch appears after the code below.

model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]

assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))


# Parameters were already initialized during model instantiation.

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
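
As mentioned above, nnx.split and nnx.merge give you that stateless, dictionary-like view when you need one. The snippet below is a minimal sketch (not part of the original example) using the model instantiated above:

graphdef, state = nnx.split(model)   # state is a dict-like nnx.State
state['block']['linear']['kernel']   # nested, path-based access to the variables
model = nnx.merge(graphdef, state)   # rebuild the stateful Module from the pieces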

Training step and compilation#

Now, let’s proceed to writing a training step and compiling it using JAX just-in-time compilation. Below are some of the differences between the Flax Linen and Flax NNX approaches.

Compiling the training step:

  • Flax Linen uses @jax.jit - a JAX transform - to compile the training step.

  • Flax NNX uses @nnx.jit - a Flax NNX transform (one of several transform APIs that behave similarly to JAX transforms, but also work well with Flax NNX objects). So, while jax.jit only accepts functions with pure, stateless arguments, nnx.jit allows the arguments to be stateful NNX Modules. This greatly reduces the number of lines needed for a train step.

Taking gradients:

  • Similarly, Flax Linen uses jax.grad (a JAX transform for automatic differentiation) to return a raw dictionary of gradients.

  • Flax NNX uses nnx.grad (a Flax NNX transform) to return the gradients of NNX Modules as nnx.State dictionaries. If you want to use regular jax.grad with Flax NNX you need to use the Flax NNX split/merge API.

Optimizers:

  • If you are already using Optax optimizers like optax.adamw (instead of the raw jax.tree.map computation shown here) with Flax Linen, check out the nnx.Optimizer example in the Flax NNX basics guide for a much more concise way of training and updating your model; a minimal sketch also appears after the training-step code below.

Model updates during each training step:

  • The Flax Linen training step needs to return a pytree of parameters as the input of the next step.

  • The Flax NNX training step doesn’t need to return anything, because the model was already updated in-place within nnx.jit.

  • In addition, nnx.Module objects are stateful, and a Module automatically tracks several things within it, such as PRNG keys and BatchNorm stats. That is why you don’t need to explicitly pass a PRNG key in on every step. Also note that you can use nnx.reseed to reset its underlying PRNG state.

Dropout behavior:

  • In Flax Linen, you need to explicitly define and pass in the training argument to control the behavior of flax.linen.Dropout (nn.Dropout), namely, its deterministic flag, which means random dropout only happens if training=True.

  • In Flax NNX, you can call model.train() (flax.nnx.Module.train()) to automatically switch nnx.Dropout to the training mode. Conversely, you can call model.eval() (flax.nnx.Module.eval()) to turn off the training mode. You can learn more about what nnx.Module.train does in its API reference.

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)

  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  return params
model.train() # Sets `deterministic=False` under the hood for nnx.Dropout

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(inputs)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.merge_state(params, rest))
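
As noted in the Optimizers bullet above, nnx.Optimizer can replace the manual jax.tree.map update. The sketch below is illustrative (the exact Optimizer signature may vary slightly between Flax versions) and assumes the model and optax from the surrounding example:

optimizer = nnx.Optimizer(model, optax.adamw(1e-3))

@nnx.jit
def train_step(model, optimizer, inputs, labels):
  def loss_fn(model):
    logits = model(inputs)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  optimizer.update(grads)  # applies the Optax update and mutates the model in place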

Collections and variable types#

One key difference between the Flax Linen and NNX APIs is how they group variables into categories. Flax Linen uses different collections, while in Flax NNX, since all variables are top-level Python attributes, you use different variable types.

In Flax NNX, you can freely create your own variable types as subclasses of nnx.Variable.

For all the built-in Flax Linen layers and collections, Flax NNX already provides corresponding layers and variable types. For example:

  • flax.linen.Dense (nn.Dense) creates params -> nnx.Linear creates nnx.Param.

  • flax.linen.BatchNorm (nn.BatchNorm) creates batch_stats -> nnx.BatchNorm creates nnx.BatchStat.

  • flax.linen.Module.sow() creates intermediates -> nnx.Module.sow() creates nnx.Intermediate.

  • In Flax NNX, you can also simply obtain the intermediates by assigning them to an nnx.Module attribute - for example, self.sowed = nnx.Intermediate(x). This is similar to Flax Linen’s self.variable('intermediates', 'sowed', lambda: x).

class Block(nn.Module):
  features: int
  def setup(self):
    self.dense = nn.Dense(self.features)
    self.batchnorm = nn.BatchNorm(momentum=0.99)
    self.count = self.variable('counter', 'count',
                                lambda: jnp.zeros((), jnp.int32))


  @nn.compact
  def __call__(self, x, training: bool):
    x = self.dense(x)
    x = self.batchnorm(x, use_running_average=not training)
    self.count.value += 1
    x = jax.nn.relu(x)
    return x

x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape         # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape  # (4, )
variables['counter']['count']                        # 1
class Counter(nnx.Variable): pass

class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )
    self.count = Counter(jnp.array(0))

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)
    self.count += 1
    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)
model.count           # Counter(value=...)

If you want to extract certain arrays from the pytree of variables:

  • In Flax Linen, you can access the specific dictionary path.

  • In Flax NNX, you can use nnx.split to pull the different variable types apart. The code below is a simple example that splits up the variables by their types - check out the Flax NNX Filters guide for more sophisticated filtering expressions.

params, batch_stats, counter = (
  variables['params'], variables['batch_stats'], variables['counter'])
params.keys()       # ['dense', 'batchnorm']
batch_stats.keys()  # ['batchnorm']
counter.keys()      # ['count']

# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
  model, nnx.Param, nnx.BatchStat, Counter)
params.keys()       # ['batchnorm', 'linear']
batch_stats.keys()  # ['batchnorm']
count.keys()        # ['count']

# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)
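
If you only need one group of variables and not the graphdef, nnx.state with a filter is a lighter-weight option than a full nnx.split. The snippet below is a small sketch (not from the original guide) using the Block model above:

batch_stats = nnx.state(model, nnx.BatchStat)  # only the BatchStat variables
counters = nnx.state(model, Counter)           # only the custom Counter variables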

Using multiple methods#

In this section you will learn how to use multiple methods in both Flax Linen and Flax NNX. As an example, you will implement an auto-encoder model with three methods: encode, decode, and __call__.

Defining the encoder and decoder layers:

  • In Flax Linen, as before, define the layers without having to pass in the input shape, since the flax.linen.Module parameters will be initialized lazily using shape inference.

  • In Flax NNX, you must pass in the input shape since the nnx.Module parameters will be initialized eagerly without shape inference.

class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs: nnx.Rngs):
    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))

The variable structure is as follows:

# variables['params']
{
  decoder: {
      bias: (784,),
      kernel: (256, 784),
  },
  encoder: {
      bias: (256,),
      kernel: (784, 256),
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
{
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
}

To call methods other than __call__:

  • In Flax Linen, you still need to use the apply API.

  • In Flax NNX, you can simply call the method directly.

z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))

Transformations#

Both Flax Linen and Flax NNX provide their own sets of transforms that wrap JAX transforms so that they can be used with Module objects.

Most of the transforms in Flax Linen, such as grad or jit, don’t change much in Flax NNX. But, for example, if you try to do scan over layers, as described in the next section, the code differs significantly.

Let’s start with an example:

  • First, define an RNNCell Module that will contain the logic for a single step of the RNN.

  • Define an initial_state method that will be used to initialize the state (a.k.a. carry) of the RNN. As with jax.lax.scan (API doc), the RNNCell.__call__ method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same.

class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

Next, define an RNN Module that will contain the logic for the entire RNN.

In Flax Linen:

  • You will use flax.linen.scan (nn.scan) to define a new temporary type that wraps RNNCell. During this process you will also: 1) instruct nn.scan to broadcast the params collection (all steps share the same parameters) and to not split the params PRNG stream (so that all steps initialize with the same parameters); and, finally, 2) specify that you want scan to run over the second axis of the input and stack outputs along the second axis as well.

  • You will then use this temporary type immediately to create an instance of the “lifted” RNNCell, use it to create the carry, and then run the __call__ method, which will scan over the sequence.

In Flax NNX:

  • You will create a scan function (scan_fn) that will use the RNNCell defined in __init__ to scan over the sequence, and explicitly set in_axes=(nnx.Carry, None, 1). nnx.Carry marks the carry argument as the scan carry, None means that cell will be broadcast to all steps, and 1 means x will be scanned across axis 1.

class RNN(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(
      RNNCell, variable_broadcast='params',
      split_rngs={'params': False}, in_axes=1, out_axes=1
    )(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))

y = model(x)
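
As a quick sanity check of out_axes=(nnx.Carry, 1), the per-step outputs are stacked along axis 1, so the result keeps the batch and time axes. The assertion below is illustrative and not part of the original guide:

assert y.shape == (3, 12, 64)  # (batch, time, hidden_size)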

Scan over layers#

In general, Flax Linen and Flax NNX transforms look the same. However, Flax NNX transforms are designed to be closer to their lower-level JAX counterparts, and thus drop some of the assumptions baked into certain Linen lifted transforms. The scan-over-layers use case below is a good example to showcase this.

Scan-over-layers is a technique where you run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for large models. In the example below, you will repeat the Block Module 5 times in the top-level MLP Module.

  • In Flax Linen, you apply the flax.linen.scan (nn.scan) transform upon the Block nn.Module to create a larger ScanBlock nn.Module that contains 5 Block nn.Module objects. It will automatically create a large parameter of shape (5, 64, 64) at initialization time, and at call time iterate over every (64, 64) slice for a total of 5 times, like jax.lax.scan (API doc) would.

  • Up close, in the logic of this model there actually is no need for the jax.lax.scan operation at initialization time. What happens there is more like a jax.vmap operation - you are given a Block sub-Module that accepts (in_dim, out_dim), and you “vmap” it num_layers times to create a larger array.

  • In Flax NNX, you take advantage of the fact that model initialization and running code are completely decoupled, and instead use the nnx.vmap transform to initialize the underlying Block parameters, and the nnx.scan transform to run the model input through them.

For more information on Flax NNX transforms, check out the Transforms guide.

class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

model = MLP(64, num_layers=5)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)

model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

There are a few other details to explain in the Flax NNX example above:

  • The @nnx.split_rngs decorator: Flax NNX transforms are completely agnostic of PRNG state, which makes them behave more like JAX transforms but diverge from the Flax Linen transforms that handle PRNG state. To regain this functionality, the nnx.split_rngs decorator allows you to split the nnx.Rngs before passing them to the decorated function and ‘lower’ them afterwards, so they can be used outside.

    • Here, you split the PRNG keys because jax.vmap and jax.lax.scan require an explicit batch of PRNG keys if each of their internal operations needs its own key. So for the 5 layers inside the MLP, you split and provide 5 different PRNG keys from the rngs argument before going down to the JAX transform.

    • Note that create_block() actually knows it needs to create 5 layers precisely because it sees 5 PRNG keys: in_axes=(0,) indicates that vmap will look at the first argument’s leading dimension to determine the size it will map over.

    • The same goes for forward(), which looks at the variables inside its model argument to find out how many times it needs to scan. nnx.split_rngs here actually splits the PRNG state inside the model. (If the Block Module didn’t have dropout, you wouldn’t need the nnx.split_rngs line, as it wouldn’t consume any PRNG key anyway.)

  • Why the Block Module in Flax NNX doesn’t need to take and return that extra dummy value: taking a second input and returning a second output is a requirement of jax.lax.scan (API doc). Flax NNX simplifies this, so that you can choose to ignore the second output by setting out_axes=nnx.Carry instead of the default (nnx.Carry, 0).

    • This is one of the rare cases where Flax NNX transforms diverge from the JAX transforms APIs.

There are more lines of code in the Flax NNX example above, but they express what happens at each step more precisely. Since Flax NNX transforms are much closer to the JAX transform APIs, it is recommended that you have a good understanding of the underlying JAX transforms before using their Flax NNX equivalents.

Now inspect the variable pytree on both sides:

# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
  ScanBlock_0: {
    Dense_0: {
      bias: (5, 64),
      kernel: (5, 64, 64),
    },
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
{
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
}

Using TrainState in Flax NNX#

Flax Linen has a convenient TrainState data class to bundle the model, parameters and optimizer. In Flax NNX, this is not really necessary. In this section, you will learn how to construct your Flax NNX code around TrainState for any backward compatibility needs.

In Flax NNX:

  • You must first call nnx.split on the model to get the separate nnx.GraphDef and nnx.State objects.

  • You can pass in nnx.Param to filter all trainable parameters into a single nnx.State, and pass in ... for the remaining variables.

  • You also need to subclass TrainState to add a field for the other variables.

  • Then, you can pass in nnx.GraphDef.apply as the apply function, nnx.State as the parameters and other variables, and an optimizer as arguments to the TrainState constructor.

Note that nnx.GraphDef.apply will take in nnx.State objects as arguments and return a callable function. This function can be called on the inputs to output the model’s logits, as well as the updated nnx.GraphDef and nnx.State objects. Notice the use of plain @jax.jit below, since you aren’t passing Flax NNX Modules into train_step.

from flax.training import train_state

sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']

state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      inputs, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)


  state = state.apply_gradients(grads=grads)

  return state
from flax.training import train_state

model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

class TrainState(train_state.TrainState):
  other_variables: nnx.State

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  other_variables=other_variables,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, other_variables):
    logits, (graphdef, new_state) = state.apply_fn(
      params,
      other_variables
    )(inputs)  # <== inputs
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params, state.other_variables)


  state = state.apply_gradients(grads=grads)

  return state
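
For completeness, either train_step above can be driven in the usual way. The loop below is a minimal, illustrative sketch (the batch size, data, and step count are placeholders, not from the original guide), shown for the Flax NNX version:

inputs = jnp.ones((32, 784))
labels = jnp.zeros((32,), dtype=jnp.int32)

for _ in range(10):
  state = train_step(state, inputs, labels)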