# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import typing as tp
from functools import partial
import jax.tree_util as jtu
from flax.experimental.nnx.nnx import (
filterlib,
graph,
)
from flax.experimental.nnx.nnx import variables as variableslib
from flax.experimental.nnx.nnx.graph import GraphDef, GraphNode, GraphNodeMeta
from flax.experimental.nnx.nnx.proxy_caller import (
CallableProxy,
DelayedAccessor,
)
from flax.experimental.nnx.nnx.state import State
from flax.experimental.nnx.nnx.variables import Variable
from flax.typing import Path, PathParts
A = tp.TypeVar('A')
B = tp.TypeVar('B')
M = tp.TypeVar('M', bound='Module')
S = tp.TypeVar('S', bound=tp.Union[State, tuple[State, ...]])
V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any])
StateMapping = tp.Mapping[Path, tp.Any]
tuple_reduce = lambda xs, x: xs + (x,)
tuple_init = lambda: ()
@tp.runtime_checkable
class _HasSetup(tp.Protocol):
def setup(self) -> None:
...
class ModuleMeta(GraphNodeMeta):
if not tp.TYPE_CHECKING:
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
return _module_meta_call(cls, *args, **kwargs)
def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M:
module: M = GraphNodeMeta.__call__(cls, *args, **kwargs)
if dataclasses.is_dataclass(module):
if isinstance(module, _HasSetup):
module.setup()
return module
[docs]class Module(graph.GraphNode, metaclass=ModuleMeta):
""""""
[docs] def sow(
self,
variable_type: tp.Type[variableslib.Variable[tp.Any]],
name: str,
value: A,
reduce_fn: tp.Callable[[B, A], B] = tuple_reduce,
init_fn: tp.Callable[[], B] = tuple_init, # type: ignore
) -> None:
if hasattr(self, name):
variable = getattr(self, name)
if not isinstance(variable, variableslib.Variable):
raise ValueError(
f"Expected '{name}' to be a Variable, got {type(variable).__name__}"
)
elif type(variable) != variable_type:
raise ValueError(
f"Expected '{name}' to be of type '{variable_type.__name__}', "
f"got '{type(variable).__name__}'"
)
variable.raw_value = reduce_fn(variable.raw_value, value)
else:
reduced_value = reduce_fn(init_fn(), value)
setattr(self, name, variable_type(reduced_value))
@property
def init(self: M) -> M:
"""Calls a method in initialization mode.
When a method is called using ``init``, the ``is_initializing`` method
will return ``True``. This is useful to implement Modules that support
lazy initialization.
Example::
>>> from flax.experimental import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
... def __init__(self, dout, rngs: nnx.Rngs):
... self.dout = dout
... self.rngs = rngs
...
... def __call__(self, x):
... if self.is_initializing():
... din = x.shape[-1]
... if not hasattr(self, 'w'):
... key = self.rngs.params()
... self.w = nnx.Param(jax.random.uniform(key, (din, self.dout)))
... if not hasattr(self, 'b'):
... self.b = nnx.Param(jnp.zeros((self.dout,)))
...
... return x @ self.w + self.b
...
>>> linear = Linear(3, nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
>>> y = linear.init(x)
>>> linear.w.value.shape
(2, 3)
>>> linear.b.value.shape
(3,)
>>> y.shape
(5, 3)
"""
def _init_context(accessor: DelayedAccessor, *args, **kwargs):
for _, value in graph.iter_nodes(self):
if isinstance(value, GraphNode):
value._graph_node__state._initializing = True
method = accessor(self)
try:
out = method(*args, **kwargs)
finally:
for _, value in graph.iter_nodes(self):
if isinstance(value, GraphNode):
value._graph_node__state._initializing = False
return out
return CallableProxy(_init_context) # type: ignore
[docs] def is_initializing(self) -> bool:
"""Returns whether the Module is initializing.
``is_initializing`` returns ``True`` if the Module is currently being run
under ``init``.
"""
return self._graph_node__state._initializing
[docs] def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
"""Iterates over all nested Modules of the current Module, including the current Module.
``iter_modules`` creates a generator that yields the path and the Module instance, where
the path is a tuple of strings or integers representing the path to the Module from the
root Module.
Example::
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5)
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
... print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
() Block
"""
for path, value in graph.iter_nodes(self):
if isinstance(value, Module):
yield path, value
[docs] def set_attributes(
self,
*filters: filterlib.Filter,
raise_if_not_found: bool = True,
**attributes: tp.Any,
) -> None:
"""Sets the attributes of nested Modules including the current Module.
If the attribute is not found in the Module, it is ignored.
Example::
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5, deterministic=False)
... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
``Filter``'s can be used to set the attributes of specific Modules::
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Args:
*filters: Filters to select the Modules to set the attributes of.
raise_if_not_found: If True (default), raises a ValueError if at least one attribute
instance is not found in one of the selected Modules.
**attributes: The attributes to set.
"""
remaining_attributes = set(attributes.keys())
if not filters:
filters = (True,)
predicates = tuple(map(filterlib.to_predicate, filters))
for path, module in self.iter_modules():
for predicate in predicates:
if predicate(path, module):
for name, value in attributes.items():
if hasattr(module, name):
if name in remaining_attributes:
remaining_attributes.remove(name)
setattr(module, name, value)
break
if remaining_attributes and raise_if_not_found:
raise ValueError(
f'Could not find at least one instance of the following attributes: {remaining_attributes}'
)
[docs] def train(self, **attributes):
"""Sets the Module to training mode.
``train`` uses ``set_attributes`` to recursively set attributes ``deterministic=False``
and ``use_running_average=False`` of all nested Modules that have these attributes.
Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
Modules.
Example::
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... # initialize Dropout and BatchNorm in eval mode
... self.dropout = nnx.Dropout(0.5, deterministic=True)
... self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Args:
**attributes: additional attributes passed to ``set_attributes``.
"""
return self.set_attributes(
deterministic=False,
use_running_average=False,
**attributes,
raise_if_not_found=False,
)
[docs] def eval(self, **attributes):
"""Sets the Module to evaluation mode.
``eval`` uses ``set_attributes`` to recursively set attributes ``deterministic=True``
and ``use_running_average=True`` of all nested Modules that have these attributes.
Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
Modules.
Example::
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5)
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Args:
**attributes: additional attributes passed to ``set_attributes``.
"""
return self.set_attributes(
deterministic=True,
use_running_average=True,
**attributes,
raise_if_not_found=False,
)
def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()
if experimental_pytree:
jtu.register_pytree_with_keys(
cls,
partial(_module_flatten, with_keys=True),
_module_unflatten,
flatten_func=partial(_module_flatten, with_keys=False),
)
# -------------------------
# Pytree Definition
# -------------------------
def _module_flatten(module: Module, *, with_keys: bool):
graphdef, state = graph.split(module)
key_values = sorted(state.raw_mapping.items())
keys = tuple(key for key, _ in key_values)
if with_keys:
children = tuple((jtu.DictKey(key), value) for key, value in key_values)
else:
children = tuple(value for _, value in key_values)
return children, (keys, graphdef)
def _module_unflatten(
paths_moduledef: tuple[tuple[Path, ...], GraphDef[M]],
variables: tuple[Variable[tp.Any], ...],
) -> M:
paths, graphdef = paths_moduledef
return graph.merge(graphdef, State(zip(paths, variables)))
def first_from(*args: tp.Optional[A], error_msg: str) -> A:
"""Return the first non-None argument.
If all arguments are None, raise a ValueError with the given error message.
Args:
*args: the arguments to check
error_msg: the error message to raise if all arguments are None
Returns:
The first non-None argument.
"""
for arg in args:
if arg is not None:
return arg
raise ValueError(error_msg)