Evolution from Linen to NNX
This guide walks you through the differences between Flax Linen and Flax NNX models, with side-by-side comparisons to help you migrate your code from the Linen API to NNX.
Before reading this guide, we highly recommend going through The Basics of Flax NNX to learn about the core concepts and code examples of Flax NNX.
This guide mainly covers converting arbitrary Linen code to NNX. If you want to play it safe and convert your codebase iteratively, check out the guide that allows you to use NNX and Linen code together.
Basic Module Definition
Both Linen and NNX use the Module as the default unit to express a neural network library layer. There are two fundamental differences between Linen and NNX modules:
Stateless vs. stateful: Linen module instances are stateless: variables are returned from a purely functional Module.init() call and managed separately. NNX modules, however, own their variables as attributes of the Python object.
Lazy vs. eager: Linen modules only allocate space to create variables when they actually see their input, whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.
Linen can use the @nn.compact decorator to define the model in a single method and use shape inference from the input sample, whereas NNX modules generally request additional shape information to create all parameters during __init__ and separately define the computation in __call__.
# Linen
import jax
import flax.linen as nn

class Block(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x
# NNX
import jax
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x
Variable Creation
To generate the model parameters for a Linen model, you call the init method with a jax.random.key plus some sample inputs that the model will take. The result is a nested dictionary of JAX arrays that has to be carried around and maintained separately.
In NNX, the model parameters are automatically initialized when the user instantiates the model, and the variables are stored inside the module (or its submodules) as attributes. You still need to give it an RNG key, but the key is wrapped inside an nnx.Rngs object, which is stored inside the module and generates more RNG keys when needed.
If you want to access NNX model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the NNX split/merge API.
# Linen
import jax
import jax.numpy as jnp

model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]

assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
# NNX
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# parameters were already initialized during model instantiation
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
Training Step and Compilation
Now we write a training step and compile it using JAX just-in-time compilation. Note a few differences here:
Linen uses @jax.jit to compile the training step, whereas NNX uses @nnx.jit. jax.jit only accepts pure stateless arguments, but nnx.jit allows the arguments to be stateful NNX modules. This greatly reduces the number of lines needed for a train step.
Similarly, Linen uses jax.grad() to return a raw dictionary of gradients, whereas NNX can use nnx.grad to return the gradients of Modules as NNX State dictionaries. To use regular jax.grad with NNX, you need to use the NNX split/merge API.
If you are already using Optax optimizers (instead of the raw jax.tree.map computation shown here), check out the nnx.Optimizer example for a much more concise way of training and updating your model.
The Linen train step needs to return a tree of parameters as the input of the next step. On the other hand, the NNX step doesn't need to return anything, because the model was already updated in place within nnx.jit.
NNX modules are stateful and automatically track a few things within, such as RNG keys and BatchNorm stats. That's why you don't need to explicitly pass an RNG key in on every step. Note that you can use nnx.reseed to reset its underlying RNG state.
In Linen, you need to explicitly define and pass in an argument training to control the behavior of nn.Dropout (namely, its deterministic flag, which means random dropout only happens if training=True). In NNX, you can call model.train() to automatically switch nnx.Dropout to training mode. Conversely, call model.eval() to turn off training mode. You can learn more about what this API does at its API reference.
# Linen
import jax
import optax

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True,  # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  # Plain SGD step; the updated parameters must be returned to the caller.
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  return params
# NNX
import jax
import optax
from flax import nnx
model.train()  # Sets deterministic=False under the hood for nnx.Dropout

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(inputs)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  # Write the new parameters back into the model in place; nothing is returned.
  nnx.update(model, nnx.GraphState.merge(params, rest))
Collections and Variable Types
One key difference between Linen and NNX APIs is how we group variables into categories. In Linen, we use different collections; in NNX, since all variables are top-level Python attributes, you use different variable types.
You can freely create your own variable types as subclasses of nnx.Variable.
For all the built-in Flax Linen layers and collections, NNX already created the corresponding layers and variable types. For example:
nn.Dense creates params -> nnx.Linear creates nnx.Param.
nn.BatchNorm creates batch_stats -> nnx.BatchNorm creates nnx.BatchStat.
nn.Module.sow() creates intermediates -> nnx.Module.sow() creates nnx.Intermediate variables.
You can also simply get the intermediates by assigning them to a module attribute, like self.sowed = nnx.Intermediate(x). This is similar to Linen's self.variable('intermediates', 'sowed', lambda: x).
# Linen
class Block(nn.Module):
  features: int

  def setup(self):
    self.dense = nn.Dense(self.features)
    self.batchnorm = nn.BatchNorm(momentum=0.99)
    self.count = self.variable('counter', 'count',
                               lambda: jnp.zeros((), jnp.int32))

  def __call__(self, x, training: bool):
    x = self.dense(x)
    x = self.batchnorm(x, use_running_average=not training)
    self.count.value += 1
    x = jax.nn.relu(x)
    return x

x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape  # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape  # (4,)
variables['counter']['count']  # 1
# NNX
class Counter(nnx.Variable): pass

class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )
    self.count = Counter(jnp.array(0))

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)
    self.count.value += 1  # update the variable's value in place
    x = jax.nn.relu(x)
    return x

model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel  # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)
model.count  # Counter(value=...)
If you want to extract certain arrays from the tree of variables, you can access the specific dictionary path in Linen, or use nnx.split to separate the variable types in NNX. The code below is a simple example; check out the Filter API Guide for more sophisticated filtering expressions.
# Linen
params, batch_stats, counter = (
    variables['params'], variables['batch_stats'], variables['counter'])
params.keys() # ['dense', 'batchnorm']
batch_stats.keys() # ['batchnorm']
counter.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
# NNX
graphdef, params, batch_stats, count = nnx.split(
    model, nnx.Param, nnx.BatchStat, Counter)
params.keys()  # ['batchnorm', 'linear']
batch_stats.keys()  # ['batchnorm']
count.keys()  # ['count']
# ... make arbitrary modifications ...
# Merge back with nnx.merge to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)
Using Multiple Methods
In this section we will take a look at how to use multiple methods in both frameworks. As an example, we will implement an auto-encoder model with three methods: encode, decode, and __call__.
As before, in Linen we define the encoder and decoder layers without having to pass in the input shape, since the module parameters will be initialized lazily using shape inference. In NNX, we must pass in the input shape, since the module parameters will be initialized eagerly without shape inference.
# Linen
class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
# NNX
class AutoEncoder(nnx.Module):
  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
The variable structure is as follows:
# variables['params']
{
  decoder: {
    bias: (784,),
    kernel: (256, 784),
  },
  encoder: {
    bias: (256,),
    kernel: (784, 256),
  },
}

# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
{
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
}
To call methods other than __call__, in Linen you still need to use the apply API, whereas in NNX you can simply call the method directly.
# Linen
z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
# NNX
z = model.encode(jnp.ones((1, 784)))
Lifted Transforms
Flax APIs provide a set of transforms, which we will refer to as lifted transforms, that wrap JAX transforms in such a way that they can be used with Modules.
Most of the transforms in Linen don't change much in NNX. See the next section (Scan over Layers) for a case in which the code differs a lot more.
To begin, we will first define an RNNCell module that will contain the logic for a single step of the RNN. We will also define an initial_state method that will be used to initialize the state (a.k.a. carry) of the RNN. Like with jax.lax.scan, the RNNCell.__call__ method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same.
# Linen
class RNNCell(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
# NNX
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
Next, we will define an RNN Module that will contain the logic for the entire RNN.
In Linen, we will use nn.scan to define a new temporary type that wraps RNNCell. During this process we will also instruct nn.scan to broadcast the params collection (all steps share the same parameters) and to not split the params RNG stream (so all steps initialize with the same parameters), and finally we will specify that we want scan to run over the second axis of the input and stack the outputs along the second axis as well. We will then use this temporary type immediately to create an instance of the lifted RNNCell, use it to create the carry, and then run the __call__ method, which will scan over the sequence.
In NNX, we define a scan function scan_fn that will use the RNNCell defined in __init__ to scan over the sequence, and explicitly set in_axes=(nnx.Carry, None, 1): nnx.Carry means that the carry argument will be the carry, None means that cell will be broadcast to all steps, and 1 means x will be scanned across axis 1.
# Linen
class RNN(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(
      RNNCell, variable_broadcast='params',
      split_rngs={'params': False}, in_axes=1, out_axes=1
    )(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)
    return y

x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
# NNX
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)
    return y

x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))
y = model(x)
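Under the hood, both nn.scan and nnx.scan build on jax.lax.scan. As a rough sketch of what the lifted versions arrange for you, here is the same RNN written directly in JAX: the weights are closed over, so every step shares them (the "broadcast" case), and because jax.lax.scan always iterates over the leading axis, time axis 1 is moved to the front and back by hand. The weight initialization is an arbitrary choice for the example.

```python
import jax
import jax.numpy as jnp

batch, time, input_size, hidden_size = 3, 12, 32, 64
kw, kx = jax.random.split(jax.random.key(0))
w = jax.random.normal(kw, (hidden_size + input_size, hidden_size)) * 0.05
b = jnp.zeros((hidden_size,))
x = jax.random.normal(kx, (batch, time, input_size))

def cell(carry, x_t):
  # Same logic as RNNCell.__call__: concatenate, dense, relu; carry == output.
  h = jax.nn.relu(jnp.concatenate([carry, x_t], axis=-1) @ w + b)
  return h, h

carry0 = jnp.zeros((batch, hidden_size))  # like RNNCell.initial_state
xs = jnp.swapaxes(x, 0, 1)                # move the time axis to the front
carry, ys = jax.lax.scan(cell, carry0, xs)
y = jnp.swapaxes(ys, 0, 1)                # stack outputs along axis 1 again
```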
Scan over Layers
In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms are designed to be closer to their lower-level JAX counterparts, and thus we drop some assumptions made by certain Linen lifted transforms. This scan-over-layers use case is a good example to showcase this.
Scan-over-layers is a technique in which we run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for big models. In this example, we will repeat the module Block 5 times in a top-level module MLP.
In Linen, we apply nn.scan to the module Block to create a larger module ScanBlock that contains 5 Block modules. It will automatically create a large parameter of shape (5, 64, 64) at initialization time, and at call time iterate over every (64, 64) slice for a total of 5 times, like a jax.lax.scan would.
But if you look closely, there actually isn't any need for a jax.lax.scan operation at initialization time. What happens there is more like a jax.vmap operation: you are given a Block that accepts (in_dim, out_dim), and you "vmap" it num_layers times to create a larger array.
In NNX we take advantage of the fact that model initialization and running code are completely decoupled, and instead use nnx.vmap to initialize the underlying blocks, and nnx.scan to run the model input through them.
For more information on NNX transforms, check out the Transforms Guide.
# Linen
class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

model = MLP(64, num_layers=5)
# NNX
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x  # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)

model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
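The observation that initialization is a vmap and the forward pass is a scan can be seen in plain JAX. This sketch stacks 5 independently initialized (64, 64) dense layers with jax.vmap over 5 split keys, then runs an input through them with jax.lax.scan. Dropout is omitted, and the 1/sqrt(64) weight scaling is an arbitrary choice for the example.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, FEATURES = 5, 64

def init_layer(key):
  # Parameters for one Dense layer; vmap adds the leading layer axis.
  return {
    'kernel': jax.random.normal(key, (FEATURES, FEATURES)) / jnp.sqrt(FEATURES),
    'bias': jnp.zeros((FEATURES,)),
  }

# "vmap" the initializer over 5 keys: the number of keys sets the number of layers.
keys = jax.random.split(jax.random.key(0), NUM_LAYERS)
stacked = jax.vmap(init_layer)(keys)

def layer(x, params):
  # jax.lax.scan requires a (carry, output) pair, hence the dummy None.
  return jax.nn.relu(x @ params['kernel'] + params['bias']), None

x = jnp.ones((1, FEATURES))
y, _ = jax.lax.scan(layer, x, stacked)  # iterates over the leading (layer) axis
```

Because each layer got its own key, the 5 stacked kernels differ, which is the role nnx.split_rngs plays in the NNX version.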
There are a few other details to explain in this example:
What is that nnx.split_rngs decorator? NNX transforms are completely agnostic of RNG state, which makes them behave more like JAX transforms but diverge from the Linen transforms that handle RNG state. To regain this functionality, the nnx.split_rngs decorator allows you to split the Rngs before passing them to the decorated function and 'lower' them afterwards so they can be used outside.
Here we split the RNG keys because the underlying JAX transforms (jax.vmap and jax.lax.scan) require a list of RNG keys if each of their internal operations needs its own key. So for the 5 layers inside MLP, we split and provide 5 different RNG keys from its arguments before going down to the JAX transform.
Note that create_block() knows it needs to create 5 layers precisely because it sees 5 RNG keys: in_axes=(0,) makes nnx.vmap look into the first dimension of the first argument to know the size it will map over.
The same goes for forward(), which looks at the variables inside its second argument (a.k.a. model) to find out how many times it needs to scan. nnx.split_rngs here actually splits the RNG state inside the model. (If Block didn't have dropout, you wouldn't need the nnx.split_rngs line, because it would not consume any RNG key anyway.)
Why doesn't the Block in NNX need to take and return that extra dummy value? This is a requirement from jax.lax.scan. NNX simplifies this so that you can choose to ignore the second input/output if you set out_axes=nnx.Carry instead of the default (nnx.Carry, 0).
This is one of the rare cases in which NNX transforms diverge from the JAX transforms API.
This takes more lines of code, but it expresses what happens at each step more precisely. Since NNX lifted transforms are much closer to the JAX APIs, it's recommended to have a good understanding of the underlying JAX transform before using its NNX version.
Now take a look at the variable tree on both sides:
# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
  ScanBlock_0: {
    Dense_0: {
      bias: (5, 64),
      kernel: (5, 64, 64),
    },
  },
}

# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
{
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
}
Using TrainState in NNX
Flax offered a convenient TrainState dataclass to bundle the model, parameters and optimizer. This is not really necessary in the NNX era, but in this section we show how to construct your NNX code around it, for any backward compatibility needs.
In NNX, we must first call nnx.split on the model to get the separated GraphDef and State objects. We can pass in nnx.Param to filter all trainable parameters into a single State, and pass in ... for the remaining variables. We also need to subclass TrainState to add a field for the other variables.
We can then pass in GraphDef.apply as the apply function, State as the parameters and other variables, and an optimizer as arguments to the TrainState constructor. One thing to note is that GraphDef.apply will take in States as arguments and return a callable function. This function can be called on the inputs to output the model's logits, as well as the updated GraphDef and State objects. Notice that we also use @jax.jit, since we aren't passing NNX modules into train_step.
# Linen
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']

state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=optax.adam(1e-3)  # optimizer choice is illustrative
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      inputs,  # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
  grads = jax.grad(loss_fn)(state.params)
  state = state.apply_gradients(grads=grads)
  return state
# NNX
import jax
import optax
from flax import nnx
from flax.training import train_state

model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train()  # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

class TrainState(train_state.TrainState):
  other_variables: nnx.State

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  other_variables=other_variables,
  tx=optax.adam(1e-3)  # optimizer choice is illustrative
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, other_variables):
    logits, (graphdef, new_state) = state.apply_fn(params, other_variables)(inputs)  # <== inputs
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params, state.other_variables)
  state = state.apply_gradients(grads=grads)
  return state