
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file

from collections import deque
import dataclasses
import functools
import typing as tp

from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx import extract, filterlib, graph, spmd, variablelib
from flax.nnx import statelib
from flax.nnx.module import Module
from flax.nnx.statelib import State
from flax.nnx.transforms.transforms import resolve_kwargs
from flax.typing import Leaf, Missing, PytreeDeque
import jax
import jax.core
import jax.numpy as jnp
import jax.stages
import numpy as np

A = tp.TypeVar('A')
C = tp.TypeVar('C')
B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
T = tp.TypeVar('T')
StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable
Leaves = tp.List[Leaf]
Index = int


class Carry:
  pass
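
# Carry is used as a sentinel in `in_axes` / `out_axes` to mark the position
# of the loop carry in transforms like `scan` below; for example, the default
# `in_axes=(Carry, 0)` carries the first argument and scans the second over
# axis 0.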


# -------------------------------
# vmap
# -------------------------------


class StateAxes(extract.PrefixMapping):

  def __init__(
    self,
    filter_axes: (
      statelib.State
      | tp.Mapping[filterlib.Filter, Index | type[Carry] | None]
      | tp.Iterable[tuple[filterlib.Filter, Index | type[Carry] | None]]
    ),
    /,
  ):
    if isinstance(filter_axes, statelib.State):
      filter_axes = statelib.create_path_filters(filter_axes)  # type: ignore

    iterable = tuple(
        filter_axes.items()
        if isinstance(filter_axes, tp.Mapping)
        else filter_axes
    )
    self._filters = tuple(filter for filter, _ in iterable)
    self._axes = tuple(axis for _, axis in iterable)

  @property
  def filters(self) -> tuple[filterlib.Filter, ...]:
    return self._filters

  @property
  def axes(self) -> tuple[Index | type[Carry] | None, ...]:
    return self._axes

  def map_prefix(
    self, path: variablelib.PathParts, variable: variablelib.Variable
  ) -> tp.Any:
    for filter, axis in zip(self.filters, self.axes):
      predicate = filterlib.to_predicate(filter)
      if predicate(path, variable):
        return axis
    raise ValueError(f'No axis found for {path=}, {variable=}')

  def __repr__(self):
    return f'StateAxes({dict(self.items())})'

  def items(self):
    return zip(self.filters, self.axes)

  def __eq__(self, other):
    return (
        isinstance(other, StateAxes)
        and self.filters == other.filters
        and self.axes == other.axes
    )

  def __hash__(self):
    return hash((self.filters, self.axes))
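
  # Example (illustrative sketch): StateAxes maps substate filters to
  # vmap-style axes. Filters are matched in order, so `...` can serve as a
  # catch-all, e.g.:
  #
  #   state_axes = StateAxes({variablelib.Param: 0, ...: None})
  #
  # vectorizes Param substates along axis 0 and broadcasts everything else.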


AxisFn = tp.Callable[[graph.GraphState, int, tp.Mapping], graph.GraphState]


def _update_variable_sharding_metadata(
    tree, transform_metadata, axis_fn: AxisFn
):
  def _update_axes_fn(node_states):
    if isinstance(node_states, extract.NodeStates) and isinstance(
      node_states.metadata, (StateAxes, int)
    ):
      if isinstance(node_states.metadata, int):
        state = node_states.state
        assert isinstance(state, State)
        state = axis_fn(state, node_states.metadata, transform_metadata)
        return node_states.replace(states=(state,))
      else:
        states_out: list[graph.GraphState] = []
        for state, axis in zip(node_states.states, node_states.metadata.axes):
          assert isinstance(state, graph.State)
          if isinstance(axis, int):
            state = axis_fn(state, axis, transform_metadata)
          states_out.append(state)
        return node_states.replace(states=tuple(states_out))
    return node_states

  return jax.tree.map(
    _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )
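
# For instance (a descriptive sketch of the axis_fn contract): with
# transform_metadata={spmd.PARTITION_NAME: 'layers'}, `spmd.add_axis` inserts
# the axis name 'layers' into each Variable's sharding metadata at the mapped
# axis index, and `spmd.remove_axis` undoes this, keeping sharding annotations
# consistent inside and outside the transform.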


def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x):
  if isinstance(prefix, StateAxes):
    return extract.NodeStates.from_split(
      *ctx.split(x, *prefix.filters), metadata=prefix
    )
  return extract.NodeStates.from_split(*ctx.split(x), metadata=prefix)


@dataclasses.dataclass(eq=False)
class VmapFn:
  f: tp.Callable[..., tp.Any]
  transform_metadata: tp.Mapping[str, tp.Any]
  in_axes: tp.Any
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args: tuple[tp.Any, ...]):
    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args = _update_variable_sharding_metadata(
          pure_args, self.transform_metadata, spmd.remove_axis
      )
    args = extract.from_tree(pure_args, ctxtag='vmap')

    out = self.f(*args)

    args_out = extract.clear_non_graph_nodes(args)
    pure_args_out, pure_out = extract.to_tree(
        (args_out, out),
        prefix=(self.in_axes, self.out_axes),
        split_fn=_vmap_split_fn,
        ctxtag='vmap',
    )
    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args_out, pure_out = _update_variable_sharding_metadata(
          (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
      )
    return pure_args_out, pure_out
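
# VmapFn runs inside jax.vmap: it rebuilds graph nodes from the pure pytree
# inputs, calls `f`, then splits the (possibly mutated) inputs and the outputs
# back into pure pytrees, so state updates made by `f` propagate out of the
# transform.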


@tp.overload
def vmap(
    *,
    in_axes: int | None | tp.Sequence[tp.Any] = 0,
    out_axes: tp.Any = 0,
    axis_name: AxisName | None = None,
    axis_size: int | None = None,
    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
    # nnx specific
    transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> tp.Callable[[F], F]:
  ...


@tp.overload
def vmap(
    f: F,
    *,
    in_axes: int | None | tp.Sequence[tp.Any] = 0,
    out_axes: tp.Any = 0,
    axis_name: AxisName | None = None,
    axis_size: int | None = None,
    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
    # nnx specific
    transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F:
  ...


def vmap(
  f: F | type[Missing] = Missing,
  *,
  in_axes: int | None | tp.Sequence[tp.Any] = 0,
  out_axes: tp.Any = 0,
  axis_name: AxisName | None = None,
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F | tp.Callable[[F], F]:
  """Reference-aware version of `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.

  Args:
    f: Function to be mapped over additional axes.
    in_axes: An integer, None, or sequence of values specifying which input
      array axes to map over (see `jax.vmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__). In
      addition to integers and None, :class:`StateAxes` can be used to control
      how graph nodes like Modules are vectorized by specifying the axes to be
      applied to substates of the graph node given a `Filter
      <https://flax.readthedocs.io/en/latest/nnx/filters_guide.html>`__.
    out_axes: An integer, None, or pytree indicating where the mapped axis
      should appear in the output (see `jax.vmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__).
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    axis_size: Optional, an integer indicating the size of the axis to be
      mapped. If not provided, the mapped axis size is inferred from
      arguments.

  Returns:
    Batched/vectorized version of ``f`` with arguments that correspond to
    those of ``f``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``f``, but
    with extra array axes at positions indicated by ``out_axes``.

  Example::

    >>> from flax import nnx
    >>> from jax import random, numpy as jnp
    ...
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((5, 2))
    ...
    >>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
    ... def forward(model, x):
    ...   return model(x)
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

    >>> class LinearEnsemble(nnx.Module):
    ...   def __init__(self, num, rngs):
    ...     self.w = nnx.Param(random.uniform(rngs(), (num, 2, 3)))
    ...
    >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((2,))
    ...
    >>> @nnx.vmap(in_axes=(0, None), out_axes=0)
    ... def forward(model, x):
    ...   return jnp.dot(x, model.w.value)
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

  To control how graph node substates are vectorized, ``StateAxes`` can be
  passed to ``in_axes`` and ``out_axes`` specifying the axes to be applied to
  each substate given a filter. The following example shows how to share the
  parameters between the ensemble members while keeping different batch
  statistics and dropout random state::

    >>> class Foo(nnx.Module):
    ...   def __init__(self):
    ...     self.a = nnx.Param(jnp.arange(4))
    ...     self.b = nnx.BatchStat(jnp.arange(4))
    ...
    >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
    >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
    ... def mul(foo):
    ...   return foo.a * foo.b
    ...
    >>> foo = Foo()
    >>> y = mul(foo)
    >>> y
    Array([[0, 0, 0, 0],
           [0, 1, 2, 3],
           [0, 2, 4, 6],
           [0, 3, 6, 9]], dtype=int32)
  """
  if f is Missing:
    return functools.partial(
      vmap,
      in_axes=in_axes,
      out_axes=out_axes,
      axis_name=axis_name,
      axis_size=axis_size,
      spmd_axis_name=spmd_axis_name,
      transform_metadata=transform_metadata,
    )  # type: ignore[return-value]

  jax_in_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    in_axes,
  )
  jax_out_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    out_axes,
  )
  vmapped_fn = jax.vmap(
    VmapFn(f, transform_metadata, in_axes, out_axes),
    in_axes=jax_in_axes,
    out_axes=(jax_in_axes, jax_out_axes),
    axis_name=axis_name,
    axis_size=axis_size,
    spmd_axis_name=spmd_axis_name,
  )

  @functools.wraps(f)
  @graph.update_context('vmap')
  def vmap_wrapper(*args, **kwargs):
    args = resolve_kwargs(f, args, kwargs)
    pure_args = extract.to_tree(
      args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
    )
    pure_args_out, pure_out = vmapped_fn(*pure_args)
    _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap')
    return out

  return vmap_wrapper  # type: ignore
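
# Usage note (hedged sketch, not from the original source): `transform_metadata`
# can carry an axis name under `spmd.PARTITION_NAME`, which VmapFn strips from
# Variable sharding annotations on the way in (spmd.remove_axis) and re-inserts
# on the way out (spmd.add_axis). For a hypothetical `MLP` Module whose Params
# carry `sharding` annotations:
#
#   @nnx.vmap(in_axes=0, transform_metadata={spmd.PARTITION_NAME: 'batch'})
#   def create_ensemble(rngs):
#     return MLP(rngs=rngs)  # hypothetical Module with annotated Params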
# -------------------------------
# pmap
# -------------------------------


@dataclasses.dataclass(eq=False)
class PmapFn:
  f: tp.Callable[..., tp.Any]
  transform_metadata: tp.Mapping[str, tp.Any]
  in_axes: tp.Any
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args: tuple[tp.Any, ...]):
    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args = _update_variable_sharding_metadata(
        pure_args, self.transform_metadata, spmd.remove_axis
      )
    args = extract.from_tree(pure_args, ctxtag='pmap')

    out = self.f(*args)

    args_out = extract.clear_non_graph_nodes(args)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      prefix=(self.in_axes, self.out_axes),
      split_fn=_vmap_split_fn,
      ctxtag='pmap',
    )
    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args_out, pure_out = _update_variable_sharding_metadata(
        (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
      )
    return pure_args_out, pure_out


@tp.overload
def pmap(
  *,
  axis_name: AxisName | None = None,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  static_broadcasted_argnums: int | tp.Iterable[int] = (),
  devices: tp.Sequence[jax.Device] | None = None,  # noqa: F811
  backend: str | None = None,
  axis_size: int | None = None,
  donate_argnums: int | tp.Iterable[int] = (),
  global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> tp.Callable[[F], F]:
  ...


@tp.overload
def pmap(
  f: F,
  *,
  axis_name: AxisName | None = None,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  static_broadcasted_argnums: int | tp.Iterable[int] = (),
  devices: tp.Sequence[jax.Device] | None = None,  # noqa: F811
  backend: str | None = None,
  axis_size: int | None = None,
  donate_argnums: int | tp.Iterable[int] = (),
  global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F:
  ...


def pmap(
  f: F | type[Missing] = Missing,
  *,
  axis_name: AxisName | None = None,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  static_broadcasted_argnums: int | tp.Iterable[int] = (),
  devices: tp.Sequence[jax.Device] | None = None,  # noqa: F811
  backend: str | None = None,
  axis_size: int | None = None,
  donate_argnums: int | tp.Iterable[int] = (),
  global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F | tp.Callable[[F], F]:
  """Reference-aware version of `jax.pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__.

  Args:
    f: Function to be mapped over additional axes.
    in_axes: An integer, None, or sequence of values specifying which input
      array axes to map over (see `jax.pmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__). In
      addition to integers and None, :class:`StateAxes` can be used to control
      how graph nodes like Modules are vectorized by specifying the axes to be
      applied to substates of the graph node given a `Filter
      <https://flax.readthedocs.io/en/latest/nnx/filters_guide.html>`__.
    out_axes: An integer, None, or pytree indicating where the mapped axis
      should appear in the output (see `jax.pmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__).
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    axis_size: Optional, an integer indicating the size of the axis to be
      mapped. If not provided, the mapped axis size is inferred from
      arguments.

  Returns:
    Batched/vectorized version of ``f`` with arguments that correspond to
    those of ``f``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``f``, but
    with extra array axes at positions indicated by ``out_axes``.

  Example (the axes semantics shown here are shared with :func:`vmap`)::

    >>> from flax import nnx
    >>> from jax import random, numpy as jnp
    ...
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((5, 2))
    ...
    >>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
    ... def forward(model, x):
    ...   return model(x)
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

    >>> class LinearEnsemble(nnx.Module):
    ...   def __init__(self, num, rngs):
    ...     self.w = nnx.Param(random.uniform(rngs(), (num, 2, 3)))
    ...
    >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((2,))
    ...
    >>> @nnx.vmap(in_axes=(0, None), out_axes=0)
    ... def forward(model, x):
    ...   return jnp.dot(x, model.w.value)
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

  To control how graph node substates are vectorized, ``StateAxes`` can be
  passed to ``in_axes`` and ``out_axes`` specifying the axes to be applied to
  each substate given a filter. The following example shows how to share the
  parameters between the ensemble members while keeping different batch
  statistics and dropout random state::

    >>> class Foo(nnx.Module):
    ...   def __init__(self):
    ...     self.a = nnx.Param(jnp.arange(4))
    ...     self.b = nnx.BatchStat(jnp.arange(4))
    ...
    >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
    >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
    ... def mul(foo):
    ...   return foo.a * foo.b
    ...
    >>> foo = Foo()
    >>> y = mul(foo)
    >>> y
    Array([[0, 0, 0, 0],
           [0, 1, 2, 3],
           [0, 2, 4, 6],
           [0, 3, 6, 9]], dtype=int32)
  """
  if f is Missing:
    return functools.partial(
      pmap,
      axis_name=axis_name,
      in_axes=in_axes,
      out_axes=out_axes,
      static_broadcasted_argnums=static_broadcasted_argnums,
      devices=devices,
      backend=backend,
      axis_size=axis_size,
      donate_argnums=donate_argnums,
      global_arg_shapes=global_arg_shapes,
      transform_metadata=transform_metadata,
    )  # type: ignore[return-value]

  jax_in_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    in_axes,
  )
  jax_out_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    out_axes,
  )
  pmapped_fn = jax.pmap(
    PmapFn(f, transform_metadata, in_axes, out_axes),
    axis_name=axis_name,
    in_axes=jax_in_axes,
    out_axes=(jax_in_axes, jax_out_axes),
    static_broadcasted_argnums=static_broadcasted_argnums,
    devices=devices,
    backend=backend,
    axis_size=axis_size,
    donate_argnums=donate_argnums,
    global_arg_shapes=global_arg_shapes,
  )

  @functools.wraps(f)
  @graph.update_context('pmap')
  def pmap_wrapper(*args):
    pure_args = extract.to_tree(
      args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='pmap'
    )
    pure_args_out, pure_out = pmapped_fn(*pure_args)
    _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='pmap')
    return out

  return pmap_wrapper  # type: ignore


# -------------------------------
# scan
# -------------------------------


class Broadcasted(struct.PyTreeNode):
  data: tp.Any


def _get_carry_argnum(axes, is_in_axes: bool):
  if axes is Carry:
    return 'all'
  elif isinstance(axes, int) or axes is None:
    return None

  obj_repr = 'in_axes' if is_in_axes else 'out_axes'
  carry_argnum: int | None = None
  prev_key: tp.Any = None
  for key, x in jax.tree_util.tree_leaves_with_path(axes):
    if x is not Carry:
      continue
    assert isinstance(key[0], jax.tree_util.SequenceKey)
    i = key[0].idx
    if len(key) >= 2:
      raise ValueError(
        f'Carry must be at the top-level, it cannot be nested. Found {axes=}'
      )
    if carry_argnum is not None:
      raise ValueError(
        f'Found multiple Carry definitions at '
        f'{obj_repr}{jax.tree_util.keystr(prev_key)} and '
        f'{obj_repr}{jax.tree_util.keystr(key)}'
      )
    carry_argnum = i
    prev_key = key
  return carry_argnum


def _check_out_axes(out_axes):
  for key, x in jax.tree_util.tree_leaves_with_path(
    out_axes, is_leaf=lambda x: x is None
  ):
    if x is None:
      raise ValueError(
        f'Cannot broadcast output state. '
        f'Got out_axes=None at: out_axes{jax.tree_util.keystr(key)}'
      )
    elif isinstance(x, StateAxes):
      for filter, value in x.items():
        if value is None:
          raise ValueError(
            f'Cannot broadcast output state. '
            f'Got StateAxes({{{filter}: None}}) at: out_axes'
            f'{jax.tree_util.keystr(key)}'
          )
        elif value is Carry:
          raise ValueError(
            f'Cannot carry output state. '
            f'Got StateAxes({{{filter}: Carry}}) at: out_axes'
            f'{jax.tree_util.keystr(key)}'
          )


def _check_carry_same_references(carry_arg, carry_arg_out):
  def check_carry_same_references(key_path, arg, out):
    if (
      not isinstance(arg, jax.Array) or not isinstance(out, jax.Array)
    ) and arg is not out:
      raise ValueError(
        'Carry references must be the same between iterations. '
        f'Got {arg=} with id={id(arg)} and {out=} with id={id(out)} '
        f'at carry{jax.tree_util.keystr(key_path)}'
      )

  jax.tree_util.tree_map_with_path(
    check_carry_same_references, carry_arg, carry_arg_out
  )


def _extract_index_mappings(
  pure_carry_arg_out,
  carry_index_mappings: list[graph.HashableMapping[int, int]],
  /,
):
  def extract_index_mappings(x):
    if isinstance(x, extract.NodeStates) and isinstance(
      x._graphdef, graph.NodeDef
    ):
      index_mapping = x._graphdef.index_mapping
      assert index_mapping is not None
      carry_index_mappings.append(index_mapping)
      x = x.replace(
        _graphdef=dataclasses.replace(x._graphdef, index_mapping=None)
      )
    return x

  pure_carry_arg_out = jax.tree.map(
    extract_index_mappings,
    pure_carry_arg_out,
    is_leaf=lambda x: isinstance(x, extract.NodeStates),
  )
  return pure_carry_arg_out


def _insert_index_mappings(
  pure_carry_arg_out,
  carry_index_mappings: deque[graph.HashableMapping[int, int]],
  /,
):
  def insert_index_mappings(x):
    if isinstance(x, extract.NodeStates) and isinstance(
      x._graphdef, graph.NodeDef
    ):
      index_mapping = carry_index_mappings.popleft()
      x = x.replace(
        _graphdef=dataclasses.replace(
          x._graphdef, index_mapping=index_mapping
        )
      )
    return x

  pure_carry_arg_out = jax.tree.map(
    insert_index_mappings,
    pure_carry_arg_out,
    is_leaf=lambda x: isinstance(x, extract.NodeStates),
  )
  return pure_carry_arg_out


def _scan_split_in(
  carry_deque: PytreeDeque[list[State]],
  broadcast_deque: PytreeDeque[list[State]],
  broadcast_arrays: PytreeDeque[Broadcasted],
  /,
  ctx: graph.SplitContext,
  path,
  prefix,
  x,
):
  if graph.is_graph_node(x):
    vectorized_states: list[State] = []
    carry_states: list[State] = []
    broadcast_states: list[State] = []
    if isinstance(prefix, StateAxes):
      graphdef, *states = ctx.split(x, *prefix.filters)

      for state, axis in zip(states, prefix.axes):
        if axis is None:
          broadcast_states.append(state)
        elif isinstance(axis, int):
          state = jax.tree.map(lambda x: jnp.moveaxis(x, axis, 0), state)
          vectorized_states.append(state)
        else:  # axis is Carry
          carry_states.append(state)

      if not vectorized_states:
        vectorized_states.append(State({}))
      carry_deque.append(carry_states)
      broadcast_deque.append(broadcast_states)
      return extract.NodeStates.from_split(
        graphdef, *vectorized_states, metadata=prefix
      )
    elif isinstance(prefix, int):
      graphdef, state = ctx.split(x)
      state = jax.tree.map(lambda x: jnp.moveaxis(x, prefix, 0), state)
      vectorized_states.append(state)
    elif prefix is None:
      graphdef, state = ctx.split(x)
      broadcast_states.append(state)
      vectorized_states.append(State({}))
    elif prefix is Carry:
      graphdef, state = ctx.split(x)
      carry_states.append(state)
      vectorized_states.append(State({}))
    else:
      raise ValueError(
        f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}'
      )

    if not vectorized_states:
      vectorized_states.append(State({}))
    carry_deque.append(carry_states)
    broadcast_deque.append(broadcast_states)
    return extract.NodeStates.from_split(
      graphdef, *vectorized_states, metadata=prefix
    )
  else:
    if isinstance(prefix, StateAxes):
      raise ValueError(
        'Cannot use StateAxes on non-graph nodes, '
        f'found {prefix} args{jax.tree_util.keystr(path)}'
      )
    elif prefix is Carry:
      return x
    elif prefix is None:
      broadcast_arrays.append(Broadcasted(x))
      return Broadcasted(None)
    elif isinstance(prefix, int):
      if not isinstance(x, (jax.Array, np.ndarray)):
        raise ValueError(
          f'Expected an array, got {type(x).__name__} args'
          f'{jax.tree_util.keystr(path)}'
        )
      return jnp.moveaxis(x, prefix, 0)
    else:
      raise ValueError(
        f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}'
      )


def _scan_split_out(
  carry_deque: PytreeDeque[list[State]],
  broadcast_deque: PytreeDeque[list[State]],
  /,
  ctx: graph.SplitContext,
  path: extract.KeyPath,
  prefix,
  x,
):
  assert isinstance(path[0], jax.tree_util.SequenceKey)
  is_input_arg = path[0].idx == 0

  if graph.is_graph_node(x):
    vectorized_states: list[State] = []
    carry_states: list[State] = []
    broadcast_states: list[State] = []
    if isinstance(prefix, StateAxes):
      graphdef, *states = ctx.split(x, *prefix.filters)

      for state, filter, axis in zip(states, prefix.filters, prefix.axes):
        if axis is None:
          assert is_input_arg  # validated by _check_out_axes
          broadcast_states.append(state)
        elif isinstance(axis, int):
          vectorized_states.append(state)
        elif axis is Carry:
          assert is_input_arg  # validated by _check_out_axes
          carry_states.append(state)
        else:
          obj_repr = 'args' if is_input_arg else 'out'
          raise ValueError(
            f'Invalid axes {axis} for filter {filter} at '
            f'{obj_repr}{jax.tree_util.keystr(path)}'
          )

      if not vectorized_states:
        vectorized_states.append(State({}))
      if is_input_arg:
        carry_deque.append(carry_states)
        broadcast_deque.append(broadcast_states)
      return extract.NodeStates.from_split(
        graphdef, *vectorized_states, metadata=prefix
      )
    elif isinstance(prefix, int):
      graphdef, state = ctx.split(x)
      vectorized_states.append(state)
    elif prefix is None:
      assert is_input_arg  # validated by _check_out_axes
      graphdef, state = ctx.split(x)
      broadcast_states.append(state)
      vectorized_states.append(State({}))
    elif prefix is Carry:
      assert is_input_arg  # validated by _check_out_axes
      graphdef, state = ctx.split(x)
      carry_states.append(state)
      vectorized_states.append(State({}))
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )

    if not vectorized_states:
      vectorized_states.append(State({}))
    if is_input_arg:
      carry_deque.append(carry_states)
      broadcast_deque.append(broadcast_states)
    return extract.NodeStates.from_split(
      graphdef, *vectorized_states, metadata=prefix
    )
  else:
    if isinstance(prefix, StateAxes):
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        'Cannot use StateAxes on non-graph nodes, '
        f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
    elif prefix is Carry:
      return x
    elif prefix is None:
      assert not is_input_arg  # validated by _check_out_axes
      return Broadcasted(None)
    elif isinstance(prefix, int):
      return x
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )


def _scan_merge_in(
  carry_deque: PytreeDeque[list[State]],
  broadcast_deque: PytreeDeque[list[State]],
  broadcast_arrays: PytreeDeque[Broadcasted],
  /,
  ctx: graph.MergeContext,
  path,
  prefix,
  x,
):
  if isinstance(x, extract.NodeStates):
    carry_states = carry_deque.popleft()
    broadcast_states = broadcast_deque.popleft()
    return ctx.merge(x.graphdef, *x.states, *carry_states, *broadcast_states)
  elif isinstance(x, Broadcasted):
    assert x.data is None
    return broadcast_arrays.popleft().data
  else:
    return x


def _scan_merge_out(
  carry_deque: PytreeDeque[list[State]],
  broadcast_deque: PytreeDeque[list[State]],
  /,
  ctx: graph.MergeContext,
  path,
  prefix,
  x,
):
  assert isinstance(path[0], jax.tree_util.SequenceKey)
  is_input_arg = path[0].idx == 0

  if isinstance(x, extract.NodeStates):
    states: list[State] = []
    if is_input_arg:
      carry_states = deque(carry_deque.popleft())
      broadcast_states = deque(broadcast_deque.popleft())
    else:
      carry_states = deque[State]()
      broadcast_states = deque[State]()
    if isinstance(prefix, StateAxes):
      vectorized_states = deque(x.states)
      for axis in prefix.axes:
        if isinstance(axis, int):
          state = vectorized_states.popleft()
          state = jax.tree.map(lambda x: jnp.moveaxis(x, 0, axis), state)
          states.append(state)
        elif axis is None:
          states.append(broadcast_states.popleft())
        else:  # axis is Carry
          states.append(carry_states.popleft())

      assert not carry_states and not broadcast_states
      assert not vectorized_states or (
        len(vectorized_states) == 1 and not vectorized_states[0]
      )
    elif isinstance(prefix, int):
      state = jax.tree.map(lambda x: jnp.moveaxis(x, 0, prefix), x.state)
      states.extend((state, *carry_states, *broadcast_states))
    elif prefix is None:
      assert is_input_arg
      states.extend(broadcast_states)
    elif prefix is Carry:
      assert is_input_arg
      states.extend(carry_states)
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )

    return ctx.merge(x.graphdef, *states)
  else:
    if isinstance(prefix, StateAxes):
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        'Cannot use StateAxes on non-graph nodes, '
        f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
    elif prefix is Carry:
      return x
    elif prefix is None:
      return x
    elif isinstance(prefix, int):
      if not isinstance(x, (jax.Array, np.ndarray)):
        obj_repr = 'args' if is_input_arg else 'out'
        raise ValueError(
          f'Expected an array, got {type(x).__name__} at '
          f'{obj_repr}{jax.tree_util.keystr(path)}'
        )
      return jnp.moveaxis(x, 0, prefix)
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )


@dataclasses.dataclass(eq=False)
class ScanFn:
  f: tp.Callable[..., tp.Any]
  input_carry_argnum: int | None | tp.Literal['all']
  output_carry_argnum: int | None | tp.Literal['all']
  in_axes: tp.Any
  out_axes: tp.Any
  transform_metadata: tp.Mapping[str, tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(
    self,
    carry: tuple[
      tp.Any,  # carry_arg
      PytreeDeque[list[State]],  # carry_deque
      PytreeDeque[list[State]],  # broadcast_deque
      PytreeDeque[Broadcasted],  # broadcast_arrays
    ],
    pure_args: tuple[tp.Any, ...],
  ):
    pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays = carry
    broadcast_deque_out = PytreeDeque(broadcast_deque)
    broadcast_arrays_out = PytreeDeque(broadcast_arrays)

    if self.input_carry_argnum == 'all':
      assert pure_args == ()
      pure_args = (pure_carry_arg,)
    elif isinstance(self.input_carry_argnum, int):
      assert pure_args[self.input_carry_argnum] is None
      _pure_args = list(pure_args)
      _pure_args[self.input_carry_argnum] = pure_carry_arg
      pure_args = tuple(_pure_args)
    else:
      assert self.input_carry_argnum is None
      assert pure_carry_arg is None

    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args = _update_variable_sharding_metadata(
        pure_args, self.transform_metadata, spmd.remove_axis
      )
    args: tuple = extract.from_tree(
      pure_args,
      prefix=self.in_axes,
      merge_fn=functools.partial(
        _scan_merge_in, carry_deque, broadcast_deque, broadcast_arrays
      ),
      is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)),
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
    assert not carry_deque and not broadcast_deque and not broadcast_arrays

    out = self.f(*args)

    # extract the carry from the args
    if self.input_carry_argnum == 'all':
      carry_arg = args[0]
    elif isinstance(self.input_carry_argnum, int):
      carry_arg = args[self.input_carry_argnum]
    else:
      assert self.input_carry_argnum is None
      carry_arg = None

    # extract the carry from the output
    if self.output_carry_argnum == 'all':
      carry_arg_out = out
      out = None
    elif isinstance(self.output_carry_argnum, int):
      assert isinstance(out, tuple)
      carry_arg_out = out[self.output_carry_argnum]
      _out = list(out)
      _out[self.output_carry_argnum] = None
      out = tuple(_out)
    else:
      assert self.output_carry_argnum is None
      carry_arg_out = None

    # TODO(cgarciae): allowing new references might lead to inconsistencies
    # with scan's looping semantics and we would also need to propagate the
    # input
    _check_carry_same_references(carry_arg, carry_arg_out)

    args_out: tuple = extract.clear_non_graph_nodes(args)
    # replace the carry from the input args with the carry from the output
    if self.input_carry_argnum == 'all':
      args_out = (carry_arg_out,)
    elif isinstance(self.input_carry_argnum, int):
      _args_out = list(args_out)
      _args_out[self.input_carry_argnum] = carry_arg_out
      args_out = tuple(_args_out)
    else:
      assert self.input_carry_argnum is None
      assert carry_arg_out is None

    carry_deque_out = PytreeDeque[list[State]]()
    _broadcast_deque_out_tmp = PytreeDeque[list[State]]()  # discarded
    pure_args_out: tuple
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      prefix=(self.in_axes, self.out_axes),
      split_fn=functools.partial(
        _scan_split_out, carry_deque_out, _broadcast_deque_out_tmp
      ),
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args_out, pure_out = _update_variable_sharding_metadata(
        (pure_args_out, pure_out),
        self.transform_metadata,
        spmd.add_axis,
      )

    # extract the pure carry from the pure args
    if self.input_carry_argnum == 'all':
      pure_carry_arg_out = pure_args_out[0]
      pure_args_out = ()
    elif isinstance(self.input_carry_argnum, int):
      pure_carry_arg_out = pure_args_out[self.input_carry_argnum]
      _pure_args_out = list(pure_args_out)
      _pure_args_out[self.input_carry_argnum] = None
      pure_args_out = tuple(_pure_args_out)
    else:
      assert self.input_carry_argnum is None
      pure_carry_arg_out = None

    # next we have to remove all the index_mappings from the NodeDefs
    # in the carry outputs because they are not present in the inputs
    carry_index_mappings: list[graph.HashableMapping[int, int]] = []
    pure_carry_arg_out = _extract_index_mappings(
      pure_carry_arg_out, carry_index_mappings
    )

    carry_arg_out = (
      pure_carry_arg_out,
      carry_deque_out,
      broadcast_deque_out,
      broadcast_arrays_out,
    )
    scan_out = (
      graph.Static(tuple(carry_index_mappings)),
      pure_args_out,
      pure_out,
    )
    return carry_arg_out, scan_out


@tp.overload
def scan(
  *,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  # extended api
  in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> tp.Callable[[F], F]:
  ...


@tp.overload
def scan(
  f: F,
  *,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  # extended api
  in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F:
  ...
def scan(
  f: F | type[Missing] = Missing,
  *,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  # extended api
  in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F | tp.Callable[[F], F]:
  if f is Missing:
    return functools.partial(
      scan,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
      in_axes=in_axes,
      out_axes=out_axes,
      transform_metadata=transform_metadata,
    )  # type: ignore[return-value]

  _check_out_axes(out_axes)

  input_carry_argnum = _get_carry_argnum(in_axes, is_in_axes=True)
  output_carry_argnum = _get_carry_argnum(out_axes, is_in_axes=False)

  if (input_carry_argnum is None and output_carry_argnum is not None) or (
    input_carry_argnum is not None and output_carry_argnum is None
  ):
    raise ValueError(
      'If one of in_axes or out_axes has Carry, the other must also have '
      f'Carry. Got {in_axes=!r} and {out_axes=!r}'
    )

  scan_fn = ScanFn(
    f,
    input_carry_argnum,
    output_carry_argnum,
    in_axes,
    out_axes,
    transform_metadata,
  )

  @functools.wraps(f)
  @graph.update_context('scan')
  def scan_wrapper(*args, **kwargs):
    args = resolve_kwargs(f, args, kwargs)

    if in_axes is Carry and len(args) != 1:
      raise ValueError(
        f'When in_axes=Carry, the function must take exactly one argument, '
        f'got {len(args)} arguments.'
      )

    carry_deque = PytreeDeque()
    broadcast_deque = PytreeDeque()
    broadcast_arrays = PytreeDeque()
    pure_args: tuple = extract.to_tree(
      args,
      prefix=in_axes,
      split_fn=functools.partial(
        _scan_split_in, carry_deque, broadcast_deque, broadcast_arrays
      ),
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
    if isinstance(input_carry_argnum, int):
      pure_carry_arg = pure_args[input_carry_argnum]
      _pure_args = list(pure_args)
      _pure_args[input_carry_argnum] = None
      pure_args = tuple(_pure_args)
    elif input_carry_argnum == 'all':
      pure_carry_arg = pure_args[0]
      pure_args = ()
    else:
      assert input_carry_argnum is None
      pure_carry_arg = None

    carry = (pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays)

    carry_out, scan_out = jax.lax.scan(
      scan_fn,
      carry,
      pure_args,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
    )
    (
      pure_carry_arg_out,
      carry_deque_out,
      broadcast_deque_out,
      broadcast_arrays_out,
    ) = carry_out
    (
      static_carry_index_mappings,
      pure_args_out,
      pure_out,
    ) = scan_out

    # next we have to insert all the index_mappings back into the NodeDefs
    # in the carry outputs
    carry_index_mappings = deque(static_carry_index_mappings.value)
    pure_carry_arg_out = _insert_index_mappings(
      pure_carry_arg_out, carry_index_mappings
    )

    # insert pure carry into pure_args_out
    if input_carry_argnum == 'all':
      pure_args_out = (pure_carry_arg_out,)
    elif isinstance(input_carry_argnum, int):
      _pure_args_out = list(pure_args_out)
      _pure_args_out[input_carry_argnum] = pure_carry_arg_out
      pure_args_out = tuple(_pure_args_out)
    else:
      assert input_carry_argnum is None
      assert pure_carry_arg_out is None

    args_out, out = extract.from_tree(
      (pure_args_out, pure_out),
      prefix=(in_axes, out_axes),
      merge_fn=functools.partial(
        _scan_merge_out, carry_deque_out, broadcast_deque_out
      ),
      is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)),
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
    # extract the carry from args_out
    if input_carry_argnum == 'all':
      carry_arg = args_out[0]
    elif isinstance(input_carry_argnum, int):
      carry_arg = args_out[input_carry_argnum]
    else:
      assert input_carry_argnum is None
      carry_arg = None

    # insert carry into the output
    if output_carry_argnum == 'all':
      out = carry_arg
    elif isinstance(output_carry_argnum, int):
      _out = list(out)
      _out[output_carry_argnum] = carry_arg
      out = tuple(_out)
    else:
      assert output_carry_argnum is None
      assert carry_arg is None

    return out

  return scan_wrapper  # type: ignore
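
# Usage note (minimal sketch, using the defaults in_axes=(Carry, 0) and
# out_axes=(Carry, 0)): the first argument is the carry, the second is scanned
# over axis 0, and the function returns (carry, per-step output):
#
#   @nnx.scan
#   def step(cell, x):   # `cell` is a hypothetical Module used as the carry
#     y = cell(x)
#     return cell, y     # carry returned unchanged; `y` is stacked on axis 0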
# -------------------------------
# while_loop
# -------------------------------


@dataclasses.dataclass(eq=False)
class WhileLoopCondFn:
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, pure_val):
    val = extract.from_tree(pure_val)
    out = self.f(val)
    return out


def _add_fake_index_mapping(tree: tp.Any):
  global_index_mapping = {}  # for the whole context, over all inputs

  def per_node_state(ns: extract.NodeStates | tp.Any):
    if not isinstance(ns, extract.NodeStates) or not isinstance(
      ns._graphdef, graph.NodeDef
    ):
      return ns

    def per_node_def(nd: graph.NodeDef | graph.NodeRef):
      if nd.index >= 0:
        global_index_mapping[nd.index] = nd.index
      if isinstance(nd, graph.NodeRef):
        return

      for attribute in nd.attributes:
        if type(attribute) is graph.SubGraphAttribute:
          per_node_def(attribute.value)
        elif (
          type(attribute) is graph.LeafAttribute
          and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef))
          and attribute.value.index >= 0
        ):
          global_index_mapping[attribute.value.index] = attribute.value.index
      return

    per_node_def(ns._graphdef)
    return dataclasses.replace(
      ns,
      _graphdef=dataclasses.replace(
        ns._graphdef,
        index_mapping=graph.HashableMapping(global_index_mapping),
      ),
    )

  return jax.tree.map(
    per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )


def _remove_index_mapping(tree: tp.Any):
  """Remove a fake index_mapping for the input to match that of the output."""

  def per_node_state(ns: extract.NodeStates | tp.Any):
    if not isinstance(ns, extract.NodeStates) or not isinstance(
      ns._graphdef, graph.NodeDef
    ):
      return ns
    assert isinstance(ns._graphdef, graph.NodeDef)
    return dataclasses.replace(
      ns, _graphdef=dataclasses.replace(ns._graphdef, index_mapping=None)
    )

  return jax.tree.map(
    per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )


@dataclasses.dataclass(eq=False)
class WhileLoopBodyFn:
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  @graph.update_context('while_loop_body')
  def __call__(self, pure_val):
    # Remove the dummy index mapping that was added outside the body function.
    pure_val_in = _remove_index_mapping(pure_val)

    val = extract.from_tree(pure_val_in, ctxtag='while_loop_body')
    out = self.f(val)
    pure_out = extract.to_tree(out, ctxtag='while_loop_body')

    try:
      jax.tree.map(lambda a, b: None, pure_val, pure_out)
    except ValueError as e:
      msg = (
        "nnx.while_loop requires body function's input and output to "
        'have the same reference and pytree structure, but they differ. '
        'If the mismatch comes from `index_mapping` field, you might '
        'have modified reference structure within the body function, '
        'which is not allowed. '
        f'Detail of the mismatch: \n {str(e)}'
      )
      raise ValueError(msg)

    return pure_out
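
# Note on the index-mapping helpers above: `jax.lax.while_loop` requires the
# carry pytree structure to match between input and output, but graphdefs
# produced inside the body carry an `index_mapping` that the initial input
# lacks. `_add_fake_index_mapping` therefore attaches an identity mapping to
# the input up front, and `_remove_index_mapping` strips it again inside the
# body before merging, so both sides line up.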
@graph.update_context('while_loop')
def while_loop(
  cond_fun: tp.Callable[[T], tp.Any],
  body_fun: tp.Callable[[T], T],
  init_val: T,
) -> T:
  """A Flax NNX transformation of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you cannot
  change the variable reference structure of ``init_val`` inside ``body_fun``.

  Example::

    >>> import jax
    >>> from flax import nnx
    >>> def fwd_fn(input):
    ...   module, x, count = input
    ...   return module, module(x), count - 1.0

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> # `module` will be called three times
    >>> _, y, _ = nnx.while_loop(
    ...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))

  Args:
    cond_fun: A function for the continue condition of the while loop, taking
      a single input of type ``T`` and outputting a boolean.
    body_fun: A function that takes an input of type ``T`` and outputs a
      ``T``. Note that both data and modules of ``T`` must have the same
      reference structure between inputs and outputs.
    init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be of
      type ``T``.
  """
  pure_init_val = extract.to_tree(init_val, ctxtag='while_loop')

  # Add the expected reference mapping to `pure_init_val` to match
  # `body_fun`'s output pytree structure, to make JAX while_loop happy.
  pure_init_val = _add_fake_index_mapping(pure_init_val)

  pure_out = jax.lax.while_loop(
    WhileLoopCondFn(cond_fun),
    WhileLoopBodyFn(body_fun),
    pure_init_val,
  )
  out = extract.from_tree(pure_out, ctxtag='while_loop')
  return out
@dataclasses.dataclass(eq=False)
class ForiLoopBodyFn:
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  @graph.update_context('fori_loop_body')
  def __call__(self, i, pure_val):
    # Remove the dummy index mapping that was added outside the body function.
    pure_val_in = _remove_index_mapping(pure_val)

    val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body')
    out = self.f(i, val)
    pure_out = extract.to_tree(out, ctxtag='fori_loop_body')

    try:
      jax.tree.map(lambda a, b: None, pure_val, pure_out)
    except ValueError as e:
      msg = (
        "nnx.fori_loop requires body function's input and output to "
        'have the same reference and pytree structure, but they differ. '
        'If the mismatch comes from `index_mapping` field, you might '
        'have modified reference structure within the body function, '
        'which is not allowed. '
        f'Detail of the mismatch: \n {str(e)}'
      )
      raise ValueError(msg)

    return pure_out
@graph.update_context('fori_loop')
def fori_loop(
  lower: int,
  upper: int,
  body_fun: tp.Callable[[int, T], T],
  init_val: T,
  *,
  unroll: int | bool | None = None,
) -> T:
  """A Flax NNX transformation of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you cannot
  change the variable reference structure of ``init_val`` inside ``body_fun``.

  Example::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import numpy as np
    >>> from flax import nnx

    >>> def fwd_fn(i, input):
    ...   m, x = input
    ...   m.kernel.value = jnp.identity(10) * i
    ...   return m, m(x)

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
    >>> np.testing.assert_array_equal(y, x * 2 * 3)

  Args:
    lower: An integer representing the loop index lower bound (inclusive).
    upper: An integer representing the loop index upper bound (exclusive).
    body_fun: A function that takes an input of type ``T`` and outputs a
      ``T``. Note that both data and modules of ``T`` must have the same
      reference structure between inputs and outputs.
    init_val: The initial input for ``body_fun``. Must be of type ``T``.
    unroll: An optional integer or boolean that determines how much to unroll
      the loop. If an integer is provided, it determines how many unrolled
      loop iterations to run within a single rolled iteration of the loop. If
      a boolean is provided, it will determine if the loop is completely
      unrolled (i.e. ``unroll=True``) or left completely rolled (i.e.
      ``unroll=False``). This argument is only applicable if the loop bounds
      are statically known.

  Returns:
    A loop value from the final iteration, of type ``T``.
  """
  pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop')

  # Add the expected reference mapping to `pure_init_val` to match
  # `body_fun`'s output pytree structure, to make JAX happy.
  pure_init_val = _add_fake_index_mapping(pure_init_val)

  pure_out = jax.lax.fori_loop(
    lower,
    upper,
    ForiLoopBodyFn(body_fun),
    pure_init_val,
    unroll=unroll,
  )
  out = extract.from_tree(pure_out, ctxtag='fori_loop')
  return out