# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import dataclasses
import functools
import typing as tp
from flax import struct
from flax.nnx import (
extract,
filterlib,
graph,
variablelib,
)
from flax.nnx.statelib import State
import jax
import jax.core
import jax.stages
from flax.nnx.transforms import general
from flax.nnx.transforms.transforms import resolve_kwargs
from flax.typing import MISSING, Missing
A = tp.TypeVar('A')
# C = tp.TypeVar('C')
# B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
# G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
# M = tp.TypeVar('M', bound=Module)
# MA = tp.TypeVar('MA', bound=Module)
# N = tp.TypeVar('N', bound=Module)
# StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable
# Leaves = tp.List[Leaf]
# Index = int
# -------------------------------
# grad
# -------------------------------
@dataclasses.dataclass(frozen=True)
class DiffState:
argnum: int
filter: filterlib.Filter
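
# A DiffState pairs an argument position (`argnum`) with a Filter selecting which
# substate of that argument is differentiated. As a usage sketch,
# `DiffState(0, nnx.Param)` differentiates only the `Param` state of argument 0,
# which matches the default used when a plain int is passed as an argnum.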
@dataclasses.dataclass(eq=False)
class GradFn:
f: tp.Callable[..., tp.Any]
has_aux: bool
def __post_init__(self):
functools.update_wrapper(self, self.f)
def __call__(self, *pure_args):
    # rebuild the graph nodes in args by merging the differentiable substates
    # with the non-differentiable state broadcast from grad_wrapper
nondiff_states: deque[State | None] = extract.get_broadcast_state('grad')
def _grad_merge_fn(
ctx: graph.MergeContext, path, prefix, value: extract.NodeStates
):
nondiff = nondiff_states.popleft()
if nondiff is None:
return ctx.merge(value.graphdef, value.state)
else:
return ctx.merge(value.graphdef, value.state, nondiff)
args = extract.from_tree(pure_args, merge_fn=_grad_merge_fn, ctxtag='grad')
out = self.f(*args)
args_out = extract.clear_non_graph_nodes(args)
pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag='grad')
if self.has_aux:
loss, pure_aux = pure_out
fn_out = (loss, (pure_args_out, pure_aux))
else:
loss = pure_out
fn_out = (loss, pure_args_out)
return fn_out
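
# Shared implementation behind `grad` and `value_and_grad`: each differentiable
# argument is split into its differentiable / non-differentiable substates
# according to its DiffState filter, `jax.grad` / `jax.value_and_grad` is applied
# to the resulting pure pytree function, and the gradients are mapped back to
# `State` objects.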
def _grad_general(
f: tp.Callable[..., tp.Any],
argnums: int | DiffState | tp.Sequence[int | DiffState],
has_aux: bool,
holomorphic: bool,
allow_int: bool,
reduce_axes: tp.Sequence[AxisName],
return_value: bool,
) -> tp.Callable[..., tp.Any]:
transform = jax.value_and_grad if return_value else jax.grad
jax_argnums: int | tuple[int, ...]
if isinstance(argnums, (int, DiffState)):
jax_argnums = argnums.argnum if isinstance(argnums, DiffState) else argnums
else:
jax_argnums = tuple(
x.argnum if isinstance(x, DiffState) else x for x in argnums
)
_argnums = (argnums,) if isinstance(argnums, (int, DiffState)) else argnums
index_filter: dict[int, DiffState] = {}
for argnum in _argnums:
index = argnum.argnum if isinstance(argnum, DiffState) else argnum
if index in index_filter:
raise ValueError(f'argnum {index} is repeated in argnums')
index_filter[index] = (
dataclasses.replace(argnum, argnum=-1)
if isinstance(argnum, DiffState)
else DiffState(-1, variablelib.Param)
)
gradded_fn = transform(
GradFn(f, has_aux),
argnums=jax_argnums,
has_aux=True,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)
@graph.update_context('grad')
def grad_wrapper(*args, **kwargs):
args = resolve_kwargs(f, args, kwargs)
del kwargs
nondiff_states: deque[State | None] = deque()
def _grad_split_fn(
ctx: graph.SplitContext, path, prefix: DiffState | None, value
):
if prefix is None:
nondiff_states.append(None)
return extract.NodeStates.from_split(*ctx.split(value))
else:
graphdef, diff, nondiff = ctx.split(value, prefix.filter, ...) # type: ignore[misc]
nondiff_states.append(nondiff)
return extract.NodeStates.from_split(graphdef, diff)
arg_filters = tuple(index_filter.get(i) for i in range(len(args)))
pure_args = extract.to_tree(
args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad'
)
with extract.broadcast_state('grad', nondiff_states):
fn_out = gradded_fn(*pure_args)
def process_grads(grads):
return jax.tree.map(
lambda x: x.state if isinstance(x, extract.NodeStates) else x,
grads,
is_leaf=lambda x: isinstance(x, extract.NodeStates),
)
def process_out(pure_out: A, /) -> A:
return extract.from_tree(pure_out, ctxtag='grad')
if return_value:
# unpack value_and_grad output
if has_aux:
(loss, (pure_args_out, pure_aux)), grads = fn_out
grads = process_grads(grads)
_args_out, aux = process_out((pure_args_out, pure_aux))
return (loss, aux), grads
else:
(loss, pure_args_out), grads = fn_out
grads = process_grads(grads)
_args_out = process_out(pure_args_out)
return loss, grads
else:
# unpack grad output
if has_aux:
grads, (pure_args_out, pure_aux) = fn_out
grads = process_grads(grads)
_args_out, aux = process_out((pure_args_out, pure_aux))
return grads, aux
else:
grads, pure_args_out = fn_out
grads = process_grads(grads)
_args_out = process_out(pure_args_out)
return grads
return grad_wrapper
@tp.overload
def grad(
f: tp.Callable[..., tp.Any],
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
) -> tp.Callable[..., tp.Any]: ...
@tp.overload
def grad(
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def grad(
f: tp.Callable[..., tp.Any] | Missing = MISSING,
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
) -> (
tp.Callable[..., tp.Any]
| tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
"""Lifted version of ``jax.grad`` that can handle Modules / graph nodes as
arguments.
The differentiable state of each graph node is defined by the `wrt` filter,
which by default is set to `nnx.Param`. Internally the ``State`` of
graph nodes is extracted, filtered according to `wrt` filter, and
passed to the underlying ``jax.grad`` function. The gradients
of graph nodes are of type ``State``.
Example::
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
'bias': VariableState(
type=Param,
value=(3,)
),
'kernel': VariableState(
type=Param,
value=(2, 3)
)
})
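
  A ``DiffState`` can be passed as ``argnums`` to restrict differentiation to a
  substate of a graph node argument. For example (a minimal sketch reusing ``m``,
  ``x``, ``y`` and ``loss_fn`` from above, with the ``nnx.PathContains`` filter
  selecting only the ``kernel`` attribute)::

    >>> diff_state = nnx.DiffState(0, nnx.PathContains('kernel'))
    >>> grads = nnx.grad(loss_fn, argnums=diff_state)(m, x, y)
    >>> jax.tree.map(jnp.shape, grads)
    State({
      'kernel': VariableState(
        type=Param,
        value=(2, 3)
      )
    })
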
Args:
    f: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, graph nodes or standard Python
      containers. Argument arrays in the positions specified by ``argnums`` must
      be of inexact (i.e., floating-point or complex) type. It should return a
      scalar (which includes arrays with shape ``()`` but not arrays with shape
      ``(1,)`` etc.).
    argnums: Optional, integer, ``DiffState``, or sequence of these. Specifies
      which positional argument(s) to differentiate with respect to (default 0).
      A ``DiffState`` additionally restricts the differentiable substate of a
      graph node argument via its filter.
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``f`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.
    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
      ``f`` implicitly broadcasts a value over that axis, the backward pass
      will perform a ``psum`` of the corresponding gradient. Otherwise, the
      gradient will be per-example over named axes. For example, if ``'batch'``
      is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
      function that computes the total gradient while ``grad(f)`` will create
      one that computes the per-example gradient.
"""
if isinstance(f, Missing):
return functools.partial(
grad,
argnums=argnums,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)
return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
return_value=False,
)
@tp.overload
def value_and_grad(
f: tp.Callable[..., tp.Any],
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
) -> tp.Callable[..., tp.Any]: ...
@tp.overload
def value_and_grad(
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def value_and_grad(
f: tp.Callable[..., tp.Any] | type[Missing] = Missing,
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
) -> (
tp.Callable[..., tp.Any]
| tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
if f is Missing:
return functools.partial(
value_and_grad,
argnums=argnums,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)
return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
return_value=True,
)
# -----------------------------------------------
# custom_vjp
# -----------------------------------------------
# custom_vjp is one of the most complicated transforms as it requires
# handling 4 different functions:
# 1. CustomVjp: the main object that runs the outer logic, converts input graph nodes
# to pytrees and output pytrees to graph nodes.
# 2. CustomVjpFnWrapper: wraps the user's function, converting its input pytrees
# to graph nodes and its output graph nodes to pytrees.
# 3. FwdFn: wraps the user's fwd function, converting its input pytrees to graph nodes
# and its output graph nodes to pytrees. Since it might run by itself in a separate context,
# it needs to know whether the update_context is active or not in order to update the outer
# references.
# 4. BwdFn: wraps the user's bwd function, converting its input pytrees to graph nodes
# and its output graph nodes to pytrees. It doesn't need to be aware of the outer context
# since it never updates the outer references, as it runs during the backward pass.
def _custom_vjp_merge_fn(
ctx: graph.MergeContext,
path,
prefix: bool | DiffState,
value: extract.NodeStates,
*,
nondiff_states: deque[extract.GraphDefState],
):
nondiff = nondiff_states.popleft()
return ctx.merge(nondiff.graphdef, value.state, nondiff.state)
def _custom_vjp_split_fn(
ctx: graph.SplitContext,
path,
prefix: bool | DiffState,
value,
*,
nondiff_states: list[extract.GraphDefState],
):
broadcast: graph.GraphState
if prefix is False:
# pure non-differentiable arg, not supported
raise TypeError(
      'Passing integers to nondiff_argnums for graph node arguments in custom_vjp is not supported. '
f'Got {prefix} at path {jax.tree_util.keystr(path)} for value {value}'
)
elif prefix is True:
    # pure differentiable arg: we pass all the state through,
    # but we return a NodeStates.from_states (which doesn't carry a graphdef)
    # in order to keep the gradients clean of any metadata
graphdef, passed = ctx.split(value)
broadcast = State({})
nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
return extract.NodeStates.from_states(passed)
else:
    # differentiable arg with a DiffState filter: we use the filter to split the state.
    # As before, we return a NodeStates.from_states to keep the gradients clean
    # of any metadata; the non-differentiable state is stored in a deque which is
    # broadcast during the forward pass
graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...) # type: ignore[misc]
nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
return extract.NodeStates.from_states(passed)
nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)
def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]):
if isinstance(x, graph.NodeDef):
assert x.index_mapping is not None
index_mappings.append(x.index_mapping)
return dataclasses.replace(x, index_mapping=None)
return x
@dataclasses.dataclass(eq=False)
class CustomVjpFnWrapper:
f: tp.Callable[..., tp.Any]
jax_nondiff_argnums: tuple[int, ...]
ctxtag: str
nondiff_states: list[extract.GraphDefState]
def __post_init__(self):
functools.update_wrapper(self, self.f)
def __call__(self, *pure_args):
nondiff_states = deque(self.nondiff_states)
args = extract.from_tree(
pure_args,
merge_fn=functools.partial(
_custom_vjp_merge_fn, nondiff_states=nondiff_states
),
ctxtag=self.ctxtag,
)
out = self.f(*args)
    # drop the non-differentiable args so they are not propagated as updated outputs
args_out = tuple(
x for i, x in enumerate(args) if i not in self.jax_nondiff_argnums
)
args_out = extract.clear_non_graph_nodes(args_out)
pure_args_out, pure_out = extract.to_tree(
(args_out, out), ctxtag=self.ctxtag
)
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state(
self.ctxtag
)
pure_args_out, pure_out = jax.tree.map(
functools.partial(_extract_index_mappings, index_mappings=index_mappings),
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)
return pure_args_out, pure_out
@dataclasses.dataclass(eq=False)
class FwdFn:
fwd: tp.Callable[..., tp.Any]
nondiff_argnums: tuple[int, ...]
ctxtag: str
nondiff_states: list[extract.GraphDefState]
def __post_init__(self):
functools.update_wrapper(self, self.fwd)
def __call__(self, *pure_args):
    # Here we need to know whether the update_context is active or not:
    # when it's not active, index_mappings will be None;
    # when it's active, we remove the index_mappings from the NodeDefs and store
    # them in the index_mappings deque created by CustomVjp.
update_context_active = (
self.ctxtag in graph.GRAPH_CONTEXT.update_context_stacks
)
nondiff_states = deque(self.nondiff_states)
args = extract.from_tree(
pure_args,
merge_fn=functools.partial(
_custom_vjp_merge_fn, nondiff_states=nondiff_states
),
ctxtag=self.ctxtag if update_context_active else None,
)
out, residual = self.fwd(*args)
    # drop the non-differentiable args so they are not propagated as updated outputs
args_out = tuple(
x for i, x in enumerate(args) if i not in self.nondiff_argnums
)
args_out = extract.clear_non_graph_nodes(args_out)
pure_args_out, pure_out = extract.to_tree(
(args_out, out),
ctxtag=self.ctxtag if update_context_active else None,
)
pure_residual = extract.to_tree(residual)
if update_context_active:
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[graph.HashableMapping] = (
extract.get_broadcast_state(self.ctxtag)
)
pure_args_out, pure_out = jax.tree.map(
functools.partial(
_extract_index_mappings, index_mappings=index_mappings
),
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)
return (pure_args_out, pure_out), pure_residual
@dataclasses.dataclass(eq=False)
class BwdFn:
bwd: tp.Callable[..., tp.Any]
tree_node_args: tuple[tp.Any, ...]
def __post_init__(self):
functools.update_wrapper(self, self.bwd)
def __call__(self, *args):
*nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args
residual = extract.from_tree(pure_residual)
(pure_args_out_g, pure_out_g) = jax.tree.map(
lambda x: x.state if isinstance(x, extract.NodeStates) else x,
(pure_args_out_g, pure_out_g),
is_leaf=lambda x: isinstance(x, extract.NodeStates),
)
tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g))
def state_to_node_states(is_differentiable: bool, x):
if is_differentiable:
if isinstance(x, jax.Array):
return x
elif not isinstance(x, State):
raise ValueError(f'Expected State, got {type(x)}')
return extract.NodeStates.from_states(x)
return x
pure_tangent = jax.tree.map(
state_to_node_states,
self.tree_node_args,
tangent,
is_leaf=lambda x: isinstance(x, State),
)
return pure_tangent
class CustomVjp(tp.Generic[A]):
def __init__(
self,
fun: tp.Callable[..., A],
nondiff_argnums: tuple[int | DiffState, ...],
):
functools.update_wrapper(self, fun)
# first argument is metadata
self.jax_nondiff_argnums = tuple(
x for x in nondiff_argnums if isinstance(x, int)
)
self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}'
self.fun = fun
self.fwd: tp.Callable | None = None
self.bwd: tp.Callable | None = None
self.symbolic_zeros: bool | None = None
self.nondiff_argnums = nondiff_argnums
self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {}
for argnum in self.nondiff_argnums:
index = argnum.argnum if isinstance(argnum, DiffState) else argnum
if index in self.diff_filter:
raise ValueError(f'argnum {index} is repeated in nondiff_argnums')
self.diff_filter[index] = (
dataclasses.replace(argnum, argnum=-1)
if isinstance(argnum, DiffState)
else False
)
# def __getattr__(self, name: str) -> tp.Any:
# if not hasattr(self.custom_vjp_fn, name):
# raise AttributeError(f'{type(self).__name__} has no attribute {name}')
# return getattr(self.custom_vjp_fn, name)
def __call__(
self, *args: tp.Any, **kwargs: tp.Any
) -> A: # pytype: disable=invalid-annotation
with graph.update_context(self.ctxtag):
args = resolve_kwargs(self.fun, args, kwargs)
del kwargs
nondiff_states: list[extract.GraphDefState] = []
arg_filters = tuple(
self.diff_filter.get(i, True) for i in range(len(args))
)
pure_args = extract.to_tree(
args,
prefix=arg_filters,
split_fn=functools.partial(
_custom_vjp_split_fn, nondiff_states=nondiff_states
),
ctxtag=self.ctxtag,
)
tree_node_args = jax.tree.map(
lambda x: isinstance(x, extract.NodeStates),
pure_args,
is_leaf=lambda x: isinstance(x, extract.NodeStates),
)
tree_node_args = tuple(
x
for i, x in enumerate(tree_node_args)
if i not in self.jax_nondiff_argnums
)
index_mappings: deque[graph.HashableMapping] = deque()
with extract.broadcast_state(self.ctxtag, index_mappings):
if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
          raise ValueError(
            'defvjp must be called before calling a custom_vjp decorated function'
          )
custom_vjp_fn = jax.custom_vjp(
fun=CustomVjpFnWrapper(
f=self.fun,
jax_nondiff_argnums=self.jax_nondiff_argnums,
ctxtag=self.ctxtag,
nondiff_states=nondiff_states,
),
nondiff_argnums=self.jax_nondiff_argnums,
)
custom_vjp_fn.defvjp(
fwd=FwdFn(
fwd=self.fwd,
nondiff_argnums=self.jax_nondiff_argnums,
ctxtag=self.ctxtag,
nondiff_states=nondiff_states,
),
bwd=BwdFn(
bwd=self.bwd,
tree_node_args=tree_node_args,
),
symbolic_zeros=self.symbolic_zeros,
)
pure_args_out, pure_out = custom_vjp_fn(*pure_args)
# insert index_mappings
def _insert_index_mappings(x):
if isinstance(x, graph.NodeDef):
index_mapping: graph.HashableMapping = index_mappings.popleft()
return dataclasses.replace(x, index_mapping=index_mapping)
return x
pure_args_out, pure_out = jax.tree_util.tree_map(
_insert_index_mappings,
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)
args_out, out = extract.from_tree(
(pure_args_out, pure_out), ctxtag=self.ctxtag
)
return out
def defvjp(
self,
fwd: tp.Callable[..., tuple[A, tp.Any]],
bwd: tp.Callable[..., tuple[tp.Any, ...]],
symbolic_zeros: bool = False,
) -> None:
self.fwd = fwd
self.bwd = bwd
self.symbolic_zeros = symbolic_zeros
@tp.overload
def custom_vjp(
fun: tp.Callable[..., A],
*,
nondiff_argnums: tuple[int | DiffState, ...] = (),
) -> CustomVjp[A]: ...
@tp.overload
def custom_vjp(
*,
nondiff_argnums: tuple[int | DiffState, ...] = (),
) -> tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]: ...
def custom_vjp(
fun: tp.Callable[..., A] | Missing = MISSING,
*,
nondiff_argnums: tuple[int | DiffState, ...] = (),
) -> CustomVjp[A] | tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]:
"""Reference aware version of
`jax.custom_vjp <https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html>`__.
``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference
with the JAX version is that, because Modules follow reference semantics, they propagate the State
updates for the inputs as auxiliary outputs. This means that the incomming gradients in the ``bwd`` function
will have the form ``(input_updates_g, out_g)`` where ``input_updates_g`` is the gradient updated state of
the inputs w.r.t. to the inputs. All Module terms on the inputs will an associated ``State`` term in
``input_updates_g``, while all non-Module terms will appear as None. The shape of the tanget will be
expected to have the same shape as the input, with ``State`` terms in place of the corresponding Module terms.
Example::
>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
...
>>> class Foo(nnx.Module):
... def __init__(self, x, y):
... self.x = nnx.Param(x)
... self.y = nnx.Param(y)
...
>>> @nnx.custom_vjp
... def f(m: Foo):
... return jnp.sin(m.x) * m.y
...
>>> def f_fwd(m: Foo):
... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
...
>>> def f_bwd(res, g):
... input_updates_g, out_g = g
... cos_x, sin_x, m = res
... (m_updates_g,) = input_updates_g
... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
... m_g['x'].value = cos_x * out_g * m.y
... m_g['y'].value = sin_x * out_g
... return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grads = nnx.grad(f)(m)
...
>>> jax.tree.map(jnp.shape, grads)
State({
'x': VariableState(
type=Param,
value=()
),
'y': VariableState(
type=Param,
value=()
)
})
  Note that the State objects that represent Module terms in ``input_updates_g`` have the
  same shape as the State objects expected in the output tangent. This means that you can
  usually just copy them from ``input_updates_g`` and update them with their corresponding
  gradient values.
You can select which substates are differentiable (have a tangent) for Modules and other
graph nodes by passing a ``DiffState`` to ``nondiff_argnums``. For example, if you want to
differentiate only the ``x`` attribute of the ``Foo`` class, you can do the following::
>>> x_attribute = nnx.PathContains('x')
>>> diff_state = nnx.DiffState(0, x_attribute)
...
>>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
... def f(m: Foo):
... return jnp.sin(m.x) * m.y # type: ignore
>>> def f_fwd(m: Foo):
... y = f(m)
... res = (jnp.cos(m.x), m) # type: ignore
... return y, res
...
>>> def f_bwd(res, g):
... input_updates_g, out_g = g
... cos_x, m = res
... (m_updates_g,) = input_updates_g
... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
... m_g.x.value = cos_x * out_g * m.y
... del m_g['y'] # y is not differentiable
... return (m_g,)
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
...
>>> jax.tree.map(jnp.shape, grad)
State({
'x': VariableState(
type=Param,
value=()
)
})
  Note that ``grad`` cannot calculate gradients for states that don't have a tangent
  defined by ``custom_vjp``; in the example above we reuse the same ``x_attribute``
  filter to keep ``custom_vjp`` and ``grad`` in sync.
Args:
fun: Callable base function.
    nondiff_argnums: Tuple of integers or DiffState objects specifying the
      argument indices that are not differentiated. By default all arguments are
      differentiated. Integers cannot be used to mark graph nodes such as Modules
      as non-differentiable; in that case use a DiffState object. DiffState objects
      define the set of differentiable substates; contrary to what the name of this
      argument suggests, this is done for compatibility with ``grad``.
"""
if isinstance(fun, Missing):
return functools.partial(custom_vjp, nondiff_argnums=nondiff_argnums)
return CustomVjp(fun, nondiff_argnums)
# -------------------------------
# remat
# -------------------------------
@tp.overload
def remat(
*,
prevent_cse: bool = True,
static_argnums: int | tuple[int, ...] = (),
policy: tp.Callable[..., bool] | None = None,
) -> tp.Callable[[F], F]: ...
@tp.overload
def remat(
f: F,
*,
prevent_cse: bool = True,
static_argnums: int | tuple[int, ...] = (),
policy: tp.Callable[..., bool] | None = None,
) -> F: ...
def remat(
f: F | Missing = MISSING,
*,
prevent_cse: bool = True,
static_argnums: int | tuple[int, ...] = (),
policy: tp.Callable[..., bool] | None = None,
) -> F | tp.Callable[[F], F]:
  """A 'lifted' version of the
  `jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
  decorator (a.k.a. ``jax.remat``).

  ``flax.nnx.remat``, similar to ``jax.checkpoint``, can provide control over, for
  example, how the values needed by ``flax.nnx.grad`` are computed and saved during
  the forward pass versus how they are recomputed during the backward pass, trading
  off memory and FLOPs.

  Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.

  To learn about ``jax.remat``, go to JAX's
  `fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
  and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
  """
  if isinstance(f, Missing):
    return functools.partial(
      remat,
      prevent_cse=prevent_cse,
      static_argnums=static_argnums,
      policy=policy,
    )  # type: ignore[return-value]
  return resolve_kwargs()(
    graph.update_context('remat')(
      general.split_inputs(
        jax.checkpoint(
          general.merge_inputs(f, ctxtag='remat'),
          prevent_cse=prevent_cse,
          static_argnums=static_argnums,
          policy=policy,
        ),
        ctxtag='remat',
      ),
    )
  )