compat
NNX Compat API.
The `compat` module provides wrappers around many NNX APIs, such as `split`, `state`, and the transforms (`jit`, `grad`, etc.), that use the legacy graph-mode implementation by default. The wrappers do this by changing the default values of the mode arguments to `graph=True` and, for the transforms, `graph_updates=True`.
Example:

```python
from flax import nnx

graphdef, state = nnx.compat.split(model)  # graph=True by default

@nnx.compat.jit  # graph=True, graph_updates=True by default
def train_step(model, optimizer, x, y):
    ...  # training logic goes here
```
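The default-overriding idea can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the Flax implementation: `split` and `jit` below are hypothetical stand-ins that only report which mode they would run in, their non-compat defaults are assumed to be `graph=False` / `graph_updates=False`, and the `compat` namespace is built with `functools.partial`, which stores new keyword defaults while still letting a caller pass an explicit value.

```python
import functools
import types

# Hypothetical stand-ins for tree-mode APIs (NOT the real flax.nnx functions).
# Assumption: outside compat, the defaults are graph=False / graph_updates=False.
def split(model, *, graph=False):
    """Pretend split: tags its outputs with the mode it ran in."""
    mode = "graph" if graph else "tree"
    return f"graphdef[{mode}]", f"state[{mode}]"


def jit(fn=None, *, graph=False, graph_updates=False):
    """Pretend jit: usable as @jit or @jit(...); records its flags on the result."""
    def decorate(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, **kwargs)  # no compilation, just a pass-through
        wrapped.flags = {"graph": graph, "graph_updates": graph_updates}
        return wrapped
    return decorate if fn is None else decorate(fn)


# Hypothetical "compat" namespace: same functions, different keyword defaults.
# functools.partial stores the new defaults; keywords passed at call time win.
compat = types.SimpleNamespace(
    split=functools.partial(split, graph=True),
    jit=functools.partial(jit, graph=True, graph_updates=True),
)


@compat.jit  # graph=True, graph_updates=True by default
def train_step(model, optimizer, x, y):
    return x + y  # placeholder body


graphdef, state = compat.split(object())  # graph mode by default
tree_graphdef, _ = compat.split(object(), graph=False)  # explicit override
```

Because `functools.partial` merges call-time keywords over the stored ones, the sketch still lets a caller opt back into the other mode; whether the real `nnx.compat` wrappers behave the same way is an assumption here.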
See [Tree Mode NNX](https://flax.readthedocs.io/en/latest/flip/5310-tree-mode-nnx.html#prefix-filters) for more details.