# Source code for flax.experimental.nnx.nnx.visualization

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import importlib.util
import typing as tp

import jax

from flax.experimental import nnx

# Probe for Penzai via its import spec so that the package itself is not
# imported at module load time.
penzai_installed = importlib.util.find_spec('penzai') is not None

# Assume a plain interpreter; only flip the flag when IPython can be imported
# and reports an active shell.
in_ipython = False
try:
  from IPython import get_ipython

  in_ipython = get_ipython() is not None
except ImportError:
  # IPython is not available, so the default of False stands.
  pass


def display(*args):
  """Render each argument with Penzai, or print it when that is unavailable.

  The rich path is taken only when both module-level flags
  ``penzai_installed`` and ``in_ipython`` are true. In that case every
  argument is converted with ``to_dataclass`` and passed to
  ``pz.ts.display`` while ``pz.ts.ArrayAutovisualizer`` is set as the active
  autovisualizer. Otherwise each argument is sent to ``print``.

  Args:
    *args: The objects to show, handled one at a time in the order given.
  """
  if penzai_installed and in_ipython:
    # Imported lazily: Penzai is an optional dependency.
    from penzai import pz  # type: ignore[import-not-found,import-untyped]

    with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
      for obj in args:
        pz.ts.display(to_dataclass(obj), ignore_exceptions=True)
    return

  # Fallback: plain text output.
  for obj in args:
    print(obj)
def to_dataclass(node):
  """Convert ``node`` into dataclass instances suitable for display.

  Starts a traversal with an empty set of visited graph-node ids and
  delegates to ``_treemap_to_dataclass``.

  Args:
    node: The object (or pytree of objects) to convert.

  Returns:
    Whatever ``_treemap_to_dataclass`` produces for ``node``.
  """
  return _treemap_to_dataclass(node, set())


def _to_dataclass(x, seen_nodes: set[int]):
  """Convert one object into its dataclass form.

  Args:
    x: The object to convert.
    seen_nodes: ``id()`` values of graph nodes already converted during this
      traversal; updated in place.

  Returns:
    * for a graph node seen before: a synthesized dataclass instance with the
      single field ``repeated=True``;
    * for a new graph node: a synthesized dataclass instance whose fields are
      the converted entries of the node's ``node_dict``, keyed by ``str(key)``;
    * for an ``nnx.Variable`` / ``nnx.VariableState``: a synthesized dataclass
      instance built from its instance attributes (see inline comments);
    * for an ``nnx.State``: the conversion of its ``_mapping``;
    * otherwise ``x`` itself, unchanged.
  """
  if nnx.graph.is_graph_node(x):
    node_id = id(x)
    if node_id in seen_nodes:
      # Already converted once; emit a marker instead of expanding it again.
      return _make_dataclass_obj(type(x), {'repeated': True})
    seen_nodes.add(node_id)

    impl = nnx.graph.get_node_impl(x)
    children = {
      str(key): _treemap_to_dataclass(value, seen_nodes)
      for key, value in impl.node_dict(x).items()
    }
    return _make_dataclass_obj(type(x), children)

  if isinstance(x, (nnx.Variable, nnx.VariableState)):
    # Work on a copy so the variable's own attribute dict is left untouched.
    attrs = dict(vars(x))
    if 'raw_value' in attrs:
      # Shown under the name 'value'; the renamed entry moves to the end.
      attrs['value'] = attrs.pop('raw_value')
    attrs.pop('_trace_state', None)
    shown = {
      name: _treemap_to_dataclass(value, seen_nodes)
      for name, value in attrs.items()
      if not name.endswith('_hooks')  # attributes ending in '_hooks' are omitted
    }
    return _make_dataclass_obj(
      type(x),
      shown,
      # VariableState gets a plain dataclass; Variable gets a Penzai one.
      penzai_dataclass=not isinstance(x, nnx.VariableState),
    )

  if isinstance(x, nnx.State):
    return _treemap_to_dataclass(x._mapping, seen_nodes)

  return x


def _treemap_to_dataclass(node, seen_nodes: set[int]):
  """Apply ``_to_dataclass`` to every leaf of ``node`` via ``jax.tree.map``.

  Args:
    node: The pytree to traverse.
    seen_nodes: Visited-id set forwarded to ``_to_dataclass``.

  Returns:
    The result of ``jax.tree.map`` over ``node``.
  """

  def _is_state_like(x):
    # Treat these as leaves so they reach ``_to_dataclass`` intact.
    return isinstance(x, (nnx.VariableState, nnx.State))

  return jax.tree.map(
    lambda leaf: _to_dataclass(leaf, seen_nodes),
    node,
    is_leaf=_is_state_like,
  )


def _make_dataclass_obj(
  cls, fields: tp.Mapping[str, tp.Any], penzai_dataclass: bool = True
) -> tp.Type:
  """Build a dataclass named after ``cls`` and instantiate it with ``fields``.

  Args:
    cls: Class whose ``__name__`` (and ``__call__`` attribute, when present)
      is reused for the synthesized class.
    fields: Field name to value mapping; each field is annotated with
      ``type(value)`` and the values are passed to the constructor.
    penzai_dataclass: When true, derive from ``pz.Layer`` and decorate with
      ``pz.pytree_dataclass``; otherwise derive from ``object`` and decorate
      with ``dataclasses.dataclass``.

  Returns:
    An instance of the synthesized dataclass populated from ``fields``.
  """
  # Imported lazily: Penzai is an optional dependency.
  from penzai import pz  # type: ignore[import-error]

  if penzai_dataclass:
    decorate, base = pz.pytree_dataclass, pz.Layer
  else:
    decorate, base = dataclasses.dataclass, object

  namespace = {
    '__annotations__': {name: type(value) for name, value in fields.items()},
  }
  if hasattr(cls, '__call__'):
    # Carry the original class's ``__call__`` attribute over.
    namespace['__call__'] = cls.__call__

  synthesized = decorate(type(cls.__name__, (base,), namespace))
  return synthesized(**fields)