Migrating from Haiku to Flax#
This guide showcases the differences between Haiku, Flax Linen and Flax NNX. Both Haiku and Linen enforce a functional paradigm with stateless modules, while NNX is a new, next-generation API that embraces the Python language to provide a more intuitive development experience.
Basic Example#
To create custom Modules, you subclass a Module base class in both Haiku and Flax. Modules can be defined inline in Haiku and Flax Linen (using the @nn.compact decorator), whereas in NNX modules can’t be defined inline and must be defined in __init__.
Linen requires a deterministic argument to control whether or not dropout is used. NNX also uses a deterministic argument, but its value can be set later using the .eval() and .train() methods, as shown in a later code snippet.
# Haiku
import haiku as hk
import jax

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x

# Linen
import flax.linen as nn
import jax

class Block(nn.Module):
  features: int

  @nn.compact  # allows submodules to be defined inline
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x

# NNX
from flax.experimental import nnx
import jax

class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

Since modules are defined inline in Haiku and Linen, the parameters are lazily initialized, by inferring the shape of a sample input. In Flax NNX, the module is stateful and is initialized eagerly. This means that the input shape must be explicitly passed during module instantiation since there is no shape inference in NNX.
# Haiku
def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)

# Linen
model = Model(256, 10)

# NNX
model = Model(784, 256, 10, rngs=nnx.Rngs(0))

To get the model parameters in both Haiku and Linen, you call the init method with a random.key plus some inputs to run the model.
In NNX, the model parameters are automatically initialized when the user instantiates the model because the input shapes are already explicitly passed at instantiation time.
Since NNX is eager and the module is bound upon instantiation, the user can access the parameters (and other fields defined in __init__) via dot-access. Haiku and Linen, on the other hand, use lazy initialization, so the parameters can only be accessed once the module has been initialized with a sample input, and neither framework supports dot-access of module attributes.
# Haiku
import jax.numpy as jnp
from jax import random

sample_x = jnp.ones((1, 784))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)

assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)

# Linen
sample_x = jnp.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables["params"]

assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)

# NNX
# parameters were already initialized during model instantiation
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

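The difference between lazy and eager initialization can be sketched in plain JAX, independently of the three libraries. The helper names init_eager and init_lazy below are hypothetical and exist only for illustration: the eager variant needs the input dimension up front, while the lazy variant reads it from a sample batch.

```python
import jax
import jax.numpy as jnp

def init_eager(key, in_features, out_features):
  # Eager style (as in NNX): both dimensions are passed explicitly.
  kernel = jax.random.normal(key, (in_features, out_features)) * 0.01
  bias = jnp.zeros((out_features,))
  return {"kernel": kernel, "bias": bias}

def init_lazy(key, sample_x, out_features):
  # Lazy style (as in Haiku and Linen): the input dimension is inferred
  # from the last axis of a sample input.
  return init_eager(key, sample_x.shape[-1], out_features)

key = jax.random.key(0)
sample_x = jnp.ones((1, 784))
lazy_params = init_lazy(key, sample_x, 256)
eager_params = init_eager(key, 784, 256)
```

Both calls produce parameters of the same shape; the only difference is where the input dimension comes from.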
Let’s take a look at the parameter structure. In Haiku and Linen, we can simply inspect the params object returned from .init().
To see the parameter structure in NNX, the user can call nnx.split to generate GraphDef and State objects. The GraphDef is a static pytree denoting the structure of the model (for example usages, see NNX Basics). State objects contain all the module variables (i.e. instances of any class that subclasses nnx.Variable). If we filter for nnx.Param, we will generate a State of all the learnable module parameters.
# Haiku
{
  'model/block/linear': {
    'b': (256,),
    'w': (784, 256),
  },
  'model/linear': {
    'b': (10,),
    'w': (256, 10),
  }
}

# Linen
{
  Block_0: {
    Dense_0: {
      bias: (256,),
      kernel: (784, 256),
    },
  },
  Dense_0: {
    bias: (10,),
    kernel: (256, 10),
  },
}

# NNX
graphdef, params, rngs = nnx.split(model, nnx.Param, nnx.RngState)

{
  'block': {
    'linear': {
      'bias': VariableState(type=Param, value=(256,)),
      'kernel': VariableState(type=Param, value=(784, 256))
    }
  },
  'linear': {
    'bias': VariableState(type=Param, value=(10,)),
    'kernel': VariableState(type=Param, value=(256, 10))
  }
}

During training in Haiku and Linen, you pass the parameters structure to the apply method to run the forward pass. To use dropout, we must pass in training=True and provide a key to apply in order to generate the random dropout masks.
To use dropout in NNX, we first call model.train(), which sets the dropout layer’s deterministic attribute to False (conversely, calling model.eval() would set deterministic to True). Since the stateful NNX module already contains both the parameters and the RNG key (used for dropout), we simply need to call the module to run the forward pass. We use nnx.split to extract the learnable parameters (all learnable parameters subclass the NNX class nnx.Param), then apply the gradients and statefully update the model using nnx.update.
To compile train_step, we decorate the function using @jax.jit for Haiku and Linen, and @nnx.jit for NNX. Like @jax.jit, @nnx.jit compiles functions, with the additional feature of allowing the user to compile functions that take NNX modules as arguments.
# Haiku
import optax

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      params, key,
      inputs, training=True # <== inputs
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params

# Linen
@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params

# NNX
model.train() # set deterministic=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(
      inputs, # <== inputs
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  # we can use Ellipsis to filter out the rest of the variables
  _, params, _ = nnx.split(model, nnx.Param, ...)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  nnx.update(model, params)

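All three train_step variants above share the same update rule: compute gradients with the same tree structure as the parameters, then combine the two trees leaf by leaf with jax.tree_util.tree_map. A minimal sketch of that rule in plain JAX follows; the linear model, the squared-error loss and the name sgd_step are stand-ins chosen for illustration and are not part of the guide's model.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, inputs, targets):
  # Mean squared error of a linear model; stands in for the real model loss.
  preds = inputs @ params["w"] + params["b"]
  return jnp.mean((preds - targets) ** 2)

@jax.jit
def sgd_step(params, inputs, targets):
  grads = jax.grad(loss_fn)(params, inputs, targets)
  # grads mirrors the structure of params, so tree_map can pair up leaves.
  return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
inputs = jnp.ones((4, 3))
targets = jnp.ones((4, 1))

loss_before = loss_fn(params, inputs, targets)
new_params = sgd_step(params, inputs, targets)
loss_after = loss_fn(new_params, inputs, targets)
```

Because the update is a pure function of the parameter tree, the same pattern applies whether the tree comes from Haiku's init, Linen's variables["params"], or an NNX State.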
Flax also offers a convenient TrainState dataclass that bundles the model, parameters and optimizer to simplify training and updating the model. In Haiku and Linen, we simply pass in the model.apply function, the initialized parameters and an optimizer as arguments to the TrainState constructor.
In NNX, we must first call nnx.split on the model to get the separated GraphDef and State objects. We can pass in nnx.Param to filter all trainable parameters into a single State, and pass in ... for the remaining variables. We also need to subclass TrainState to add a field for the other variables. We can then pass in GraphDef.apply as the apply function, the State objects as the parameters and other variables, and an optimizer as arguments to the TrainState constructor.
One thing to note is that GraphDef.apply takes State objects as arguments and returns a callable function. This function can be called on the inputs to output the model’s logits, as well as updated GraphDef and State objects. This isn’t needed for our current example with dropout, but in the next section you will see that these updated objects are relevant for layers like batch norm. Notice that we also use @jax.jit here, since we aren’t passing NNX modules into train_step.
# Haiku
from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=optax.sgd(0.1)  # optimizer assumed for illustration; any optax optimizer works
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      params, key,
      inputs, training=True # <== inputs
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)

  state = state.apply_gradients(grads=grads)
  return state

# Linen
from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=optax.sgd(0.1)  # optimizer assumed for illustration; any optax optimizer works
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)

  state = state.apply_gradients(grads=grads)
  return state

# NNX
from flax.training import train_state

model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

class TrainState(train_state.TrainState):
  other_variables: nnx.State

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  other_variables=other_variables,
  tx=optax.sgd(0.1)  # optimizer assumed for illustration; any optax optimizer works
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, other_variables):
    logits, (graphdef, new_state) = state.apply_fn(
      params,
      other_variables
    )(inputs) # <== inputs
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params, state.other_variables)

  state = state.apply_gradients(grads=grads)
  return state

Handling State#
Now let’s see how mutable state is handled in all three frameworks. We will take the same model as before, but now we will replace Dropout with BatchNorm.
# Haiku
class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x

# Linen
class Block(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.BatchNorm(
      momentum=0.99  # same value as decay_rate in Haiku; also the Linen default
    )(x, use_running_average=not training)
    x = jax.nn.relu(x)
    return x

# NNX
class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)
    x = jax.nn.relu(x)
    return x

Haiku requires an is_training argument and Linen requires a use_running_average argument to control whether or not to update the running statistics. NNX also uses a use_running_average argument, but its value can be set later using the .eval() and .train() methods, as shown in later code snippets.
As before, you need to pass in the input shape to construct the Module eagerly in NNX.
# Haiku
def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform_with_state(forward)

# Linen
model = Model(256, 10)

# NNX
model = Model(784, 256, 10, rngs=nnx.Rngs(0))

To initialize both the parameters and state in Haiku and Linen, you just call the init method as before. However, in Haiku you now get batch_stats as a second return value, and in Linen you get a new batch_stats collection in the variables dictionary. Note that since hk.BatchNorm only initializes batch statistics when is_training=True, we must set training=True when initializing the parameters of a Haiku model with an hk.BatchNorm layer. In Linen, we can set training=False as usual.
In NNX, the parameters and state are already initialized upon module instantiation. The batch statistics are of class nnx.BatchStat, which subclasses the nnx.Variable class (not nnx.Param, since they aren’t learnable parameters). Calling nnx.split with no additional filter arguments returns a state containing all nnx.Variable objects by default.
# Haiku
sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(
  random.key(0),
  sample_x, training=True # <== inputs
)

# Linen
sample_x = jnp.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]

# NNX
graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)

Training now looks very similar in Haiku and Linen, as you use the same apply method to run the forward pass. In Haiku, you now pass batch_stats as the second argument to apply, and get the newly updated batch_stats as the second return value. In Linen, you instead add batch_stats as a new key to the input dictionary, and get the updates variables dictionary as the second return value. To update the batch statistics, we must pass in training=True to apply.
In NNX, the training code is identical to the earlier example, as the batch statistics (which are bound to the stateful NNX module) are updated statefully. To update batch statistics in NNX, we first call model.train(), which sets the batchnorm layer’s use_running_average attribute to False (conversely, calling model.eval() would set use_running_average to True). Since the stateful NNX module already contains the parameters and batch statistics, we simply need to call the module to run the forward pass. We use nnx.split to extract the learnable parameters (all learnable parameters subclass the NNX class nnx.Param), then apply the gradients and statefully update the model using nnx.update.
# Haiku
@jax.jit
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, batch_stats = model.apply(
      params, batch_stats,
      None, # <== rng
      inputs, training=True # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, batch_stats

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, batch_stats

# Linen
@jax.jit
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, updates = model.apply(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True, # <== inputs
      mutable=['batch_stats'],  # needed so apply returns the updated statistics
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates["batch_stats"]

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, batch_stats

# NNX
model.train() # set use_running_average=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(
      inputs, # <== inputs
    ) # batch statistics are updated statefully in this step
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss

  grads = nnx.grad(loss_fn)(model)
  _, params, _ = nnx.split(model, nnx.Param, ...)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  nnx.update(model, params)

To use TrainState, we subclass it to add an additional field that can store the batch statistics:
# Haiku
from typing import Any

class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.sgd(0.1)  # optimizer assumed for illustration; any optax optimizer works
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, batch_stats = state.apply_fn(
      params, batch_stats,
      None, # <== rng
      inputs, training=True # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, batch_stats

  grads, batch_stats = jax.grad(
    loss_fn, has_aux=True
  )(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=batch_stats)

  return state

# Linen
from typing import Any

class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.sgd(0.1)  # optimizer assumed for illustration; any optax optimizer works
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True, # <== inputs
      mutable=['batch_stats'],  # needed so apply returns the updated statistics
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates['batch_stats']

  grads, batch_stats = jax.grad(
    loss_fn, has_aux=True
  )(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=batch_stats)

  return state

# NNX
from typing import Any

model.train() # set use_running_average=False
graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)

class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.sgd(0.1)  # optimizer assumed for illustration; any optax optimizer works
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, (graphdef, new_state) = state.apply_fn(
      params, batch_stats
    )(inputs) # <== inputs
    _, batch_stats = new_state.split(nnx.Param, nnx.BatchStat)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, batch_stats

  grads, batch_stats = jax.grad(
    loss_fn, has_aux=True
  )(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=batch_stats)

  return state

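To make the role of the batch statistics in this section more concrete, here is a minimal plain-JAX sketch of a batch-norm step that threads its running statistics explicitly. It is not how any of the three libraries implement the layer: the function name batch_norm is hypothetical, the learnable scale and offset are omitted, and the 0.99 momentum is taken from the examples above.

```python
import jax.numpy as jnp

def batch_norm(x, running, training, momentum=0.99, eps=1e-5):
  # Returns the normalized activations and the (possibly updated) statistics.
  if training:
    # Normalize with the statistics of the current batch and fold them
    # into the running averages.
    mean, var = x.mean(axis=0), x.var(axis=0)
    running = {
      "mean": momentum * running["mean"] + (1 - momentum) * mean,
      "var": momentum * running["var"] + (1 - momentum) * var,
    }
  else:
    # Normalize with the stored running averages and leave them untouched.
    mean, var = running["mean"], running["var"]
  return (x - mean) / jnp.sqrt(var + eps), running

stats = {"mean": jnp.zeros((2,)), "var": jnp.ones((2,))}
x = jnp.array([[1.0, 2.0], [3.0, 6.0]])

y_train, new_stats = batch_norm(x, stats, training=True)
y_eval, same_stats = batch_norm(x, stats, training=False)
```

In Haiku and Linen the caller receives the new statistics as a return value, much like new_stats here, whereas an NNX module performs the equivalent update on its own attributes.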
Using Multiple Methods#
In this section we will take a look at how to use multiple methods in all three frameworks. As an example, we will implement an auto-encoder model with three methods: encode, decode, and __call__.
As before, we define the encoder and decoder layers without having to pass in the input shape, since the module parameters will be initialized lazily using shape inference in Haiku and Linen. In NNX, we must pass in the input shape since the module parameters will be initialized eagerly without shape inference.
# Haiku
class AutoEncoder(hk.Module):
  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

# Linen
class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

# NNX
class AutoEncoder(nnx.Module):
  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

As before, we pass in the input shape when instantiating the NNX module.
# Haiku
def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)

# Linen
model = AutoEncoder(256, 784)

# NNX
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))

For Haiku and Linen, init can be used to trigger the __call__ method to initialize the parameters of our model, which uses both the encode and decode methods. This will create all the necessary parameters for the model. In NNX, the parameters are already initialized upon module instantiation.
# Haiku
params = model.init(
  random.key(0),
  x=jnp.ones((1, 784)),
)

# Linen
params = model.init(
  random.key(0),
  x=jnp.ones((1, 784)),
)['params']

# NNX
# parameters were already initialized during model instantiation

The parameter structure is as follows:
# Haiku
{
  'auto_encoder/~/decoder': {
    'b': (784,),
    'w': (256, 784)
  },
  'auto_encoder/~/encoder': {
    'b': (256,),
    'w': (784, 256)
  }
}

# Linen
{
  decoder: {
    bias: (784,),
    kernel: (256, 784),
  },
  encoder: {
    bias: (256,),
    kernel: (784, 256),
  },
}

# NNX
_, params, _ = nnx.split(model, nnx.Param, ...)

{
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
}

Finally, let’s explore how to run the forward pass. In Haiku and Linen, we use the apply function to invoke the encode method. In NNX, we can simply call the encode method directly.
# Haiku
encode, decode = model.apply
z = encode(
  params,
  None, # <== rng
  x=jnp.ones((1, 784)),
)

# Linen
z = model.apply(
  {"params": params},
  x=jnp.ones((1, 784)),
  method="encode",  # select which method of the Module to run
)

# NNX
z = model.encode(jnp.ones((1, 784)))

Lifted Transforms#
Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms, that wrap JAX transformations in such a way that they can be used with Modules and sometimes provide additional functionality. In this section we will take a look at how to use the lifted version of scan in both Flax and Haiku to implement a simple RNN layer.
To begin, we will first define an RNNCell module that will contain the logic for a single step of the RNN. We will also define an initial_state method that will be used to initialize the state (a.k.a. carry) of the RNN. As with jax.lax.scan, the RNNCell.__call__ method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same.
# Haiku
class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

# Linen
class RNNCell(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

# NNX
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

Next, we will define an RNN Module that will contain the logic for the entire RNN. In Haiku, we will first initialize the RNNCell, then use it to construct the carry, and finally use hk.scan to run the RNNCell over the input sequence.
In Linen, we will use nn.scan to define a new temporary type that wraps RNNCell. During this process we will also instruct nn.scan to broadcast the params collection (all steps share the same parameters) and to not split the params rng stream (so all steps initialize with the same parameters). Finally, we will specify that we want scan to run over the second axis of the input and stack the outputs along the second axis as well. We will then use this temporary type immediately to create an instance of the lifted RNNCell, use it to create the carry, and then run the __call__ method, which will scan over the sequence.
In NNX, we define a scan function scan_fn that will use the RNNCell defined in __init__ to scan over the sequence.
# Haiku
class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(
      cell, carry,
      jnp.swapaxes(x, 1, 0)  # hk.scan iterates over the leading axis
    )
    y = jnp.swapaxes(y, 0, 1)
    return y

# Linen
class RNN(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(
      RNNCell, variable_broadcast='params',
      split_rngs={'params': False}, in_axes=1, out_axes=1
    )(self.hidden_size)  # instantiate the lifted RNNCell
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)

    return y

# NNX
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

In general, the main difference between lifted transforms in Flax and Haiku is that in Haiku the lifted transforms don’t operate over the state; that is, Haiku will handle the params and state in such a way that they keep the same shape inside and outside of the transform. In Flax, the lifted transforms can operate over both variable collections and rng streams, so the user must define how different collections are treated by each transform according to the transform’s semantics.
As before, the parameters must be initialized via .init() and passed into .apply() to conduct a forward pass in Haiku and Linen. In NNX, the parameters are already eagerly initialized and bound to the stateful module, and the module can simply be called on the input to conduct a forward pass.
# Haiku
x = jnp.ones((3, 12, 32))

def forward(x):
  return RNN(64)(x)

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(
  random.key(0),
  x=jnp.ones((3, 12, 32)),
)

y = model.apply(
  params,
  x=jnp.ones((3, 12, 32)),
)

# Linen
x = jnp.ones((3, 12, 32))

model = RNN(64)

params = model.init(
  random.key(0),
  x=jnp.ones((3, 12, 32)),
)['params']

y = model.apply(
  {'params': params},
  x=jnp.ones((3, 12, 32)),
)

# NNX
x = jnp.ones((3, 12, 32))

model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))

y = model(x)

The only notable change with respect to the examples in the previous sections is that this time around we used hk.without_apply_rng in Haiku so we didn’t have to pass the rng argument as None to the apply method.
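For reference, the carry and per-step output convention used by the RNNCell in this section can be exercised with jax.lax.scan directly. The sketch below is plain JAX and does not use any of the lifted transforms; the function names rnn_step and rnn and the explicit params dictionary are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def rnn_step(params, carry, x):
  # One RNN step: concatenate carry and input, apply a dense layer and ReLU.
  h = jnp.concatenate([carry, x], axis=-1) @ params["w"] + params["b"]
  h = jax.nn.relu(h)
  return h, h  # the new carry and the per-step output are the same

def rnn(params, xs, hidden_size):
  # xs has shape (batch, time, features). lax.scan iterates over axis 0,
  # so the time axis is moved to the front and back again afterwards.
  carry = jnp.zeros((xs.shape[0], hidden_size))
  step = lambda c, x: rnn_step(params, c, x)
  carry, ys = jax.lax.scan(step, carry, jnp.swapaxes(xs, 0, 1))
  return jnp.swapaxes(ys, 0, 1)

hidden_size, input_size = 64, 32
params = {
  "w": jax.random.normal(
    jax.random.key(0), (hidden_size + input_size, hidden_size)) * 0.1,
  "b": jnp.zeros((hidden_size,)),
}
y = rnn(params, jnp.ones((3, 12, 32)), hidden_size)
```

The axis swaps play the same role as jnp.swapaxes in the Haiku version and as in_axes=1, out_axes=1 in the Linen and NNX versions.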
Scan over layers#
One very important application of scan is to apply a sequence of layers iteratively over an input, passing the output of each layer as the input to the next layer. This is very useful for reducing compilation time for big models. As an example, we will create a simple Block Module, and then use it inside an MLP Module that will apply the Block Module num_layers times.
In Haiku, we define the Block Module as usual, and then inside MLP we will use hk.experimental.layer_stack over a stack_block function to create a stack of Block Modules.
In Linen, the definition of Block is a little different: __call__ will accept and return a second dummy input/output that in both cases will be None. In MLP, we will use nn.scan as in the previous example, but by setting split_rngs={'params': True} and variable_axes={'params': 0} we are telling nn.scan to create different parameters for each step and slice the params collection along the first axis, effectively implementing a stack of Block Modules as in Haiku.
In NNX, we use nnx.Scan.constructor() to define a stack of Block modules. We can then simply call the stack of Blocks, self.blocks, on the input and carry to get the forward pass output.
# Haiku
class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
    super().__init__(name=name)  # required by hk.Module
    self.features = features
    self.num_layers = num_layers

  def __call__(self, x, training: bool):
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack(stack_block)(x)  # wrap stack_block, then apply the stack

# Linen
class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

# NNX
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array, _):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x, None

class MLP(nnx.Module):
  def __init__(self, input_dim, features, num_layers, rngs):
    self.blocks = nnx.Scan.constructor(
      Block, length=num_layers
    )(input_dim, features, rngs=rngs)

  def __call__(self, x):
    y, _ = self.blocks(x, None)
    return y

Notice how in Flax we pass None as the second argument to ScanBlock and ignore its second output. These represent the per-step inputs/outputs, but they are None because in this case we don’t have any.
Initializing each model is the same as in previous examples. In this case, we will be specifying that we want to use 5 layers, each with 64 features. As before, we also pass in the input shape for NNX.
# Haiku
def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)

model = hk.transform(forward)

sample_x = jnp.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)

# Linen
model = MLP(64, num_layers=5)

sample_x = jnp.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)['params']

# NNX
model = MLP(64, 64, num_layers=5, rngs=nnx.Rngs(0))

When using scan over layers, the one thing you should notice is that all layers are fused into a single layer whose parameters have an extra “layer” dimension on the first axis. In this case, the shape of all parameters will start with (5, ...) as we are using 5 layers.
# Haiku
{
  'mlp/__layer_stack_no_per_layer/block/linear': {
    'b': (5, 64),
    'w': (5, 64, 64)
  }
}

# Linen
{
  ScanBlock_0: {
    Dense_0: {
      bias: (5, 64),
      kernel: (5, 64, 64),
    },
  },
}

# NNX
_, params, _ = nnx.split(model, nnx.Param, ...)

{
  'blocks': {
    'scan_module': {
      'linear': {
        'bias': VariableState(type=Param, value=(5, 64)),
        'kernel': VariableState(type=Param, value=(5, 64, 64))
      }
    }
  }
}

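The extra leading “layer” dimension can be reproduced in plain JAX, which may help when reasoning about what the lifted transforms produce. The sketch below is an illustration under simple assumptions (no dropout, a hypothetical init_layer helper): it stacks per-layer parameters with jax.vmap and then lets jax.lax.scan slice one layer off the first axis at each step.

```python
import jax
import jax.numpy as jnp

num_layers, features = 5, 64
keys = jax.random.split(jax.random.key(0), num_layers)

def init_layer(key):
  return {
    "kernel": jax.random.normal(key, (features, features)) * 0.1,
    "bias": jnp.zeros((features,)),
  }

# vmap over the keys stacks the per-layer parameters along a new first axis.
stacked = jax.vmap(init_layer)(keys)

def layer(x, layer_params):
  # scan passes one layer's slice of the stacked parameters on every step;
  # the activation x is the carry.
  x = jax.nn.relu(x @ layer_params["kernel"] + layer_params["bias"])
  return x, None  # no per-step output, only the carry

y, _ = jax.lax.scan(layer, jnp.ones((1, features)), stacked)
```

Here stacked["kernel"] has shape (5, 64, 64) and stacked["bias"] has shape (5, 64), matching the parameter shapes listed above.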
Top-level Haiku functions vs top-level Flax modules#
In Haiku, it is possible to write the entire model as a single function by using the raw hk.{get,set}_{parameter,state} functions to define/access model parameters and states. It is very common to write the top-level “Module” as a function instead.
The Flax team recommends a more Module-centric approach that uses __call__ to define the forward function. In Linen, the corresponding accessors are Module.param and Module.variable (go to Handling State for an explanation of collections). In NNX, the parameters and variables can be set and accessed as normal using regular Python class semantics.
# Haiku
def forward(x):
  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter(
    'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
  )

  output = x + multiplier * counter

  hk.set_state("counter", counter + 1)
  return output

model = hk.transform_with_state(forward)

params, state = model.init(random.key(0), jnp.ones((1, 64)))

# Linen
class FooModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
    multiplier = self.param(
      'multiplier', nn.initializers.ones_init(), [1,], x.dtype
    )

    output = x + multiplier * counter.value
    if not self.is_initializing(): # otherwise model.init() also increases it
      counter.value += 1
    return output

model = FooModule()
variables = model.init(random.key(0), jnp.ones((1, 64)))
params, counter = variables['params'], variables['counter']

# NNX
class Counter(nnx.Variable):
  pass

class FooModule(nnx.Module):
  def __init__(self, rngs):
    self.counter = Counter(jnp.ones((), jnp.int32))
    self.multiplier = nnx.Param(
      nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
    )

  def __call__(self, x):
    output = x + self.multiplier * self.counter.value
    self.counter.value += 1
    return output

model = FooModule(rngs=nnx.Rngs(0))
_, params, counter = nnx.split(model, nnx.Param, Counter)

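As a point of comparison, the counter example can also be written as a plain function that threads its state explicitly, which is roughly the calling convention that the functional APIs expose to the user. This is only a sketch: the dictionaries and the forward signature below are illustrative and do not correspond to any library's exact API.

```python
import jax.numpy as jnp

def forward(params, state, x):
  # Read the current state, compute the output, and return the updated
  # state explicitly instead of mutating anything in place.
  output = x + params["multiplier"] * state["counter"]
  new_state = {"counter": state["counter"] + 1}
  return output, new_state

params = {"multiplier": jnp.ones((1,))}
state = {"counter": jnp.ones((), jnp.int32)}
x = jnp.ones((1, 64))

out, state = forward(params, state, x)   # counter is 1 during this call
out2, state = forward(params, state, x)  # counter is 2 during this call
```

The caller is responsible for carrying the returned state into the next call, whereas the NNX module above updates self.counter in place.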