# Flax NNX vs JAX Transformations

Attention: This page relates to the new Flax NNX API.

In this guide, you will learn the differences between Flax NNX and JAX transformations, and how to
switch between them or use them together. The guide focuses on the `jit` and `grad` function
transformations.

First, let’s set up imports and generate some dummy data:

```
from flax import nnx
import jax
x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))
```

## Differences between NNX and JAX transformations

The primary difference between Flax NNX and JAX transformations is that Flax NNX transformations can transform functions that take Flax NNX graph objects as arguments (`Module`, `Rngs`, `Optimizer`, etc.), even when the function mutates their state. JAX transformations do not recognize these objects. Flax NNX transformations can therefore transform functions that are not pure, that is, functions that mutate their arguments or have other side effects.

Flax NNX’s Functional API provides a way to convert graph structures to pytrees and back. By doing
this at every function boundary, you can use graph structures with any JAX transform and propagate
state updates in a way consistent with functional purity. Flax NNX custom transforms such as
`nnx.jit` and `nnx.grad` simply remove this boilerplate, so the resulting code looks stateful.

Below is an example of using the `nnx.jit` and `nnx.grad` transformations compared to using the
`jax.jit` and `jax.grad` transformations. Notice that Flax NNX-transformed functions can accept the
`nnx.Linear` module directly and make stateful updates to it, whereas JAX-transformed functions can
only accept the pytree-registered `State` and `GraphDef` objects and must return an updated copy of
them to maintain the purity of the transformed function.

```
@nnx.jit
def train_step(model, x, y):
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  # SGD step on the parameters with a learning rate of 0.1
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)  # in-place update, nothing is returned

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
```

```
@jax.jit
def train_step(graphdef, state, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  # differentiate with respect to `state` (argument 1) only
  grads = jax.grad(loss_fn, argnums=1)(graphdef, state)

  model = nnx.merge(graphdef, state)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)  # return the updated copy to stay pure

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)
```

## Mixing Flax NNX and JAX transformations

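Flax NNX and JAX transformations can be mixed together, so long as the JAX-transformed function is pure and has valid argument types that are recognized by JAX. The first example below calls `jax.grad` inside an `nnx.jit` function, and the second calls `nnx.grad` inside a `jax.jit` function.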

```
@nnx.jit
def train_step(model, x, y):
  def loss_fn(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(x) - y) ** 2).mean()
  # jax.grad needs pytrees, so split the model at the boundary
  grads = jax.grad(loss_fn, 1)(*nnx.split(model))
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)
```

```
@jax.jit
def train_step(graphdef, state, x, y):
  # rebuild the module once, then use NNX transforms on it directly
  model = nnx.merge(graphdef, state)
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  grads = nnx.grad(loss_fn)(model)
  params = nnx.state(model, nnx.Param)
  params = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, params, grads
  )
  nnx.update(model, params)
  return nnx.split(model)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)
```