class flax.experimental.nnx.Module(*args, **kwargs)[source]#
classmethod partial_init(state, *states)[source]#

Creates a constructor that initializes the Module with the given state.

partial_init takes one or more States and returns a constructor that uses jax.jit to initialize the Module and update its state with the given States. It is semantically equivalent to:

module = MyModule(*args, **kwargs)
module.update(state, *states)

However, thanks to dead code elimination, the resulting constructor will only initialize the subset of Variables that were part of the given state(s).


>>> import jax.numpy as jnp
>>> import jax
>>> from flax.experimental import nnx
>>> bias = jax.random.normal(jax.random.key(0), (4,))
>>> state = nnx.State({'bias': bias}) # in reality load it from a checkpoint
>>> linear = nnx.Linear.partial_init(state)(2, 4, rngs=nnx.Rngs(1))
>>> y = linear(jnp.ones((1, 2)))
>>> assert jnp.allclose(linear.bias, bias)
>>> assert y.shape == (1, 4)
Parameters:
  • state – The State to initialize the Module with.

  • *states – Additional States to initialize the Module with.


Returns:
  A constructor that initializes the Module with the given States.
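The construct-then-update semantics can be sketched in plain Python. Note that MyModule, its attributes, and the standalone partial_init helper below are hypothetical stand-ins for illustration, not nnx APIs, and the sketch omits the jax.jit compilation that lets nnx eliminate unused initializers as dead code:

```python
# Illustrative sketch of partial_init's construct-then-update semantics.
# MyModule and the partial_init helper are hypothetical, not nnx APIs.

class MyModule:
    def __init__(self, din, dout):
        # In a real module these would be randomly initialized parameters.
        self.kernel = [[0.0] * dout for _ in range(din)]
        self.bias = [0.0] * dout

    def update(self, state, *states):
        # Overwrite only the attributes present in the given state dicts.
        for s in (state, *states):
            for name, value in s.items():
                setattr(self, name, value)

def partial_init(cls, state, *states):
    # Return a constructor that builds the module, then applies the states.
    # (nnx additionally runs this under jax.jit, so initializers for
    # variables overwritten by the states are eliminated as dead code.)
    def constructor(*args, **kwargs):
        module = cls(*args, **kwargs)
        module.update(state, *states)
        return module
    return constructor

checkpoint = {'bias': [1.0, 2.0, 3.0, 4.0]}  # stand-in for a loaded State
linear = partial_init(MyModule, checkpoint)(2, 4)
print(linear.bias)  # the bias comes from the checkpoint state
```

The key point mirrored from the docstring: construction and update are fused into a single constructor, so callers only ever see a module whose state already matches the checkpoint.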

set_attributes(*filters, raise_if_not_found=True, **attributes)[source]#

Sets the attributes of nested Modules, including the current Module. By default, a ValueError is raised if an attribute is not found in any of the selected Modules; pass raise_if_not_found=False to silently ignore missing attributes.


>>> from flax.experimental import nnx
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True, use_running_average=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters:
  • *filters – Filters to select the Modules to set the attributes of.

  • raise_if_not_found – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • **attributes – The attributes to set.
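The filter mechanism can be sketched in plain Python. The classes and the one-level traversal below are hypothetical simplifications for illustration, not the nnx implementation (nnx recurses through arbitrarily nested Modules and supports richer filter types than plain classes):

```python
# Hypothetical sketch of filter-based attribute setting, not the nnx code.

class Dropout:
    def __init__(self):
        self.deterministic = False

class BatchNorm:
    def __init__(self):
        self.use_running_average = False

class Block:
    def __init__(self):
        self.dropout = Dropout()
        self.batch_norm = BatchNorm()

def set_attributes(module, *filters, raise_if_not_found=True, **attributes):
    # Visit the module and its direct sub-modules; a type filter restricts
    # which modules an attribute may be set on.  (nnx traverses nested
    # modules recursively; this sketch only goes one level deep.)
    found = {name: False for name in attributes}
    for sub in (module, *vars(module).values()):
        if filters and not isinstance(sub, filters):
            continue
        for name, value in attributes.items():
            if hasattr(sub, name):
                setattr(sub, name, value)
                found[name] = True
    missing = [name for name, ok in found.items() if not ok]
    if raise_if_not_found and missing:
        raise ValueError(f'attributes not found: {missing}')

block = Block()
# Only Dropout instances match the filter, so use_running_average is never
# found; raise_if_not_found=False suppresses the resulting error.
set_attributes(block, Dropout, deterministic=True,
               use_running_average=True, raise_if_not_found=False)
print(block.dropout.deterministic, block.batch_norm.use_running_average)
```

As in the doctest above, the Dropout filter confines the update to the dropout sub-module, leaving batch_norm untouched.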

flax.experimental.nnx.merge(graphdef, state, *states)[source]#