NNX#

NNX is a neural network library for JAX that focuses on a simple, intuitive development experience, so building and experimenting with neural networks is easy. It achieves this by embracing Python's object-oriented model and making it compatible with JAX transforms, resulting in code that is easy to inspect, debug, and analyze.

Features#

Pythonic

NNX modules are regular Python objects, which makes development intuitive and predictable.

Simple

NNX relies on Python's object model, which keeps the API simple and speeds up development.

Streamlined

NNX distills user feedback and hands-on experience with Linen into a new, simplified API.

Compatible

NNX objects convert to and from pure pytrees via the Functional API, making them easy to integrate with regular JAX code, as sketched below.
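
As a quick illustration (a hedged sketch, not part of the official example): nnx.split turns an object into a static graph definition plus a pytree of state that works with any JAX pytree utility, and nnx.merge reassembles it.

from flax import nnx
import jax

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)  # static structure + pytree of arrays
state = jax.tree_util.tree_map(lambda x: x * 2, state)  # plain JAX code on the state
model = nnx.merge(graphdef, state)  # back to a stateful object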

Basic usage#

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss
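
To sketch how this is used (the batch size is arbitrary; the shapes are chosen to match din=2 and dout=3 above):

import jax

x = jax.random.normal(jax.random.key(1), (32, 2))  # batch of 32 inputs
y = jax.random.normal(jax.random.key(2), (32, 3))  # matching targets

model.train()  # stochastic mode: dropout active, BatchNorm updates running stats
loss = train_step(model, optimizer, x, y)

model.eval()   # deterministic mode for inference
y_pred = model(x)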

Installation#

NNX ships as part of the Flax package. Install it via pip:

pip install flax

Or install the latest version from the repository:

pip install git+https://github.com/google/flax.git

Learn more#

NNX Basics: nnx_basics.html
MNIST Tutorial: mnist_tutorial.html
NNX vs JAX Transformations: transforms.html
Haiku and Linen vs NNX: haiku_linen_vs_nnx.html