Migrating from Haiku to Flax#

This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku.

If you are new to Flax NNX, make sure you are familiar with Flax NNX basics, which covers the nnx.Module system, Flax transformations, and the Functional API with examples.

Let’s start with some imports.
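
The side-by-side blocks below import haiku as hk and from flax import nnx where they are used; in addition, the examples assume at least the following shared imports (a minimal set inferred from the code):

import jax
import jax.numpy as jnp
import optax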

Basic Module definition#

Both Haiku and Flax use the Module class as the default unit to express a neural network layer. For example, to create a one-layer network with dropout and a ReLU activation function, you:

  • First, create a Block (by subclassing Module) composed of one linear layer with dropout and a ReLU activation function.

  • Then, use Block as a sub-Module when creating a Model (also by subclassing Module), which is made up of Block and a linear layer.

There are two fundamental differences between Haiku and Flax Module objects:

  • Stateless vs. stateful:

    • A haiku.Module instance is stateless. This means the variables are returned by a purely functional Module.init() call and managed separately.

    • A flax.nnx.Module, however, owns its variables as attributes of this Python object.

  • Lazy vs. eager:

    • A haiku.Module only allocates space for variables when it first sees an input, that is, when the user calls the model (lazy).

    • A flax.nnx.Module instance creates variables the moment they are instantiated, before seeing a sample input (eager).

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)


  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

Variable creation#

This section is about instantiating a model and initializing its parameters.

  • To generate model parameters for a Haiku model, you need to put it inside a forward function and use haiku.transform to make it purely functional. This results in a nested dictionary of JAX Arrays (jax.Array data types) to be carried around and maintained separately.

  • In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (nnx.Variable objects) are stored inside the nnx.Module (or its sub-Module) as attributes. You still need to provide it with a pseudorandom number generator (PRNG) key, but that key will be wrapped inside an nnx.Rngs class and stored inside, generating more PRNG keys when needed.

If you want to access Flax model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the Flax NNX split/merge API (nnx.split / nnx.merge).

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
sample_x = jnp.ones((1, 784))
params = model.init(jax.random.key(0), sample_x, training=False)


assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
...


model = Model(784, 256, 10, rngs=nnx.Rngs(0))

# Parameters were already initialized during model instantiation.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
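
As a quick illustration of that stateless view, here is a minimal sketch using the model defined above:

# Split the stateful model into a static graph definition and a state pytree.
graphdef, state = nnx.split(model)

# `state` can be inspected, checkpointed, or edited like a (nested) dictionary,
# and an equivalent stateful model can be rebuilt from the two pieces.
model = nnx.merge(graphdef, state)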

Training step and compilation#

This section covers writing a training step and compiling it with JAX just-in-time compilation.

When compiling the training step:

  • Haiku uses @jax.jit - a JAX transformation - to compile a purely functional training step.

  • Flax NNX uses @nnx.jit - a Flax NNX transformation (one of several transform APIs that behave similarly to JAX transforms, but also work well with Flax objects). While jax.jit only accepts functions with pure stateless arguments, flax.nnx.jit allows the arguments to be stateful Modules. This greatly reduces the number of lines needed for a train step.

When taking gradients:

  • Similarly, Haiku uses jax.grad (a JAX transformation for automatic differentiation) to return a raw dictionary of gradients.

  • Meanwhile, Flax NNX uses flax.nnx.grad (a Flax NNX transformation) to return the gradients of Flax NNX Modules as flax.nnx.State dictionaries. If you want to use regular jax.grad with Flax NNX, you need to use the split/merge API (see the sketch after the training-step code below).

For optimizers:

  • If you are already using Optax optimizers like optax.adamw (instead of the raw jax.tree.map computation shown here) with Haiku, check out the flax.nnx.Optimizer example in the Flax basics guide for a much more concise way of training and updating your model.

Model updates during each training step:

  • The Haiku training step needs to return a JAX pytree of parameters as the input of the next step.

  • The Flax NNX training step does not need to return anything, because the model was already updated in-place within nnx.jit.

  • In addition, nnx.Module objects are stateful: a Module automatically tracks the state it holds, such as PRNG keys and flax.nnx.BatchNorm statistics. That is why you don’t need to explicitly pass a PRNG key in at every step. Also note that you can use flax.nnx.reseed to reset the model’s underlying PRNG state.

The dropout behavior:

  • In Haiku, you need to explicitly define and pass in the training argument to toggle haiku.dropout and make sure that random dropout only happens if training=True.

  • In Flax NNX, you can call model.train() (flax.nnx.Module.train()) to automatically switch flax.nnx.Dropout to the training mode. Conversely, you can call model.eval() (flax.nnx.Module.eval()) to turn off the training mode. You can learn more about what flax.nnx.Module.train does in its API reference.

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      params, key,
      inputs, training=True  # <== inputs
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
model.train() # set deterministic=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(
      inputs,  # <== inputs
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.GraphState.merge(params, rest))
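
If you want to keep using plain jax.jit and jax.grad with a Flax NNX model, the split/merge API mentioned above makes that possible. Below is a minimal, hedged sketch of that pattern; the name jax_train_step, the plain SGD update, and the choice not to carry the updated PRNG state back out are illustrative simplifications rather than part of either library.

graphdef, params, rest = nnx.split(model, nnx.Param, ...)

@jax.jit
def jax_train_step(params, rest, inputs, labels):
  def loss_fn(params):
    # Rebuild a stateful model from the static graph definition and the states.
    model = nnx.merge(graphdef, params, rest)
    logits = model(inputs)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)  # plain jax.grad over the Param state
  # (For brevity, this sketch does not return the PRNG state consumed by dropout.)
  return jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)

# After each step, write the new parameters back into the stateful model:
# params = jax_train_step(params, rest, inputs, labels)
# nnx.update(model, params)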

Handling non-parameter states#

Haiku makes a distinction between trainable parameters and all other data (“states”) that the model tracks. For example, the batch stats used in batch norm are considered a state. Models with state need to be transformed with hk.transform_with_state so that their .init() returns both params and state.

In Flax, there isn’t such a strong distinction: all of them are subclasses of nnx.Variable and stored on a Module as attributes. Parameters are instances of a subclass called nnx.Param, and batch stats can be of another subclass called nnx.BatchStat. You can use nnx.split to quickly extract all data of a certain variable type.

Let’s see an example of this by taking the Block definition above and replacing dropout with BatchNorm.

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x

def forward(x, training: bool):
  return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)

sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(jax.random.key(0), sample_x, training=True)
class Block(nnx.Module):
  def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)
    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)

Flax takes the difference between trainable params and other data into account: nnx.grad will only take gradients with respect to the nnx.Param variables, automatically skipping the batch norm arrays. As a result, the training step looks the same for Flax NNX with this model.
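
You can make this separation explicit with nnx.split. The following is a minimal sketch reusing the model = Block(4, 4, rngs=nnx.Rngs(0)) instance above; the trailing ... filter is simply a catch-all for any remaining variables.

# Split into a static graph definition plus one State per filter.
graphdef, params, batch_stats, rest = nnx.split(model, nnx.Param, nnx.BatchStat, ...)

# `params` holds the Linear and BatchNorm parameters, while `batch_stats`
# holds the running mean and variance tracked by nnx.BatchNorm.
model = nnx.merge(graphdef, params, batch_stats, rest)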

Using multiple methods#

In this section you will learn how to use multiple methods in Haiku and Flax. As an example, you will implement an auto-encoder model with three methods: encode, decode, and __call__.

In Haiku, you need to use hk.multi_transform to explicitly define how the model is initialized and which methods (encode and decode here) can be called. Note that you still need to define a __call__ that activates both layers so that the lazy initialization creates all model parameters.

In Flax, it is simpler: the parameters are already initialized in __init__, and the nnx.Module methods encode and decode can be used directly.

class AutoEncoder(hk.Module):

  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
params = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):

  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):

    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)


model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
...

The parameter structure is as follows:

...


{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
})

To call those custom methods:

  • In Haiku, you need to unpack the returned .apply functions to extract the method before calling it.

  • In Flax, you can simply call the method directly.

encode, decode = model.apply
z = encode(params, None, x=jnp.ones((1, 784)))
...
z = model.encode(jnp.ones((1, 784)))

Transformations#

Both Haiku and Flax provide their own sets of transforms that wrap JAX transforms so that they can be used with Module objects.

For more information on Flax transforms, check out the Transforms guide.

Let’s start with an example:

  • First, define an RNNCell Module that will contain the logic for a single step of the RNN.

  • Define an initial_state method that will be used to initialize the state (a.k.a. carry) of the RNN. As with jax.lax.scan (API doc), the RNNCell.__call__ method will be a function that takes the carry and the input, and returns the new carry and the output. In this case, the carry and the output are the same.

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

Next, we will define an RNN Module that contains the logic for the entire RNN. In both cases, we use the library’s scan call to run the RNNCell over the input sequence.

The only difference is that Flax nnx.scan allows you to specify which axis to scan over via the in_axes and out_axes arguments, which are forwarded to the underlying jax.lax.scan (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), whereas in Haiku you need to transpose the input and output explicitly.

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(
      cell, carry,
      jnp.swapaxes(x, 1, 0)
    )
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y
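
As a brief usage sketch for the Flax RNN above (the batch, sequence, and feature sizes here are arbitrary), the input is laid out as (batch, time, features) and the scan runs over axis 1:

model = RNN(input_size=8, hidden_size=16, rngs=nnx.Rngs(0))

x = jnp.ones((4, 10, 8))       # (batch, time, features)
y = model(x)                   # scanned over the time axis, as set by in_axes/out_axes
assert y.shape == (4, 10, 16)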

Scan over layers#

Most Haiku transforms should look similar to their Flax counterparts, since they all wrap the corresponding JAX transforms, but the scan-over-layers use case is an exception.

Scan-over-layers is a technique where you run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for large models. In the example below, you will repeat the Block Module 5 times in the top-level MLP Module.

In Haiku, we define the Block Module as usual, and then inside MLP we use hk.experimental.layer_stack over a stack_block function to create a stack of Block Modules. The same code creates 5 layers of parameters at initialization time and runs the input through them at call time.

In Flax, model initialization and calling code are completely decoupled, so we use the nnx.vmap transform to initialize the underlying Block parameters, and the nnx.scan transform to run the model input through them.

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
    super().__init__(name=name)
    self.features = features
    self.num_layers = num_layers

  def __call__(self, x, training: bool):
    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    return stack_block(x)

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)

sample_x = jnp.ones((1, 64))
params = model.init(jax.random.key(0), sample_x, training=False)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)



model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

There are a few other details to explain in the Flax example above:

  • The nnx.split_rngs decorator: Flax transforms, like their JAX counterparts, are completely agnostic of the PRNG state and rely on input for PRNG keys. The nnx.split_rngs decorator allows you to split the nnx.Rngs before passing them to the decorated function and ‘lower’ them afterwards, so they can be used outside.

    • Here, you split the PRNG keys because jax.vmap and jax.lax.scan require an array of PRNG keys if each of their internal operations needs its own key. So for the 5 layers inside the MLP, you split and provide 5 different PRNG keys before going down to the JAX transform.

    • Note that create_block() knows it needs to create 5 layers precisely because it sees 5 PRNG keys: in_axes=(0,) indicates that vmap will look at the first argument’s leading dimension to determine the size it maps over.

    • Same goes for forward(), which looks at the variables inside the second argument (aka model) to find out how many times it needs to scan. nnx.split_rngs here actually splits the PRNG state inside the model. (If the Block Module didn’t have dropout, you wouldn’t need the nnx.split_rngs line, as it would not consume any PRNG key anyway.)

  • Why the Block Module in Flax doesn’t need to take and return that extra dummy value: jax.lax.scan (API doc) requires its body function to return two outputs - the carry and the stacked output. In this case, we didn’t use the latter. Flax simplifies this, so that you can choose to ignore the second output by setting out_axes=nnx.Carry instead of the default (nnx.Carry, 0).

    • This is one of the rare cases where Flax NNX transforms diverge from the JAX transforms APIs.

There are more lines of code in the Flax example above, but they express what happens at each step more precisely. Since Flax NNX transforms stay much closer to the JAX transform APIs, it is recommended to have a good understanding of the underlying JAX transforms before using their Flax NNX equivalents.

Now inspect the variable pytree on both sides:

...


{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}



...
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
})

Top-level Haiku functions vs top-level Flax modules#

In Haiku, it is possible to write the entire model as a single function by using the raw hk.{get,set}_{parameter,state} to define and access model parameters and states, and it is very common to write the top-level “Module” as a function instead.

The Flax team recommends a more Module-centric approach that uses __call__ to define the forward function. In Flax modules, the parameters and variables can be set and accessed as normal using regular Python class semantics.

...


def forward(x):
  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter(
    'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
  )

  output = x + multiplier * counter

  hk.set_state("counter", counter + 1)
  return output

model = hk.transform_with_state(forward)

params, state = model.init(jax.random.key(0), jnp.ones((1, 64)))
class Counter(nnx.Variable):
  pass

class FooModule(nnx.Module):

  def __init__(self, rngs):
    self.counter = Counter(jnp.ones((), jnp.int32))
    self.multiplier = nnx.Param(
      nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
    )

  def __call__(self, x):
    output = x + self.multiplier * self.counter.value

    self.counter.value += 1
    return output

model = FooModule(rngs=nnx.Rngs(0))

_, params, counter = nnx.split(model, nnx.Param, Counter)
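
As a short closing sketch of the Flax side (the concrete values simply follow from the ones initializers used above), the state updates happen in place on the object:

x = jnp.ones((1, 64))

y = model(x)                     # uses counter == 1, then increments it in place
assert model.counter.value == 2

y = model(x)                     # the second call sees the updated counter
assert model.counter.value == 3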