Flax NNX vs JAX Transformations

Attention: This page relates to the new Flax NNX API.

In this guide, you will learn the differences between Flax NNX and JAX transformations, and how to switch between them or use them together. The guide focuses on the jit and grad function transformations.

First, let’s set up imports and generate some dummy data:

from flax import nnx
import jax

x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))

Differences between NNX and JAX transformations

The primary difference between Flax NNX and JAX transformations is that Flax NNX transformations can transform functions that take Flax NNX graph objects (Module, Rngs, Optimizer, etc.) as arguments, even objects whose state will be mutated, whereas JAX transformations do not recognize these objects. Flax NNX transformations can therefore transform functions that are not pure, that is, functions that mutate state and have side effects.

Flax NNX’s Functional API provides a way to convert graph structures to pytrees and back. By doing this at every function boundary, you can use graph structures with any JAX transform and propagate state updates in a way consistent with functional purity. Flax NNX custom transforms such as nnx.jit and nnx.grad simply remove this boilerplate; as a result, the code looks stateful.

Below is an example that uses the nnx.jit and nnx.grad transformations, followed by the equivalent code using jax.jit and jax.grad. Notice that the Flax NNX-transformed function accepts the nnx.Linear module directly and can make stateful updates to it, whereas the JAX-transformed function can only accept the pytree-registered State and GraphDef objects and must return updated copies of them to keep the transformed function pure.

@nnx.jit
def train_step(model, x, y):
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)

# The same training step written with jax.jit and jax.grad.
@jax.jit
def train_step(graphdef, state, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  grads = jax.grad(loss_fn, argnums=1)(graphdef, state)

  model = nnx.merge(graphdef, state)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)

Mixing Flax NNX and JAX transformations

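Flax NNX and JAX transformations can be mixed, as long as the JAX-transformed function is pure and its arguments are of types that JAX recognizes. The first example below uses jax.grad inside an nnx.jit-transformed function; the second uses nnx.grad inside a jax.jit-transformed function.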

@nnx.jit
def train_step(model, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  grads = jax.grad(loss_fn, 1)(*nnx.split(model))
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)

# The reverse combination: nnx.grad inside a jax.jit-transformed function.
@jax.jit
def train_step(graphdef, state, x, y):
  model = nnx.merge(graphdef, state)
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)