graph
- flax.nnx.split(node, *filters)
Split a graph node into a GraphDef and one or more States. State is a Mapping from strings or integers to Variables, Arrays or nested States. GraphDef contains all the static information needed to reconstruct a Module graph; it is analogous to JAX's PyTreeDef. split() is used in conjunction with merge() to switch seamlessly between stateful and stateless representations of the graph.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> jax.tree.map(jnp.shape, params)
State({
  'batch_norm': {
    'bias': VariableState(
      type=Param,
      value=(2,)
    ),
    'scale': VariableState(
      type=Param,
      value=(2,)
    )
  },
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
>>> jax.tree.map(jnp.shape, batch_stats)
State({
  'batch_norm': {
    'mean': VariableState(
      type=BatchStat,
      value=(2,)
    ),
    'var': VariableState(
      type=BatchStat,
      value=(2,)
    )
  }
})
split() and merge() are primarily used to interact directly with JAX transformations; see Functional API for more information.
- Parameters
node – graph node to split.
*filters – optional filters that group the state into mutually exclusive substates.
- Returns
A GraphDef and one or more States, one State per filter passed. If no filters are passed, a single State is returned.
- flax.nnx.merge(graphdef, state, /, *states)
The inverse of flax.nnx.split().
nnx.merge takes a flax.nnx.GraphDef and one or more flax.nnx.States and creates a new node with the same structure as the original node.
Recall: flax.nnx.split() is used to represent a flax.nnx.Module by: 1) a static nnx.GraphDef that captures its Pythonic static information; and 2) one or more nnx.States of flax.nnx.Variables that capture its jax.Arrays in the form of JAX pytrees.
nnx.merge is used in conjunction with nnx.split to switch seamlessly between stateful and stateless representations of the graph.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> new_node = nnx.merge(graphdef, params, batch_stats)
>>> assert isinstance(new_node, Foo)
>>> assert isinstance(new_node.batch_norm, nnx.BatchNorm)
>>> assert isinstance(new_node.linear, nnx.Linear)
nnx.split and nnx.merge are primarily used to interact directly with JAX transformations (refer to Functional API for more information).
- Parameters
graphdef – A flax.nnx.GraphDef object.
state – A flax.nnx.State object.
*states – Additional flax.nnx.State objects.
- Returns
The merged flax.nnx.Module.
- flax.nnx.update(node, state, /, *states)
Update the given graph node in place with one or more new states.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> def loss_fn(model, x, y):
...   return jnp.mean((y - model(x))**2)
>>> prev_loss = loss_fn(model, x, y)
>>> grads = nnx.grad(loss_fn)(model, x, y)
>>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads)
>>> nnx.update(model, new_state)
>>> assert loss_fn(model, x, y) < prev_loss
- flax.nnx.pop(node, *filters)
Pop one or more Variable types from the graph node.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> intermediates = nnx.pop(model, nnx.Intermediate)
>>> assert intermediates['i'].value[0].shape == (1, 3)
>>> assert not hasattr(model, 'i')
- flax.nnx.state(node, *filters)
Similar to split() but only returns the States indicated by the filters.
Example usage:
>>> from flax import nnx
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batch_norm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
- flax.nnx.variables(node, *filters)
Similar to state() but returns the current Variable objects instead of new VariableState instances.
Example:
>>> from flax import nnx
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> params = nnx.variables(model, nnx.Param)
...
>>> assert params['kernel'] is model.kernel
>>> assert params['bias'] is model.bias
- flax.nnx.graph()
- flax.nnx.graphdef(node, /)
Get the GraphDef of the given graph node.
Example usage:
>>> from flax import nnx
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, _ = nnx.split(model)
>>> assert graphdef == nnx.graphdef(model)
- flax.nnx.iter_graph(node, /)
Iterates over all nested nodes and leaves of the given graph node, including the current node.
iter_graph creates a generator that yields path and value pairs, where the path is a tuple of strings or integers representing the path to the value from the root. Repeated nodes are visited only once. Leaves include static values.
Example:
>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.din, self.dout = din, dout
...     self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...
>>> module = Linear(3, 4, rngs=nnx.Rngs(0))
>>> graph = [module, module]
...
>>> for path, value in nnx.iter_graph(graph):
...   print(path, type(value).__name__)
...
(0, '_object__state') ObjectState
(0, 'b') Param
(0, 'din') int
(0, 'dout') int
(0, 'w') Param
(0,) Linear
() list
- flax.nnx.clone(node)
Create a deep copy of the given graph node.
Example usage:
>>> from flax import nnx
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> cloned_model = nnx.clone(model)
>>> model.bias.value += 1
>>> assert (model.bias.value != cloned_model.bias.value).all()
- Parameters
node – A graph node object.
- Returns
A deep copy of the Module object.
- flax.nnx.call(graphdef_state, /)
Calls a method of the underlying graph node defined by a (GraphDef, State) pair.
call takes a (GraphDef, State) pair and creates a proxy object that can be used to call methods on the underlying graph node. When a method is called, the output is returned along with a new (GraphDef, State) pair that represents the updated state of the graph node.
call is equivalent to merge() > method > split() but is more convenient to use in pure JAX functions.
Example:
>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> linear = StatefulLinear(3, 2, nnx.Rngs(0))
>>> linear_state = nnx.split(linear)
...
>>> @jax.jit
... def forward(x, linear_state):
...   y, linear_state = nnx.call(linear_state)(x)
...   return y, linear_state
...
>>> x = jnp.ones((1, 3))
>>> y, linear_state = forward(x, linear_state)
>>> y, linear_state = forward(x, linear_state)
...
>>> linear = nnx.merge(*linear_state)
>>> linear.count.value
Array(2, dtype=uint32)
The proxy object returned by call supports indexing and attribute access to reach nested methods. In the example below, indexing is used to call the increment method of the StatefulLinear module stored at the b key of a nodes dictionary.
>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> rngs = nnx.Rngs(0)
>>> nodes = dict(
...   a=StatefulLinear(3, 2, rngs),
...   b=StatefulLinear(2, 1, rngs),
... )
...
>>> node_state = nnx.split(nodes)
>>> # use indexing, then attribute access
>>> _, node_state = nnx.call(node_state)['b'].increment()
...
>>> nodes = nnx.merge(*node_state)
>>> nodes['a'].count.value
Array(0, dtype=uint32)
>>> nodes['b'].count.value
Array(1, dtype=uint32)
- flax.nnx.cached_partial(f, *cached_args)
Create a partial from an NNX transformed function along with some cached input arguments, reducing the Python overhead by caching the traversal of NNX graph nodes. This is useful for speeding up functions that are called repeatedly with the same subset of inputs, e.g. a train_step with a model and optimizer:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3))
...
>>> @nnx.jit
... def train_step(model, optimizer, x, y):
...   def loss_fn(model):
...     return jnp.mean((model(x) - y) ** 2)
...
...   loss, grads = nnx.value_and_grad(loss_fn)(model)
...   optimizer.update(grads)
...   return loss
...
>>> cached_train_step = nnx.cached_partial(train_step, model, optimizer)
...
>>> for step in range(total_steps:=2):
...   x, y = jnp.ones((10, 2)), jnp.ones((10, 3))
...   # loss = train_step(model, optimizer, x, y)
...   loss = cached_train_step(x, y)
...   print(f'Step {step}: loss={loss:.3f}')
Step 0: loss=2.669
Step 1: loss=2.660
Note that cached_partial will clone all cached graph nodes to guarantee the validity of the cache, and these clones will contain references to the same Variable objects, which guarantees that state is propagated correctly back to the original graph nodes. As a consequence, the final structure of all graph nodes must be the same after each call to the cached function, otherwise an error will be raised. Temporary mutations are allowed (e.g. the use of Module.sow) as long as they are cleaned up before the function returns (e.g. via nnx.pop).
- Parameters
f – A function to cache.
*cached_args – A subset of the input arguments containing the graph nodes to cache.
- Returns
A partial function expecting the remaining arguments to the original function.
- class flax.nnx.GraphDef(nodes: 'list[NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any]]', attributes: 'list[tuple[Key, NodeAttr | ArrayAttr | Static[tp.Any]]]', num_leaves: 'int')
- class flax.nnx.UpdateContext(tag, outer_ref_outer_index, outer_index_inner_ref, outer_index_outer_ref, inner_ref_outer_index, static_cache)
A context manager for handling complex state updates.
- flax.nnx.update_context(tag)
Creates an UpdateContext context manager which can be used to handle more complex state updates beyond what nnx.update can handle, including updates to static properties and graph structure.
UpdateContext exposes a split and merge API with the same signature as nnx.split / nnx.merge but performs some bookkeeping to have the necessary information in order to perfectly update the input objects based on the changes made inside the transform. The UpdateContext must call split and merge a total of 4 times: the first and last calls happen outside the transform, and the second and third calls happen inside the transform, as shown in the diagram below:
                  idxmap
(2) merge ─────────────────────────────► split (3)
      ▲                                    │
      │               inside               │
      │. . . . . . . . . . . . . . . . . . │ index_mapping
      │               outside              │
      │                                    ▼
(1) split──────────────────────────────► merge (4)
                  refmap
The first call to split (1) creates a refmap which keeps track of the outer references, and the first call to merge (2) creates an idxmap which keeps track of the inner references. The second call to split (3) combines the refmap and idxmap to produce the index_mapping which indicates how the outer references map to the inner references. Finally, the last call to merge (4) uses the index_mapping and the refmap to reconstruct the output of the transform while reusing/updating the inner references. To avoid memory leaks, the idxmap is cleared after (3) and the refmap is cleared after (4), and both are cleared after the context manager exits.
Here is a simple example showing the use of update_context:
>>> from flax import nnx
>>> import jax
...
>>> m1 = nnx.Dict({})
>>> with nnx.update_context('example'):
...   with nnx.split_context('example') as ctx:
...     graphdef, state = ctx.split(m1)
...   @jax.jit
...   def f(graphdef, state):
...     with nnx.merge_context('example', inner=True) as ctx:
...       m2 = ctx.merge(graphdef, state)
...     m2.a = 1
...     m2.ref = m2  # create a reference cycle
...     with nnx.split_context('example') as ctx:
...       return ctx.split(m2)
...   graphdef_out, state_out = f(graphdef, state)
...   with nnx.merge_context('example', inner=False) as ctx:
...     m3 = ctx.merge(graphdef_out, state_out)
...
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1
Note that update_context takes in a tag argument which is used primarily as a safety mechanism to reduce the risk of accidentally using the wrong UpdateContext when using current_update_context() to access the current active context. update_context can also be used as a decorator that creates/activates an UpdateContext context for the duration of the function:
>>> from flax import nnx
>>> import jax
...
>>> m1 = nnx.Dict({})
>>> @jax.jit
... def f(graphdef, state):
...   with nnx.merge_context('example', inner=True) as ctx:
...     m2 = ctx.merge(graphdef, state)
...   m2.a = 1     # insert static attribute
...   m2.ref = m2  # create a reference cycle
...   with nnx.split_context('example') as ctx:
...     return ctx.split(m2)
...
>>> @nnx.update_context('example')
... def g(m1):
...   with nnx.split_context('example') as ctx:
...     graphdef, state = ctx.split(m1)
...   graphdef_out, state_out = f(graphdef, state)
...   with nnx.merge_context('example', inner=False) as ctx:
...     return ctx.merge(graphdef_out, state_out)
...
>>> m3 = g(m1)
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1
The context can be accessed using current_update_context().
- Parameters
tag – A string tag to identify the context.
- flax.nnx.current_update_context(tag)
Returns the current active UpdateContext for the given tag.