NNX

NNX is a JAX-based neural network library designed for simplicity and power. Its modular approach follows standard Python conventions, making it both intuitive and compatible with the broader JAX ecosystem.

Note

NNX is currently in an experimental state and is subject to change. Linen is still the recommended option for large-scale projects. Feedback and contributions are welcome!

Features

Pythonic

Modules are standard Python classes, promoting ease of use and a more familiar development experience.

Compatible

Effortlessly convert between Modules and pytrees using the Functional API for maximum flexibility.
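The idea behind this Functional API can be sketched in plain JAX. An object is split into a static description plus a pytree of arrays, JAX transforms the pytree, and the object is rebuilt afterwards. The `TinyLinear` class and the `split`/`merge` helpers below are hypothetical stand-ins written for illustration. They are not NNX's own functions, whose exact names and return order should be checked in the NNX reference for your installed version.

```python
import jax
import jax.numpy as jnp


class TinyLinear:
    """Hypothetical stateful layer, used only to illustrate split/merge."""

    def __init__(self, w: jax.Array, b: jax.Array):
        self.w = w
        self.b = b

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.w + self.b


def split(module: TinyLinear):
    # Static part: the class needed to rebuild the object.
    # State part: a dict of arrays, which JAX treats as a pytree.
    return type(module), {"w": module.w, "b": module.b}


def merge(cls, state: dict) -> TinyLinear:
    # Rebuild a fresh object from the static part and a (possibly transformed) state.
    return cls(**state)


model = TinyLinear(jnp.ones((2, 3)), jnp.zeros((3,)))
static, state = split(model)


@jax.jit
def forward(state, x):
    # Only pytrees cross the jit boundary; the object is rebuilt inside the trace.
    return merge(static, state)(x)


y = forward(state, jnp.ones((1, 2)))
doubled = jax.tree_util.tree_map(lambda a: 2 * a, state)  # ordinary pytree ops apply to the state
```

Keeping the static part out of the traced arguments is what lets `jax.jit` treat the state as an ordinary pytree.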

Control

Manage a Module’s state with precision using typed Variable collections, enabling fine-grained control over JAX transformations.
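Typed Variables make it possible to select parts of a Module's state by type, for example to differentiate only trainable parameters while carrying running statistics separately. The plain-Python sketch below illustrates that filtering idea with hypothetical `Variable`, `Param`, and `BatchStat` classes and a hypothetical `filter_state` helper. It mirrors the concept only and does not reproduce NNX's classes or filtering API.

```python
from dataclasses import dataclass
from typing import Any, Dict, Type


@dataclass
class Variable:
    """Hypothetical base class: a typed container around a value."""
    value: Any


class Param(Variable):
    """Hypothetical trainable parameter."""


class BatchStat(Variable):
    """Hypothetical running statistic that should not receive gradients."""


def filter_state(state: Dict[str, Variable], kind: Type[Variable]) -> Dict[str, Any]:
    # Keep only entries whose Variable type matches `kind`, unwrapping raw values.
    return {name: var.value for name, var in state.items() if isinstance(var, kind)}


state = {
    "kernel": Param([[0.1, 0.2]]),
    "bias": Param([0.0]),
    "mean": BatchStat([0.5]),
}

params = filter_state(state, Param)            # what an optimizer would update
batch_stats = filter_state(state, BatchStat)   # what a training step would carry separately
```

Because the selection is driven by type rather than by attribute name, adding a new kind of state only requires a new subclass.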

User-friendly

NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen to provide a streamlined experience.

Installation

Because NNX is under active development, we recommend installing the latest version from Flax’s GitHub repository:

```shell
pip install git+https://github.com/google/flax.git
```

Basic usage

```python
import jax
import jax.numpy as jnp
from flax.experimental import nnx

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs()  # get a unique random key
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))  # uniform weights
    self.b = nnx.Param(jnp.zeros((dout,)))  # zero-initialized bias
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w.value + self.b.value

rngs = nnx.Rngs(0)  # explicit RNG handling
model = Linear(din=2, dout=3, rngs=rngs)  # initialize the model

x = jnp.ones((1, 2))  # dummy input batch of shape (1, 2)
y = model(x)  # forward pass; y has shape (1, 3)
```
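For comparison, the same layer can be written in plain JAX as a pure function over an explicit parameter dictionary, which is the style that NNX's stateful Modules abstract away. The sketch is illustrative only: `init_params`, `forward`, `loss_fn`, and the learning rate `0.1` are hypothetical choices, and the update is a single hand-written SGD step rather than anything NNX provides.

```python
import jax
import jax.numpy as jnp


def init_params(key: jax.Array, din: int, dout: int) -> dict:
    # Same initialization as the Linear Module: uniform weights, zero bias.
    return {
        "w": jax.random.uniform(key, (din, dout)),
        "b": jnp.zeros((dout,)),
    }


def forward(params: dict, x: jax.Array) -> jax.Array:
    return x @ params["w"] + params["b"]


def loss_fn(params: dict, x: jax.Array, y: jax.Array) -> jax.Array:
    # Mean squared error against a target.
    return jnp.mean((forward(params, x) - y) ** 2)


key = jax.random.PRNGKey(0)  # explicit PRNG key, playing the role of nnx.Rngs(0)
params = init_params(key, din=2, dout=3)

x = jnp.ones((1, 2))        # dummy input batch
target = jnp.zeros((1, 3))  # dummy regression target

grads = jax.grad(loss_fn)(params, x, target)  # same dict structure as params
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Here every piece of state is threaded through function arguments by hand; the Module version keeps the same arrays as attributes on `model`.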

Learn more

NNX Basics
MNIST Tutorial
NNX vs JAX Transformations