module
- class flax.experimental.nnx.Module(*args, **kwargs)
- classmethod partial_init(state, *states)
Creates a constructor that initializes the Module with the given state.
partial_init takes one or more States and returns a constructor that uses jax.jit to initialize the Module and update its state with the given States. It is semantically equivalent to:
    module = MyModule(*args, **kwargs)
    module.update(state, *states)
However, thanks to dead code elimination, the resulting constructor only computes initial values for the Variables that are not part of the given state(s); the initializers of Variables supplied by the state(s) are eliminated.
Example:
>>> import jax.numpy as jnp
>>> import jax
>>> from flax.experimental import nnx
...
>>> bias = jax.random.normal(jax.random.key(0), (4,))
>>> state = nnx.State({'bias': bias}) # in practice, load it from a checkpoint
>>> linear = nnx.Linear.partial_init(state)(2, 4, rngs=nnx.Rngs(1))
>>> y = linear(jnp.ones((1, 2)))
...
>>> assert jnp.allclose(linear.bias, bias)
>>> assert y.shape == (1, 4)
- Parameters
state – The State to initialize the Module with.
*states – Additional States to initialize the Module with.
- Returns
A constructor that initializes the Module with the given States.
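The jit-plus-overwrite idea behind partial_init can be sketched in plain JAX, with a dict of arrays standing in for a Module's Variables. This is a conceptual analogue, not nnx's implementation: init_linear, eager_init, and this partial_init are hypothetical helpers, and applying later states after earlier ones (so later states win on overlapping keys) is an assumption. eager_init spells out the "construct, then update" semantics; partial_init wraps the same function in jax.jit, so the random draws for entries that the given state overwrites never reach an output and XLA is free to drop them as dead code.

```python
import jax
import jax.numpy as jnp


def init_linear(key, din, dout):
    # Hypothetical stand-in for a Module constructor: random kernel and bias.
    kernel_key, bias_key = jax.random.split(key)
    return {
        "kernel": jax.random.normal(kernel_key, (din, dout)),
        "bias": jax.random.normal(bias_key, (dout,)),
    }


def eager_init(state, *states):
    # Reference semantics: build every entry, then overwrite with the given
    # states in order (assumed: later states win on overlapping keys).
    def constructor(key, din, dout):
        params = init_linear(key, din, dout)
        for s in (state, *states):
            params.update(s)
        return params

    return constructor


def partial_init(state, *states):
    # Same computation under jax.jit. Random draws for overwritten entries
    # do not feed any output, so XLA may eliminate them as dead code.
    # din and dout are static so array shapes are concrete while tracing.
    return jax.jit(eager_init(state, *states), static_argnums=(1, 2))


# Usage: provide only the bias; the kernel is still drawn from the key.
bias = jnp.arange(4.0)
params = partial_init({"bias": bias})(jax.random.key(1), 2, 4)
```

Marking din and dout as static plays the role of constructor hyperparameters: they are ordinary Python ints during tracing, while the closed-over state arrays become constants of the compiled function.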
- set_attributes(*filters, raise_if_not_found=True, **attributes)
Sets the attributes of nested Modules, including the current Module. A Module that does not have a given attribute is skipped for that attribute.
Example:
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Filters can be used to set the attributes of specific Modules:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True, use_running_average=True, raise_if_not_found=False)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
- Parameters
*filters – Filters to select the Modules to set the attributes of.
raise_if_not_found – If True (default), raises a ValueError if, for at least one of the given attributes, no instance is found in any of the selected Modules.
**attributes – The attributes to set.
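The traversal rules described above can be modeled in a few lines of plain Python. This is a hypothetical sketch, not nnx's implementation: the Module, Dropout, BatchNorm, and Block classes are toy stand-ins, and filters are limited to Module types (nnx Filters are more general). It follows the reading that a Module lacking an attribute is skipped, and that raise_if_not_found only triggers when an attribute is found in none of the selected Modules.

```python
class Module:
    # Hypothetical minimal Module: child Modules are plain instance attributes.
    def iter_modules(self):
        # Yield self first, then every nested Module, depth-first.
        yield self
        for value in vars(self).values():
            if isinstance(value, Module):
                yield from value.iter_modules()

    def set_attributes(self, *filters, raise_if_not_found=True, **attributes):
        # No filters selects every Module; a type filter selects its instances.
        selected = filters or (Module,)
        remaining = set(attributes)
        # Materialize the traversal before mutating any attributes.
        for module in list(self.iter_modules()):
            if not isinstance(module, selected):
                continue
            for name, value in attributes.items():
                # A Module without this attribute is skipped for it.
                if hasattr(module, name):
                    setattr(module, name, value)
                    remaining.discard(name)
        # Raise only for attributes that no selected Module had.
        if remaining and raise_if_not_found:
            raise ValueError(f"Could not find attributes: {sorted(remaining)}")


class Dropout(Module):
    def __init__(self, rate, deterministic=False):
        self.rate = rate
        self.deterministic = deterministic


class BatchNorm(Module):
    def __init__(self, use_running_average=False):
        self.use_running_average = use_running_average


class Block(Module):
    def __init__(self):
        self.dropout = Dropout(0.5)
        self.batch_norm = BatchNorm()


block = Block()
block.set_attributes(Dropout, deterministic=True, use_running_average=True,
                     raise_if_not_found=False)
```

In this model, dropping raise_if_not_found=False from the last call raises a ValueError, because use_running_average exists on none of the selected Dropout Modules.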