compat

NNX Compat API.

The `compat` module provides wrappers that use the legacy graph-mode implementation by default. It covers many NNX APIs, such as `split` and `state`, and all of the transforms, such as `jit` and `grad`. It does this by changing the default values to `graph=True` and `graph_updates=True`.
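The sketch below illustrates the general pattern of a wrapper that only changes default keyword arguments. It is written in plain Python and is not Flax's actual implementation. The helper `with_defaults` and the stand-ins `split` and `jit` are hypothetical. It also assumes two things: the non-compat APIs default to `graph=False`, and arguments the caller passes explicitly still override the compat defaults.

```python
import functools
from typing import Any, Callable


def with_defaults(fn: Callable[..., Any], **defaults: Any) -> Callable[..., Any]:
    """Wrap `fn` so that any keyword in `defaults` the caller omits is filled in."""
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        for name, value in defaults.items():
            kwargs.setdefault(name, value)  # explicit caller arguments win
        return fn(*args, **kwargs)
    return wrapper


# Hypothetical stand-in for a tree-mode API (assumed default: graph=False).
def split(obj: Any, *, graph: bool = False) -> tuple[str, Any]:
    mode = "graph" if graph else "tree"
    return mode, obj


# Hypothetical stand-in for a transform; it records the flags it was given.
def jit(fn: Callable[..., Any], *, graph: bool = False,
        graph_updates: bool = False) -> Callable[..., Any]:
    @functools.wraps(fn)
    def traced(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)
    traced.config = {"graph": graph, "graph_updates": graph_updates}
    return traced


# "Compat" versions: same functions, different defaults.
compat_split = with_defaults(split, graph=True)
compat_jit = with_defaults(jit, graph=True, graph_updates=True)


@compat_jit  # behaves like jit(step, graph=True, graph_updates=True)
def step(x: int) -> int:
    return x + 1


print(split({"w": 1}))                      # ('tree', {'w': 1})
print(compat_split({"w": 1}))               # ('graph', {'w': 1})
print(compat_split({"w": 1}, graph=False))  # ('tree', {'w': 1})
print(step.config)                          # {'graph': True, 'graph_updates': True}
```

The sketch uses `kwargs.setdefault`, so the wrapper fills in only the arguments the caller left out. Because of this, the wrapped function keeps its full signature, and a caller can still opt out of a default for a single call.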

Example:

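```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # any NNX Module works here

graphdef, state = nnx.compat.split(model)  # graph=True by default

@nnx.compat.jit  # graph=True, graph_updates=True by default
def train_step(model, optimizer, x, y):
  ...  # training logic elided
```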

See [Tree Mode NNX](https://flax.readthedocs.io/en/latest/flip/5310-tree-mode-nnx.html#prefix-filters) for more details.