# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations
import dataclasses
import functools
import typing as tp
from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx import extract, filterlib, graph, rnglib, spmd, variablelib
from flax.nnx.module import GraphDef, Module
from flax.nnx.proxy_caller import DelayedAccessor
from flax.nnx.statelib import State
from flax.nnx.transforms.transforms import LiftedModule
from flax.typing import MISSING, Leaf, Missing
import jax
from jax._src.tree_util import broadcast_prefix
import jax.core
import jax.numpy as jnp
import jax.stages
from flax import nnx
A = tp.TypeVar('A')
C = tp.TypeVar('C')
B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable
Leaves = tp.List[Leaf]
Index = int
def _normalize_sequence(
x: StrInt | tp.Iterable[StrInt] | None, /
) -> tuple[StrInt, ...]:
if x is None:
return ()
elif isinstance(x, (str, int)):
return (x,) # type: ignore
else:
return tuple(x)
# -------------------------------
# vmap
# -------------------------------
class _VmapForkStates(tp.NamedTuple):
split_keys: State
split_counts: State
broadcast_keys: State
broadcast_counts: State
def _get_axis_sizes(pytree, axes):
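  """Collect the set of sizes found along the mapped axes of ``pytree``.

  ``axes`` is an ``in_axes``-style prefix that is broadcast against
  ``pytree``; leaves whose axis entry is ``None`` are ignored. For example,
  leaves of shape ``(4, 2)`` and ``(4,)`` under ``axes=0`` both contribute
  the size ``4``.
  """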
axes = broadcast_prefix(axes, pytree, is_leaf=lambda x: x is None)
leaves = jax.tree_util.tree_leaves(pytree)
axis_sizes = {
leaf.shape[axis] for axis, leaf in zip(axes, leaves) if axis is not None
}
return axis_sizes
def _fork_vmap_keys(
state: State,
split_filter: filterlib.Filter,
num_splits: int,
) -> _VmapForkStates:
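  """Fork the RNG state for ``vmap``.

  Keys selected by ``split_filter`` are folded with their counts and split
  into ``num_splits`` per-slice keys (their counts are reset to zero), while
  the remaining keys and counts are returned unchanged to be broadcast.
  """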
split_keys, split_counts, broadcast_keys, broadcast_counts = state.split(
filterlib.All(split_filter, rnglib.RngKey),
filterlib.All(split_filter, rnglib.RngCount),
rnglib.RngKey,
rnglib.RngCount,
)
def split_key(key: tp.Any, count: tp.Any) -> jax.Array:
if not isinstance(key, jax.Array):
raise TypeError(f'key must be a jax.Array, got {type(key)}')
if not isinstance(count, jax.Array):
raise TypeError(f'count must be a jax.Array, got {type(count)}')
key = jax.random.fold_in(key, count)
return jax.random.split(key, num_splits)
split_keys_leaves, split_keys_treedef = jax.tree.flatten(split_keys)
split_counts_leaves, split_counts_treedef = jax.tree.flatten(split_counts)
if len(split_keys_leaves) != len(split_counts_leaves):
raise ValueError(
      'split_keys and split_counts must have the same number of leaves, '
      f'got {len(split_keys_leaves)} and {len(split_counts_leaves)}',
)
split_keys_leaves = [
split_key(key, count)
for key, count in zip(split_keys_leaves, split_counts_leaves)
]
split_counts_leaves = [
jnp.full((num_splits,), 0, dtype=jnp.uint32) for _ in split_counts_leaves
]
split_keys = jax.tree.unflatten(split_keys_treedef, split_keys_leaves)
split_counts = jax.tree.unflatten(split_counts_treedef, split_counts_leaves)
return _VmapForkStates(
split_keys, split_counts, broadcast_keys, broadcast_counts
)
def _backup_vmap_keys(node: tp.Any, /):
backups: list[
tuple[graph.PathParts, rnglib.RngStream, jax.Array, jax.Array]
] = []
for path, stream in graph.iter_graph(node):
if isinstance(stream, rnglib.RngStream):
backups.append((path, stream, stream.key.value, stream.count.value))
return backups
def _restore_vmap_keys(
backups: list[tuple[graph.PathParts, rnglib.RngStream, jax.Array, jax.Array]],
split_rngs: filterlib.Filter,
/,
):
predicate_fn = filterlib.to_predicate(split_rngs)
for path, stream, key, count in backups:
stream.key.value = key
count_path = (*path, 'count')
if predicate_fn(count_path, stream.count.to_state()):
# restore count only if it was split
# add 1 to reflect the split
stream.count.value = count + 1
def vmap_fn(
args: tuple[tp.Any, ...],
kwargs: dict[str, tp.Any],
graphdef: GraphDef[tuple[tp.Any, ...]],
split_keys: State,
split_counts: State,
broadcast_keys: State,
broadcast_counts: State,
vectorized_states: list[State],
broadcast_state: State,
transform_metadata: tp.Mapping[str, tp.Any],
state_axes_: list[tuple[filterlib.Filter, int]],
f: tp.Callable[..., tp.Any],
filters: tp.Tuple[filterlib.Filter, ...],
split_rngs: filterlib.Filter,
):
ctx = graph.current_update_context('vmap')
state_axes = dict(state_axes_)
# remove metadata axis name from Variable.sharding
if spmd.PARTITION_NAME in transform_metadata:
vectorized_states = [
spmd.remove_axis(state, index, transform_metadata)
for state, index in zip(vectorized_states, state_axes.values())
]
# merge module state
input_graph_nodes = ctx.merge(
graphdef,
*vectorized_states,
broadcast_state,
split_keys,
split_counts,
broadcast_keys,
broadcast_counts,
)
(args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes)
out = f(*args, **kwargs)
out, output_graph_nodes = extract.extract_graph_nodes(out)
# split module state
(
graphdef_out,
rng_state_out,
*vectorized_states_out,
broadcast_state_out,
) = ctx.split( # type: ignore[misc]
(input_graph_nodes, output_graph_nodes),
rnglib.RngState,
*filters,
)
split_keys_out, broadcast_keys_out = rng_state_out.split(split_rngs, ...)
broadcast_state_out = State.merge(broadcast_state_out, broadcast_keys_out)
# add metadata axis name to Variable.sharding
if spmd.PARTITION_NAME in transform_metadata:
vectorized_states_out = [
spmd.add_axis(state, index, transform_metadata)
for state, index in zip(vectorized_states_out, state_axes.values())
]
return (
graphdef_out,
broadcast_state_out,
vectorized_states_out,
split_keys_out,
out,
)
@tp.overload
def vmap(
*,
in_axes: int | None | tp.Sequence[tp.Any] = 0,
out_axes: tp.Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> tp.Callable[[F], F]: ...
@tp.overload
def vmap(
f: F,
*,
in_axes: int | None | tp.Sequence[tp.Any] = 0,
out_axes: tp.Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F: ...
def vmap(
f: F | Missing = MISSING,
*,
in_axes: int | None | tp.Sequence[tp.Any] = 0,
out_axes: tp.Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int | None] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F | tp.Callable[[F], F]:
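  """Deprecated reference-semantics version of ``jax.vmap`` that accepts
  Modules / graph nodes as arguments.

  ``state_axes`` maps state filters to the axis along which the matching
  substates of input/output graph nodes are vectorized (``None`` broadcasts
  them), ``in_axes_kwargs`` plays the role of ``in_axes`` for keyword
  arguments, and ``split_rngs`` selects which RNG streams are split across
  the mapped axis.

  Usage sketch (illustrative only; assumes the standard ``nnx.Rngs`` and
  ``nnx.Linear`` constructors)::

    @functools.partial(vmap, axis_size=4)
    def create_stack(rngs: nnx.Rngs):
      return nnx.Linear(2, 3, rngs=rngs)

    stack = create_stack(nnx.Rngs(0))
    # the created Linear's parameters gain a leading axis of size 4,
    # e.g. stack.kernel has shape (4, 2, 3)
  """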
if isinstance(f, Missing):
return functools.partial(
vmap,
in_axes=in_axes,
out_axes=out_axes,
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
in_axes_kwargs=in_axes_kwargs,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
) # type: ignore[return-value]
vectorized_states_axes = list(state_axes.values())
vmapped_fn = jax.vmap(
vmap_fn,
in_axes=(
in_axes, # args
in_axes_kwargs, # kwargs
None, # graphdef
0, # split_keys
0, # split_counts
None, # broadcast_keys
None, # broadcast_counts
vectorized_states_axes, # vectorized_states
None, # broadcast_state
None, # transform_metadata
None, # states_axes
None, # f
None, # vectorized_states_filters
None, # split_rngs
),
out_axes=(
None, # graphdef_out
None, # broadcast_state
vectorized_states_axes,
0, # keys_out
out_axes, # out_axes
),
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
)
@functools.wraps(f)
@graph.update_context('vmap')
def vmap_wrapper(*args, **kwargs):
ctx = graph.current_update_context('vmap')
(args, kwargs), input_graph_nodes = extract.extract_graph_nodes(
(args, kwargs)
)
input_rng_streams = _backup_vmap_keys(input_graph_nodes)
# split module state
filters = (*state_axes.keys(), ...)
graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc]
input_graph_nodes, rnglib.RngState, *filters
)
# infer length
axis_sizes: tp.Set[int] = set()
axis_sizes.update(_get_axis_sizes(args, in_axes))
axis_sizes.update(_get_axis_sizes(kwargs, in_axes_kwargs))
for state, state_axis in zip(vectorized_states, state_axes.values()):
axis_sizes.update(_get_axis_sizes(state, state_axis))
if len(axis_sizes) > 1:
raise ValueError(
        'Inconsistent axis sizes between state_axes states and '
        f'arguments: {axis_sizes}'
)
elif len(axis_sizes) == 0:
if axis_size is None:
raise ValueError(
          'Cannot infer the axis size from state_axes states or the '
          'arguments, please specify `axis_size`'
)
_axis_size = axis_size
else:
_axis_size = axis_sizes.pop()
if axis_size is not None and axis_size != _axis_size:
raise ValueError(
f'Specified axis_size {axis_size} is not the same as the'
f' inferred length {_axis_size}'
)
split_keys, split_counts, broadcast_keys, broadcast_counts = (
_fork_vmap_keys(
rng_state,
split_rngs,
_axis_size,
)
)
(
graphdef_out,
broadcast_state,
vectorized_states,
split_keys_out,
out,
) = vmapped_fn(
args,
kwargs,
graphdef,
split_keys,
split_counts,
broadcast_keys,
broadcast_counts,
vectorized_states,
broadcast_state,
transform_metadata,
list(state_axes.items()),
f,
filters,
split_rngs,
)
_, output_graph_nodes = ctx.merge(
graphdef_out,
*vectorized_states,
broadcast_state,
split_keys_out,
)
out = extract.insert_graph_nodes(out, output_graph_nodes)
_restore_vmap_keys(input_rng_streams, split_rngs)
return out
return vmap_wrapper # type: ignore
class Vmap(tp.Generic[M], LiftedModule[M]):
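  """Deprecated lifted-module wrapper around :func:`vmap`.

  ``Vmap.constructor(Block, **options)`` returns a factory that constructs
  ``Block`` under ``vmap`` (so its state carries the mapped axis) and routes
  every method call through a vmapped wrapper.
  """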
@staticmethod
def constructor(
module_constructor: tp.Callable[..., MA],
*,
in_axes: int | None | tp.Sequence[tp.Any] = 0,
out_axes: tp.Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> tp.Callable[..., Vmap[MA]]:
def _create_vmap(*args, **kwargs):
return Vmap(
module_constructor=module_constructor,
in_axes=in_axes,
out_axes=out_axes,
axis_size=axis_size,
axis_name=axis_name,
spmd_axis_name=spmd_axis_name,
# nnx specific
in_axes_kwargs=in_axes_kwargs,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
# submodule args
module_init_args=args,
module_init_kwargs=kwargs,
)
return _create_vmap
def __init__(
self,
module_constructor: tp.Callable[..., M],
*,
in_axes: int | None | tp.Sequence[tp.Any] = 0,
out_axes: tp.Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
# submodule args
module_init_args: tuple[tp.Any, ...],
module_init_kwargs: dict[str, tp.Any],
):
self.module_constructor = module_constructor
@functools.partial(
vmap,
in_axes=None,
out_axes=None,
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
in_axes_kwargs=None,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
)
def vmap_init(*args, **kwargs):
return module_constructor(*args, **kwargs)
self.vmap_module = vmap_init(*module_init_args, **module_init_kwargs)
@functools.partial(
vmap,
in_axes=in_axes,
out_axes=out_axes,
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
in_axes_kwargs=in_axes_kwargs,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
)
def vmap_call(module, *args, _nnx_vmap_accessor: DelayedAccessor, **kwargs):
method = _nnx_vmap_accessor(module)
return method(*args, **kwargs)
self.vmap_call = vmap_call
@property
def _submodule(self) -> M:
return self.vmap_module
def _call(self, accessor: DelayedAccessor, *args, **kwargs):
return self.vmap_call(
self._submodule, *args, _nnx_vmap_accessor=accessor, **kwargs
) # type: ignore[type-var, call-arg]
# -------------------------------
# pmap
# -------------------------------
@struct.dataclass
class PmapInputs:
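  """Static bundle of the non-array inputs to ``pmap_fn``; every field is
  marked ``pytree_node=False`` so ``jax.pmap`` treats them as constants."""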
transform_metadata: tp.Mapping[str, tp.Any] = struct.field(pytree_node=False)
state_axes: tp.Mapping[filterlib.Filter, int] = struct.field(
pytree_node=False
)
f: tp.Callable[..., tp.Any] = struct.field(pytree_node=False)
filters: tp.Tuple[filterlib.Filter, ...] = struct.field(pytree_node=False)
split_rngs: filterlib.Filter = struct.field(pytree_node=False)
def pmap_fn(
args: tuple[tp.Any, ...],
kwargs: dict[str, tp.Any],
graphdef: GraphDef[tuple[tp.Any, ...]],
split_keys: State,
split_counts: State,
broadcast_keys: State,
broadcast_counts: State,
vectorized_states: list[State],
broadcast_state: State,
pmap_inputs: PmapInputs,
):
transform_metadata = pmap_inputs.transform_metadata
state_axes = pmap_inputs.state_axes
f = pmap_inputs.f
filters = pmap_inputs.filters
split_rngs = pmap_inputs.split_rngs
ctx = graph.current_update_context('pmap')
# remove metadata axis name from Variable.sharding
if spmd.PARTITION_NAME in transform_metadata:
vectorized_states = [
spmd.remove_axis(state, index, transform_metadata)
for state, index in zip(vectorized_states, state_axes.values())
]
# merge module state
input_graph_nodes = ctx.merge(
graphdef,
*vectorized_states,
broadcast_state,
split_keys,
split_counts,
broadcast_keys,
broadcast_counts,
)
(args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes)
out = f(*args, **kwargs)
out, output_graph_nodes = extract.extract_graph_nodes(out)
# split module state
(
graphdef_out,
rng_state_out,
*vectorized_states_out,
broadcast_state_out,
) = ctx.split( # type: ignore[misc]
(input_graph_nodes, output_graph_nodes),
rnglib.RngState,
*filters,
)
not_keys_out, split_keys_out, broadcast_keys_out = rng_state_out.split(
rnglib.NotKey, split_rngs, ...
)
broadcast_state_out = State.merge(
broadcast_state_out, broadcast_keys_out, not_keys_out
)
# add metadata axis name to Variable.sharding
if spmd.PARTITION_NAME in transform_metadata:
vectorized_states_out = [
spmd.add_axis(state, index, transform_metadata)
for state, index in zip(vectorized_states_out, state_axes.values())
]
return (
graphdef_out,
broadcast_state_out,
vectorized_states_out,
split_keys_out,
out,
)
@tp.overload
def pmap(
*,
axis_name: AxisName | None = None,
in_axes: tp.Any = 0,
out_axes: tp.Any = 0,
static_broadcasted_argnums: int | tp.Iterable[int] = (),
devices: tp.Sequence[jax.Device] | None = None, # noqa: F811
backend: str | None = None,
axis_size: int | None = None,
donate_argnums: int | tp.Iterable[int] = (),
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> tp.Callable[[F], F]: ...
@tp.overload
def pmap(
f: F,
*,
axis_name: AxisName | None = None,
in_axes: tp.Any = 0,
out_axes: tp.Any = 0,
static_broadcasted_argnums: int | tp.Iterable[int] = (),
devices: tp.Sequence[jax.Device] | None = None, # noqa: F811
backend: str | None = None,
axis_size: int | None = None,
donate_argnums: int | tp.Iterable[int] = (),
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F: ...
def pmap(
f: F | Missing = MISSING,
*,
axis_name: AxisName | None = None,
in_axes: tp.Any = 0,
out_axes: tp.Any = 0,
static_broadcasted_argnums: int | tp.Iterable[int] = (),
devices: tp.Sequence[jax.Device] | None = None, # noqa: F811
backend: str | None = None,
axis_size: int | None = None,
donate_argnums: int | tp.Iterable[int] = (),
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> F | tp.Callable[[F], F]:
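  """Deprecated reference-semantics version of ``jax.pmap`` that accepts
  Modules / graph nodes as arguments.

  Semantics mirror :func:`vmap` above, with the mapped axis distributed
  across devices: ``state_axes`` selects which substates carry the mapped
  axis, ``in_axes_kwargs`` handles keyword arguments, and ``split_rngs``
  selects the RNG streams that are split per device.
  """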
if isinstance(f, Missing):
return functools.partial(
pmap,
axis_name=axis_name,
in_axes=in_axes,
out_axes=out_axes,
static_broadcasted_argnums=static_broadcasted_argnums,
devices=devices,
backend=backend,
axis_size=axis_size,
donate_argnums=donate_argnums,
global_arg_shapes=global_arg_shapes,
in_axes_kwargs=in_axes_kwargs,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
) # type: ignore[return-value]
if static_broadcasted_argnums:
raise NotImplementedError(
'static_broadcasted_argnums is not yet supported in nnx.pmap'
)
if donate_argnums != ():
raise NotImplementedError('donate_argnums is not yet supported in nnx.pmap')
if global_arg_shapes is not None:
raise NotImplementedError(
'global_arg_shapes is not yet supported in nnx.pmap'
)
vectorized_states_axes = list(state_axes.values())
pmapped_fn = jax.pmap(
pmap_fn,
axis_name=axis_name,
in_axes=(
in_axes, # args_axes
in_axes_kwargs, # kwargs_axes
None, # graphdef_axes
0, # split_keys_axes
None, # split_counts_axes
None, # broadcast_keys_axes
None, # broadcast_counts_axes
vectorized_states_axes, # vectorized_states_axes
None, # broadcast_state_axes
None, # pmap_inputs_axes
), # type: ignore
out_axes=(
None, # graphdef_out_axes
None, # broadcast_state_axes
vectorized_states_axes,
0, # keys_axes_out
out_axes, # out_axes
), # type: ignore
devices=devices,
backend=backend,
axis_size=axis_size,
)
@functools.wraps(f)
@graph.update_context('pmap')
def pmap_wrapper(*args, **kwargs):
ctx = graph.current_update_context('pmap')
(args, kwargs), input_graph_nodes = extract.extract_graph_nodes(
(args, kwargs)
)
input_rng_streams = rnglib.backup_keys(input_graph_nodes)
# split module state
filters = (*state_axes.keys(), ...)
graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc]
input_graph_nodes, rnglib.RngState, *filters
)
# infer length
axis_sizes: tp.Set[int] = set()
axis_sizes.update(_get_axis_sizes(args, in_axes))
axis_sizes.update(_get_axis_sizes(kwargs, in_axes_kwargs))
for state, state_axis in zip(vectorized_states, state_axes.values()):
axis_sizes.update(_get_axis_sizes(state, state_axis))
if len(axis_sizes) > 1:
raise ValueError(
        'Inconsistent axis sizes between state_axes states and '
        f'arguments: {axis_sizes}'
)
elif len(axis_sizes) == 0:
if axis_size is None:
raise ValueError(
          'Cannot infer the axis size from state_axes states or the '
          'arguments, please specify `axis_size`'
)
_axis_size = axis_size
else:
_axis_size = axis_sizes.pop()
if axis_size is not None and axis_size != _axis_size:
raise ValueError(
f'Specified axis_size {axis_size} is not the same as the'
f' inferred length {_axis_size}'
)
split_keys, split_counts, broadcast_keys, broadcast_counts = rnglib.fork(
rng_state,
split_rngs,
_axis_size,
)
(
graphdef_out,
broadcast_state,
vectorized_states,
split_keys_out,
out,
) = pmapped_fn(
args,
kwargs,
graphdef,
split_keys,
split_counts,
broadcast_keys,
broadcast_counts,
vectorized_states,
broadcast_state,
PmapInputs(
transform_metadata=transform_metadata,
state_axes=state_axes,
f=f,
filters=filters,
split_rngs=split_rngs,
),
)
_, output_graph_nodes = ctx.merge(
graphdef_out,
*vectorized_states,
broadcast_state,
split_keys_out,
)
out = extract.insert_graph_nodes(out, output_graph_nodes)
rnglib.restore_rngs(input_rng_streams)
return out
return pmap_wrapper # type: ignore
class Pmap(tp.Generic[M], LiftedModule[M]):
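  """Deprecated lifted-module wrapper around :func:`pmap`.

  Mirrors :class:`Vmap`: the submodule is constructed under ``pmap`` and its
  method calls are executed through a pmapped wrapper.
  """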
@staticmethod
def constructor(
module_constructor: tp.Callable[..., MA],
*,
axis_name: AxisName | None = None,
in_axes: tp.Any = 0,
out_axes: tp.Any = 0,
static_broadcasted_argnums: int | tp.Iterable[int] = (),
devices: tp.Sequence[jax.Device] | None = None, # noqa: F811
backend: str | None = None,
axis_size: int | None = None,
donate_argnums: int | tp.Iterable[int] = (),
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
) -> tp.Callable[..., Pmap[MA]]:
def _create_pmap(*args, **kwargs):
return Pmap(
module_constructor=module_constructor,
axis_name=axis_name,
in_axes=in_axes,
out_axes=out_axes,
static_broadcasted_argnums=static_broadcasted_argnums,
devices=devices,
backend=backend,
        axis_size=axis_size,
        donate_argnums=donate_argnums,
        global_arg_shapes=global_arg_shapes,
# nnx specific
in_axes_kwargs=in_axes_kwargs,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
# submodule args
module_init_args=args,
module_init_kwargs=kwargs,
)
return _create_pmap
def __init__(
self,
module_constructor: tp.Callable[..., M],
*,
axis_name: AxisName | None = None,
in_axes: tp.Any = 0,
out_axes: tp.Any = 0,
static_broadcasted_argnums: int | tp.Iterable[int] = (),
devices: tp.Sequence[jax.Device] | None = None, # noqa: F811
backend: str | None = None,
axis_size: int | None = None,
donate_argnums: int | tp.Iterable[int] = (),
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
# nnx specific
in_axes_kwargs: tp.Any = 0,
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
# submodule args
module_init_args: tuple[tp.Any, ...],
module_init_kwargs: dict[str, tp.Any],
):
self.module_constructor = module_constructor
@pmap(
axis_name=axis_name,
in_axes=None,
out_axes=None,
static_broadcasted_argnums=static_broadcasted_argnums,
devices=devices,
backend=backend,
axis_size=axis_size,
donate_argnums=(),
global_arg_shapes=None,
in_axes_kwargs=None,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
)
def pmap_init(*args, **kwargs):
return module_constructor(*args, **kwargs)
self.pmap_module = pmap_init(*module_init_args, **module_init_kwargs)
@pmap(
axis_name=axis_name,
in_axes=in_axes,
out_axes=out_axes,
static_broadcasted_argnums=static_broadcasted_argnums,
devices=devices,
backend=backend,
axis_size=axis_size,
donate_argnums=donate_argnums,
global_arg_shapes=global_arg_shapes,
in_axes_kwargs=in_axes_kwargs,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
)
    def pmap_call(module, *args, _nnx_pmap_accessor: DelayedAccessor, **kwargs):
      method = _nnx_pmap_accessor(module)
return method(*args, **kwargs)
self.pmap_call = pmap_call
@property
def _submodule(self) -> M:
return self.pmap_module
def _call(self, accessor: DelayedAccessor, *args, **kwargs):
return self.pmap_call(
      self._submodule, *args, _nnx_pmap_accessor=accessor, **kwargs
)
# -------------------------------
# scan
# -------------------------------
@dataclasses.dataclass(frozen=True)
class FlatDef(tp.Generic[A]):
type: type[A]
treedef: jax.tree_util.PyTreeDef
flat_axes: list[int | None]
jax.tree_util.register_static(FlatDef)
def _transpose_tree(tree: A, axes, /, *, move_front: bool) -> A:
flatdef, flat_transposes, _ = _transpose_and_split(
tree, axes, allow_none=False, move_front=move_front
)
return flatdef.treedef.unflatten(flat_transposes)
def _transpose_and_split(
tree: A, axes, /, *, allow_none: bool = True, move_front: bool = True
) -> tuple[
FlatDef[A],
list[jax.Array | None],
list[tp.Any],
]:
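  """Split ``tree`` into scanned and broadcast leaves.

  Each leaf with a non-``None`` axis has that axis normalized and moved to
  the front (or moved back from the front when ``move_front=False``); leaves
  with a ``None`` axis are collected as broadcasts. Returns the ``FlatDef``
  plus the parallel lists of transposed and broadcast leaves.
  """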
flat_axes: list[int | None] = broadcast_prefix(
axes, tree, is_leaf=lambda x: x is None
)
flat_tree, treedef = jax.tree.flatten(tree)
flat_broadcasts: list[tp.Any] = []
flat_transposes: list[jax.Array | None] = []
for i, (axis, node) in enumerate(zip(flat_axes, flat_tree)):
if axis is None:
if not allow_none:
raise ValueError('None axis not allowed')
flat_broadcasts.append(node)
flat_transposes.append(None)
else:
if not isinstance(node, jax.Array):
raise TypeError(
f'Expected a jax.Array, got {type(node).__name__} for axis {axis}'
)
# normalize axis
if axis < 0:
if axis < -len(node.shape):
raise ValueError(
f'Axis {axis} out of bounds for array with shape {node.shape}'
)
axis = len(node.shape) + axis
flat_axes[i] = axis
if node.shape == ():
raise ValueError(f'Cannot map over a scalar array, got {node}')
elif axis >= len(node.shape):
raise ValueError(
f'Axis {axis} out of bounds for array with shape {node.shape}'
)
if move_front:
node = jnp.moveaxis(node, axis, 0)
else:
node = jnp.moveaxis(node, 0, axis)
flat_broadcasts.append(None)
flat_transposes.append(node)
flatdef = FlatDef(type(tree), treedef, flat_axes)
return flatdef, flat_transposes, flat_broadcasts
def _unflatten_splits(
flatdef: FlatDef[A],
flat_transposes: list[jax.Array | None],
flat_broadcasts: list[tp.Any] | None = None,
/,
*,
allow_none: bool = True,
) -> A:
flat_axes = flatdef.flat_axes
treedef = flatdef.treedef
if flat_broadcasts is None:
if allow_none:
raise ValueError('flat_broadcasts must be provided if allow_none is True')
flat_broadcasts = [None] * len(flat_axes)
flat_tree = []
for axis, transpose, broadcast in zip(
flat_axes, flat_transposes, flat_broadcasts
):
if axis is None:
if not allow_none:
raise ValueError('None axis not allowed')
flat_tree.append(broadcast)
else:
if transpose is None:
raise ValueError('None transpose not allowed')
flat_tree.append(transpose)
tree = treedef.unflatten(flat_tree)
return tree
def _extract_carry_arg(
args: tuple[tp.Any, ...], carry_argnum: int, /
) -> tuple[tp.Any, tuple[tp.Any, ...]]:
# extract carry arg
if len(args) < carry_argnum + 1:
raise TypeError(
f'Expected at least {carry_argnum + 1} positional arguments, '
f'got {len(args)}'
)
args_ = list(args)
carry_arg = args_[carry_argnum]
args_[carry_argnum] = None
args = tuple(args_)
return carry_arg, args
def _insert_carry_arg(
args: tuple[tp.Any, ...], carry_argnum: int, carry_arg: tp.Any, /
) -> tuple[tp.Any, ...]:
args_ = list(args)
args_[carry_argnum] = carry_arg
args = tuple(args_)
return args
@struct.dataclass
class ScanBroadcasts(tp.Generic[C, B]):
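  """Auxiliary data threaded through the carry of ``jax.lax.scan``: the
  flattened broadcast inputs plus the static graphdef, filters, function and
  scan options. Only ``flat_carry`` participates as a pytree."""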
flatdef: FlatDef[
tuple[tuple[tp.Any, ...], dict[str, tp.Any], list[State]]
] = struct.field(pytree_node=False)
flat_carry: list[tp.Any] = struct.field(pytree_node=True)
graphdef: GraphDef[tuple[tp.Any, ...]] = struct.field(pytree_node=False)
filters: tuple[filterlib.Filter, ...] = struct.field(pytree_node=False)
f: tp.Callable[..., tuple[C, B] | C] = struct.field(pytree_node=False)
# options
carry_argnum: int = struct.field(pytree_node=False)
state_axes: tp.Mapping[filterlib.Filter, int] = struct.field(
pytree_node=False
)
split_rngs: filterlib.Filter = struct.field(pytree_node=False)
transform_metadata: tp.Mapping[str, tp.Any] = struct.field(pytree_node=False)
scan_output: bool = struct.field(pytree_node=False)
def scan_fn(
carry: tuple[
State, # split_rng_state
State, # broadcast_rng_state
State, # carry_state
tp.Any, # carry_arg
ScanBroadcasts[C, B], # broadcasts
],
scan: tuple[
list[jax.Array | None], # flat_scan
],
):
split_rng_state, broadcast_rng_state, carry_state, carry_arg, broadcasts = (
carry
)
(flat_scan,) = scan
flatdef = broadcasts.flatdef
flat_carry = broadcasts.flat_carry
graphdef, filters = broadcasts.graphdef, broadcasts.filters
f = broadcasts.f
ctx = graph.current_update_context('scan')
# merge args and kwargs
args, kwargs, scan_states = _unflatten_splits(flatdef, flat_scan, flat_carry)
# remove metadata axis name from Variable.sharding
if spmd.PARTITION_NAME in broadcasts.transform_metadata:
scan_states = [
spmd.remove_axis(state, index, broadcasts.transform_metadata)
for state, index in zip(scan_states, broadcasts.state_axes.values())
]
# insert carry arg
args = _insert_carry_arg(args, broadcasts.carry_argnum, carry_arg)
# merge module state
input_graph_nodes = ctx.merge(
graphdef, *scan_states, carry_state, split_rng_state, broadcast_rng_state
)
(args, kwargs) = extract.insert_graph_nodes((args, kwargs), input_graph_nodes)
out = f(*args, **kwargs)
if broadcasts.scan_output:
if not isinstance(out, tuple) or len(out) != 2:
raise ValueError(
'Expected a tuple of length 2 as the output of the scan function, '
f'got {out}'
)
out = tp.cast(tuple[C, B], out) # type: ignore[invalid-annotation]
carry_arg_out, scan_args_out = out
else:
out = tp.cast(C, out) # type: ignore[invalid-annotation]
carry_arg_out = out
scan_args_out = None
((carry_arg_out, scan_args_out), output_graph_nodes) = (
extract.extract_graph_nodes((carry_arg_out, scan_args_out))
)
# split module state
(
graphdef_out,
rng_state_out,
*scan_states_out,
carry_state_out,
) = ctx.split( # type: ignore[misc]
(input_graph_nodes, output_graph_nodes),
rnglib.RngState,
*filters,
)
split_rng_state_out, broadcast_rng_state_out = rng_state_out.split(
broadcasts.split_rngs, ...
)
def _extract_carry_state(state: State, /):
if 1 in state:
raise ValueError(
f'Cannot add new carry state during scan, got {state[1]}'
)
if 0 in state:
_state = state[0]
assert isinstance(_state, State)
state = _state
return state
carry_state_out = _extract_carry_state(carry_state_out)
split_rng_state_out = _extract_carry_state(split_rng_state_out)
broadcast_rng_state_out = _extract_carry_state(broadcast_rng_state_out)
# override broadcast_rng_state_out to keep the same state
# for the next iteration
broadcast_rng_state_out = broadcast_rng_state
# add metadata axis name to Variable.sharding
if spmd.PARTITION_NAME in broadcasts.transform_metadata:
scan_states_out = [
spmd.add_axis(state, index, broadcasts.transform_metadata)
for state, index in zip(scan_states_out, broadcasts.state_axes.values())
]
carry_out = (
split_rng_state_out,
broadcast_rng_state_out,
carry_state_out,
carry_arg_out,
broadcasts,
)
scan_out = (graphdef_out, scan_args_out, scan_states_out)
return carry_out, scan_out
@tp.overload
def scan(
*,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
# extended api
in_axes: int | None | tp.Sequence[tp.Any] = 0,
in_axes_kwargs: tp.Any = 0,
out_axes: tp.Any = 0,
carry_argnum: int = 0,
# nnx specific
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
scan_output: bool = True,
) -> tp.Callable[[F], F]: ...
@tp.overload
def scan(
f: F,
*,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
# extended api
in_axes: int | None | tp.Sequence[tp.Any] = 0,
in_axes_kwargs: tp.Any = 0,
out_axes: tp.Any = 0,
carry_argnum: int = 0,
# nnx specific
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
scan_output: bool = True,
) -> F: ...
def scan(
f: F | Missing = MISSING,
*,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
# extended api
in_axes: int | None | tp.Sequence[tp.Any] = 0,
in_axes_kwargs: tp.Any = 0,
out_axes: tp.Any = 0,
carry_argnum: int = 0,
# nnx specific
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
scan_output: bool = True,
) -> F | tp.Callable[[F], F]:
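  """Deprecated reference-semantics version of ``jax.lax.scan`` that accepts
  Modules / graph nodes as arguments.

  The positional argument at ``carry_argnum`` is the carry; the remaining
  arguments are scanned over the axes given by ``in_axes`` /
  ``in_axes_kwargs``, and graph-node substates selected by ``state_axes``
  are scanned along the corresponding axis. When ``scan_output`` is ``True``
  the function must return a ``(carry, y)`` pair, otherwise only the carry.

  Minimal sketch without graph nodes (illustrative only)::

    @functools.partial(scan, carry_argnum=0)
    def cumulative_sum(carry, xs):
      carry = carry + xs
      return carry, carry

    total, partials = cumulative_sum(jnp.zeros(()), jnp.arange(4.0))
    # total == 6.0, partials == [0., 1., 3., 6.]
  """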
if isinstance(f, Missing):
return functools.partial(
scan,
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
in_axes=in_axes,
in_axes_kwargs=in_axes_kwargs,
out_axes=out_axes,
carry_argnum=carry_argnum,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
scan_output=scan_output,
) # type: ignore[return-value]
@functools.wraps(f)
@graph.update_context('scan')
def scan_apply_wrapper(*args, **kwargs):
# extract nodes
(args, kwargs), input_graph_nodes = extract.extract_graph_nodes(
(args, kwargs)
)
input_rng_streams = rnglib.backup_keys(input_graph_nodes)
# extract carry arg
carry_arg, args = _extract_carry_arg(args, carry_argnum)
ctx = graph.current_update_context('scan')
# split module state
filters = (*state_axes.keys(), ...)
graphdef, rng_state, *scan_states, carry_state = ctx.split( # type: ignore[misc]
input_graph_nodes, rnglib.RngState, *filters
)
# transpose axes arg
flatdef, flat_scan, flat_carry = _transpose_and_split(
(args, kwargs, scan_states),
(in_axes, in_axes_kwargs, list(state_axes.values())),
)
# infer length
lengths: set[int] = {
x.shape[0] # type: ignore
for x, axis in zip(flat_scan, flatdef.flat_axes)
if axis is not None
}
if len(lengths) > 1:
raise ValueError(
'Inconsistent lengths between state_axes states and '
f'arguments: {lengths}'
)
elif len(lengths) == 0:
if length is None:
raise ValueError(
'Cannot infer length from state_axes states or axes_arg, '
'please specify `length`'
)
      inferred_length = length
else:
      inferred_length = lengths.pop()
    if length is not None and length != inferred_length:
raise ValueError(
f'Specified length {length} is not the same as the inferred '
        f'length {inferred_length}'
)
# split rng state
split_rng_state, broadcast_rng_state = rng_state.split(split_rngs, ...)
broadcasts = ScanBroadcasts(
flatdef,
flat_carry,
graphdef,
filters,
f,
# options
carry_argnum,
state_axes,
split_rngs,
transform_metadata,
scan_output,
)
carry = (
split_rng_state,
broadcast_rng_state,
carry_state,
carry_arg,
broadcasts,
)
scan = (flat_scan,)
carry_out, scan_out = jax.lax.scan(
scan_fn,
carry,
scan,
      length=inferred_length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
)
(
split_rng_state_out,
broadcast_rng_state_out,
carry_state_out,
carry_arg_out,
broadcasts,
) = carry_out
graphdef_out, scan_args_out, scan_states_out = scan_out
scan_args_out, scan_states_out = _transpose_tree(
(scan_args_out, scan_states_out),
(out_axes, list(state_axes.values())),
move_front=False,
)
if carry_state_out:
carry_state_out = State({0: carry_state_out._mapping})
if split_rng_state_out:
split_rng_state_out = State({0: split_rng_state_out._mapping})
if broadcast_rng_state_out:
broadcast_rng_state_out = State({0: broadcast_rng_state_out._mapping})
_, output_graph_nodes = ctx.merge(
graphdef_out,
*scan_states_out,
carry_state_out,
split_rng_state_out,
broadcast_rng_state_out,
)
carry_arg_out, scan_args_out = extract.insert_graph_nodes(
(carry_arg_out, scan_args_out), output_graph_nodes
)
rnglib.restore_rngs(input_rng_streams)
if scan_output:
scan_args_out = tp.cast(B, scan_args_out)
return carry_arg_out, scan_args_out
else:
return carry_arg_out
return scan_apply_wrapper # type: ignore
class Scan(tp.Generic[M], LiftedModule[M]):
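  """Deprecated lifted-module wrapper around :func:`scan`.

  The submodule is initialized under :class:`Vmap` so that its state is
  stacked along the scan axis, and its method calls are executed with
  :func:`scan`.
  """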
@staticmethod
def constructor(
module_constructor: tp.Callable[..., MA],
*,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
# extended api
in_axes: int | None | tp.Sequence[tp.Any] = 0,
in_axes_kwargs: tp.Any = 0,
out_axes: tp.Any = 0,
carry_argnum: int = 1,
# nnx specific
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
scan_output: bool = True,
) -> tp.Callable[..., Scan[MA]]:
def _create_scan(*args, **kwargs):
return Scan(
module_constructor=module_constructor,
module_init_args=args,
module_init_kwargs=kwargs,
# base api
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
# extended api
in_axes=in_axes,
in_axes_kwargs=in_axes_kwargs,
out_axes=out_axes,
carry_argnum=carry_argnum,
# nnx specific
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
scan_output=scan_output,
)
return _create_scan
def __init__(
self,
module_constructor: tp.Callable[..., M],
*,
length: int | None = None,
reverse: bool = False,
unroll: int | bool = 1,
_split_transpose: bool = False,
# extended api
in_axes: int | None | tp.Sequence[tp.Any] = 0,
in_axes_kwargs: tp.Any = 0,
out_axes: tp.Any = 0,
carry_argnum: int = 1,
# nnx specific
state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}),
split_rngs: filterlib.Filter = ...,
transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
scan_output: bool = True,
# submodule args
module_init_args: tuple[tp.Any, ...],
module_init_kwargs: dict[str, tp.Any],
):
self.module_constructor = module_constructor
# use Vmap to handle initialisation
vmapped_module = Vmap.constructor(
module_constructor,
in_axes=in_axes,
out_axes=None,
axis_name=None,
axis_size=length,
spmd_axis_name=None,
state_axes=state_axes,
split_rngs=split_rngs,
in_axes_kwargs=in_axes_kwargs,
transform_metadata=transform_metadata,
)(*module_init_args, **module_init_kwargs)
self.scan_module = vmapped_module.vmap_module
@functools.partial(
scan,
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
in_axes=in_axes,
in_axes_kwargs=in_axes_kwargs,
out_axes=out_axes,
carry_argnum=carry_argnum,
state_axes=state_axes,
split_rngs=split_rngs,
transform_metadata=transform_metadata,
scan_output=scan_output,
)
def scan_call(module, *args, _nnx_scan_accessor: DelayedAccessor, **kwargs):
method = _nnx_scan_accessor(module)
return method(*args, **kwargs)
self.scan_call = scan_call
@property
def _submodule(self) -> M:
return self.scan_module
def _call(
self, accessor: DelayedAccessor, *args, **kwargs
) -> tuple[tp.Any, tp.Any]:
return self.scan_call(
self._submodule, *args, _nnx_scan_accessor=accessor, **kwargs
) # type: ignore[call-arg, return-value, type-var]
# -------------------------------
# remat
# -------------------------------
class Remat(tp.Generic[M], LiftedModule[M]):
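  """Deprecated lifted-module wrapper around ``nnx.remat``: the submodule is
  constructed eagerly and its method calls are rematerialized (checkpointed)
  during the backward pass."""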
@staticmethod
def constructor(
module_constructor: tp.Callable[..., MA],
prevent_cse: bool = True,
static_argnums: int | tuple[int, ...] = (),
policy: tp.Callable[..., bool] | None = None,
) -> tp.Callable[..., Remat[MA]]:
def create_remat(*args, **kwargs):
return Remat(
module_constructor=module_constructor,
module_init_args=args,
module_init_kwargs=kwargs,
prevent_cse=prevent_cse,
static_argnums=static_argnums,
policy=policy,
)
return create_remat
def __init__(
self,
*,
module_constructor: tp.Callable[..., M],
prevent_cse: bool = True,
static_argnums: int | tuple[int, ...] = (),
policy: tp.Callable[..., bool] | None = None,
# submodule args
module_init_args: tuple[tp.Any, ...],
module_init_kwargs: dict[str, tp.Any],
):
self.module_constructor = module_constructor
self.remat_module = self.module_constructor(
*module_init_args, **module_init_kwargs
)
@nnx.remat(
prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy
)
def remat_call(module, *args):
accessor: DelayedAccessor
*args, accessor = args
method = accessor(module)
return method(*args)
self.rem_call = remat_call
@property
def _submodule(self) -> M:
return self.remat_module
def _call(self, accessor: DelayedAccessor, *args) -> tp.Any:
return self.rem_call(self._submodule, *args, accessor)
# -------------------------------
# grad
# -------------------------------
def grad_fn(*args):
f: tp.Callable[..., tp.Any]
graphdef: GraphDef[tuple[dict[int, tp.Any], tuple[tp.Any, ...]]]
non_diff_state: State
has_aux: bool
diff_args: list[int]
ctx = graph.current_update_context('grad')
*args, f, graphdef, non_diff_state, has_aux, diff_args = args
# rebuild diff_state from substates in args
diff_state = State({})
for i in diff_args:
diff_state[i] = args[i]
diff_state: graph.GraphState = State({0: diff_state.raw_mapping})
diff_graph_nodes, input_nodes = ctx.merge(
graphdef, diff_state, non_diff_state
)
# add nodes to the args
for i, arg in diff_graph_nodes.items():
args[i] = arg
# add other nodes to the args
args = extract.insert_graph_nodes(args, input_nodes)
out = f(*args)
out, out_nodes = extract.extract_graph_nodes(out)
graphdef_out, state_out = ctx.split((input_nodes, out_nodes))
if has_aux:
loss, aux = out
out = (loss, (graphdef_out, state_out, aux))
else:
out = (out, (graphdef_out, state_out))
return out
def _grad_general(
f: tp.Callable[..., tp.Any],
argnums: int | tp.Sequence[int],
has_aux: bool,
holomorphic: bool,
allow_int: bool,
reduce_axes: tp.Sequence[AxisName],
wrt: filterlib.Filter,
return_value: bool,
) -> tp.Callable[..., tp.Any]:
@graph.update_context('grad')
def grad_wrapper(*args):
ctx: graph.UpdateContext = graph.current_update_context('grad')
_argnums = _normalize_sequence(argnums)
diff_graph_nodes: dict[int, tp.Any] = {
i: arg
for i, arg in enumerate(args)
if i in _argnums and graph.is_node(arg)
}
args, input_nodes = extract.extract_graph_nodes(args)
args = list(args)
def only_diff(path: tuple, value: tp.Any) -> bool:
# diff_graph_nodes is the first element in the tuple
return path[0] == 0
graphdef, diff_state, non_diff_state = ctx.split(
(diff_graph_nodes, input_nodes), filterlib.All(wrt, only_diff), ...
) # type: ignore[misc]
# extract diff_state substates into the args
diff_args: list[int] = []
if 0 in diff_state:
for i, diff_substate in diff_state[0].items(): # type: ignore
assert isinstance(i, int)
args[i] = diff_substate
diff_args.append(i)
transform = jax.value_and_grad if return_value else jax.grad
_argnums = _argnums[0] if len(_argnums) == 1 else _argnums
out = transform(
grad_fn,
argnums=_argnums,
has_aux=True,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)(*args, f, graphdef, non_diff_state, has_aux, diff_args)
if return_value:
if has_aux:
(loss, (graphdef_out, state_out, aux)), grads = out
out = (loss, aux), grads
else:
(loss, (graphdef_out, state_out)), grads = out
out = loss, grads
else:
if has_aux:
grads, (graphdef_out, state_out, aux) = out
out = grads, aux
else:
out, (graphdef_out, state_out) = out
input_nodes, out_nodes = ctx.merge(graphdef_out, state_out)
out = extract.insert_graph_nodes(out, out_nodes)
return out
return grad_wrapper
def grad(
f: tp.Callable[..., tp.Any],
argnums: int | tp.Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
*,
wrt: filterlib.Filter = variablelib.Param,
) -> tp.Callable[..., tp.Any]:
"""Lifted version of ``jax.grad`` that can handle Modules / graph nodes as
arguments.
The differentiable state of each graph node is defined by the `wrt` filter,
which by default is set to `nnx.Param`. Internally the ``State`` of
graph nodes is extracted, filtered according to `wrt` filter, and
passed to the underlying ``jax.grad`` function. The gradients
of graph nodes are of type ``State``.
Example::
>>> from flax import nnx
    >>> import jax, jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.transforms.deprecated.grad(loss_fn, wrt=nnx.Param)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
'bias': VariableState(
type=Param,
value=(3,)
),
'kernel': VariableState(
type=Param,
value=(2, 3)
)
})
Args:
fun: Function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, graph nodes or standard Python
containers. Argument arrays in the positions specified by ``argnums`` must
be of inexact (i.e., floating-point or complex) type. It should return a
scalar (which includes arrays with shape ``()`` but not arrays with shape
``(1,)`` etc.)
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
function that computes the total gradient while ``grad(f)`` will create
one that computes the per-example gradient.
wrt: Optional, filterlib.Filter. Filter to extract the differentiable state
of each graph node. Default is `nnx.Param`.
"""
return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
wrt,
return_value=False,
)
def value_and_grad(
f: tp.Callable[..., tp.Any],
argnums: int | tp.Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
*,
wrt: filterlib.Filter = variablelib.Param,
) -> tp.Callable[..., tp.Any]:
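  """Lifted version of ``jax.value_and_grad`` that can handle Modules / graph
  nodes as arguments; see :func:`grad` above for the semantics of ``wrt`` and
  the remaining options.

  Usage sketch, reusing the names from the :func:`grad` example (illustrative
  only)::

    loss, grads = value_and_grad(loss_fn, wrt=nnx.Param)(m, x, y)
  """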
return _grad_general(
f,
argnums,
has_aux,
holomorphic,
allow_int,
reduce_axes,
wrt,
return_value=True,
)
class Grad(tp.Generic[M], LiftedModule[M]):
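  """Deprecated lifted-module wrapper around :func:`grad`: calling the
  wrapped module runs the selected method under ``grad`` and returns the
  gradients (as ``State``) with respect to the ``wrt``-filtered state."""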
@staticmethod
def constructor(
module_constructor: tp.Callable[..., MA],
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
return_value: bool = False,
*,
wrt: filterlib.Filter = variablelib.Param,
) -> tp.Callable[..., Grad[MA]]:
def _create_grad(*args, **kwargs):
return Grad(
module_constructor=module_constructor,
wrt=wrt,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
return_value=return_value,
# submodule args
module_init_args=args,
module_init_kwargs=kwargs,
)
return _create_grad
def __init__(
self,
module_constructor: tp.Callable[..., M],
argnums: int | tp.Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: tp.Sequence[AxisName] = (),
*,
wrt: filterlib.Filter = variablelib.Param,
# submodule args
module_init_args: tuple[tp.Any, ...],
module_init_kwargs: dict[str, tp.Any],
):
self.module_constructor = module_constructor
self.grad_module = self.module_constructor(
*module_init_args, **module_init_kwargs
)
@functools.partial(
grad,
argnums=argnums,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
wrt=wrt,
)
def grad_call_apply(module, *args):
*args, accessor = args
method = accessor(module)
return method(*args)
self.grad_apply = grad_call_apply
@property
def _submodule(self) -> M:
return self.grad_module
def _call(self, accessor: DelayedAccessor, *args) -> tp.Any:
return self.grad_apply(self.grad_module, *args, accessor)
# -------------------------------
# jit
# -------------------------------
class Jit(tp.Generic[M], LiftedModule[M]):
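  """Deprecated lifted-module wrapper around ``nnx.jit``: the submodule is
  constructed eagerly and its method calls are executed through a jitted
  wrapper."""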
@staticmethod
def constructor(
module_constructor: tp.Callable[..., MA],
*,
in_shardings: tp.Any = None,
out_shardings: tp.Any = None,
static_argnums: int | tp.Sequence[int] | None = None,
static_argnames: str | tp.Iterable[str] | None = None,
donate_argnums: int | tp.Sequence[int] | None = None,
donate_argnames: str | tp.Iterable[str] | None = None,
keep_unused: bool = False,
device: tp.Optional[jax.Device] = None,
backend: tp.Optional[str] = None,
inline: bool = False,
abstracted_axes: tp.Optional[tp.Any] = None,
) -> tp.Callable[..., Jit[MA]]:
def _create_jit(*args, **kwargs):
return Jit(
module_constructor=module_constructor,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
donate_argnames=donate_argnames,
keep_unused=keep_unused,
device=device,
backend=backend,
inline=inline,
abstracted_axes=abstracted_axes,
# submodule args
module_init_args=args,
module_init_kwargs=kwargs,
)
return _create_jit
def __init__(
self,
module_constructor: tp.Callable[..., M],
*,
in_shardings: tp.Any = None,
out_shardings: tp.Any = None,
static_argnums: int | tp.Sequence[int] | None = None,
static_argnames: str | tp.Iterable[str] | None = None,
donate_argnums: int | tp.Sequence[int] | None = None,
donate_argnames: str | tp.Iterable[str] | None = None,
keep_unused: bool = False,
device: tp.Optional[jax.Device] = None,
backend: tp.Optional[str] = None,
inline: bool = False,
abstracted_axes: tp.Optional[tp.Any] = None,
# submodule args
module_init_args: tuple[tp.Any, ...],
module_init_kwargs: dict[str, tp.Any],
):
@functools.partial(
nnx.jit,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
donate_argnames=donate_argnames,
keep_unused=keep_unused,
device=device,
backend=backend,
inline=inline,
abstracted_axes=abstracted_axes,
)
def jit_call_module(
module, *args, _nnx_jit_accessor: DelayedAccessor, **kwargs
):
method = _nnx_jit_accessor(module)
return method(*args, **kwargs)
self.jitted_fn = jit_call_module
self.module_constructor = module_constructor
self.jit_module = self.module_constructor(
*module_init_args, **module_init_kwargs
)
@property
def _submodule(self) -> M:
return self.jit_module
def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any:
out = self.jitted_fn(
self.jit_module, *args, _nnx_jit_accessor=accessor, **kwargs
) # type: ignore[call-arg, type-var]
return out