module#
- class flax.nnx.Module(self, /, *args, **kwargs)[source]#
Base class for all neural network modules.
Layers and models should subclass this class.
Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.

You can define arbitrary "forward pass" methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
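Modules also compose with the NNX transforms; a minimal sketch (reusing the Model and x defined above) that compiles the forward pass with nnx.jit:

>>> @nnx.jit  # nnx.jit handles the Module's state across the jit boundary
... def forward(model, x):
...   return model(x)

>>> y = forward(model, x)
>>> assert y.shape == (1, 4)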
- eval(**attributes)[source]#
Sets the Module to evaluation mode.
eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx

>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
- Parameters
**attributes – additional attributes passed to set_attributes.
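A typical pattern is to switch to evaluation mode for validation and back to training mode afterwards; a small sketch continuing the example above:

>>> block.eval()   # deterministic forward passes
>>> # ... compute validation metrics here ...
>>> block.train()  # restore training behavior
>>> block.dropout.deterministic
False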
- iter_children()[source]#
Iterates over all immediate child Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse further down.

iter_children creates a generator that yields the key and the Module instance, where the key is a string representing the attribute name under which the corresponding child Module is stored.

Example:

>>> from flax import nnx

>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)

>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)

>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...   print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
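Because iter_children yields (name, module) pairs, the children can also be collected into a dict for direct access; a small sketch continuing the example above:

>>> children = dict(model.iter_children())
>>> type(children['submodule']).__name__
'SubModule'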
- iter_modules()[source]#
Recursively iterates over all nested Modules of the current Module, including the current Module itself.

iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example:

>>> from flax import nnx

>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)

>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)

>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
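The yielded (path, module) pairs make it easy to locate every Module of a given type; for example, collecting the paths of all nnx.Linear layers (a small sketch continuing the example above):

>>> linear_paths = [path for path, m in model.iter_modules()
...                 if isinstance(m, nnx.Linear)]
>>> sorted(linear_paths)
[('linear',), ('submodule', 'linear1'), ('submodule', 'linear2')]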
- perturb(name, value, variable_type=<class 'flax.nnx.variablelib.Perturbation'>)[source]#
Adds a zero-valued variable (a “perturbation”) to the intermediate value.

The gradient of value is the same as the gradient of this perturbation variable: since the perturbation is initialized to zero and added to value, it leaves the forward pass unchanged while receiving the same gradient as value. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation variable.

Since the shape of the perturbation value depends on the shape of the input, a perturbation variable is only created after you run a sample input through the model once.

Note

This creates extra dummy variables of the same size as value, and thus occupies more memory. Use it only to debug gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.value.shape == (1, 3)  # same as the intermediate value

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=nnx.DiffState(argnum=0, filter=nnx.Any(nnx.Param, nnx.Perturbation)))
... def grad_loss(model, inputs, targets):
...   preds = model(inputs)
...   return jnp.square(preds - targets).mean()

>>> intm_grads = grad_loss(model, x, y)
>>> # `intm_grads.xgrad.value` is the intermediate gradient
>>> assert not jnp.array_equal(intm_grads.xgrad.value, jnp.zeros((1, 3)))
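To differentiate with respect to the perturbations alone, the filter can be narrowed to nnx.Perturbation; a minimal sketch continuing the example above:

>>> @nnx.grad(argnums=nnx.DiffState(argnum=0, filter=nnx.Perturbation))
... def intermediate_grads(model, inputs, targets):
...   preds = model(inputs)
...   return jnp.square(preds - targets).mean()

>>> xgrad_only = intermediate_grads(model, x, y)
>>> assert xgrad_only.xgrad.value.shape == (1, 3)  # only the Perturbation gradient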
- Parameters
name – A string denoting the Module attribute name for the perturbation value.
value – The value to take the intermediate gradient of.
variable_type – The Variable type for the stored perturbation. Defaults to nnx.Perturbation.
- set_attributes(*filters, raise_if_not_found=True, **attributes)[source]#
Sets the attributes of nested Modules, including the current Module. Attributes that a particular Module does not have are ignored for that Module.
Example:
>>> from flax import nnx

>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
- Parameters
*filters – Filters to select the Modules to set the attributes of.
raise_if_not_found – If True (the default), raises a ValueError when a given attribute is not found in any of the selected Modules.
**attributes – The attributes to set.
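If none of the selected Modules have a given attribute, passing raise_if_not_found=False turns the call into a no-op instead of an error; a small sketch (assuming the Block defined above):

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> # nnx.Linear has no `deterministic` attribute; with the default
>>> # raise_if_not_found=True this call would raise a ValueError
>>> block.set_attributes(nnx.Linear, deterministic=True, raise_if_not_found=False)
>>> block.dropout.deterministic  # unchanged, Dropout was not selected
False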
- sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#
sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

sow() stores a value in a new Module attribute, denoted by name. The value will be wrapped by a Variable of type variable_type, which can be useful to filter for in split(), state() and pop().

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x + add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1  # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2  # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
Alternatively, a custom init/reduce function can be passed:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev + curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev * curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate * 2).all()
>>> assert (model.product.value == intermediate ** 2).all()
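The sowed values can afterwards be read or extracted with the regular State APIs such as state() and pop(); a minimal sketch (continuing the custom-reduce example above, and assuming pop() removes the extracted variables from the Module):

>>> intermediates = nnx.state(model, nnx.Intermediate)  # State of the sowed variables
>>> popped = nnx.pop(model, nnx.Intermediate)           # extract the Intermediate variables
>>> assert popped['sum'].value.shape == (1, 3)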
- Parameters
variable_type – The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.
name – A string denoting the Module attribute name, where the sowed value is stored.
value – The value to be stored.
reduce_fn – The function used to combine the existing value with the new value. The default is to append the value to a tuple.
init_fn – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.
- train(**attributes)[source]#
Sets the Module to training mode.
train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx

>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
- Parameters
**attributes – additional attributes passed to set_attributes.