NNX is a neural network library for JAX designed to make building and experimenting with neural networks easy and intuitive. It does this by embracing Python’s object-oriented model and making it compatible with JAX transforms, which results in code that is easy to inspect, debug, and analyze.



NNX lets you work with regular Python objects, which makes development intuitive and predictable.


NNX relies on Python’s object model, which keeps the API simple for users and speeds up development.


NNX distills user feedback and hands-on experience with Linen into a new, simplified API.


NNX objects can be integrated with regular JAX code through the Functional API.

Basic usage

from flax import nnx
import optax

class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss


NNX ships as part of Flax. Install it via pip:

pip install flax

Or install the latest version from the repository:

pip install git+https://github.com/google/flax.git

Learn more

NNX Basics
MNIST Tutorial
NNX vs JAX Transformations
Haiku and Linen vs NNX