# Source code for flax.nnx.graph

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import dataclasses
import functools
import threading
import typing as tp

import jax
import numpy as np
import typing_extensions as tpe

from flax.nnx import filterlib, reprlib
from flax.nnx.proxy_caller import (
  ApplyCaller,
  CallableProxy,
  DelayedAccessor,
)
from flax.nnx.statelib import State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like

# Generic type variables used throughout this module.
A = tp.TypeVar('A')
B = tp.TypeVar('B')
C = tp.TypeVar('C')
F = tp.TypeVar('F', bound=tp.Callable)

# Type variables bound to hashable types (key/value types of HashableMapping).
HA = tp.TypeVar('HA', bound=tp.Hashable)
HB = tp.TypeVar('HB', bound=tp.Hashable)
KeyT = tp.TypeVar('KeyT', bound=Key)

# Integer index assigned to graph nodes and Variables while flattening.
Index = int
Names = tp.Sequence[int]
# Type parameters of the node implementations (see NodeImplBase).
Node = tp.TypeVar('Node')
Leaf = tp.TypeVar('Leaf')
AuxData = tp.TypeVar('AuxData')

# Leaf types: VariableState inside a State, Variable inside a live node.
StateLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
  """Returns ``True`` if ``x`` is a :class:`VariableState` instance."""
  return isinstance(x, VariableState)


def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
  """Returns ``True`` if ``x`` is a :class:`Variable` instance."""
  return isinstance(x, Variable)


class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]):
  """A mutable mapping keyed by object identity instead of hash/equality.

  Each entry is stored under ``id(key)`` together with the key object itself,
  so the key is kept alive for as long as its entry exists.
  """

  def __init__(
    self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), /
  ):
    # id(key) -> (key, value)
    self._mapping: dict[int, tuple[A, B]] = {}
    self.update(mapping)

  def __getitem__(self, key: A) -> B:
    _, stored_value = self._mapping[id(key)]
    return stored_value

  def __contains__(self, key: object) -> bool:
    return id(key) in self._mapping

  def __setitem__(self, key: A, value: B):
    entry = (key, value)
    self._mapping[id(key)] = entry

  def __delitem__(self, key: A):
    del self._mapping[id(key)]

  def __iter__(self) -> tp.Iterator[A]:
    # yield the original key objects, not their ids
    return (entry[0] for entry in self._mapping.values())

  def __len__(self) -> int:
    return len(self._mapping)

  def __str__(self) -> str:
    return repr(self)


@dataclasses.dataclass(frozen=True, slots=True)
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
  """Common part of a node implementation: the node type and how to flatten it."""

  # The Python type this implementation handles.
  type: type[Node]
  # Maps a node to a sequence of (key, leaf) pairs plus auxiliary data.
  flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]]

  def node_dict(self, node: Node) -> dict[Key, Leaf]:
    """Returns the (key, leaf) pairs produced by ``flatten`` as a dict.

    The auxiliary data returned by ``flatten`` is discarded.
    """
    nodes, _ = self.flatten(node)
    return dict(nodes)


@dataclasses.dataclass(frozen=True, slots=True)
class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
  """Node implementation for graph nodes that can be created empty and mutated."""

  # Sets the attribute ``key`` of the node to the given leaf.
  set_key: tp.Callable[[Node, Key, Leaf], None]
  # Removes the attribute ``key`` from the node and returns its value.
  pop_key: tp.Callable[[Node, Key], Leaf]
  # Creates an empty node from the auxiliary data produced by ``flatten``.
  create_empty: tp.Callable[[AuxData], Node]
  # Called on an existing node before it is refilled during unflattening.
  clear: tp.Callable[[Node], None]
  # Fills a node with the given (key, leaf) pairs.
  init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None]


@dataclasses.dataclass(frozen=True, slots=True)
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
  """Node implementation for nodes that are rebuilt in one step from children."""

  # Builds a node from its (key, leaf) pairs and the auxiliary data.
  unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]


# Either kind of node implementation.
NodeImpl = tp.Union[
  GraphNodeImpl[Node, Leaf, AuxData], PytreeNodeImpl[Node, Leaf, AuxData]
]


# Registered node implementations, keyed by the exact node type.
GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {}
PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {}


def register_graph_node_type(
  type: type,
  flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
  set_key: tp.Callable[[Node, Key, Leaf], None],
  pop_key: tp.Callable[[Node, Key], Leaf],
  create_empty: tp.Callable[[AuxData], Node],
  clear: tp.Callable[[Node], None],
  init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None],
):
  """Registers ``type`` as a graph node type in ``GRAPH_REGISTRY``.

  The callables are bundled into a :class:`GraphNodeImpl` stored under
  ``type``.

  Raises:
    ValueError: if ``type`` is already present in ``GRAPH_REGISTRY``.
  """
  if type in GRAPH_REGISTRY:
    raise ValueError(f'Node type {type} is already registered.')

  node_impl = GraphNodeImpl(
    type=type,
    flatten=flatten,
    set_key=set_key,
    pop_key=pop_key,
    create_empty=create_empty,
    clear=clear,
    init=init,
  )
  GRAPH_REGISTRY[type] = node_impl


def register_pytree_node_type(
  type: type,
  flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
  unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node],
):
  """Registers ``type`` as a pytree node type in ``PYTREE_REGISTRY``.

  Raises:
    ValueError: if ``type`` is already present in ``PYTREE_REGISTRY``.
  """
  if type in PYTREE_REGISTRY:
    raise ValueError(f'Node type {type} is already registered.')

  node_impl = PytreeNodeImpl(type=type, flatten=flatten, unflatten=unflatten)
  PYTREE_REGISTRY[type] = node_impl


def is_node(x: tp.Any) -> bool:
  """Returns whether ``x`` is a registered graph node or a pytree node."""
  # the exact type is looked up first; the pytree check is only a fallback
  return type(x) in GRAPH_REGISTRY or is_pytree_node(x)


def is_graph_node(x: tp.Any) -> bool:
  """Returns ``True`` if the exact type of ``x`` is in ``GRAPH_REGISTRY``."""
  return type(x) in GRAPH_REGISTRY


def is_node_type(x: type[tp.Any]) -> bool:
  """Returns ``True`` if the type ``x`` is registered in either registry or is
  ``GenericPytree``."""
  return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree


def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
  """Returns the node implementation for the instance ``x``.

  ``GRAPH_REGISTRY`` is consulted first, then ``PYTREE_REGISTRY`` (both by
  exact type), and finally the generic pytree implementation is used if
  ``is_pytree_node(x)`` holds.

  Raises:
    ValueError: if ``x`` is a :class:`Variable` or no implementation matches.
  """
  if isinstance(x, Variable):
    raise ValueError(f'Variable is not a node: {x}')

  node_type = type(x)
  for registry in (GRAPH_REGISTRY, PYTREE_REGISTRY):
    if node_type in registry:
      return registry[node_type]

  if is_pytree_node(x):
    return PYTREE_NODE_IMPL  # type: ignore

  raise ValueError(f'Unknown node type: {x}')


def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
  """Returns the node implementation registered for the type ``x``.

  ``GenericPytree`` maps to the generic pytree implementation; other types are
  looked up in ``PYTREE_REGISTRY`` and then ``GRAPH_REGISTRY``. A type found in
  neither raises the ``KeyError`` of the ``GRAPH_REGISTRY`` lookup.
  """
  if x is GenericPytree:
    return PYTREE_NODE_IMPL  # type: ignore
  if x in PYTREE_REGISTRY:
    return PYTREE_REGISTRY[x]
  return GRAPH_REGISTRY[x]


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
  """A read-only mapping that is hashable and comparable by content.

  Args:
    mapping: the mapping to wrap; its keys and values must be hashable.
    copy: if ``True`` (default) the items are copied into a new ``dict``,
      otherwise ``mapping`` is stored as-is. In the latter case the caller
      must not mutate ``mapping`` afterwards, since the hash is computed from
      its current items.
  """

  def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True):
    self._mapping = dict(mapping) if copy else mapping

  def __contains__(self, key: object) -> bool:
    return key in self._mapping

  def __getitem__(self, key: HA) -> HB:
    return self._mapping[key]

  def __iter__(self) -> tp.Iterator[HA]:
    return iter(self._mapping)

  def __len__(self) -> int:
    return len(self._mapping)

  def __hash__(self) -> int:
    # A frozenset of the items is order-independent, like ``__eq__``, and
    # unlike sorting it does not require keys to be mutually orderable
    # (e.g. a mix of ``str`` and ``int`` keys).
    return hash(frozenset(self._mapping.items()))

  def __eq__(self, other: tp.Any) -> bool:
    return (
      isinstance(other, HashableMapping) and self._mapping == other._mapping
    )

  def __repr__(self) -> str:
    return repr(self._mapping)


class GraphDef(tp.Generic[Node]):
  """A class that represents all the static, stateless, and Pythonic parts of a
  Flax :class:`Module`. A ``GraphDef`` can be generated by either
  calling :func:`split` or :func:`graphdef` on the :class:`Module`."""

  # The type of the node this definition describes.
  type: type[Node]
  # The index of the node within the flattened graph.
  index: int

@dataclasses.dataclass(frozen=True, repr=False)
class NodeRef(GraphDef[Node], reprlib.Representable):
  """A reference to a node or Variable that was already assigned ``index``."""

  type: type[Node]
  index: int

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('type', self.type.__name__)
    yield reprlib.Attr('index', self.index)

  def __treescope_repr__(self, path, subtree_renderer):
    import treescope  # type: ignore[import-not-found,import-untyped]

    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={'type': self.type, 'index': self.index},
      path=path,
      subtree_renderer=subtree_renderer,
    )


jax.tree_util.register_static(NodeRef)


@dataclasses.dataclass(frozen=True, repr=False)
class VariableDef(reprlib.Representable):
  """Static description of a Variable: its type, index and metadata."""

  type: type[Variable]
  index: int
  metadata: HashableMapping[str, tp.Any]

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))
    yield reprlib.Attr('type', self.type.__name__)
    yield reprlib.Attr('index', self.index)
    yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))

  def __treescope_repr__(self, path, subtree_renderer):
    import treescope  # type: ignore[import-not-found,import-untyped]

    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={
        'type': self.type,
        'index': self.index,
        'metadata': self.metadata,
      },
      path=path,
      subtree_renderer=subtree_renderer,
    )


jax.tree_util.register_static(VariableDef)


@dataclasses.dataclass(frozen=True, slots=True)
class SubGraphAttribute:
  """An attribute whose value is a sub-node (or a reference to one)."""

  key: Key
  value: NodeDef[tp.Any] | NodeRef[tp.Any]


@dataclasses.dataclass(frozen=True, slots=True)
class StaticAttribute:
  """An attribute whose value is stored as-is in the graph definition."""

  key: Key
  value: tp.Any


@dataclasses.dataclass(frozen=True, slots=True)
class LeafAttribute:
  """An attribute whose value is a Variable (or a reference to one)."""

  key: Key
  value: VariableDef | NodeRef[tp.Any]


@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(GraphDef[Node], reprlib.Representable):
  """A dataclass that denotes the tree structure of a
  :class:`Module`. A ``GraphDef`` can be generated by either
  calling :func:`split` or :func:`graphdef` on the :class:`Module`."""

  type: tp.Type[Node]
  index: int
  attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...]
  metadata: tp.Any
  index_mapping: HashableMapping[Index, Index] | None

  @classmethod
  def create(
    cls,
    type: tp.Type[Node],
    index: int,
    attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...],
    metadata: tp.Any,
    index_mapping: tp.Mapping[Index, Index] | None,
  ):
    """Builds a ``NodeDef``, wrapping ``index_mapping`` (if given) in a
    :class:`HashableMapping`."""
    return cls(
      type=type,
      index=index,
      attributes=attributes,
      metadata=metadata,
      index_mapping=HashableMapping(index_mapping)
      if index_mapping is not None
      else None,
    )

  def __nnx_repr__(self):
    yield reprlib.Object(type=type(self))

    yield reprlib.Attr('type', self.type.__name__)
    yield reprlib.Attr('index', self.index)
    yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes))
    yield reprlib.Attr('metadata', self.metadata)
    yield reprlib.Attr(
      'index_mapping',
      reprlib.PrettyMapping(self.index_mapping)
      if self.index_mapping is not None
      else None,
    )

  def __treescope_repr__(self, path, subtree_renderer):
    import treescope  # type: ignore[import-not-found,import-untyped]

    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes={
        'type': self.type,
        'index': self.index,
        'attributes': self.attributes,
        'metadata': self.metadata,
      },
      path=path,
      subtree_renderer=subtree_renderer,
    )

  def apply(
    self, state: GraphState, *states: GraphState
  ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]:
    """Returns a proxy that merges this definition with the given states,
    calls the accessed callable on the merged node, and returns its output
    together with the re-flattened node."""
    accessor = DelayedAccessor()

    def _apply(
      accessor: DelayedAccessor, *args, **kwargs
    ) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]:
      module = merge(self, state, *states)
      fn = accessor(module)
      out = fn(*args, **kwargs)
      return out, flatten(module)

    return CallableProxy(_apply, accessor)  # type: ignore


jax.tree_util.register_static(NodeDef)

PureState = tuple[GraphDef[A], GraphState]


def flatten(
  node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
) -> tuple[GraphDef[Node], GraphState]:
  """Flattens a graph node into a (graphdef, state) pair.

  Args:
    node: A graph node.
    ref_index: A mapping from nodes to indexes, defaults to None. If not
      provided, a new empty dictionary is created. This argument can be used
      to flatten a sequence of graph nodes that share references.
  """
  if ref_index is None:
    ref_index = RefMap()
  flat_state: list[tuple[PathParts, StateLeaf]] = []
  graphdef = _graph_flatten((), ref_index, flat_state, node)
  return graphdef, GraphState.from_flat_path(flat_state)


def _graph_flatten(
  path: PathParts,
  ref_index: RefMap[tp.Any, Index],
  flat_state: list[tuple[PathParts, StateLeaf]],
  node: Node,
) -> NodeDef[Node] | NodeRef:
  """Recursive helper for :func:`flatten`.

  Args:
    path: The path of ``node`` relative to the root being flattened.
    ref_index: Mapping from already visited graph nodes and Variables to
      their index; updated in place.
    flat_state: List of ``(path, VariableState)`` pairs; appended to in place.
    node: The node to flatten.

  Returns:
    A ``NodeRef`` if ``node`` is already in ``ref_index``, otherwise a new
    ``NodeDef``.

  Raises:
    RuntimeError: if ``node`` is not a node.
    ValueError: if an attribute is a raw ``jax.Array`` or ``np.ndarray``.
  """
  if not is_node(node):
    raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

  if node in ref_index:
    return NodeRef(type(node), ref_index[node])

  node_impl = get_node_impl(node)

  # only cache graph nodes
  if isinstance(node_impl, GraphNodeImpl):
    index = len(ref_index)
    ref_index[node] = index
  else:
    index = -1

  attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = []

  values, metadata = node_impl.flatten(node)
  for key, value in values:
    if is_node(value):
      nodedef = _graph_flatten((*path, key), ref_index, flat_state, value)
      attributes.append(SubGraphAttribute(key, nodedef))
    elif isinstance(value, Variable):
      if value in ref_index:
        # shared Variable: only store a reference to the first occurrence
        attributes.append(
          LeafAttribute(key, NodeRef(type(value), ref_index[value]))
        )
      else:
        flat_state.append(((*path, key), value.to_state()))
        variable_index = ref_index[value] = len(ref_index)
        variabledef = VariableDef(
          type(value), variable_index, HashableMapping(value._var_metadata)
        )
        attributes.append(LeafAttribute(key, variabledef))
    else:
      if isinstance(value, (jax.Array, np.ndarray)):
        path_str = '/'.join(map(str, (*path, key)))
        raise ValueError(
          f'Arrays leaves are not supported, at {path_str!r}: {value}'
        )
      attributes.append(StaticAttribute(key, value))

  nodedef = NodeDef.create(
    type=node_impl.type,
    index=index,
    attributes=tuple(attributes),
    metadata=metadata,
    index_mapping=None,
  )
  return nodedef


def unflatten(
  graphdef: GraphDef[Node],
  state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]],
  /,
  *,
  index_ref: dict[Index, tp.Any] | None = None,
  index_ref_cache: dict[Index, tp.Any] | None = None,
) -> Node:
  """Unflattens a graphdef into a node with the given state.

  Args:
    graphdef: A GraphDef instance.
    state: A State instance.
    index_ref: A mapping from indexes to nodes references found during the
      graph traversal, defaults to None. If not provided, a new empty
      dictionary is created. This argument can be used to unflatten a sequence
      of (graphdef, state) pairs that share the same index space.
    index_ref_cache: A mapping from indexes to existing nodes that can be
      reused. When a reference is reused, ``GraphNodeImpl.clear`` is called to
      leave the object in an empty state and then filled by the unflatten
      process, as a result existing graph nodes are mutated to have the new
      content/topology specified by the graphdef.
  """
  if isinstance(state, State):
    state = state.raw_mapping  # type: ignore
  if index_ref is None:
    index_ref = {}
  assert isinstance(graphdef, (NodeDef, NodeRef))
  node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache)
  return node


def _graph_unflatten(
  nodedef: NodeDef[Node] | NodeRef[Node],
  state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]],
  index_ref: dict[Index, tp.Any],
  index_ref_cache: dict[Index, tp.Any] | None,
) -> Node:
  """Recursive helper for :func:`unflatten`.

  Args:
    nodedef: A GraphDef instance or an index to a node in the cache.
    state: A mapping from attribute names to variables or subgraphs.
    index_ref: A mapping from indexes to nodes that have been traversed.
      If a node is already in the cache, it won't be traversed again.
    index_ref_cache: A mapping from indexes to existing nodes that can be
      reused. When a reference is reused, ``GraphNodeImpl.clear`` is called to
      leave the object in an empty state and then filled by the unflatten
      process, as a result existing graph nodes are mutated to have the new
      content/topology specified by the nodedef.
  """
  if isinstance(nodedef, NodeRef):
    return index_ref[nodedef.index]

  if not is_node_type(nodedef.type):
    raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')

  if nodedef.index in index_ref:
    raise RuntimeError(f'GraphDef index {nodedef.index} already used.')

  node_impl = get_node_impl_for_type(nodedef.type)

  def _get_children():
    children: list[tuple[Key, NodeLeaf | Node]] = []
    state_keys: set = set(state.keys())

    # for every key in attributes there are 6 possible cases:
    #  - (2) the key can either be present in the state or not
    #  - (3) the key can be a subgraph, a leaf, or a static attribute
    for attribute in nodedef.attributes:
      key = attribute.key
      if key not in state:
        # the key is not present in the state, create an empty value
        if type(attribute) is StaticAttribute:
          children.append((key, attribute.value))
        elif type(attribute) is SubGraphAttribute:
          # if the key is a subgraph we create an empty node
          subgraphdef = attribute.value
          assert not isinstance(subgraphdef, VariableDef)
          if isinstance(subgraphdef, NodeRef):
            # subgraph exists, take it from the cache
            children.append((key, index_ref[subgraphdef.index]))
          else:
            # create a node from an empty state, reasoning:
            # * it's a node with no state
            # * it's a node with state but only through references of already
            #   created nodes
            substate = {}
            subnode = _graph_unflatten(
              subgraphdef, substate, index_ref, index_ref_cache
            )
            children.append((key, subnode))
        elif type(attribute) is LeafAttribute:
          variabledef = attribute.value
          if variabledef.index in index_ref:
            # variable exists, take it from the cache
            children.append((key, index_ref[variabledef.index]))
          else:
            # key for a variable is missing, raise an error
            raise ValueError(
              f'Expected key {key!r} in state while building node of type '
              f'{nodedef.type.__name__}.'
            )
        else:
          raise RuntimeError(f'Unknown static field: {key!r}')
      else:
        state_keys.remove(key)
        value = state[key]
        if type(attribute) is StaticAttribute:
          raise ValueError(
            f'Got state for static field {key!r}, this is not supported.'
          )
        elif type(attribute) is SubGraphAttribute:
          if is_state_leaf(value):
            raise ValueError(
              f'Expected value of type {attribute.value} for '
              f'{key!r}, but got {value!r}'
            )
          assert isinstance(value, dict)
          subgraphdef = attribute.value

          if isinstance(subgraphdef, NodeRef):
            children.append((key, index_ref[subgraphdef.index]))
          else:
            subnode = _graph_unflatten(
              subgraphdef, value, index_ref, index_ref_cache
            )
            children.append((key, subnode))

        elif type(attribute) is LeafAttribute:
          variabledef = attribute.value

          if variabledef.index in index_ref:
            # add an existing variable
            assert isinstance(variabledef, NodeRef)
            children.append((key, index_ref[variabledef.index]))
          else:
            # it's an unseen variable, create a new one
            assert isinstance(variabledef, VariableDef)
            # when index_ref_cache is present, check if the Variable exists
            # there and update the existing Variable if it does
            if (
              index_ref_cache is not None
              and variabledef.index in index_ref_cache
            ):
              # if variable exists, update it
              variable = index_ref_cache[variabledef.index]
              if not isinstance(variable, Variable):
                raise ValueError(
                  f'Expected a Variable type for {key!r}, but got {type(variable)}.'
                )
              if isinstance(value, VariableState):
                variable.update_from_state(value)
              else:
                variable.raw_value = value
            else:  # if it doesn't, create a new variable
              if isinstance(value, VariableState):
                variable = value.to_variable()
              else:
                variable = variabledef.type.from_metadata(
                  value, variabledef.metadata
                )
            children.append((key, variable))
            index_ref[variabledef.index] = variable
        else:
          raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

    # NOTE: we could allow adding new StateLeafs here
    if state_keys:
      raise ValueError(f'Unknown keys: {state_keys}')

    return children

  if isinstance(node_impl, GraphNodeImpl):
    # we create an empty node first and add it to the index
    # this avoids infinite recursion when there is a reference cycle
    if index_ref_cache is not None and nodedef.index in index_ref_cache:
      node = index_ref_cache[nodedef.index]
      if type(node) != nodedef.type:
        raise ValueError(
          f'Expected a node of type {nodedef.type} for index '
          f'{nodedef.index}, but got a node of type {type(node)}.'
        )
      node_impl.clear(node)
    else:
      node = node_impl.create_empty(nodedef.metadata)
    index_ref[nodedef.index] = node
    node_impl.init(node, _get_children())
  else:
    # if the node type does not support the creation of an empty object it means
    # that it cannot reference itself, so we can create its children first
    node = node_impl.unflatten(_get_children(), nodedef.metadata)

  return node


def graph_pop(
  node: tp.Any,
  filters: tuple[filterlib.Filter, ...],
) -> tuple[GraphState, ...]:
  """Removes the Variables matching ``filters`` from ``node`` (in place) and
  returns them as one ``State`` per filter."""
  id_to_index: dict[int, Index] = {}
  path_parts: PathParts = ()
  predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
  flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple(
    {} for _ in predicates
  )
  _graph_pop(node, id_to_index, path_parts, flat_states, predicates)
  return tuple(
    GraphState.from_flat_path(flat_state) for flat_state in flat_states
  )


def _graph_pop(
  node: tp.Any,
  id_to_index: dict[int, Index],
  path_parts: PathParts,
  flat_states: tuple[dict[PathParts, StateLeaf], ...],
  predicates: tuple[filterlib.Predicate, ...],
) -> None:
  """Recursive helper for :func:`graph_pop`.

  Each Variable is popped into the flat state of the first predicate it
  matches; Variables matching no predicate are left in place.

  Raises:
    RuntimeError: if ``node`` is not a node.
    ValueError: if a matching Variable belongs to a pytree node.
  """
  if not is_node(node):
    raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

  # already visited (shared reference or cycle)
  if id(node) in id_to_index:
    return

  id_to_index[id(node)] = len(id_to_index)
  node_impl = get_node_impl(node)
  node_dict = node_impl.node_dict(node)

  for name, value in node_dict.items():
    if is_node(value):
      _graph_pop(
        node=value,
        id_to_index=id_to_index,
        path_parts=(*path_parts, name),
        flat_states=flat_states,
        predicates=predicates,
      )
      continue
    elif not is_node_leaf(value):
      continue
    elif id(value) in id_to_index:
      continue

    node_path = (*path_parts, name)
    for state, predicate in zip(flat_states, predicates):
      if predicate(node_path, value):
        if isinstance(node_impl, PytreeNodeImpl):
          raise ValueError(
            f'Cannot pop key {name!r} from node of type {type(node).__name__}'
          )
        id_to_index[id(value)] = len(id_to_index)
        node_impl.pop_key(node, name)
        if isinstance(value, Variable):
          value = value.to_state()
        state[node_path] = value  # type: ignore[index] # mypy is wrong here?
        break
    else:
      # NOTE: should we raise an error here?
      pass


def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]):
  """Updates ``node`` in place from the (possibly nested) ``state`` mapping.

  Raises:
    RuntimeError: if ``node`` is not a node.
    ValueError: if a new key is set on a pytree node, a state leaf is given
      for a subgraph, or the attribute being updated is not a Variable.
  """
  if not is_node(node):
    raise RuntimeError(f'Unsupported type: {type(node)}')

  node_impl = get_node_impl(node)
  node_dict = node_impl.node_dict(node)
  for key, value in state.items():
    # case 1: new state is being added
    if key not in node_dict:
      if isinstance(node_impl, PytreeNodeImpl):
        raise ValueError(
          f'Cannot set key {key!r} on immutable node of '
          f'type {type(node).__name__}'
        )
      if isinstance(value, Variable):
        value = value.copy()
      node_impl.set_key(node, key, value)
      continue

    # check values are of the same type
    current_value = node_dict[key]

    # case 2: subgraph is being updated
    if is_node(current_value):
      if is_state_leaf(value):
        raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
      _graph_update_dynamic(current_value, value)
    else:
      # case 3: state leaf is being updated
      if not isinstance(current_value, Variable):
        raise ValueError(
          f'Trying to update a non-Variable attribute {key!r} with a Variable: '
          f'{value!r}'
        )
      if isinstance(value, VariableState):
        # updated from VariableState
        current_value.update_from_state(value)
      else:
        # updated from raw value
        current_value.raw_value = value


# --------------------------------------------------------
# UpdateContext
# --------------------------------------------------------


@dataclasses.dataclass
class GraphContext(threading.local):
  """Thread-local stacks of the active update, split and merge contexts."""

  update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field(
    default_factory=dict
  )
  ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
  index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)


GRAPH_CONTEXT = GraphContext()


@dataclasses.dataclass
class SplitContext:
  """Context yielded by :func:`split_context`; all ``split`` calls made
  through it share the same ``ref_index``."""

  ctxtag: str | None
  ref_index: RefMap[tp.Any, Index]

  @tp.overload
  def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...

  @tp.overload
  def split(
    self, graph_node: A, first: filterlib.Filter, /
  ) -> tuple[GraphDef[A], GraphState]: ...

  @tp.overload
  def split(
    self,
    graph_node: A,
    first: filterlib.Filter,
    second: filterlib.Filter,
    /,
    *filters: filterlib.Filter,
  ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...

  def split(
    self, node: A, *filters: filterlib.Filter
  ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]:
    """Flattens ``node`` using the shared ``ref_index`` and splits its state
    by ``filters``. If the tagged update context holds an ``index_ref``, the
    composed index mapping is attached to the returned ``NodeDef``."""
    ctx = (
      current_update_context(self.ctxtag) if self.ctxtag is not None else None
    )
    graphdef, state = flatten(node, self.ref_index)
    states = _split_state(state, filters)
    if ctx is not None:
      if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
        index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
        graphdef = dataclasses.replace(
          graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
        )

    return graphdef, *states


@contextlib.contextmanager
def split_context(ctxtag: str | None = None):
  """Context manager yielding a :class:`SplitContext` with a fresh ``RefMap``.

  On exit the context is popped from ``GRAPH_CONTEXT.ref_index_stack``, the
  ``RefMap`` is handed to the update context tagged ``ctxtag`` (if a tag was
  given), and the attributes of the yielded context are deleted.
  """
  ref_index: RefMap[tp.Any, Index] = RefMap()
  flatten_ctx = SplitContext(ctxtag, ref_index)
  GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx)

  try:
    yield flatten_ctx
  finally:
    GRAPH_CONTEXT.ref_index_stack.pop()
    if ctxtag is not None:
      ctx = current_update_context(ctxtag)
      ctx.flatten_end(ref_index)
    # drop the references so the yielded context cannot be used after exit
    del flatten_ctx.ref_index
    del flatten_ctx.ctxtag


@dataclasses.dataclass
class MergeContext:
  """Context yielded by :func:`merge_context`; all ``merge`` calls made
  through it share the same ``index_ref``."""

  ctxtag: str | None
  index_ref: dict[Index, tp.Any]

  def merge(
    self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState
  ) -> A:
    """Merges the states and unflattens ``graphdef`` using the shared
    ``index_ref``."""
    ctx = (
      current_update_context(self.ctxtag) if self.ctxtag is not None else None
    )
    if (
      ctx is not None
      and isinstance(graphdef, NodeDef)
      and graphdef.index_mapping is not None
    ):
      # outer merge (4), create index_ref_cache
      assert ctx.ref_index is not None
      index_ref_cache = compose_mapping_reversed(
        ctx.ref_index, graphdef.index_mapping
      )
    else:
      # inner merge (2)
      index_ref_cache = None

    state = State.merge(state, *states)
    node = unflatten(
      graphdef,
      state,
      index_ref=self.index_ref,
      index_ref_cache=index_ref_cache,
    )
    return node


@contextlib.contextmanager
def merge_context(ctxtag: str | None = None):
  """Context manager yielding a :class:`MergeContext` with a fresh ``index_ref``.

  On exit the context is popped from ``GRAPH_CONTEXT.index_ref_stack``, the
  ``index_ref`` is handed to the update context tagged ``ctxtag`` (if a tag
  was given), and the attributes of the yielded context are deleted.
  """
  index_ref: dict[Index, tp.Any] = {}

  unflatten_ctx = MergeContext(ctxtag, index_ref)
  GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx)

  try:
    yield unflatten_ctx
  finally:
    GRAPH_CONTEXT.index_ref_stack.pop()
    if ctxtag is not None:
      ctx = current_update_context(ctxtag)
      ctx.unflatten_end(index_ref)
    # drop the references so the yielded context cannot be used after exit
    del unflatten_ctx.index_ref
    del unflatten_ctx.ctxtag


@dataclasses.dataclass
class UpdateContext:
  """A context manager for handling complex state updates."""

  tag: str
  ref_index: RefMap[tp.Any, Index] | None
  index_ref: dict[Index, tp.Any] | None

  # define hash and eq to make this an opaque object
  def __hash__(self):
    return 0

  def __eq__(self, other):
    return isinstance(other, UpdateContext)

  def flatten_end(self, ref_index: RefMap[tp.Any, Index]):
    """Records the end of a split: stores ``ref_index`` on the first (outer)
    split, clears ``index_ref`` on a later (inner) split."""
    if self.ref_index is None:
      # outer split (1), store the references
      self.ref_index = ref_index
    else:
      # inner split (3), clear index_ref
      self.index_ref = None

  def unflatten_end(self, index_ref: dict[Index, tp.Any]):
    """Records the end of a merge by storing ``index_ref``."""
    self.index_ref = index_ref

  @tp.overload
  def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...

  @tp.overload
  def split(
    self, graph_node: A, first: filterlib.Filter, /
  ) -> tuple[GraphDef[A], GraphState]: ...

  @tp.overload
  def split(
    self,
    graph_node: A,
    first: filterlib.Filter,
    second: filterlib.Filter,
    /,
    *filters: filterlib.Filter,
  ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...

  def split(
    self, node: A, *filters: filterlib.Filter
  ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
    """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
    a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
    contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
    to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
    seamlessly between stateful and stateless representations of the graph.

    Example usage::

      >>> from flax import nnx
      >>> import jax, jax.numpy as jnp
      ...
      >>> class Foo(nnx.Module):
      ...   def __init__(self, rngs):
      ...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
      ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
      ...
      >>> node = Foo(nnx.Rngs(0))
      >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
      ...
      >>> jax.tree.map(jnp.shape, params)
      State({
        'batch_norm': {
          'bias': VariableState(
            type=Param,
            value=(2,)
          ),
          'scale': VariableState(
            type=Param,
            value=(2,)
          )
        },
        'linear': {
          'bias': VariableState(
            type=Param,
            value=(3,)
          ),
          'kernel': VariableState(
            type=Param,
            value=(2, 3)
          )
        }
      })
      >>> jax.tree.map(jnp.shape, batch_stats)
      State({
        'batch_norm': {
          'mean': VariableState(
            type=BatchStat,
            value=(2,)
          ),
          'var': VariableState(
            type=BatchStat,
            value=(2,)
          )
        }
      })

    Arguments:
      node: graph node to split.
      *filters: some optional filters to group the state into mutually exclusive substates.
    Returns:
      :class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no
      filters are passed, a single :class:`State` is returned.
    """
    ref_index: RefMap[tp.Any, Index] = RefMap()
    graphdef, state = flatten(node, ref_index)
    states = _split_state(state, filters)

    if self.index_ref is not None and isinstance(graphdef, NodeDef):
      # inner split (3): record how the inner indexes map to the new ones
      index_to_index = compose_mapping(self.index_ref, ref_index)
      graphdef = dataclasses.replace(
        graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
      )

    self.flatten_end(ref_index)

    return graphdef, *states

  def merge(
    self,
    graphdef: GraphDef[A],
    state: GraphState,
    *states: GraphState,
  ) -> A:
    """Merges the given states and unflattens ``graphdef`` into a node.

    If ``graphdef`` carries an ``index_mapping`` (outer merge), it is combined
    with the stored ``ref_index`` to build the cache of existing objects that
    ``unflatten`` may reuse. The ``index_ref`` produced by the unflatten is
    stored on the context.

    Arguments:
      graphdef: a ``NodeDef`` instance.
      state: a :class:`State` object.
      *states: additional :class:`State` objects.
    Returns:
      The unflattened node.
    Raises:
      ValueError: if ``graphdef`` is not a ``NodeDef`` or if no ``ref_index``
        has been stored on the context yet.
    """
    if not isinstance(graphdef, NodeDef):
      raise ValueError(
        f'Expected a NodeDef instance, but got {type(graphdef)}.'
      )
    if self.ref_index is None:
      raise ValueError('Cannot merge without ref_index.')

    if graphdef.index_mapping is not None:
      # outer merge (4), create index_ref_cache
      assert self.ref_index is not None
      index_ref_cache = compose_mapping_reversed(
        self.ref_index, graphdef.index_mapping
      )
    else:
      # inner merge (2)
      index_ref_cache = None

    state = State.merge(state, *states)
    index_ref: dict[Index, tp.Any] = {}
    node = unflatten(
      graphdef, state, index_ref=index_ref, index_ref_cache=index_ref_cache
    )

    self.unflatten_end(index_ref)

    return node


jax.tree_util.register_static(UpdateContext)


@dataclasses.dataclass
class UpdateContextManager:
  """Context manager / decorator that activates an :class:`UpdateContext`
  for ``tag`` on the stacks held by ``GRAPH_CONTEXT``."""

  tag: str

  def __enter__(self):
    ctx = UpdateContext(self.tag, None, None)
    if self.tag not in GRAPH_CONTEXT.update_context_stacks:
      GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx]
    else:
      GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx)
    return ctx

  def __exit__(self, *args):
    if self.tag not in GRAPH_CONTEXT.update_context_stacks:
      raise RuntimeError(
        f'No update context found for tag {self.tag!r}, this is a bug.'
      )
    stack = GRAPH_CONTEXT.update_context_stacks[self.tag]

    ctx = stack.pop()
    # clear references
    del ctx.ref_index
    del ctx.index_ref

    # remove the stack once the last context for this tag has exited
    if not stack:
      del GRAPH_CONTEXT.update_context_stacks[self.tag]

  def __call__(self, f: F) -> F:
    @functools.wraps(f)
    def update_context_manager_wrapper(*args, **kwargs):
      with self:
        return f(*args, **kwargs)

    return update_context_manager_wrapper  # type: ignore


def update_context(tag: str):
  """Creates an :class:`UpdateContext` context manager which can be used to handle
  more complex state updates beyond what ``nnx.update`` can handle, including
  updates to static properties and graph structure.

  UpdateContext exposes a ``split`` and ``merge`` API with the same
  signature as ``nnx.split`` / ``nnx.merge`` but performs some bookkeeping
  to have the necessary information in order to perfectly update the input
  objects based on the changes made inside the transform. The UpdateContext
  must call split and merge a total of 4 times, the first
  and last calls happen outside the transform and the second and third calls
  happen inside the transform as shown in the diagram below::


                          idxmap
    (2) merge ─────────────────────────────► split (3)
          ▲                                    │
          │               inside               │
          │. . . . . . . . . . . . . . . . . . │ index_mapping
          │               outside              │
          │                                    ▼
    (1) split──────────────────────────────► merge (4)
                          refmap


  The first call to split ``(1)`` creates a ``refmap`` which keeps track of the
  outer references, and the first call to merge ``(2)`` creates an ``idxmap`` which
  keeps track of the inner references. The second call to split ``(3)`` combines
  the refmap and idxmap to produce the ``index_mapping`` which indicates
  how the outer references map to the inner references. Finally, the last call to
  merge ``(4)`` uses the index_mapping and the refmap to reconstruct the
  output of the transform while reusing/updating the inner references. To avoid
  memory leaks, the idxmap is cleared after ``(3)`` and the refmap is
  cleared after ``(4)``, and both are cleared after the context manager exits.

  Here is a simple example showing the use of ``update_context``::

    >>> from flax import nnx
    ...
    >>> m1 = nnx.Dict({})
    >>> with nnx.update_context('example') as ctx:
    ...   graphdef, state = ctx.split(m1)
    ...   @jax.jit
    ...   def f(graphdef, state):
    ...     m2 = ctx.merge(graphdef, state)
    ...     m2.a = 1
    ...     m2.ref = m2  # create a reference cycle
    ...     return ctx.split(m2)
    ...   graphdef_out, state_out = f(graphdef, state)
    ...   m3 = ctx.merge(graphdef_out, state_out)
    ...
    >>> assert m1 is m3
    >>> assert m1.a == 1
    >>> assert m1.ref is m1

  Note that ``update_context`` takes in a ``tag`` argument which is used
  primarily as a safety mechanism reduce the risk of accidentally using the
  wrong UpdateContext when using :func:`current_update_context` to access the
  current active context. current_update_context can be used as a way of
  accessing the current active context without having to pass it as a capture::

    >>> from flax import nnx
    ...
    >>> m1 = nnx.Dict({})
    >>> @jax.jit
    ... def f(graphdef, state):
    ...   ctx = nnx.current_update_context('example')
    ...   m2 = ctx.merge(graphdef, state)
    ...   m2.a = 1     # insert static attribute
    ...   m2.ref = m2  # create a reference cycle
    ...   return ctx.split(m2)
    ...
    >>> @nnx.update_context('example')
    ... def g(m1):
    ...   ctx = nnx.current_update_context('example')
    ...   graphdef, state = ctx.split(m1)
    ...   graphdef_out, state_out = f(graphdef, state)
    ...   return ctx.merge(graphdef_out, state_out)
    ...
    >>> m3 = g(m1)
    >>> assert m1 is m3
    >>> assert m1.a == 1
    >>> assert m1.ref is m1

  As shown in the code above, ``update_context`` can also be used as a
  decorator that creates/activates an UpdateContext context for the
  duration of the function. The context can be accessed using
  :func:`current_update_context`.

  Args:
    tag: A string tag to identify the context.
  """
  return UpdateContextManager(tag)


def current_update_context(tag: str) -> UpdateContext:
  """Look up the active :class:`UpdateContext` registered under ``tag``.

  Args:
    tag: The string tag the context was created with.

  Returns:
    The last entry of the context stack stored for ``tag``.

  Raises:
    ValueError: If ``tag`` has no entry in the update-context stacks.
  """
  stacks = GRAPH_CONTEXT.update_context_stacks
  if tag in stacks:
    # the last element of the stack is the one returned to the caller
    return stacks[tag][-1]
  raise ValueError(f'No update context found for tag {tag!r}.')
# --------------------------------------------------------
# Functional API
# --------------------------------------------------------


def _split_state(
  state: GraphState,
  filters: tuple[filterlib.Filter, ...],
) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
  """Partition ``state`` by ``filters`` and always hand back a tuple.

  Args:
    state: The state to partition.
    filters: Filters forwarded to ``state.split``; may be empty.

  Returns:
    ``(state,)`` when no filters are given, otherwise the result of
    ``state.split(*filters)`` normalised to a non-empty tuple.
  """
  if filters:
    partitions = state.split(*filters)
    if isinstance(partitions, State):
      # ``split`` handed back a bare State instead of a tuple; wrap it so
      # callers can always unpack the result.
      return (partitions,)
    assert len(partitions) > 0
    return partitions  # type: ignore[return-value]
  return (state,)


@tp.overload
def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...
@tp.overload
def split(
  graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], GraphState]: ...
@tp.overload
def split(
  graph_node: A,
  first: filterlib.Filter,
  second: filterlib.Filter,
  /,
  *filters: filterlib.Filter,
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...
def split(
  node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
  """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s.

  State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or
  nested States. GraphDef contains all the static information needed to
  reconstruct a ``Module`` graph, it is analogous to JAX’s ``PyTreeDef``.
  :func:`split` is used in conjunction with :func:`merge` to switch seamlessly
  between stateful and stateless representations of the graph.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    ...
    >>> class Foo(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...
    >>> node = Foo(nnx.Rngs(0))
    >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
    ...
    >>> jax.tree.map(jnp.shape, params)
    State({
      'batch_norm': {
        'bias': VariableState(
          type=Param,
          value=(2,)
        ),
        'scale': VariableState(
          type=Param,
          value=(2,)
        )
      },
      'linear': {
        'bias': VariableState(
          type=Param,
          value=(3,)
        ),
        'kernel': VariableState(
          type=Param,
          value=(2, 3)
        )
      }
    })
    >>> jax.tree.map(jnp.shape, batch_stats)
    State({
      'batch_norm': {
        'mean': VariableState(
          type=BatchStat,
          value=(2,)
        ),
        'var': VariableState(
          type=BatchStat,
          value=(2,)
        )
      }
    })

  :func:`split` and :func:`merge` are primarily used to interact directly with JAX
  transformations, see
  `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
  for more information.

  Arguments:
    node: graph node to split.
    *filters: some optional filters to group the state into mutually exclusive substates.

  Returns:
    ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
    filters are passed, a single ``State`` is returned.
  """
  definition, full_state = flatten(node)
  # the definition always comes first, followed by one state per partition
  return (definition, *_split_state(full_state, filters))
def merge(
  graphdef: GraphDef[A],
  state: tp.Mapping[KeyT, tp.Any],
  /,
  *states: tp.Mapping[KeyT, tp.Any],
) -> A:
  """The inverse of :func:`split`.

  ``merge`` combines a :class:`GraphDef` with one or more :class:`State`'s and
  builds a new node that has the same structure as the original node.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    ...
    >>> class Foo(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...
    >>> node = Foo(nnx.Rngs(0))
    >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
    ...
    >>> new_node = nnx.merge(graphdef, params, batch_stats)
    >>> assert isinstance(new_node, Foo)
    >>> assert isinstance(new_node.batch_norm, nnx.BatchNorm)
    >>> assert isinstance(new_node.linear, nnx.Linear)

  :func:`split` and :func:`merge` are primarily used to interact directly with JAX
  transformations, see
  `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
  for more information.

  Args:
    graphdef: A :class:`GraphDef` object.
    state: A :class:`State` object.
    *states: Additional :class:`State` objects.

  Returns:
    The merged :class:`Module`.
  """
  # fold every provided state into one before rebuilding the node
  combined = State.merge(state, *states)
  return unflatten(graphdef, combined)
def update(
  node, state: tp.Mapping[KeyT, tp.Any], /, *states: tp.Mapping[KeyT, tp.Any]
) -> None:
  """Update the given graph node with a new state(s) in-place.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> x = jnp.ones((1, 2))
    >>> y = jnp.ones((1, 3))
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

    >>> def loss_fn(model, x, y):
    ...   return jnp.mean((y - model(x))**2)
    >>> prev_loss = loss_fn(model, x, y)

    >>> grads = nnx.grad(loss_fn)(model, x, y)
    >>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads)
    >>> nnx.update(model, new_state)
    >>> assert loss_fn(model, x, y) < prev_loss

  Args:
    node: A graph node to update.
    state: A :class:`State` object.
    *states: Additional :class:`State` objects.
  """
  # several states are merged into one; a single state is used as given
  combined = State.merge(state, *states) if states else state
  # a State is unwrapped to its raw mapping; other mappings pass through
  raw = combined.raw_mapping if isinstance(combined, State) else combined
  _graph_update_dynamic(node, raw)
def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]:
  """Yield ``(path, value)`` for every :class:`Variable` that
  :func:`iter_graph` produces for ``node``, in the order it produces them."""
  for path, candidate in iter_graph(node):
    if not isinstance(candidate, Variable):
      continue
    yield path, candidate


@tp.overload
def variables(node, /) -> State[Key, Variable]: ...
@tp.overload
def variables(node, first: filterlib.Filter, /) -> State[Key, Variable]: ...
@tp.overload
def variables(
  node,
  first: filterlib.Filter,
  second: filterlib.Filter,
  /,
  *filters: filterlib.Filter,
) -> tuple[State[Key, Variable], ...]: ...
def variables(
  node,
  *filters: filterlib.Filter,
) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]:
  """Similar to :func:`state` but returns the current :class:`Variable` objects instead
  of new :class:`VariableState` instances.

  Example::

    >>> from flax import nnx
    ...
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> params = nnx.variables(model, nnx.Param)
    ...
    >>> assert params['kernel'] is model.kernel
    >>> assert params['bias'] is model.bias

  Args:
    node: A graph node object.
    *filters: One or more :class:`Variable` objects to filter by.

  Returns:
    One or more :class:`State` mappings containing the :class:`Variable` objects.
  """
  filter_count = len(filters)
  # NOTE(review): two catch-all ``...`` filters are appended after the user
  # filters (three in total when none are given), and with two or more user
  # filters the whole result tuple is returned unchanged. Presumably
  # ``split_flat_state`` yields one group per filter, in which case that tuple
  # carries two extra trailing groups — confirm before trimming.
  if filter_count == 0:
    selectors = (..., ..., ...)
  else:
    selectors = (*filters, ..., ...)

  grouped = variablelib.split_flat_state(_variables_generator(node), selectors)
  result = tuple(State.from_flat_path(group) for group in grouped)
  return result[0] if filter_count < 2 else result
# Typing-only overloads for ``state``: with no filter or exactly one filter the
# declared result is a single ``GraphState``; with two or more filters it is a
# tuple of ``GraphState``.
@tp.overload
def state(node, /) -> GraphState: ...
@tp.overload
def state(node, first: filterlib.Filter, /) -> GraphState: ...
@tp.overload
def state(
  node,
  first: filterlib.Filter,
  second: filterlib.Filter,
  /,
  *filters: filterlib.Filter,
) -> tuple[GraphState, ...]: ...
def state(
  node,
  *filters: filterlib.Filter,
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
  """Similar to :func:`split` but only returns the :class:`State`'s indicated by the filters.

  Example usage::

    >>> from flax import nnx

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...   def __call__(self, x):
    ...     return self.linear(self.batch_norm(x))

    >>> model = Model(rngs=nnx.Rngs(0))
    >>> # get the learnable parameters from the batch norm and linear layer
    >>> params = nnx.state(model, nnx.Param)
    >>> # get the batch statistics from the batch norm layer
    >>> batch_stats = nnx.state(model, nnx.BatchStat)
    >>> # get them separately
    >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
    >>> # get them together
    >>> state = nnx.state(model)

  Args:
    node: A graph node object.
    *filters: One or more :class:`Variable` objects to filter by.

  Returns:
    One or more :class:`State` mappings.
  """
  _, full_state = flatten(node)
  if not filters:
    # no filters: the complete state is returned untouched
    return full_state
  first, *rest = filters
  return full_state.filter(first, *rest)
def graphdef(node: tp.Any, /) -> GraphDef[tp.Any]:
  """Get the :class:`GraphDef` of the given graph node.

  Example usage::

    >>> from flax import nnx

    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> graphdef, _ = nnx.split(model)
    >>> assert graphdef == nnx.graphdef(model)

  Args:
    node: A graph node object.

  Returns:
    The :class:`GraphDef` of the :class:`Module` object.
  """
  # only the static half of the flattened pair is needed here
  definition, _ = flatten(node)
  return definition
# Typing-only overloads for ``pop``: one filter is declared to give a single
# ``GraphState``, two or more a tuple of ``GraphState``. No zero-filter
# overload is declared.
@tp.overload
def pop(
  node,
  filter: filterlib.Filter,
  /,
) -> GraphState: ...
@tp.overload
def pop(
  node,
  filter: filterlib.Filter,
  filter2: filterlib.Filter,
  /,
  *filters: filterlib.Filter,
) -> tuple[GraphState, ...]: ...
def pop(
  node, *filters: filterlib.Filter
) -> tp.Union[GraphState, tuple[GraphState, ...]]:
  """Pop one or more :class:`Variable` types from the graph node.

  Example usage::

    >>> from flax import nnx
    >>> import jax.numpy as jnp

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
    ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
    ...   def __call__(self, x):
    ...     x = self.linear1(x)
    ...     self.sow(nnx.Intermediate, 'i', x)
    ...     x = self.linear2(x)
    ...     return x

    >>> x = jnp.ones((1, 2))
    >>> model = Model(rngs=nnx.Rngs(0))
    >>> assert not hasattr(model, 'i')
    >>> y = model(x)
    >>> assert hasattr(model, 'i')

    >>> intermediates = nnx.pop(model, nnx.Intermediate)
    >>> assert intermediates['i'].value[0].shape == (1, 3)
    >>> assert not hasattr(model, 'i')

  Args:
    node: A graph node object.
    *filters: One or more :class:`Variable` objects to filter by.

  Returns:
    The popped :class:`State` containing the :class:`Variable`
    objects that were filtered for.

  Raises:
    ValueError: If no filter is given.
  """
  if not filters:
    raise ValueError('Expected at least one filter')

  predicates = tuple(map(filterlib.to_predicate, filters))
  # one flat dict per predicate; ``_graph_pop`` receives them to fill in
  flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple(
    {} for _ in predicates
  )
  _graph_pop(
    node=node,
    id_to_index={},
    path_parts=(),
    flat_states=flat_states,
    predicates=predicates,
  )
  popped = tuple(GraphState.from_flat_path(flat) for flat in flat_states)
  # a single filter gives back a bare state rather than a 1-tuple
  return popped[0] if len(popped) == 1 else popped
def clone(node: Node) -> Node:
  """Create a deep copy of the given graph node.

  The copy is produced by splitting ``node`` and merging the result back.

  Example usage::

    >>> from flax import nnx

    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> cloned_model = nnx.clone(model)
    >>> model.bias.value += 1
    >>> assert (model.bias.value != cloned_model.bias.value).all()

  Args:
    node: A graph node object.

  Returns:
    A deep copy of the :class:`Module` object.
  """
  # with no filters ``split`` gives exactly (graphdef, state)
  return merge(*split(node))
def call(
  graphdef_state: tuple[GraphDef[A], GraphState], /
) -> ApplyCaller[tuple[GraphDef[A], GraphState]]:
  """Calls a method underlying graph node defined by a (GraphDef, State) pair.

  ``call`` takes a ``(GraphDef, State)`` pair and creates a proxy object that can be
  used to call methods on the underlying graph node. When a method is called, the
  output is returned along with a new (GraphDef, State) pair that represents the
  updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method``
  > :func:`split` but is more convenient to use in pure JAX functions.

  Example::

    >>> from flax import nnx
    >>> import jax
    >>> import jax.numpy as jnp
    ...
    >>> class StatefulLinear(nnx.Module):
    ...   def __init__(self, din, dout, rngs):
    ...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    ...     self.b = nnx.Param(jnp.zeros((dout,)))
    ...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
    ...
    ...   def increment(self):
    ...     self.count += 1
    ...
    ...   def __call__(self, x):
    ...     self.increment()
    ...     return x @ self.w + self.b
    ...
    >>> linear = StatefulLinear(3, 2, nnx.Rngs(0))
    >>> linear_state = nnx.split(linear)
    ...
    >>> @jax.jit
    ... def forward(x, linear_state):
    ...   y, linear_state = nnx.call(linear_state)(x)
    ...   return y, linear_state
    ...
    >>> x = jnp.ones((1, 3))
    >>> y, linear_state = forward(x, linear_state)
    >>> y, linear_state = forward(x, linear_state)
    ...
    >>> linear = nnx.merge(*linear_state)
    >>> linear.count.value
    Array(2, dtype=uint32)

  The proxy object returned by ``call`` supports indexing and attribute access
  to access nested methods. In the example below, indexing is used to reach the
  ``StatefulLinear`` module at the ``b`` key of a ``nodes`` dictionary and call
  its ``increment`` method::

    >>> class StatefulLinear(nnx.Module):
    ...   def __init__(self, din, dout, rngs):
    ...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    ...     self.b = nnx.Param(jnp.zeros((dout,)))
    ...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
    ...
    ...   def increment(self):
    ...     self.count += 1
    ...
    ...   def __call__(self, x):
    ...     self.increment()
    ...     return x @ self.w + self.b
    ...
    >>> rngs = nnx.Rngs(0)
    >>> nodes = dict(
    ...   a=StatefulLinear(3, 2, rngs),
    ...   b=StatefulLinear(2, 1, rngs),
    ... )
    ...
    >>> node_state = nnx.split(nodes)
    >>> # use indexing followed by attribute access
    >>> _, node_state = nnx.call(node_state)['b'].increment()
    ...
    >>> nodes = nnx.merge(*node_state)
    >>> nodes['a'].count.value
    Array(0, dtype=uint32)
    >>> nodes['b'].count.value
    Array(1, dtype=uint32)
  """

  def pure_caller(accessor: DelayedAccessor, *args, **kwargs):
    # rebuild the node, resolve the requested method on it, invoke it, then
    # re-split the same node so the caller gets the state after the call
    rebuilt = merge(*graphdef_state)
    result = accessor(rebuilt)(*args, **kwargs)
    return result, split(rebuilt)

  return CallableProxy(pure_caller)  # type: ignore
def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
  """Iterates over all nested nodes and leaves of the given graph node, including the
  current node.

  ``iter_graph`` creates a generator that yields path and value pairs, where
  the path is a tuple of strings or integers representing the path to the value from the
  root. Repeated nodes are visited only once. Leaves include static values.

  Example::

    >>> from flax import nnx
    >>> import jax
    >>> import jax.numpy as jnp
    ...
    >>> class Linear(nnx.Module):
    ...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
    ...     self.din, self.dout = din, dout
    ...     self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout)))
    ...     self.b = nnx.Param(jnp.zeros((dout,)))
    ...
    >>> module = Linear(3, 4, rngs=nnx.Rngs(0))
    >>> graph = [module, module]
    ...
    >>> for path, value in nnx.iter_graph(graph):
    ...   print(path, type(value).__name__)
    ...
    (0, 'b') Param
    (0, 'din') int
    (0, 'dout') int
    (0, 'w') Param
    (0,) Linear
    () list
  """
  # start the traversal with an empty visited set and the empty root path
  seen: set[int] = set()
  root_path: PathParts = ()
  yield from _iter_graph(node, seen, root_path)
def _iter_graph(
  node: tp.Any, visited: set[int], path_parts: PathParts
) -> tp.Iterator[tuple[PathParts, tp.Any]]:
  """Recursive worker for :func:`iter_graph`.

  Children of a graph node are yielded before the node itself. A graph node
  whose ``id`` is already in ``visited`` yields nothing at all; values that
  are not graph nodes are always yielded.
  """
  if is_node(node):
    node_id = id(node)
    if node_id in visited:
      return
    visited.add(node_id)
    children = get_node_impl(node).node_dict(node)
    for name, child in children.items():
      yield from _iter_graph(child, visited, (*path_parts, name))

  yield path_parts, node


def compose_mapping(
  map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], /
) -> dict[A, C]:
  """Compose ``a -> b`` with ``b -> c`` into ``a -> c``; entries whose ``b``
  is not a key of ``map_bc`` are dropped."""
  composed: dict[A, C] = {}
  for a, b in map_ab.items():
    if b in map_bc:
      composed[a] = map_bc[b]
  return composed


def compose_mapping_reversed(
  map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], /
) -> dict[C, A]:
  """Compose ``a -> b`` with ``b -> c`` into the reversed ``c -> a``; entries
  whose ``b`` is not a key of ``map_bc`` are dropped."""
  composed: dict[C, A] = {}
  for a, b in map_ab.items():
    if b in map_bc:
      composed[map_bc[b]] = a
  return composed


@dataclasses.dataclass(frozen=True)
class Static(tp.Generic[A]):
  """An empty pytree node that treats its inner value as static.
  ``value`` must define ``__eq__`` and ``__hash__``.
  """

  value: A


jax.tree_util.register_static(Static)

# ---------------------------------------------------------
# Pytree
# ---------------------------------------------------------
class GenericPytree: ...


def is_pytree_node(x: tp.Any) -> bool:
  """Return whether ``x`` is classified as a pytree node.

  The checks run in order: the exact type is looked up in ``PYTREE_REGISTRY``
  (True) and ``GRAPH_REGISTRY`` (False), ``Variable`` instances are rejected,
  exact ``VariableState`` / ``State`` types are accepted, and anything else
  is accepted only when JAX does not treat it as a leaf.
  """
  x_type = type(x)
  if x_type in PYTREE_REGISTRY:
    return True
  if x_type in GRAPH_REGISTRY:
    return False
  # known non-pytree types
  if isinstance(x, Variable):
    return False
  # known pytree types
  if x_type is VariableState or x_type is State:
    return True
  return not jax.tree_util.all_leaves((x,))


def _key_path_to_key(key: tp.Any) -> Key:
  """Convert a single JAX key-path entry into a plain key.

  Raises:
    ValueError: If a ``DictKey`` / ``FlattenedIndexKey`` wraps a key that
      ``is_key_like`` rejects.
  """
  if isinstance(key, jax.tree_util.SequenceKey):
    return key.idx
  if isinstance(key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)):
    inner = key.key
    if is_key_like(inner):
      return inner
    raise ValueError(
      f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
    )
  if isinstance(key, jax.tree_util.GetAttrKey):
    return key.name
  # any other entry type falls back to its string form
  return str(key)


def _flatten_pytree(pytree: tp.Any):
  """Flatten ``pytree`` one level deep into ``(key, child)`` pairs plus the
  JAX treedef."""

  def _is_direct_child(candidate: tp.Any) -> bool:
    # everything except the root itself is a leaf, so only one level opens
    return candidate is not pytree

  leaves, treedef = jax.tree_util.tree_flatten_with_path(
    pytree, is_leaf=_is_direct_child
  )
  nodes = tuple(
    (_key_path_to_key(key_path[0]), child) for key_path, child in leaves
  )
  return nodes, treedef


def _unflatten_pytree(
  nodes: tuple[tuple[Key, tp.Any], ...], treedef: jax.tree_util.PyTreeDef
):
  """Rebuild a pytree from ``treedef`` using the values of ``nodes``; the
  keys are ignored."""
  return treedef.unflatten(child for _, child in nodes)


PYTREE_NODE_IMPL = PytreeNodeImpl(
  type=GenericPytree,
  flatten=_flatten_pytree,
  unflatten=_unflatten_pytree,  # type: ignore
)

# common pytrees
# list
register_pytree_node_type(
  list,
  flatten=lambda seq: (list(enumerate(seq)), None),
  unflatten=lambda nodes, _: [child for _, child in nodes],  # type: ignore
)
# tuple
register_pytree_node_type(
  tuple,
  flatten=lambda seq: (list(enumerate(seq)), None),
  unflatten=lambda nodes, _: tuple(child for _, child in nodes),  # type: ignore
)
# dict
register_pytree_node_type(
  dict,
  flatten=lambda mapping: (sorted(mapping.items()), None),
  unflatten=lambda nodes, _: {name: child for name, child in nodes},  # type: ignore
)
# None
register_pytree_node_type(
  type(None),
  flatten=lambda x: ([], None),
  unflatten=lambda _, __: None,  # type: ignore
)