# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations
import dataclasses
import functools
from functools import partial
import typing as tp
from typing import Any
import jax
from flax import errors
from flax.nnx import filterlib, reprlib, tracers
from flax.typing import Missing, PathParts
import jax.tree_util as jtu
# Generic type variables.
# `A` parameterizes the value stored by `Variable` / `VariableState`; `B` is
# used by `replace` methods whose result may hold a different value type.
A = tp.TypeVar('A')
B = tp.TypeVar('B')
# NOTE(review): `F` is not referenced in the visible part of this module.
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
# Any `Variable` subclass; used for `self`-typed methods and axis hooks.
V = tp.TypeVar('V', bound='Variable[Any]')
# Value hooks are called as `hook(variable, value)` and return the value to use.
GetValueHook = tp.Callable[['Variable[A]', A], A]
SetValueHook = tp.Callable[['Variable[A]', A], A]
CreateValueHook = tp.Callable[['Variable[A]', A], A]
# Axis hooks are called as `hook(variable, axis_index, axis_name)`.
AxisName = str
AxisIndex = int
AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
# Maps a string key to a `Variable` subclass.
# NOTE(review): nothing in the visible code reads or writes this cache;
# presumably it is populated elsewhere -- confirm before relying on it.
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
class Variable(tp.Generic[A], reprlib.Representable):
  """The base class for all ``Variable`` types. Create custom ``Variable``
  types by subclassing this class. Numerous NNX graph functions can filter
  for specific ``Variable`` types, for example, :func:`split`, :func:`state`,
  :func:`pop`, and :func:`State.filter`.

  A ``Variable`` stores a raw value plus a dictionary of metadata. Attribute
  access that is not resolved on the instance falls back first to the
  metadata and then to the wrapped value, and most operators are forwarded
  to the wrapped value.

  Subclasses may define any of the ``on_get_value``, ``on_set_value``,
  ``on_create_value``, ``on_add_axis`` and ``on_remove_axis`` hooks directly
  in their class body; they are registered in the metadata at construction
  time unless a hook with the same name is passed explicitly as metadata.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class CustomVariable(nnx.Variable):
    ...   pass

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
    ...   def __call__(self, x):
    ...     return self.linear(x) + self.custom_variable
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> linear_variables = nnx.state(model, nnx.Param)
    >>> jax.tree.map(jnp.shape, linear_variables)
    State({
      'linear': {
        'bias': VariableState(
          type=Param,
          value=(3,)
        ),
        'kernel': VariableState(
          type=Param,
          value=(2, 3)
        )
      }
    })

    >>> custom_variable = nnx.state(model, CustomVariable)
    >>> jax.tree.map(jnp.shape, custom_variable)
    State({
      'custom_variable': VariableState(
        type=CustomVariable,
        value=(1, 3)
      )
    })

    >>> variables = nnx.state(model)
    >>> jax.tree.map(jnp.shape, variables)
    State({
      'custom_variable': VariableState(
        type=CustomVariable,
        value=(1, 3)
      ),
      'linear': {
        'bias': VariableState(
          type=Param,
          value=(3,)
        ),
        'kernel': VariableState(
          type=Param,
          value=(2, 3)
        )
      }
    })
  """

  raw_value: A
  _trace_state: tracers.TraceState
  _var_metadata: dict[str, tp.Any]

  def __init__(
    self,
    value: tp.Union[A, VariableMetadata[A]],
    **metadata: tp.Any,
  ):
    """Creates a variable from a value (or a ``VariableMetadata``) and metadata.

    Args:
      value: the value to store. If it is a ``VariableMetadata`` its
        ``metadata`` is merged into ``metadata`` (overriding equal keys) and
        its ``raw_value`` is used as the value.
      **metadata: arbitrary metadata, exposed afterwards as attributes.
    """
    type_vars = vars(type(self))
    vars_self = vars(self)
    vars_self['_trace_state'] = tracers.TraceState()

    if isinstance(value, VariableMetadata):
      metadata.update(value.metadata)
      value = tp.cast(A, value.raw_value)

    object.__setattr__(self, 'raw_value', value)

    # Register hooks defined in the class body. They must be stored under the
    # same ``on_*`` key that `value`, `create_value`, `add_axis` and
    # `remove_axis` look up, otherwise they would never be invoked.
    for hook_name in (
      'on_get_value',
      'on_set_value',
      'on_create_value',
      'on_add_axis',
      'on_remove_axis',
    ):
      if hook_name in type_vars and hook_name not in metadata:
        metadata[hook_name] = getattr(type(self), hook_name)

    vars_self['_var_metadata'] = metadata
    # run create_value hooks
    vars_self['raw_value'] = self.create_value(self.raw_value)

  def __getattr__(self, name: str) -> tp.Any:
    """Resolves ``name`` from the metadata, then from the wrapped value.

    Only called by Python when regular attribute lookup fails.
    """
    metadata = vars(self)['_var_metadata']
    if name in metadata:
      return metadata[name]
    return getattr(self.value, name)

  def __setattr__(self, name: str, value: tp.Any):
    """Sets a core attribute, or stores ``value`` as metadata under ``name``.

    Raises:
      TraceContextError: if the variable's trace state is no longer valid.
    """
    if not self._trace_state.is_valid():
      raise errors.TraceContextError(
        f'Cannot mutate {type(self).__name__} from a different trace level'
      )
    if name in ('value', 'raw_value', '_var_metadata', '_trace_state'):
      # 'value' is a property, so this dispatches to its setter.
      object.__setattr__(self, name, value)
    else:
      self._var_metadata[name] = value

  def __delattr__(self, name: str):
    """Deletes a core attribute, or removes the metadata entry ``name``.

    Raises:
      TraceContextError: if the variable's trace state is no longer valid.
    """
    if not self._trace_state.is_valid():
      raise errors.TraceContextError(
        f'Cannot mutate {type(self).__name__} from a different trace level'
      )
    if name in ('value', 'raw_value', '_var_metadata', '_trace_state'):
      object.__delattr__(self, name)
    else:
      del self._var_metadata[name]

  @classmethod
  def state(cls, value: A, **metadata) -> VariableState[A]:
    """Builds a variable of this type and returns its ``VariableState``."""
    return cls(value, **metadata).to_state()

  def get_metadata(self):
    """Returns the live metadata dictionary (not a copy)."""
    return self._var_metadata

  def copy_from(self, other: Variable[A]) -> None:
    """Overwrites this variable's value and metadata with those of ``other``.

    Raises:
      ValueError: if ``other`` is not of exactly the same type.
    """
    if type(self) is not type(other):
      raise ValueError(
        f'Cannot copy from incompatible container, '
        f'expected {type(self).__name__}, got {type(other).__name__}'
      )
    if self is other:
      return
    self.raw_value = other.raw_value
    self._var_metadata.clear()
    self._var_metadata.update(other.get_metadata())

  def update_from_state(self, variable_state: VariableState[A]):
    """Replaces value and metadata with those of ``variable_state``.

    Writes straight into the instance ``__dict__``, so neither the trace
    check of ``__setattr__`` nor the ``on_set_value`` hook is involved.
    """
    vars_self = vars(self)
    vars_self['raw_value'] = variable_state.value
    vars_self['_var_metadata'] = variable_state._var_metadata.copy()

  @property
  def value(self) -> A:
    """The stored value, passed through ``on_get_value`` if registered."""
    value = self.raw_value
    if 'on_get_value' in self._var_metadata:
      value = self._var_metadata['on_get_value'](self, value)
    return value

  @value.setter
  def value(self, value: A):
    if isinstance(value, Variable):
      raise ValueError(
        'Cannot set value to a Variable, ' 'use `copy_from` method instead'
      )
    if 'on_set_value' in self._var_metadata:
      value = self._var_metadata['on_set_value'](self, value)
    vars(self)['raw_value'] = value

  def create_value(self, value: A):
    """Returns ``value`` passed through ``on_create_value`` if registered."""
    if 'on_create_value' in self._var_metadata:
      value = self._var_metadata['on_create_value'](self, value)
    return value

  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Calls the ``on_add_axis`` hook, if registered."""
    if 'on_add_axis' in self._var_metadata:
      self._var_metadata['on_add_axis'](self, axis_index, axis_name)

  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Calls the ``on_remove_axis`` hook, if registered."""
    if 'on_remove_axis' in self._var_metadata:
      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)

  # NOTE: defining `__eq__` without `__hash__` makes instances unhashable.
  def __eq__(self, other: object) -> bool:
    """True if ``other`` has the same type and an equal instance ``__dict__``."""
    return type(self) is type(other) and vars(other) == vars(self)

  @tp.overload
  def replace(self, value: B, **kwargs) -> Variable[B]: ...

  @tp.overload
  def replace(self, **kwargs) -> Variable[A]: ...

  def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]:
    """Returns a new variable with the value and/or metadata replaced.

    Args:
      value: the new raw value; may also be passed as ``raw_value=...``. When
        omitted the current raw value is kept. If it is a ``Variable`` of the
        same type, that variable is returned instead (with ``kwargs`` applied
        through its own ``replace`` when any are given).
      **kwargs: metadata entries to add or override on the result.

    Returns:
      A new instance sharing this variable's trace state. This variable and
      its metadata are left untouched.

    Raises:
      ValueError: if ``value`` is a ``Variable`` of a different type.
    """
    if value is not Missing:
      kwargs['raw_value'] = value

    new_value = kwargs.pop('raw_value', Missing)

    if isinstance(new_value, Variable):
      if type(self) is not type(new_value):
        raise ValueError(
          'Cannot replace value from incompatible container, '
          f'expected {type(self).__name__}, got {type(new_value).__name__}'
        )
      # apply the remaining metadata on top of the given variable, if any
      return new_value.replace(**kwargs) if kwargs else new_value

    if new_value is Missing:
      # metadata-only replace: keep the current value
      new_value = self.raw_value

    # Copy the metadata so updating the new instance never mutates `self`.
    metadata = self.get_metadata().copy()
    metadata.update(kwargs)

    obj = object.__new__(type(self))
    object.__setattr__(obj, '_trace_state', self._trace_state)
    object.__setattr__(obj, 'raw_value', new_value)
    object.__setattr__(obj, '_var_metadata', metadata)
    return obj

  @classmethod
  def from_metadata(cls, value: A, attributes: tp.Mapping[str, tp.Any]):
    """Builds an instance from a value and metadata without running ``__init__``.

    No hooks are registered or executed. ``attributes`` is copied into a new
    ``dict`` so later attribute writes on the variable do not affect the
    caller's mapping.
    """
    obj = object.__new__(cls)
    object.__setattr__(obj, '_trace_state', tracers.TraceState())
    object.__setattr__(obj, 'raw_value', value)
    object.__setattr__(obj, '_var_metadata', dict(attributes))
    return obj

  def copy(self: Variable[A]) -> Variable[A]:
    """Returns a new instance with the same value and a copy of the metadata.

    The raw value and the trace state are shared, not copied.
    """
    obj = object.__new__(type(self))
    object.__setattr__(obj, '_trace_state', self._trace_state)
    object.__setattr__(obj, 'raw_value', self.raw_value)
    object.__setattr__(obj, '_var_metadata', self.get_metadata().copy())
    return obj

  def to_state(self: Variable[A]) -> VariableState[A]:
    """Returns a ``VariableState`` with this variable's type, value, metadata."""
    return VariableState(type(self), self.raw_value, **self._var_metadata)

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('value', self.raw_value)
    for name, value in self._var_metadata.items():
      yield reprlib.Attr(name, repr(value))

  def __treescope_repr__(self, path, subtree_renderer):
    # imported lazily so treescope is only needed when this is called
    import treescope  # type: ignore[import-not-found,import-untyped]

    children = {'value': self.raw_value, **self._var_metadata}
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes=children,
      path=path,
      subtree_renderer=subtree_renderer,
    )

  # hooks API
  # Declared for type checkers only; at runtime a hook exists only when a
  # subclass defines it or it is passed as metadata.
  if tp.TYPE_CHECKING:

    def on_get_value(self, value: A) -> A: ...

    def on_set_value(self, value: A) -> A: ...

    def on_create_value(self, value: A) -> A: ...

    def on_add_axis(
      self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

    def on_remove_axis(
      self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

  def __jax_array__(self):
    return self.value

  # pickle support
  def __getstate__(self):
    return vars(self).copy()

  def __setstate__(self, state):
    vars(self).update(state)

  # --------------------------------------------
  # proxy methods
  # --------------------------------------------
  # Everything below forwards to `self.value`.

  def __getitem__(self, key) -> tp.Any:
    return self.value[key]  # type: ignore

  def __setitem__(self, key, value) -> None:
    self.value[key] = value  # type: ignore

  def __call__(self, *args, **kwargs) -> tp.Any:
    return self.value(*args, **kwargs)  # type: ignore

  def __len__(self) -> int:
    return len(self.value)  # type: ignore

  def __iter__(self) -> tp.Iterator:
    return iter(self.value)  # type: ignore

  def __contains__(self, item) -> bool:
    return item in self.value  # type: ignore

  def __add__(self, other) -> A:
    return self.value.__add__(other)  # type: ignore

  def __sub__(self, other) -> A:
    return self.value.__sub__(other)  # type: ignore

  def __mul__(self, other) -> A:
    return self.value.__mul__(other)  # type: ignore

  def __matmul__(self, other) -> A:
    return self.value.__matmul__(other)  # type: ignore

  def __truediv__(self, other) -> A:
    return self.value.__truediv__(other)  # type: ignore

  def __floordiv__(self, other) -> A:
    return self.value.__floordiv__(other)  # type: ignore

  def __mod__(self, other) -> A:
    return self.value.__mod__(other)  # type: ignore

  def __divmod__(self, other) -> A:
    return self.value.__divmod__(other)  # type: ignore

  def __pow__(self, other) -> A:
    return self.value.__pow__(other)  # type: ignore

  def __lshift__(self, other) -> A:
    return self.value.__lshift__(other)  # type: ignore

  def __rshift__(self, other) -> A:
    return self.value.__rshift__(other)  # type: ignore

  def __and__(self, other) -> A:
    return self.value.__and__(other)  # type: ignore

  def __xor__(self, other) -> A:
    return self.value.__xor__(other)  # type: ignore

  def __or__(self, other) -> A:
    return self.value.__or__(other)  # type: ignore

  def __radd__(self, other) -> A:
    return self.value.__radd__(other)  # type: ignore

  def __rsub__(self, other) -> A:
    return self.value.__rsub__(other)  # type: ignore

  def __rmul__(self, other) -> A:
    return self.value.__rmul__(other)  # type: ignore

  def __rmatmul__(self, other) -> A:
    return self.value.__rmatmul__(other)  # type: ignore

  def __rtruediv__(self, other) -> A:
    return self.value.__rtruediv__(other)  # type: ignore

  def __rfloordiv__(self, other) -> A:
    return self.value.__rfloordiv__(other)  # type: ignore

  def __rmod__(self, other) -> A:
    return self.value.__rmod__(other)  # type: ignore

  def __rdivmod__(self, other) -> A:
    return self.value.__rdivmod__(other)  # type: ignore

  def __rpow__(self, other) -> A:
    return self.value.__rpow__(other)  # type: ignore

  def __rlshift__(self, other) -> A:
    return self.value.__rlshift__(other)  # type: ignore

  def __rrshift__(self, other) -> A:
    return self.value.__rrshift__(other)  # type: ignore

  def __rand__(self, other) -> A:
    return self.value.__rand__(other)  # type: ignore

  def __rxor__(self, other) -> A:
    return self.value.__rxor__(other)  # type: ignore

  def __ror__(self, other) -> A:
    return self.value.__ror__(other)  # type: ignore

  def _inplace(self: V, inplace_name: str, fallback_name: str, other) -> V:
    """Shared implementation of the augmented-assignment operators.

    If the wrapped value defines ``inplace_name`` it is called and its return
    value is discarded; otherwise the result of ``fallback_name`` is assigned
    to ``self.value``. Either way the variable itself is returned so that
    ``var += x`` keeps ``var`` bound to the same ``Variable``.
    """
    value = self.value
    if hasattr(value, inplace_name):
      getattr(value, inplace_name)(other)
    else:
      self.value = getattr(value, fallback_name)(other)
    return self

  def __iadd__(self: V, other) -> V:
    return self._inplace('__iadd__', '__add__', other)

  def __isub__(self: V, other) -> V:
    return self._inplace('__isub__', '__sub__', other)

  def __imul__(self: V, other) -> V:
    return self._inplace('__imul__', '__mul__', other)

  def __imatmul__(self: V, other) -> V:
    return self._inplace('__imatmul__', '__matmul__', other)

  def __itruediv__(self: V, other) -> V:
    return self._inplace('__itruediv__', '__truediv__', other)

  def __ifloordiv__(self: V, other) -> V:
    return self._inplace('__ifloordiv__', '__floordiv__', other)

  def __imod__(self: V, other) -> V:
    return self._inplace('__imod__', '__mod__', other)

  def __ipow__(self: V, other) -> V:
    return self._inplace('__ipow__', '__pow__', other)

  def __ilshift__(self: V, other) -> V:
    return self._inplace('__ilshift__', '__lshift__', other)

  def __irshift__(self: V, other) -> V:
    return self._inplace('__irshift__', '__rshift__', other)

  def __iand__(self: V, other) -> V:
    return self._inplace('__iand__', '__and__', other)

  def __ixor__(self: V, other) -> V:
    return self._inplace('__ixor__', '__xor__', other)

  def __ior__(self: V, other) -> V:
    return self._inplace('__ior__', '__or__', other)

  def __neg__(self) -> A:
    return self.value.__neg__()  # type: ignore

  def __pos__(self) -> A:
    return self.value.__pos__()  # type: ignore

  def __abs__(self) -> A:
    return self.value.__abs__()  # type: ignore

  def __invert__(self) -> A:
    return self.value.__invert__()  # type: ignore

  def __complex__(self) -> A:
    return self.value.__complex__()  # type: ignore

  def __int__(self) -> A:
    return self.value.__int__()  # type: ignore

  def __float__(self) -> A:
    return self.value.__float__()  # type: ignore

  def __index__(self) -> A:
    return self.value.__index__()  # type: ignore

  def __round__(self, ndigits: int | None = None) -> A:
    # `round(var)` calls `__round__()` without arguments, so `ndigits` must be
    # optional; forward it only when given.
    if ndigits is None:
      return self.value.__round__()  # type: ignore
    return self.value.__round__(ndigits)  # type: ignore

  def __trunc__(self) -> A:
    return self.value.__trunc__()  # type: ignore

  def __floor__(self) -> A:
    return self.value.__floor__()  # type: ignore

  def __ceil__(self) -> A:
    return self.value.__ceil__()  # type: ignore
class Param(Variable[A]):
  """:class:`Variable` type used for learnable parameters.

  The learnable parameters of NNX layer modules are stored as ``Param``
  variables::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': VariableState(
        type=Param,
        value=(3,)
      ),
      'kernel': VariableState(
        type=Param,
        value=(2, 3)
      )
    })
  """
class BatchStat(Variable[A]):
  """:class:`Variable` type for the batch statistics of :class:`BatchNorm`.

  These are the running mean and variance that are typically used during
  post-training inference, not the learnable scale and bias parameters::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': VariableState(
        type=Param,
        value=(3,)
      ),
      'mean': VariableState(
        type=BatchStat,
        value=(3,)
      ),
      'scale': VariableState(
        type=Param,
        value=(3,)
      ),
      'var': VariableState(
        type=BatchStat,
        value=(3,)
      )
    })
  """
class Cache(Variable[A]):
  """:class:`Variable` type for the autoregressive cache of
  :class:`MultiHeadAttention`::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.MultiHeadAttention(
    ...   num_heads=2,
    ...   in_features=3,
    ...   qkv_features=6,
    ...   out_features=6,
    ...   decode=True,
    ...   rngs=nnx.Rngs(0),
    ... )

    >>> layer.init_cache((1, 3))
    >>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
    State({
      'cache_index': VariableState(
        type=Cache,
        value=()
      ),
      'cached_key': VariableState(
        type=Cache,
        value=(1, 2, 3)
      ),
      'cached_value': VariableState(
        type=Cache,
        value=(1, 2, 3)
      )
    })
  """
class VariableState(tp.Generic[A], reprlib.Representable):
  """Holds the type, value and metadata of a :class:`Variable`.

  Only ``type``, ``value`` and ``_var_metadata`` are stored as slots; every
  other attribute name is read from, written to, or deleted from the
  metadata dictionary.
  """

  __slots__ = ('type', 'value', '_var_metadata')
  type: type[Variable[A]]
  value: A
  _var_metadata: dict[str, tp.Any]

  def __init__(
    self,
    type: type[Variable[A]],  # type: ignore [valid-type]
    value: A,
    **metadata,
  ):
    """Stores the variable type, the value and the metadata keyword arguments.

    Args:
      type: the ``Variable`` subclass this state belongs to.
      value: the value to hold.
      **metadata: metadata entries, later accessible as attributes.
    """
    object.__setattr__(self, 'type', type)
    object.__setattr__(self, 'value', value)
    object.__setattr__(self, '_var_metadata', metadata)

  def __getattr__(self, name: str) -> tp.Any:
    """Returns the metadata entry ``name``.

    Only called by Python when regular attribute lookup fails.

    Raises:
      AttributeError: if ``name`` is not a metadata key.
    """
    # object.__getattribute__ avoids re-entering __getattr__ when the
    # `_var_metadata` slot has not been set yet.
    var_metadata = object.__getattribute__(self, '_var_metadata')
    if name not in var_metadata:
      raise AttributeError(f"'VariableState' object has no attribute '{name}'")
    return var_metadata[name]

  def __setattr__(self, name: str, value: Any) -> None:
    """Sets a slot attribute, or stores ``value`` as metadata under ``name``."""
    if name in ('type', 'value', '_var_metadata'):
      object.__setattr__(self, name, value)
    else:
      self._var_metadata[name] = value

  def __delattr__(self, name: str) -> None:
    """Deletes a slot attribute, or removes the metadata entry ``name``."""
    if name in ('type', 'value', '_var_metadata'):
      object.__delattr__(self, name)
    else:
      del self._var_metadata[name]

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('type', self.type.__name__)
    yield reprlib.Attr('value', self.value)
    for name, value in self._var_metadata.items():
      yield reprlib.Attr(name, repr(value))

  def __treescope_repr__(self, path, subtree_renderer):
    # imported lazily so treescope is only needed when this is called
    import treescope  # type: ignore[import-not-found,import-untyped]

    children = {'type': self.type, 'value': self.value, **self._var_metadata}
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes=children,
      path=path,
      subtree_renderer=subtree_renderer,
    )

  def replace(self, value: B) -> VariableState[B]:
    """Returns a new state with the same type and metadata but a new value."""
    return VariableState(self.type, value, **self.get_metadata())

  def to_variable(self) -> Variable[A]:
    """Builds a ``Variable`` of type ``self.type`` from this state.

    The variable gets a fresh trace state, this state's value and a copy of
    its metadata.
    """
    # we use object.__new__ to avoid calling __init__ and bypass the
    # __init__ logic which should not be called twice
    variable = object.__new__(self.type)
    object.__setattr__(variable, '_trace_state', tracers.TraceState())
    object.__setattr__(variable, 'raw_value', self.value)
    object.__setattr__(variable, '_var_metadata', self.get_metadata().copy())
    return variable

  def copy(self: VariableState[A]) -> VariableState[A]:
    """Returns a copy produced by an identity ``jax.tree.map`` over ``self``."""
    return jax.tree.map(lambda x: x, self)

  def get_metadata(self) -> dict[str, tp.Any]:
    """Returns the live metadata dictionary (not a copy)."""
    return self._var_metadata

  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Calls the ``on_add_axis`` metadata hook, if present."""
    if 'on_add_axis' in self._var_metadata:
      self._var_metadata['on_add_axis'](self, axis_index, axis_name)

  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Calls the ``on_remove_axis`` metadata hook, if present."""
    if 'on_remove_axis' in self._var_metadata:
      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
  """Flattens a ``VariableState`` into ``(children, aux_data)``.

  Args:
    x: the state to flatten.
    with_keys: if True the single child is a ``(GetAttrKey('value'), value)``
      pair, otherwise it is the bare value.

  Returns:
    A one-element tuple holding the child, and the auxiliary data
    ``(x.type, metadata_items)`` where ``metadata_items`` is a tuple of the
    metadata ``(key, value)`` pairs.
  """
  metadata_items = tuple(x.get_metadata().items())
  child = (jtu.GetAttrKey('value'), x.value) if with_keys else x.value
  return (child,), (x.type, metadata_items)
def _variable_state_unflatten(
  static: tuple[type[Variable[A]], tuple[tuple[str, tp.Any], ...]],
  children: tuple[A],
) -> VariableState[A]:
  """Rebuilds a ``VariableState`` from the output of the flatten function.

  Args:
    static: the auxiliary data ``(variable_type, metadata_items)``.
    children: a sequence whose first element is the value.

  Returns:
    A new ``VariableState`` with the given type, value and metadata.
  """
  variable_type = static[0]
  value = children[0]
  metadata = dict(static[1])
  return VariableState(type=variable_type, value=value, **metadata)
# Register `VariableState` as a JAX pytree node. Per the flatten function
# above, the value is the only child and the variable type plus the metadata
# items are carried as auxiliary data. Both a keyed flatten (used for
# key-path aware traversals) and a plain `flatten_func` are supplied.
jtu.register_pytree_with_keys(
  VariableState,
  partial(_variable_state_flatten, with_keys=True), # type: ignore
  _variable_state_unflatten, # type: ignore
  flatten_func=partial(_variable_state_flatten, with_keys=False), # type: ignore
)
def split_flat_state(
  flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]],
  filters: tuple[filterlib.Filter, ...],
) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]:
  """Partitions ``(path, value)`` pairs into one list per filter.

  Each pair is placed in the list of the first filter whose predicate
  accepts it; later filters are not consulted for that pair.

  Args:
    flat_state: the ``(path, value)`` pairs to partition.
    filters: the filters, in priority order.

  Returns:
    A tuple with one list per filter, in the same order as ``filters``.

  Raises:
    ValueError: if a pair is not matched by any filter.
  """
  predicates = filterlib.filters_to_predicates(filters)

  # one bucket per predicate; there is no bucket for unmatched entries,
  # those are reported as an error below
  flat_states: tuple[list[tuple[PathParts, Variable | VariableState]], ...] = (
    tuple([] for _ in predicates)
  )

  for path, value in flat_state:
    # index of the first predicate that accepts this entry, if any
    bucket_index = next(
      (
        index
        for index, predicate in enumerate(predicates)
        if predicate(path, value)
      ),
      None,
    )
    if bucket_index is None:
      raise ValueError(
        'Non-exhaustive filters, got a non-empty remainder: '
        f'{path} -> {value}.'
        '\nUse `...` to match all remaining elements.'
      )
    flat_states[bucket_index].append((path, value))

  return flat_states