Source code for flax.nnx.statelib

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations

from collections.abc import MutableMapping
import typing as tp

import jax
import jax.tree_util as jtu

from flax.nnx import traversals
from flax.nnx import filterlib, reprlib
from flax.nnx import variablelib
from flax.typing import PathParts

# Generic type parameters shared by the definitions in this module.
A = tp.TypeVar('A')
# Key type of a ``State`` mapping; constrained to hashable types.
K = tp.TypeVar('K', bound=tp.Hashable)
# Leaf (value) type stored in ``FlatState`` / ``State``.
V = tp.TypeVar('V')

# Callable that takes a single leaf and returns a value; used as the type of
# the ``extract_fn`` argument of ``State.to_pure_dict``.
ExtractValueFn = tp.Callable[[tp.Any], tp.Any]
# Callable that takes the current leaf and a new value and returns the leaf to
# store; used as the type of the ``replace_fn`` argument of
# ``State.replace_by_pure_dict``.
SetValueFn = tp.Callable[[V, tp.Any], V]


class NestedStateRepr(reprlib.Representable):
  """Display wrapper that renders a nested ``State`` as a plain ``{...}`` dict.

  It only affects how the wrapped state is printed; the wrapped object is
  stored as-is on ``self.state``.
  """

  def __init__(self, state: State):
    self.state = state

  def __nnx_repr__(self):
    # Emit our own header (no type name, curly braces) ...
    yield reprlib.Object('', value_sep=': ', start='{', end='}')
    # ... then forward every entry of the wrapped state's repr except its own
    # ``reprlib.Object`` header.
    yield from (
      entry
      for entry in self.state.__nnx_repr__()
      if not isinstance(entry, reprlib.Object)
    )

  def __treescope_repr__(self, path, subtree_renderer):
    # Nested states are wrapped again so they are rendered the same way.
    rendered_children = {
      name: NestedStateRepr(child) if isinstance(child, State) else child
      for name, child in self.state.items()
    }
    # Render as the dictionary itself at the same path.
    return subtree_renderer(rendered_children, path=path)

class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence):
  """Sequence of ``(path, value)`` pairs stored as two parallel containers.

  Paths are kept in the ``_keys`` tuple and the matching values in the
  ``_values`` list; element ``i`` of the sequence is
  ``(_keys[i], _values[i])``.
  """

  _keys: tuple[PathParts, ...]
  _values: list[V]

  def __init__(self, items: tp.Iterable[tuple[PathParts, V]]):
    # Materialise once so that one-shot iterables (e.g. ``zip``) are supported.
    pairs = [(path, leaf) for path, leaf in items]
    self._keys = tuple(path for path, _ in pairs)
    self._values = [leaf for _, leaf in pairs]

  @tp.overload
  def __getitem__(self, index: int) -> tuple[PathParts, V]: ...
  @tp.overload
  def __getitem__(self, index: slice) -> FlatState[V]: ...
  def __getitem__(
    self, index: int | slice
  ) -> tuple[PathParts, V] | FlatState[V]:
    """Return one ``(path, value)`` pair, or a new ``FlatState`` for a slice."""
    if not isinstance(index, int):
      return FlatState(zip(self._keys[index], self._values[index]))
    return self._keys[index], self._values[index]

  def __len__(self) -> int:
    return len(self._keys)

  def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]:
    # ``zip`` objects are already iterators.
    return zip(self._keys, self._values)


def _flat_state_pytree_flatten(x: FlatState[V]):
  """Split a ``FlatState`` into ``(children, aux_data)`` for pytree flattening.

  The values are the children and the paths are the auxiliary data.
  """
  children, aux_data = x._values, x._keys
  return children, aux_data


def _flat_state_pytree_unflatten(
  keys: tuple[PathParts, ...], values: list[V]
) -> FlatState[V]:
  """Rebuild a ``FlatState`` from its auxiliary data and children.

  The instance is created with ``object.__new__`` so ``FlatState.__init__``
  is not run; ``keys`` and ``values`` are attached as given, without copying.
  """
  restored = object.__new__(FlatState)
  restored._keys, restored._values = keys, values
  return restored


# Register ``FlatState`` as a pytree node: values are the children, paths are
# the auxiliary data (see the flatten/unflatten helpers above).
jtu.register_pytree_node(
  FlatState, _flat_state_pytree_flatten, _flat_state_pytree_unflatten
)


class State(MutableMapping[K, V], reprlib.Representable):
  """A pytree-like structure that contains a ``Mapping`` from hashable and
  comparable keys to leaves. Leaves can be of any type but
  :class:`VariableState` and :class:`Variable` are the most common.
  """

  def __init__(
    self,
    mapping: tp.Union[
      tp.Mapping[K, tp.Mapping | V],
      tp.Iterator[tuple[K, tp.Mapping | V]],
    ],
    /,
    *,
    _copy: bool = True,
  ):
    """Create a ``State`` from a mapping or an iterable of key/value pairs.

    Args:
      mapping: The (possibly nested) contents of the state.
      _copy: If ``True`` (default) the top level of ``mapping`` is copied into
        a new ``dict``. If ``False``, ``mapping`` must already be a ``dict``
        and is used directly, without copying.

    Raises:
      ValueError: If ``_copy=False`` and ``mapping`` is not a ``dict``.
    """
    if _copy:
      _mapping = dict(mapping)
    else:
      if not isinstance(mapping, dict):
        raise ValueError(
          'Expected a dictionary when `_copy=False`, '
          f'got {type(mapping)} instead.'
        )
      _mapping = mapping

    if tp.TYPE_CHECKING:
      self._mapping = _mapping
    else:
      # ``__setattr__`` is aliased to ``__setitem__`` below, so the backing
      # dict has to be stored through the parent's ``__setattr__``.
      super().__setattr__('_mapping', _mapping)

  @property
  def raw_mapping(self) -> tp.Mapping[K, tp.Mapping[K, tp.Any] | V]:
    """The underlying dictionary (returned as-is, not copied)."""
    return self._mapping  # type: ignore

  def __contains__(self, key) -> bool:
    return key in self._mapping

  def __getitem__(self, key: K) -> State | V:  # type: ignore
    """Return the entry at ``key``; nested mappings are wrapped in a ``State``.

    The wrapper is built with ``_copy=False`` so it shares the nested dict with
    this state rather than copying it.
    """
    value = self._mapping[key]
    if isinstance(value, tp.Mapping):
      return State(value, _copy=False)
    return value

  def __getattr__(self, key: K) -> State | V:  # type: ignore[misc]
    """Attribute-style access to entries, e.g. ``state.linear``."""
    # The ``vars(self)`` guard prevents infinite recursion when ``_mapping``
    # itself has not been set on the instance yet.
    if '_mapping' not in vars(self) or key not in self._mapping:
      raise AttributeError(f"No attribute '{key}' in State")
    return self[key]

  def __setitem__(self, key: K, value: State | V) -> None:
    """Store ``value`` at ``key``; ``State`` values are stored as raw dicts."""
    if key == '__orig_class__':
      # Because ``__setattr__`` is aliased to this method, ``__orig_class__``
      # must be set as a real attribute instead of becoming a state entry.
      object.__setattr__(self, key, value)  # type: ignore
    elif isinstance(value, State):
      self._mapping[key] = value._mapping
    else:
      self._mapping[key] = value

  # Attribute assignment (``state.x = v``) behaves like ``state['x'] = v``.
  __setattr__ = __setitem__  # type: ignore

  def __delitem__(self, key: K) -> None:
    del self._mapping[key]

  def __iter__(self) -> tp.Iterator[K]:
    return iter(self._mapping)

  def __len__(self) -> int:
    return len(self._mapping)

  def __nnx_repr__(self):
    yield reprlib.Object(type(self), value_sep=': ', start='({', end='})')

    for k, v in self.items():
      if isinstance(v, State):
        v = NestedStateRepr(v)
      yield reprlib.Attr(repr(k), v)

  def __treescope_repr__(self, path, subtree_renderer):
    # Imported lazily so ``treescope`` is only needed when this is called.
    import treescope  # type: ignore[import-not-found,import-untyped]

    children = {}
    for k, v in self.items():
      if isinstance(v, State):
        v = NestedStateRepr(v)
      children[k] = v
    return treescope.repr_lib.render_dictionary_wrapper(
      object_type=type(self),
      wrapped_dict=children,
      path=path,
      subtree_renderer=subtree_renderer,
    )

  def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]:
    """Return a new ``State`` with ``f(path, leaf)`` applied to every leaf."""
    flat_state = self.flat_state()
    result = [
      (path, f(path, variable_state)) for path, variable_state in flat_state
    ]
    return State.from_flat_path(result)

  def flat_state(self) -> FlatState[V]:
    """Return the contents as a ``FlatState`` of ``(path, leaf)`` pairs."""
    return FlatState(traversals.flatten_to_sequence(self._mapping))

  @classmethod
  def from_flat_path(
    cls,
    flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]],
    /,
  ) -> State:
    """Build a state from a mapping (or iterable of pairs) of path -> leaf."""
    if not isinstance(flat_state, tp.Mapping):
      flat_state = dict(flat_state)
    nested_state = traversals.unflatten_mapping(flat_state)
    return cls(nested_state)

  def to_pure_dict(
    self, extract_fn: ExtractValueFn | None = None
  ) -> dict[str, tp.Any]:
    """Convert the state to a nested dict of extracted leaf values.

    Args:
      extract_fn: Called on every leaf to obtain the stored value. By default
        a leaf's ``.value`` attribute is used when it has one, otherwise the
        leaf itself.

    Returns:
      The nested dictionary produced by un-flattening the extracted values.
    """
    # Works for nnx.Variable and nnx.VariableState
    if extract_fn is None:
      extract_fn = lambda x: x.value if hasattr(x, 'value') else x
    flat_values = {k: extract_fn(x) for k, x in self.flat_state()}
    return traversals.unflatten_mapping(flat_values)

  def replace_by_pure_dict(
    self,
    pure_dict: dict[str, tp.Any],
    replace_fn: SetValueFn | None = None,
  ):
    """Update leaves in place from a nested dict of plain values.

    Each key path of ``pure_dict`` is first looked up exactly as given. If
    that path is not present, every path element that ``int()`` accepts is
    converted to an integer and the lookup is retried, so that e.g. ``'0'``
    matches an integer key ``0``.

    Args:
      pure_dict: Nested dictionary whose leaves are the new values.
      replace_fn: Called as ``replace_fn(current_leaf, new_value)`` to produce
        the leaf to store. By default ``current_leaf.replace(new_value)`` is
        used when the leaf has a ``replace`` attribute, otherwise
        ``new_value`` itself is stored.

    Raises:
      ValueError: If a key path of ``pure_dict`` is not present in the state.
    """

    def try_convert_int(x):
      try:
        return int(x)
      except (ValueError, TypeError):
        # ``TypeError`` covers keys ``int()`` cannot handle at all (e.g.
        # ``None``); such keys are kept unchanged.
        return x

    # Works for nnx.Variable and nnx.VariableState
    if replace_fn is None:
      replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v
    current_flat = dict(self.flat_state())
    for kp, v in traversals.flatten_mapping(pure_dict).items():
      if kp not in current_flat:
        # Only coerce when the exact path is missing, so genuine digit-string
        # keys in the state (e.g. ``'0'``) remain addressable.
        kp = tuple(map(try_convert_int, kp))
        if kp not in current_flat:
          raise ValueError(f'key in pure_dict not available in state: {kp}')
      current_flat[kp] = replace_fn(current_flat[kp], v)
    self.update(traversals.unflatten_mapping(current_flat))

  @tp.overload
  def split(self, first: filterlib.Filter, /) -> State[K, V]: ...

  @tp.overload
  def split(
    self,
    first: filterlib.Filter,
    second: filterlib.Filter,
    /,
    *filters: filterlib.Filter,
  ) -> tuple[State[K, V], ...]: ...

  @tp.overload
  def split(
    self, /, *filters: filterlib.Filter
  ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: ...

  def split(  # type: ignore[misc]
    self, first: filterlib.Filter, /, *filters: filterlib.Filter
  ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
    """Split a ``State`` into one or more ``State``'s. The
    user must pass at least one ``Filter`` (e.g. :class:`Variable`),
    and the filters must be exhaustive (i.e. they must cover all
    :class:`Variable` types in the ``State``).

    Example usage::

      >>> from flax import nnx

      >>> class Model(nnx.Module):
      ...   def __init__(self, rngs):
      ...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
      ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
      ...   def __call__(self, x):
      ...     return self.linear(self.batchnorm(x))

      >>> model = Model(rngs=nnx.Rngs(0))
      >>> state = nnx.state(model)
      >>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)

    Arguments:
      first: The first filter
      *filters: The optional, additional filters to group the state into
        mutually exclusive substates.
    Returns:
      One or more ``States`` equal to the number of filters passed.
    Raises:
      ValueError: If some leaves are not matched by any of the filters.
    """
    filters = (first, *filters)
    # The last element returned holds the leaves no filter matched.
    *states_, rest = _split_state(self.flat_state(), *filters)

    if rest:
      raise ValueError(
        'Non-exhaustive filters, got a non-empty remainder: '
        f'{rest}.\nUse `...` to match all remaining elements.'
      )

    states: State | tuple[State, ...]
    if len(states_) == 1:
      states = states_[0]
    else:
      states = tuple(states_)
    return states  # type: ignore

  @tp.overload
  def filter(
    self,
    first: filterlib.Filter,
    /,
  ) -> State[K, V]: ...

  @tp.overload
  def filter(
    self,
    first: filterlib.Filter,
    second: filterlib.Filter,
    /,
    *filters: filterlib.Filter,
  ) -> tuple[State[K, V], ...]: ...

  def filter(
    self,
    first: filterlib.Filter,
    /,
    *filters: filterlib.Filter,
  ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
    """Filter a ``State`` into one or more ``State``'s. The
    user must pass at least one ``Filter`` (e.g. :class:`Variable`).
    This method is similar to :meth:`split() <flax.nnx.State.split>`,
    except the filters can be non-exhaustive.

    Example usage::

      >>> from flax import nnx

      >>> class Model(nnx.Module):
      ...   def __init__(self, rngs):
      ...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
      ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
      ...   def __call__(self, x):
      ...     return self.linear(self.batchnorm(x))

      >>> model = Model(rngs=nnx.Rngs(0))
      >>> state = nnx.state(model)
      >>> param = state.filter(nnx.Param)
      >>> batch_stats = state.filter(nnx.BatchStat)
      >>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)

    Arguments:
      first: The first filter
      *filters: The optional, additional filters to group the state into
        mutually exclusive substates.
    Returns:
      One or more ``States`` equal to the number of filters passed.
    """
    # Unlike ``split``, the unmatched remainder is simply discarded.
    *states_, _rest = _split_state(self.flat_state(), first, *filters)

    assert len(states_) == len(filters) + 1

    states: State | tuple[State, ...]
    if len(states_) == 1:
      states = states_[0]
    else:
      states = tuple(states_)
    return states  # type: ignore

  @staticmethod
  def merge(
    state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V]
  ) -> State[K, V]:
    """The inverse of :meth:`split() <flax.nnx.State.split>`.

    ``merge`` takes one or more ``State``'s and creates
    a new ``State``. When several states define the same path, the value from
    the state passed last is kept.

    Example usage::

      >>> from flax import nnx
      >>> import jax.numpy as jnp

      >>> class Model(nnx.Module):
      ...   def __init__(self, rngs):
      ...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
      ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
      ...   def __call__(self, x):
      ...     return self.linear(self.batchnorm(x))

      >>> model = Model(rngs=nnx.Rngs(0))
      >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
      >>> params.linear.bias.value += 1

      >>> state = nnx.State.merge(params, batch_stats)
      >>> nnx.update(model, state)
      >>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()

    Args:
      state: A ``State`` object.
      *states: Additional ``State`` objects.
    Returns:
      The merged ``State``.
    """
    if not states:
      # A single ``State`` is returned unchanged (same object, no copy).
      if isinstance(state, State):
        return state
      return State(state)

    states = (state, *states)

    new_state: dict[PathParts, V] = {}
    for state in states:
      new_state.update(traversals.flatten_mapping(state))  # type: ignore[attribute-error] # pytype is wrong here
    return State.from_flat_path(new_state)

  def __or__(self, other: State[K, V]) -> State[K, V]:
    """``self | other``: merge the two states (``other`` wins on conflicts)."""
    if not other:
      return self
    return State.merge(self, other)

  def __sub__(self, other: State[K, V]) -> State[K, V]:
    """``self - other``: keep only the paths of ``self`` absent from ``other``."""
    if not other:
      return self

    self_flat = dict(self.flat_state())
    other_flat = dict(other.flat_state())
    diff = {k: v for k, v in self_flat.items() if k not in other_flat}

    return State.from_flat_path(diff)
def _state_flatten_with_keys(x: State):
  """Flatten a ``State`` for pytree registration.

  Returns a tuple of ``(DictKey(key), value)`` children, ordered by sorted
  key, together with the sorted keys as auxiliary data.
  """
  items = sorted(x._mapping.items())
  children = tuple((jtu.DictKey(key), value) for key, value in items)
  return children, tuple(key for key, _ in items)


def _state_unflatten(
  static: tuple[K, ...],
  leaves: tuple[V, ...] | tuple[dict[K, V]],
):
  """Rebuild a ``State`` by pairing the stored keys with the given children."""
  return State(zip(static, leaves))


jax.tree_util.register_pytree_with_keys(
  State,
  _state_flatten_with_keys,
  _state_unflatten,  # type: ignore[arg-type]
)


def _split_state(
  flat_state: FlatState[V],
  *filters: filterlib.Filter,
) -> tuple[State[PathParts, V], ...]:
  """Partition ``flat_state`` into ``len(filters) + 1`` states.

  Every ``(path, value)`` pair goes to the state of the first filter whose
  predicate accepts it; pairs that no filter accepts go to the last state.

  Raises:
    ValueError: If ``...`` or ``True`` appears before a filter that is not
      itself ``...`` or ``True``.
  """
  for i, filter_ in enumerate(filters):
    if filter_ in (..., True) and i != len(filters) - 1:
      remaining_filters = filters[i + 1 :]
      if not all(f in (..., True) for f in remaining_filters):
        raise ValueError(
          '`...` or `True` can only be used as the last filters, '
          f'got {filter_} at index {i}.'
        )
  predicates = tuple(map(filterlib.to_predicate, filters))

  # we have n + 1 states, where n is the number of predicates
  # the last state is for values that don't match any predicate
  flat_states: tuple[list[tuple[PathParts, V]], ...] = tuple(
    [] for _ in range(len(predicates) + 1)
  )

  for path, value in flat_state:
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        flat_states[i].append((path, value))  # type: ignore[index] # mypy is wrong here?
        break
    else:
      # if we didn't break, set leaf to last state
      flat_states[-1].append((path, value))  # type: ignore[index] # mypy is wrong here?

  # Loop variable named ``group`` so it does not shadow the ``flat_state``
  # parameter.
  return tuple(State.from_flat_path(group) for group in flat_states)


def create_path_filters(state: State):
  """Group the paths of ``state`` by leaf value.

  ``Variable`` / ``VariableState`` leaves are replaced by their ``.value``
  before grouping, so the values used as keys must be hashable.

  Returns:
    A dict mapping a ``filterlib.PathIn`` filter (built from all paths that
    share a value) to that value.
  """
  flat_state = state.flat_state()
  value_paths: dict[tp.Any, set[PathParts]] = {}
  for path, value in flat_state:
    if isinstance(value, (variablelib.Variable, variablelib.VariableState)):
      value = value.value
    value_paths.setdefault(value, set()).add(path)
  return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}