NNX
NNX is a JAX-based neural network library designed for simplicity and power. Its modular approach follows standard Python conventions, making it both intuitive and compatible with the broader JAX ecosystem.
Note: NNX is currently experimental and subject to change. Flax's Linen API is still the recommended option for large-scale projects. Feedback and contributions are welcome!
Features
- Modules are standard Python classes, so they are easy to use and feel familiar to Python developers.
- Convert between Modules and pytrees with the Functional API, so Modules can interoperate with JAX transformations and other pytree-based code.
- Manage a Module’s state precisely with typed Variable collections, which give fine-grained control over how JAX transformations treat each kind of state.
- NNX prioritizes simplicity for common use cases, building on lessons learned from Linen to provide a streamlined experience.
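The typed-collection idea can be illustrated without NNX itself. The sketch below is a library-agnostic analogue in plain JAX, not NNX's API: it assumes a hypothetical state layout of nested dicts keyed by collection name (`"params"` standing in for `nnx.Param` variables, `"counts"` for some non-trainable state) and uses `jax.grad(..., argnums=0, has_aux=True)` so that only the parameter collection is differentiated while the other collection is updated and returned next to the loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical state layout (illustration only, not NNX):
# one sub-dict per collection type.
state = {
    "params": {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))},
    "counts": {"calls": jnp.array(0)},  # stand-in for non-trainable state
}


def forward(params, other, x):
    y = x @ params["w"] + params["b"]
    # Non-trainable state is updated functionally and never differentiated.
    new_other = {"calls": other["calls"] + 1}
    return y, new_other


def loss_fn(params, other, x):
    y, new_other = forward(params, other, x)
    return jnp.sum(y ** 2), new_other


x = jnp.ones((1, 2))
# argnums=0: differentiate only with respect to the "params" collection.
# has_aux=True: also return the updated "counts" collection.
grads, new_counts = jax.grad(loss_fn, argnums=0, has_aux=True)(
    state["params"], state["counts"], x
)
```

Grouping state by type is what lets a transformation act on one group and not another; here that choice is made by hand through `argnums`.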
Installation
Because NNX is under active development, we recommend installing the latest version from Flax’s GitHub repository:
```shell
pip install git+https://github.com/google/flax.git
```
Basic usage
```python
import jax
import jax.numpy as jnp
from flax.experimental import nnx


class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs()  # get a unique random key
        # initialize parameters
        self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))
        self.din, self.dout = din, dout

    def __call__(self, x: jax.Array):
        return x @ self.w.value + self.b.value


rngs = nnx.Rngs(0)  # explicit RNG handling
model = Linear(din=2, dout=3, rngs=rngs)  # initialize the model

x = jnp.ones((1, 2))  # dummy input batch of shape (1, din)
y = model(x)  # forward pass; y has shape (1, dout) = (1, 3)
```
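To see what the `Linear` module above reduces to once its state is a pytree, here is a minimal pure-JAX sketch. It is not NNX's Functional API: `LinearState`, `apply`, and `sgd_step` are hypothetical names, and the pytree registration uses `jax.tree_util.register_pytree_node_class`. It shows the payoff of a Modules-to-pytrees conversion: `jax.jit` and `jax.grad` accept the state directly, the gradients come back with the same structure, and an SGD update is a single `tree_map`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class LinearState:
    """Hypothetical pytree holding the same arrays as Linear.w and Linear.b."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def tree_flatten(self):
        # Children are the array leaves; no static auxiliary data is needed.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def apply(state, x):
    # Same computation as Linear.__call__: x @ w + b.
    return x @ state.w + state.b


@jax.jit
def sgd_step(state, x, y_target, lr=0.1):
    def loss(s):
        return jnp.mean((apply(s, x) - y_target) ** 2)

    grads = jax.grad(loss)(state)  # grads is also a LinearState
    # tree_map walks both pytrees leaf by leaf.
    return tree_util.tree_map(lambda p, g: p - lr * g, state, grads)


# Usage: initialize like Linear(din=2, dout=3) and take one step.
key = jax.random.PRNGKey(0)
state = LinearState(jax.random.uniform(key, (2, 3)), jnp.zeros((3,)))
x = jnp.ones((1, 2))
y_target = jnp.zeros((1, 3))
new_state = sgd_step(state, x, y_target)
```

Because `x` here is all ones and the target is zero, one step with `lr=0.1` multiplies each output by exactly 0.8 (up to float rounding), which makes a convenient sanity check.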