Use Flax NNX and Linen together#

This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX Modules, which is made possible thanks to the flax.nnx.bridge API.

This will be helpful if you:

  • Want to migrate your codebase to NNX gradually, one module at a time;

  • Have an external dependency that has already moved to NNX while you haven’t, or one that is still in Linen while you’ve moved to NNX.

We hope this allows you to move to and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs on the few aspects in which they are fundamentally different.

Note:

This guide is about gluing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the Migrate from Flax Linen to Flax NNX guide.

And all built-in Linen layers should have equivalent NNX versions! Check out the list of Built-in NNX layers.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from flax import nnx
from flax import linen as nn
from flax.nnx import bridge
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from typing import *

Submodule is all you need#

A Flax model is always a tree of modules - either old Linen modules (flax.linen.Module, usually written as nn.Module) or NNX modules (nnx.Module).

The nnx.bridge wrappers glue the two types together, in both directions:

  • nnx.bridge.ToNNX: Convert a Linen module to NNX, so that it can be a submodule of another NNX module, or stand alone to be trained in NNX-style training loops.

  • nnx.bridge.ToLinen: Vice versa, convert an NNX module to Linen.

This means you can migrate either top-down or bottom-up: convert the whole top-level Linen module to NNX first and then gradually move down, or convert all the lower-level modules to NNX first and then move up.
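
For a quick preview of both directions, here is a minimal sketch of the two wrappers, pairing the built-in nn.Dense and nnx.Linear layers (using the imports from the setup cell above); the rest of this guide walks through each wrapper in detail.

# A Linen layer wrapped so it can act as an NNX module
# (its Linen variables are created later via `bridge.lazy_init`, as shown below).
linen_as_nnx = bridge.ToNNX(nn.Dense(features=3), rngs=nnx.Rngs(0))

# An NNX layer wrapped so it can act as a Linen module
# (in_features=4, out_features=3; initialized later via the usual `.init()` / `.apply()`).
nnx_as_linen = bridge.to_linen(nnx.Linear, 4, 3)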

The Basics#

There are two fundamental differences between Linen and NNX modules:

  • Stateless vs. stateful: Linen module instances are stateless: variables are returned from a purely functional .init() call and managed separately. An NNX module, however, owns its variables as instance attributes.

  • Lazy vs. eager: Linen modules only allocate space for their variables when they actually see their input, whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input. Both differences are illustrated in the sketch below.
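
As a quick illustration of both differences, here is a minimal sketch comparing the built-in nn.Dense (Linen) and nnx.Linear (NNX) layers:

# Linen: stateless and lazy. The instance holds no variables; they are created
# by the purely functional `.init()` call once a sample input is available,
# and are returned as a separate pytree.
linen_layer = nn.Dense(features=3)
linen_vars = linen_layer.init(jax.random.key(0), jnp.ones((1, 4)))
print(linen_vars['params']['kernel'].shape)  # (4, 3)

# NNX: stateful and eager. The variables are created at instantiation time,
# without a sample input, and live on the instance itself.
nnx_layer = nnx.Linear(in_features=4, out_features=3, rngs=nnx.Rngs(0))
print(nnx_layer.kernel.value.shape)          # (4, 3)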

With that in mind, let’s look at how the nnx.bridge wrappers tackle the differences.

Linen -> NNX#

Since Linen modules may require an input to create variables, we semi-formally support lazy initialization in NNX modules converted from Linen. The Linen variables are created when you give the converted module a sample input.

In practice, you call nnx.bridge.lazy_init() where you would call module.init() in Linen code.

(Note: you can call nnx.display on any NNX module to inspect all its variables and state.)

class LinenDot(nn.Module):
  out_dim: int
  w_init: Callable[..., Any] = nn.initializers.lecun_normal()
  @nn.compact
  def __call__(self, x):
    # Linen might need the input shape to create the weight!
    w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))
    return x @ w

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64), 
                     rngs=nnx.Rngs(0))  # => `model = LinenDot(64)` in Linen
bridge.lazy_init(model, x)              # => `var = model.init(key, x)` in Linen
y = model(x)                            # => `y = model.apply(var, x)` in Linen

nnx.display(model)

# In-place swap your weight array and the model still works!
model.w.value = jax.random.normal(jax.random.key(1), (32, 64))
assert not jnp.allclose(y, model(x))

nnx.bridge.lazy_init also works even if the top-level module is a pure-NNX one, so you can nest submodules however you like:

class NNXOuter(nnx.Module):
  def __init__(self, out_dim: int, rngs: nnx.Rngs):
    self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)
    self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))

  def __call__(self, x):
    return self.dot(x) + self.b

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x)  # Can fit into one line
nnx.display(model)

The Linen weight has already been converted to a typical NNX variable, which is a thin wrapper around the actual JAX array value within. Here, w is an nnx.Param because it belongs to the params collection of the LinenDot module.

We will talk more about different collections and types in the NNX Variable <-> Linen Collections section. Right now, just know that they are converted to NNX variables like native ones.

assert isinstance(model.dot.w, nnx.Param)
assert isinstance(model.dot.w.value, jax.Array)

If you create this model without using nnx.bridge.lazy_init, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside ToNNX) will not.

partial_model = NNXOuter(64, rngs=nnx.Rngs(0))
nnx.display(partial_model)
full_model = bridge.lazy_init(partial_model, x)
nnx.display(full_model)

NNX -> Linen#

To convert an NNX module to Linen, you should forward your creation arguments to bridge.ToLinen and let it handle the actual creation process.

This is because an NNX module instance initializes all of its variables eagerly when it is created, which consumes memory and compute. Linen modules, on the other hand, are stateless, and the typical init and apply process creates them multiple times. So bridge.to_linen handles the actual module creation and makes sure no memory is allocated twice.

class NNXDot(nnx.Module):
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(
      rngs.params(), (in_dim, out_dim)))
  def __call__(self, x: jax.Array):
    return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))
# Pass in the arguments, not an actual module
model = bridge.to_linen(NNXDot, 32, out_dim=64)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)

print(list(variables.keys()))
print(variables['params']['w'].shape)  # => (32, 64)
print(y.shape)                         # => (4, 64)
['nnx', 'params']
(32, 64)
(4, 64)

Note that ToLinen modules need to track an extra variable collection - nnx - for the static metadata of the underlying NNX module.

# This new field stores the static data that defines the underlying `NNXDot`
print(type(variables['nnx']['graphdef']))    # => `nnx.graph.NodeDef`
<class 'flax.nnx.graph.NodeDef'>

bridge.to_linen is actually a convenience wrapper around the Linen module bridge.ToLinen. Most likely you won’t need to use ToLinen directly at all, unless you need one of the built-in arguments of ToLinen. For example, if your NNX module should not be initialized with RNG handling:

class NNXAddConstant(nnx.Module):
  def __init__(self):
    self.constant = nnx.Variable(jnp.array(1))
  def __call__(self, x):
    return x + self.constant

# You have to use `skip_rng=True` because this module's `__init__` doesn't
# take an `rngs` argument
model = bridge.ToLinen(NNXAddConstant, skip_rng=True)
y, var = model.init_with_output(jax.random.key(0), x)

Similar to ToNNX, you can use ToLinen to create a submodule of another Linen module.

class LinenOuter(nn.Module):
  out_dim: int
  @nn.compact
  def __call__(self, x):
    dot = bridge.to_linen(NNXDot, x.shape[-1], self.out_dim)
    b = self.param('b', nn.initializers.lecun_normal(), (1, self.out_dim))
    return dot(x) + b

x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenOuter(out_dim=64)
y, variables = model.init_with_output(jax.random.key(0), x)
w, b = variables['params']['ToLinen_0']['w'], variables['params']['b']
print(w.shape, b.shape, y.shape)
(32, 64) (1, 64) (4, 64)

Handling RNG keys#

All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropout. However, the specific logic of RNG key splitting is different, so you cannot generate the same params from Linen and NNX modules, even if you pass in the same keys.

Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves.
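
For example, here is a minimal sketch (using the built-in nn.Dense and nnx.Linear layers) showing that the same integer seed generally does not produce the same kernel in the two APIs:

# Same seed, same layer shape, but the RNG key splitting differs between the
# two APIs, so the resulting kernels are (in general) not identical.
linen_params = nn.Dense(features=3, use_bias=False).init(
    jax.random.key(0), jnp.ones((1, 4)))['params']
nnx_linear = nnx.Linear(4, 3, use_bias=False, rngs=nnx.Rngs(params=0))
print(jnp.allclose(linen_params['kernel'], nnx_linear.kernel.value))  # Likely False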

Linen to NNX#

If you convert a Linen module to NNX, you enjoy the stateful benefit and don’t need to pass in extra RNG keys on every module call. You can always use nnx.reseed to reset its internal RNG state.

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))
# We don't really need to call lazy_init because no extra params were created here,
# but it's a good practice to always add this line.
bridge.lazy_init(model, x)
y1, y2 = model(x), model(x)
assert not jnp.allclose(y1, y2)  # Two runs yield different outputs!

# Reset the dropout RNG seed, so that next model run will be the same as the first.
nnx.reseed(model, dropout=0)
assert jnp.allclose(y1, model(x))

NNX to Linen#

If you convert an NNX module to Linen, the underlying NNX module’s RNG states will still be part of the top-level variables. On the other hand, a Linen apply() call accepts different RNG keys on each call, which resets the internal Linen environment and allows different random data to be generated.

Which one takes effect depends on whether your underlying NNX module generates new random data from its own RNG state or from a passed-in key. Fortunately, nnx.Dropout supports both - it uses the passed-in keys if there are any, and falls back to its own RNG state otherwise.

And this leaves you with two style options of handling the RNG keys:

  • The NNX style (recommended): Let the underlying NNX state manage the RNG keys, no need to pass in extra keys in apply(). This means a few more lines to mutate the variables for every apply call, but things will look easier once your whole model no longer needs ToLinen.

  • The Linen style: Just pass different RNG keys for every apply() call.

x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.to_linen(nnx.Dropout, rate=0.5)
variables = model.init({'dropout': jax.random.key(0)}, x)

# The NNX RNG state was stored inside `variables`
print('The RNG key in state:', variables['RngKey']['rngs']['dropout']['key'].value)
print('Number of key splits:', variables['RngCount']['rngs']['dropout']['count'].value)

# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
print('Number of key splits after y2:', variables['RngCount']['rngs']['dropout']['count'].value)
assert not jnp.allclose(y1, y2)  # Every call yields different output!

# Linen style: Just pass different RNG keys for every `apply()` call.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
assert not jnp.allclose(y3, y4)  # Every call yields different output!
y5 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
assert jnp.allclose(y3, y5)      # When you use same top-level RNG, outputs are same
The RNG key in state: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Number of key splits: 0
Number of key splits after y2: 2

NNX variable types vs. Linen collections#

When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables are top-level Python attributes, you use different variable types.

Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that ToNNX and ToLinen can do the conversion automatically.

Flax keeps a registry for this, and it already covers all of Flax’s built-in Linen collections. You can register extra mappings between an NNX variable type and a Linen collection name using nnx.register_variable_name_type_pair.

Linen to NNX#

For any collection of your Linen module, ToNNX will convert all its endpoint arrays (i.e., leaves) to a subtype of nnx.Variable, either taken from the registry or automatically created on the fly.

(However, we still keep the whole collection as one class attribute, because Linen modules may have duplicated names over different collections.)

class LinenMultiCollections(nn.Module):
  out_dim: int
  def setup(self):
    self.w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim))
    self.b = self.param('b', nn.zeros_init(), (self.out_dim,))
    self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))

  def __call__(self, x):
    if not self.is_initializing():
      self.count.value += 1
    y = x @ self.w + self.b
    self.sow('intermediates', 'dot_sum', jnp.sum(y))
    return y

x = jax.random.normal(jax.random.key(42), (2, 4))
model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)
print(model.w)        # Of type `nnx.Param` - note this is still under attribute `params`
print(model.b)        # Of type `nnx.Param`
print(model.count)    # Of type `counter` - auto-created type from the collection name
print(type(model.count))

y = model(x, mutable=True)  # Linen's `sow()` needs `mutable=True` to trigger
print(model.dot_sum)        # Of type `nnx.Intermediates`
Param(
  value=Array([[ 0.35401407,  0.38010964, -0.20674096],
         [-0.7356256 ,  0.35613298, -0.5099556 ],
         [-0.4783049 ,  0.4310735 ,  0.30137998],
         [-0.6102254 , -0.2668519 , -1.053598  ]], dtype=float32)
)
Param(
  value=Array([0., 0., 0.], dtype=float32)
)
counter(
  value=Array(0, dtype=int32)
)
<class 'flax.nnx.bridge.variables.counter'>
(Intermediate(
  value=Array(6.9329877, dtype=float32)
),)

You can quickly separate different types of NNX variables using nnx.split.

This can be handy when you only want to set some variables as trainable.

# Separate variables of different types with nnx.split
CountType = type(model.count)
static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)
print('All Params:', list(params.keys()))
print('All Counters:', list(counter.keys()))
print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))

model = nnx.merge(static, params, counter, the_rest)  # You can merge them back at any time
y = model(x, mutable=True)  # still works!
All Params: ['b', 'w']
All Counters: ['count']
All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']

NNX to Linen#

If you define custom NNX variable types, you should register their names with nnx.register_variable_name_type_pair so that they go to the desired collections.

class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('counts', Count, overwrite=True)

class NNXMultiCollections(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
    self.lora = nnx.LoRA(din, 3, dout, rngs=rngs)
    self.count = Count(jnp.array(0))

  def __call__(self, x):
    self.count += 1
    return (x @ self.w.value) + self.lora(x)

xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(xkey, (2, 4))
model = bridge.to_linen(NNXMultiCollections, 4, 3)
var = model.init({'params': pkey, 'dropout': dkey}, x)
print('All Linen collections:', list(var.keys()))
print(var['params'])
All Linen collections: ['nnx', 'LoRAParam', 'counts', 'params']
{'w': Array([[ 0.2916921 ,  0.22780475,  0.06553137],
       [ 0.17487915, -0.34043145,  0.24764155],
       [ 0.6420431 ,  0.6220095 , -0.44769976],
       [ 0.11161668,  0.83873135, -0.7446058 ]], dtype=float32)}

Partition metadata#

Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.

In Linen, this is an optional feature that is triggered by using nn.with_partitioning on initializers (see more in the Linen partition metadata guide). In NNX, since all NNX variables are wrapped by the nnx.Variable class anyway, that class holds the sharding annotations too.

The bridge.ToNNX and bridge.ToLinen APIs will automatically convert the sharding annotations, if you use the built-in annotation methods (i.e., nn.with_partitioning for Linen and nnx.with_partitioning for NNX).
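
As a side-by-side sketch, the two annotation methods look like this when applied to the same initializer (the full, runnable examples follow in the subsections below):

# Linen: wrap an initializer so the created param carries the ('in', 'out') names.
linen_init = nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out'))

# NNX: same idea; the annotation ends up in the `.sharding` field of the variable.
nnx_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))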

Linen to NNX#

Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to nnx.Variables that wrap the true JAX arrays within.

If you use nn.with_partitioning to annotate your Linen module’s variables, the annotation will be converted to a .sharding field in the corresponding nnx.Variable.

You can then use nnx.with_sharding_constraint inside a jax.jit-compiled function to explicitly place the arrays into the annotated partitions, so that the whole model is initialized with every array at the right sharding.

class LinenDotWithPartitioning(nn.Module):
  out_dim: int
  @nn.compact
  def __call__(self, x):
    w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), 
                                             ('in', 'out')), 
                   (x.shape[-1], self.out_dim))
    return x @ w

@nnx.jit
def create_sharded_nnx_module(x):
  model = bridge.lazy_init(
    bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
  state = nnx.state(model)
  sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
  nnx.update(model, sharded_state)
  return model


print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')
mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), 
                         axis_names=('in', 'out'))
x = jax.random.normal(jax.random.key(42), (4, 32))
with mesh:
  model = create_sharded_nnx_module(x)

print(type(model.w))           # `nnx.Param`
print(model.w.sharding)        # The partition annotation attached with `w`
print(model.w.value.sharding)  # The underlying JAX array is sharded across the 2x4 mesh
We have 8 fake JAX devices now to partition this model...
<class 'flax.nnx.variablelib.Param'>
('in', 'out')
NamedSharding(mesh=Mesh('in': 2, 'out': 4), spec=PartitionSpec('in', 'out'), memory_kind=unpinned_host)

NNX to Linen#

If you are not using any metadata feature of the nnx.Variable (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don’t need to worry about it.

But if you did add sharding annotations to your NNX variables, ToLinen will convert them to a default Linen partition metadata class called bridge.NNXMeta, retaining all the metadata you put into the NNX variable.

As with any Linen metadata wrapper, you can use linen.unbox() to get the raw JAX array tree.

class NNXDotWithPartitioning(nnx.Module):
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
    self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))
  def __call__(self, x: jax.Array):
    return x @ self.w

x = jax.random.normal(jax.random.key(42), (4, 32))

@jax.jit
def create_sharded_variables(key, x):
  model = bridge.to_linen(NNXDotWithPartitioning, 32, 64)
  variables = model.init(key, x)
  # A `NNXMeta` wrapper of the underlying `nnx.Param`
  assert type(variables['params']['w']) == bridge.NNXMeta
  # The annotation coming from the `nnx.Param` => (in, out)
  assert variables['params']['w'].metadata['sharding'] == ('in', 'out')
  
  unboxed_variables = nn.unbox(variables)
  variable_pspecs = nn.get_partition_spec(variables)
  assert isinstance(unboxed_variables['params']['w'], jax.Array)
  assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')
  
  sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint, 
                              nn.unbox(variables),
                              nn.get_partition_spec(variables))
  return sharded_vars

with mesh:
  variables = create_sharded_variables(jax.random.key(0), x)

# The underlying JAX array is sharded across the 2x4 mesh
print(variables['params']['w'].sharding)
NamedSharding(mesh=Mesh('in': 2, 'out': 4), spec=PartitionSpec('in', 'out'), memory_kind=unpinned_host)

Lifted transforms#

In general, if you want to apply Linen/NNX-style lifted transforms upon an nnx.bridge-converted module, just go ahead and do it in the usual Linen/NNX syntax.

For Linen-style transforms, note that bridge.ToLinen is the top-level module class, so you may want to use it as the first argument of your transforms (which needs to be a linen.Module class in most cases).

Linen to NNX#

NNX-style lifted transforms are similar to JAX transforms, and they work on functions.

class NNXVmapped(nnx.Module):
  def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs):
    self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs)
    self.vmap_axis_size = vmap_axis_size

  def __call__(self, x):

    @nnx.split_rngs(splits=self.vmap_axis_size)
    @nnx.vmap(in_axes=(0, 0), axis_size=self.vmap_axis_size)
    def vmap_fn(submodule, x):
      return submodule(x)

    return vmap_fn(self.linen_dot, x)

x = jax.random.normal(jax.random.key(0), (4, 32))
model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)

print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped
y = model(x)
print(y.shape)
(4, 32, 64)
(4, 64)

NNX to Linen#

Note that bridge.ToLinen is the top-level module class, so you may want to use it as the first argument of your transforms (which needs to be a linen.Module class in most cases).

Also, since bridge.ToLinen introduces the extra nnx collection, you need to mark it when using axis-changing transforms (linen.vmap, linen.scan, etc.) to make sure it is passed through correctly.

class LinenVmapped(nn.Module):
  dout: int
  @nn.compact
  def __call__(self, x):
    inner = nn.vmap(
      bridge.ToLinen,
      variable_axes={'params': 0, 'nnx': None},
      split_rngs={'params': True},
    )(nnx.Linear, args=(x.shape[-1], self.dout))
    return inner(x)

x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenVmapped(64)
var = model.init(jax.random.key(0), x)
print(var['params']['VmapToLinen_0']['kernel'].shape)  # (4, 32, 64) - leading dim 4 got vmapped
y = model.apply(var, x)
print(y.shape)
(4, 32, 64)
(4, 64)