module#
- class flax.experimental.nnx.Module(*args, **kwargs)[source]#
- sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#
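No description is given for sow here; the following is a minimal sketch of a common usage pattern, assuming the default init_fn produces an empty tuple and the default reduce_fn appends each sown value to it, so the sown value is stored on the Module as an attribute of the given variable_type (nnx.Intermediate and the name 'activations' below are illustrative choices, not taken from this page):
>>> from flax.experimental import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
...   def __call__(self, x):
...     y = self.linear(x)
...     # store the intermediate activation under the name 'activations'
...     self.sow(nnx.Intermediate, 'activations', y)
...     return y
...
>>> model = Model(nnx.Rngs(0))
>>> y = model(jnp.ones((1, 2)))
>>> len(model.activations.value)
1
>>> model.activations.value[0].shape
(1, 3)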
- eval(**attributes)[source]#
Sets the Module to evaluation mode.
eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.
Example:
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
- Parameters
**attributes – additional attributes passed to set_attributes.
- property init#
Calls a method in initialization mode.
When a method is called using init, the is_initializing method will return True. This is useful to implement Modules that support lazy initialization.
Example:
>>> from flax.experimental import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
...   def __init__(self, dout, rngs: nnx.Rngs):
...     self.dout = dout
...     self.rngs = rngs
...
...   def __call__(self, x):
...     if self.is_initializing():
...       din = x.shape[-1]
...       if not hasattr(self, 'w'):
...         key = self.rngs.params()
...         self.w = nnx.Param(jax.random.uniform(key, (din, self.dout)))
...       if not hasattr(self, 'b'):
...         self.b = nnx.Param(jnp.zeros((self.dout,)))
...
...     return x @ self.w + self.b
...
>>> linear = Linear(3, nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
>>> y = linear.init(x)
>>> linear.w.value.shape
(2, 3)
>>> linear.b.value.shape
(3,)
>>> y.shape
(5, 3)
- is_initializing()[source]#
Returns whether the Module is initializing.
is_initializing returns True if the Module is currently being run under init.
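As a small illustration (a hedged sketch, assuming the behavior described above and that plain attributes set inside __call__ persist on the Module), is_initializing is False for ordinary calls and True only while a method is invoked through init; the Probe Module below is purely hypothetical:
>>> from flax.experimental import nnx
>>> import jax.numpy as jnp
...
>>> class Probe(nnx.Module):
...   def __call__(self, x):
...     # record whether we are currently running under `init`
...     self.flag = self.is_initializing()
...     return x
...
>>> probe = Probe()
>>> _ = probe(jnp.ones((1,)))
>>> probe.flag
False
>>> _ = probe.init(jnp.ones((1,)))
>>> probe.flag
True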
- iter_modules()[source]#
Iterates over all nested Modules of the current Module, including the current Module.
iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.
Example:
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
() Block
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
- set_attributes(*filters, raise_if_not_found=True, **attributes)[source]#
Sets the attributes of nested Modules including the current Module. If the attribute is not found in the Module, it is ignored.
Example:
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Filters can be used to set the attributes of specific Modules:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
- Parameters
*filters – Filters to select the Modules to set the attributes of.
raise_if_not_found – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.
**attributes – The attributes to set.
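For illustration only (a hedged sketch based on the parameter description above): with raise_if_not_found=False, selected Modules that lack the given attribute are silently skipped instead of triggering a ValueError:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> # nnx.Linear has no `deterministic` attribute; with the default
>>> # raise_if_not_found=True this call would raise a ValueError instead
>>> block.set_attributes(nnx.Linear, deterministic=True, raise_if_not_found=False)
>>> block.dropout.deterministic
False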
- train(**attributes)[source]#
Sets the Module to training mode.
train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.
Example:
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
- Parameters
**attributes – additional attributes passed to set_attributes.