# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing as tp
from functools import partial

import jax.tree_util as jtu

from flax.nnx import (
filterlib,
graph,
)
from flax.nnx import variablelib as variableslib
from flax.nnx.graph import GraphDef
from flax.nnx.object import Object, ObjectMeta
from flax.nnx.graph import GraphState, StateLeaf
from flax.nnx.statelib import State
from flax.typing import Key, Path, PathParts

A = tp.TypeVar('A')
B = tp.TypeVar('B')
M = tp.TypeVar('M', bound='Module')
S = tp.TypeVar('S', bound=tp.Union[GraphState, tuple[GraphState, ...]])
V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any])
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
StateMapping = tp.Mapping[Path, tp.Any]
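
# Default reduction for ``sow``: accumulate sown values in a tuple,
# starting from an empty tuple.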
tuple_reduce = lambda xs, x: xs + (x,)
tuple_init = lambda: ()


class ModuleMeta(ObjectMeta):
  # we keep a trivial derived class just in case we need to
  # add more functionality in the future
  pass


class Module(Object, metaclass=ModuleMeta):
"""Base class for all neural network modules.
Layers and models should subclass this class.
``Module``'s can contain submodules, and in this way can be nested in a tree
structure. Submodules can be assigned as regular attributes inside the
``__init__`` method.
You can define arbitrary "forward pass" methods on your ``Module`` subclass.
While no methods are special-cased, ``__call__`` is a popular choice since
you can call the ``Module`` directly::
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear1 = nnx.Linear(2, 3, rngs=rngs)
... self.linear2 = nnx.Linear(3, 4, rngs=rngs)
... def __call__(self, x):
... x = self.linear1(x)
... x = nnx.relu(x)
... x = self.linear2(x)
... return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
"""

  def sow(
self,
variable_type: tp.Type[variableslib.Variable[tp.Any]],
name: str,
value: A,
reduce_fn: tp.Callable[[B, A], B] = tuple_reduce,
init_fn: tp.Callable[[], B] = tuple_init, # type: ignore
) -> None:
"""``sow()`` can be used to collect intermediate values without
the overhead of explicitly passing a container through each Module call.
``sow()`` stores a value in a new ``Module`` attribute, denoted by ``name``.
The value will be wrapped by a :class:`Variable` of type ``variable_type``,
which can be useful to filter for in :func:`split`, :func:`state` and
:func:`pop`.
By default the values are stored in a tuple and each stored value
is appended at the end. This way all intermediates can be tracked when
the same module is called multiple times.
Example usage::
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear1 = nnx.Linear(2, 3, rngs=rngs)
... self.linear2 = nnx.Linear(3, 4, rngs=rngs)
... def __call__(self, x, add=0):
... x = self.linear1(x)
... self.sow(nnx.Intermediate, 'i', x+add)
... x = self.linear2(x)
... return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)
>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()

    Alternatively, a custom init/reduce function can be passed::

>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear1 = nnx.Linear(2, 3, rngs=rngs)
... self.linear2 = nnx.Linear(3, 4, rngs=rngs)
... def __call__(self, x):
... x = self.linear1(x)
... self.sow(nnx.Intermediate, 'sum', x,
... init_fn=lambda: 0,
... reduce_fn=lambda prev, curr: prev+curr)
... self.sow(nnx.Intermediate, 'product', x,
... init_fn=lambda: 1,
... reduce_fn=lambda prev, curr: prev*curr)
... x = self.linear2(x)
... return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value
>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()

    Args:
variable_type: The :class:`Variable` type for the stored value.
Typically :class:`Intermediate` is used to indicate an
intermediate value.
name: A string denoting the ``Module`` attribute name, where
the sowed value is stored.
value: The value to be stored.
reduce_fn: The function used to combine the existing value with the new
value. The default is to append the value to a tuple.
init_fn: For the first value stored, ``reduce_fn`` will be passed the result
of ``init_fn`` together with the value to be stored. The default is an
empty tuple.
"""
    if hasattr(self, name):
      # The attribute already exists: validate it, then fold the new value
      # into the stored one with ``reduce_fn``.
      variable = getattr(self, name)
      if not isinstance(variable, variableslib.Variable):
        raise ValueError(
          f"Expected '{name}' to be a Variable, got {type(variable).__name__}"
        )
      elif type(variable) != variable_type:
        raise ValueError(
          f"Expected '{name}' to be of type '{variable_type.__name__}', "
          f"got '{type(variable).__name__}'"
        )
      variable.raw_value = reduce_fn(variable.raw_value, value)
    else:
      # First sow under this name: seed the reduction with ``init_fn()`` and
      # create the attribute.
      reduced_value = reduce_fn(init_fn(), value)
      setattr(self, name, variable_type(reduced_value))

  def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
"""Recursively iterates over all nested :class:`Module`'s of the current Module, including
the current Module.
``iter_modules`` creates a generator that yields the path and the Module instance, where
the path is a tuple of strings or integers representing the path to the Module from the
root Module.
Example::
>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
... def __init__(self, din, dout, rngs):
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.submodule = SubModule(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5)
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
... print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
"""
for path, value in graph.iter_graph(self):
if isinstance(value, Module):
yield path, value

  def iter_children(self) -> tp.Iterator[tuple[Key, Module]]:
"""Iterates over all children :class:`Module`'s of the current Module. This
method is similar to :func:`iter_modules`, except it only iterates over the
immediate children, and does not recurse further down.
``iter_children`` creates a generator that yields the key and the Module instance,
where the key is a string representing the attribute name of the Module to access
the corresponding child Module.
Example::
>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
... def __init__(self, din, dout, rngs):
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.submodule = SubModule(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5)
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
... print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
"""
node_dict = graph.get_node_impl(self).node_dict(self)
for key, value in node_dict.items():
if isinstance(value, Module):
yield key, value

  def set_attributes(
self,
*filters: filterlib.Filter,
raise_if_not_found: bool = True,
**attributes: tp.Any,
) -> None:
"""Sets the attributes of nested Modules including the current Module.
If the attribute is not found in the Module, it is ignored.
Example::
>>> from flax import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5, deterministic=False)
... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

    ``Filter``'s can be used to set the attributes of specific Modules::

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
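
    Passing ``raise_if_not_found=False`` turns missing attributes into a
    no-op instead of an error. Here ``some_mode`` is a hypothetical attribute
    name that none of the Modules define::

    >>> block.set_attributes(some_mode=True, raise_if_not_found=False)
    >>> hasattr(block.linear, 'some_mode')
    False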

    Args:
      *filters: Filters to select the Modules to set the attributes of.
      raise_if_not_found: If True (default), raises a ValueError if at least
        one of the attributes is not found in any of the selected Modules.
      **attributes: The attributes to set.
"""
    remaining_attributes = set(attributes.keys())
    if not filters:
      # No filters given: match every Module.
      filters = (True,)
    predicates = tuple(map(filterlib.to_predicate, filters))
    for path, module in self.iter_modules():
      for predicate in predicates:
        if predicate(path, module):
          for name, value in attributes.items():
            if hasattr(module, name):
              if name in remaining_attributes:
                remaining_attributes.remove(name)
              setattr(module, name, value)
          break  # process each Module with at most one matching predicate
    if remaining_attributes and raise_if_not_found:
      raise ValueError(
        f'Could not find at least one instance of the following attributes: {remaining_attributes}'
      )

  def train(self, **attributes):
"""Sets the Module to training mode.
``train`` uses ``set_attributes`` to recursively set attributes ``deterministic=False``
and ``use_running_average=False`` of all nested Modules that have these attributes.
Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
Modules.
Example::
>>> from flax import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... # initialize Dropout and BatchNorm in eval mode
... self.dropout = nnx.Dropout(0.5, deterministic=True)
... self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)

    Args:
**attributes: additional attributes passed to ``set_attributes``.
"""
return self.set_attributes(
deterministic=False,
use_running_average=False,
**attributes,
raise_if_not_found=False,
)

  def eval(self, **attributes):
"""Sets the Module to evaluation mode.
``eval`` uses ``set_attributes`` to recursively set attributes ``deterministic=True``
and ``use_running_average=True`` of all nested Modules that have these attributes.
Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
Modules.
Example::
>>> from flax import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5)
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

    Args:
**attributes: additional attributes passed to ``set_attributes``.
"""
return self.set_attributes(
deterministic=True,
use_running_average=True,
**attributes,
raise_if_not_found=False,
)

  def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()
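    # Optionally register the subclass as a JAX pytree so that instances can
    # be passed directly to jax.tree_util functions and JAX transforms; the
    # flatten/unflatten rules live at the bottom of this file.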
if experimental_pytree:
jtu.register_pytree_with_keys(
cls,
partial(_module_flatten, with_keys=True),
_module_unflatten, # type: ignore[arg-type]
flatten_func=partial(_module_flatten, with_keys=False),
)

  def __treescope_repr__(self, path, subtree_renderer):
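    # Treescope integration: render this Module as a constructor call whose
    # children are its public (non ``_``-prefixed) attributes.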
import treescope # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=treescope.formatting_util.color_from_string(
type(self).__qualname__
)
)


# -------------------------
# Pytree Definition
# -------------------------
def _module_flatten(module: Module, *, with_keys: bool):
graphdef, state = graph.split(module)
key_values = sorted(state.raw_mapping.items())
keys = tuple(key for key, _ in key_values)
children: tuple[tp.Any, ...]
if with_keys:
children = tuple((jtu.DictKey(key), value) for key, value in key_values)
else:
children = tuple(value for _, value in key_values)
return children, (keys, graphdef)


def _module_unflatten(
paths_moduledef: tuple[tuple[Path, ...], GraphDef[M]],
variables: tuple[StateLeaf, ...],
) -> M:
paths, graphdef = paths_moduledef
return graph.merge(graphdef, State(zip(paths, variables)))
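

# A minimal usage sketch of the pytree integration above. ``PytreeMLP`` is a
# hypothetical subclass used only for illustration: declaring it with
# ``experimental_pytree=True`` triggers the registration in
# ``Module.__init_subclass__``, after which instances can be handed to
# ``jax.tree_util`` directly::
#
#   from flax import nnx
#   import jax.tree_util as jtu
#
#   class PytreeMLP(nnx.Module, experimental_pytree=True):
#     def __init__(self, rngs):
#       self.linear = nnx.Linear(2, 3, rngs=rngs)
#
#   model = PytreeMLP(nnx.Rngs(0))
#   leaves = jtu.tree_leaves(model)  # the Module's Variable states as leaves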


def first_from(*args: tp.Optional[A], error_msg: str) -> A:
"""Return the first non-None argument.
If all arguments are None, raise a ValueError with the given error message.
Args:
*args: the arguments to check
error_msg: the error message to raise if all arguments are None
Returns:
The first non-None argument.
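
  Example::

    >>> first_from(None, 2, 3, error_msg='no value was given')
    2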
"""
for arg in args:
if arg is not None:
return arg
raise ValueError(error_msg)