Source code for flax.nnx.transforms.transforms

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations

from abc import abstractmethod
import dataclasses
import functools
import inspect
import typing as tp

from jax._src import checkify as checkify_lib

from flax.nnx import (
  extract,
  graph,
)
from flax.nnx.module import Module
from flax.nnx.proxy_caller import (
  CallableProxy,
  DelayedAccessor,
)
from flax.nnx.transforms import general
from flax.typing import MISSING, Leaf, Missing
import jax
import jax.core
import jax.stages

A = tp.TypeVar('A')
C = tp.TypeVar('C')
B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable
Leaves = tp.List[Leaf]
Index = int


@tp.overload
def resolve_kwargs(
  fun: tp.Callable[..., tp.Any],
  args: tuple,
  kwargs: dict[str, tp.Any],
) -> tuple: ...
@tp.overload
def resolve_kwargs() -> tp.Callable[[F], F]: ...
def resolve_kwargs(
  fun: tp.Callable[..., tp.Any] | Missing = MISSING,
  args: tuple | Missing = MISSING,
  kwargs: dict[str, tp.Any] | Missing = MISSING,
) -> tuple | tp.Callable[[F], F]:
  if isinstance(fun, Missing):

    def resolve_kwargs_decorator(f):
      @functools.wraps(f)
      def resolve_kwargs_wrapper(*args, **kwargs):
        args = resolve_kwargs(f, args, kwargs)
        return f(*args)

      return resolve_kwargs_wrapper

    return resolve_kwargs_decorator  # type: ignore

  if isinstance(args, Missing):
    raise ValueError('args must be provided')
  if isinstance(kwargs, Missing):
    raise ValueError('kwargs must be provided')

  if isinstance(fun, functools.partial):
    # functools.partial should have an opaque signature.
    fun = lambda *args, **kwargs: None
  ba = inspect.signature(fun).bind(*args, **kwargs)
  ba.apply_defaults()
  if ba.kwargs:
    raise TypeError('keyword arguments could not be resolved to positions')
  else:
    return ba.args
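

# Illustrative sketch (not part of this module): `resolve_kwargs` binds the
# given args/kwargs against `fun`'s signature and returns them as a purely
# positional tuple, raising if a keyword cannot be mapped to a position.
# The helper below is a hypothetical usage example.
def _example_resolve_kwargs():
  def scale(x, factor=2.0):
    return x * factor

  # The keyword `factor` is resolved to its positional slot.
  positional = resolve_kwargs(scale, (2.0,), {'factor': 3.0})
  assert positional == (2.0, 3.0)

  # As a decorator, the wrapped function always receives positional args.
  @resolve_kwargs()
  def offset(x, delta=1):
    return x + delta

  assert offset(1, delta=2) == 3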



class LiftedModule(tp.Generic[M], Module):  # type: ignore[ignored-abstractmethod]
  @abstractmethod
  def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any:
    pass

  @property
  @abstractmethod
  def _submodule(self) -> M:
    pass  # type: ignore[bad-return-type] # why pytype?

  def __call__(self, *args, **kwargs) -> tp.Any:
    return self.call(*args, **kwargs)  # type: ignore

  @property
  def call(self) -> tp.Any:
    module = self

    def check_and_call(accessor: DelayedAccessor, *args, **kwargs):
      return self._call(accessor, *args, **kwargs)

    proxy = CallableProxy(check_and_call)  # type: ignore[arg-type]

    # Unwrap nested LiftedModules: each level of nesting adds one `.call`
    # hop to the proxy so the call is routed through every lift in the chain.
    while isinstance(module._submodule, LiftedModule):
      module = module._submodule
      proxy = proxy.call

    return proxy  # type: ignore
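

# Illustrative sketch (not part of this module): subclasses of `LiftedModule`
# typically store an inner module, expose it via `_submodule`, and apply a
# transform around the accessed method inside `_call`. The class below is a
# hypothetical, transform-free example that simply forwards the call.
class _ExampleLift(LiftedModule[M]):
  def __init__(self, module: M):
    self.inner = module

  @property
  def _submodule(self) -> M:
    return self.inner

  def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any:
    # `accessor` replays the attribute path selected on the proxy, yielding
    # the submodule method (or the submodule itself) to invoke.
    method = accessor(self.inner)
    return method(*args, **kwargs)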


# -------------------------------
# simple transforms
# -------------------------------


def eval_shape(
  f: tp.Callable[..., A],
  *args: tp.Any,
  **kwargs: tp.Any,
) -> A:
  """A "lifted" version of `jax.eval_shape
  <https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html#jax.eval_shape>`_
  that can handle `flax.nnx.Module
  <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
  / graph nodes as arguments.

  Similar to ``jax.eval_shape``, it computes the shape/dtype of the output of a
  function ``f`` without performing any floating point operations (FLOPs),
  which can be expensive. This can be useful for performing shape inference,
  for example.
  """
  args, kwargs = extract.to_tree((args, kwargs))

  @functools.wraps(f)
  def _eval_shape_fn(*args, **kwargs):
    args, kwargs = extract.from_tree((args, kwargs))
    out = f(*args, **kwargs)
    return extract.to_tree(out)

  out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
  return extract.from_tree(out)
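

# Illustrative sketch (not part of this module): `eval_shape` can be used for
# shape inference on functions that take graph nodes such as `nnx.Module`s.
# The helper below is a hypothetical usage example.
def _example_eval_shape():
  import jax.numpy as jnp
  from flax import nnx

  model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
  x = jnp.ones((1, 2))

  # No FLOPs are performed; only shapes/dtypes are propagated.
  out = eval_shape(lambda m, x: m(x), model, x)
  assert out.shape == (1, 3)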


@dataclasses.dataclass(eq=False)
class CheckifyFn:
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args, **pure_kwargs):
    args, kwargs = extract.from_tree(
      (pure_args, pure_kwargs), ctxtag='checkify'
    )
    out = self.f(*args, **kwargs)

    args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
      (args_out, kwargs_out, out), ctxtag='checkify'
    )
    return pure_args_out, pure_kwargs_out, pure_out


def checkify(
  f: tp.Callable[..., checkify_lib.Out],
  errors: frozenset[type[checkify_lib.JaxException]] = checkify_lib.user_checks,  # type: ignore
) -> tp.Callable[..., tuple[checkify_lib.Error, checkify_lib.Out]]:
  """Reference-aware version of `jax.experimental.checkify
  <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`_.

  Example::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> import dataclasses
    >>> from flax import nnx
    ...
    >>> @dataclasses.dataclass
    ... class Foo(nnx.Module):
    ...   a: nnx.Param
    ...
    >>> @nnx.jit
    ... def f(m):
    ...   y = jnp.sin(m.a.value)  # error
    ...   return m.a + y
    ...
    >>> m = Foo(a=nnx.Param(jnp.inf))
    >>> err, out = nnx.checkify(f, errors=checkify.float_checks)(m)
    >>> # err.throw()
    >>> print(err)
    Error(nan generated by primitive: sin.)
  """
  checkify_fn = checkify_lib.checkify(CheckifyFn(f), errors)

  @functools.wraps(f)
  @graph.update_context('checkify')
  def jit_wrapper(*args, **kwargs):
    pure_args, pure_kwargs = extract.to_tree(
      (args, kwargs), ctxtag='checkify'
    )
    error, (pure_args_out, pure_kwargs_out, pure_out) = checkify_fn(
      *pure_args, **pure_kwargs
    )
    args_out, kwargs_out, out = extract.from_tree(
      (pure_args_out, pure_kwargs_out, pure_out), ctxtag='checkify'
    )
    return error, out

  return jit_wrapper  # type: ignore


@general.split_inputs(ctxtag='cond')
def cond(
  pred,
  true_fun: tp.Callable[..., A],
  false_fun: tp.Callable[..., A],
  *operands,
  **kwargs,
) -> A:
  return jax.lax.cond(
    pred,
    general.merge_inputs(true_fun, ctxtag='cond'),
    general.merge_inputs(false_fun, ctxtag='cond'),
    *operands,
    **kwargs,
  )
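

# Illustrative sketch (not part of this module), assuming branch-side state
# updates are merged back into the operands as with other reference-aware
# transforms; the `Counter` module and helper below are hypothetical.
def _example_cond():
  import jax.numpy as jnp
  from flax import nnx

  class Counter(nnx.Module):
    def __init__(self):
      self.count = nnx.Variable(jnp.array(0))

  counter = Counter()

  def increment(c):
    c.count.value += 1

  def reset(c):
    c.count.value = jnp.array(0)

  # Both branches receive the same graph node; the chosen branch's update
  # is merged back into `counter` after `jax.lax.cond` runs.
  cond(True, increment, reset, counter)
  assert counter.count.value == 1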


@general.split_inputs(ctxtag='switch')
def switch(
  index,
  branches: tp.Sequence[tp.Callable[..., A]],
  *operands,
) -> A:
  return jax.lax.switch(
    index,
    [general.merge_inputs(f, ctxtag='switch') for f in branches],
    *operands,
  )
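

# Illustrative sketch (not part of this module): like `jax.lax.switch`, but
# operands may be graph nodes such as `nnx.Module`s. The helper below is a
# hypothetical usage example selecting between two scaling branches.
def _example_switch():
  import jax.numpy as jnp
  from flax import nnx

  model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))
  x = jnp.ones((1, 2))

  branches = [
    lambda m, x: m(x),        # index 0: apply the module
    lambda m, x: m(x) * 2.0,  # index 1: apply and scale
  ]
  y = switch(1, branches, model, x)
  assert y.shape == (1, 2)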