NNX#
NNX is a JAX-based neural network library designed for simplicity and power. Its modular approach follows standard Python conventions, making it both intuitive and compatible with the broader JAX ecosystem.
Note
NNX is currently in an experimental state and is subject to change. Linen is still the recommended option for large-scale projects. Feedback and contributions are welcome!
Features#
- Modules are standard Python classes, promoting ease of use and a more familiar development experience.
- Effortlessly convert between Modules and pytrees using the Functional API for maximum flexibility (see the sketch after this list).
- Manage a Module’s state with precision using typed Variable collections, enabling fine-grained control over JAX transformations.
- NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen to provide a streamlined experience.
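As a rough illustration of the Functional API and typed Variable collections, the sketch below splits a Module into a static definition plus pytree-compatible state, filters that state by collection, and merges everything back. It assumes nnx.split/nnx.merge helpers and the nnx.Param collection; since NNX is experimental, the exact names and return order may differ in your installed version.

from flax.experimental import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# Functional API (assumed signature): split the Module into a static
# graph definition and its state, which is a regular JAX pytree.
graphdef, state = nnx.split(model)

# Typed Variable collections let you filter the state, e.g. keep only
# parameters (assumed nnx.Param filter) and capture the rest separately.
graphdef, params, rest = nnx.split(model, nnx.Param, ...)

# Merge the pieces back into a stateful Module.
model = nnx.merge(graphdef, params, rest)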
Basic usage#
from flax.experimental import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)


model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing


@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # inplace updates

  return loss
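For reference, a minimal way to exercise train_step with dummy data; the batch size and values are arbitrary and chosen only to match din=2 and dout=3 from the model above. Because nnx.jit manages state for you, the model and optimizer are updated in place across calls.

import jax.numpy as jnp

x = jnp.ones((32, 2))   # inputs matching din=2
y = jnp.zeros((32, 3))  # targets matching dout=3

for step in range(10):
  loss = train_step(model, optimizer, x, y)  # updates model and optimizer in place

print(loss)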
Installation#
NNX is under active development, so we recommend installing the latest version directly from Flax’s GitHub repository:
pip install git+https://github.com/google/flax.git
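After installing, the import used throughout this page should succeed, for example:

from flax.experimental import nnx
print(nnx.Linear)  # should print the Linear Module class if the install worked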