# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for defining custom classes that can be used with jax transformations."""
import dataclasses
from typing import TypeVar
import jax
from typing_extensions import (
dataclass_transform, # pytype: disable=not-supported-yet
)
from . import serialization
# Type variable that lets ``dataclass`` be annotated as returning the same
# class type it receives, so static type checkers keep the decorated type.
_T = TypeVar('_T')
def field(pytree_node=True, *, metadata=None, **kwargs):
  """Declare a dataclass field, optionally marking it as non-pytree data.

  Thin wrapper around ``dataclasses.field`` that records ``pytree_node`` in
  the field's metadata. ``flax.struct.dataclass`` reads that entry to decide
  whether the field is a pytree child (``True``) or static metadata
  (``False``).

  Args:
    pytree_node: whether the field's value is a child of the pytree node.
      Set to ``False`` for attributes that JAX transformations should not
      touch (e.g. callables such as an ``apply_fn``).
    metadata: optional extra metadata mapping to attach to the field. It is
      merged with the ``pytree_node`` entry; the ``pytree_node`` argument
      wins on a key collision. The caller's mapping is not mutated.
    **kwargs: forwarded unchanged to ``dataclasses.field`` (e.g.
      ``default``, ``default_factory``, ``repr``).

  Returns:
    The ``dataclasses.Field`` object created by ``dataclasses.field``.
  """
  # Merge rather than letting a user-supplied ``metadata`` travel through
  # ``**kwargs``; that would collide with the ``metadata=`` passed below and
  # raise "got multiple values for keyword argument 'metadata'".
  merged_metadata = {**(metadata or {}), 'pytree_node': pytree_node}
  return dataclasses.field(metadata=merged_metadata, **kwargs)
@dataclass_transform(field_specifiers=(field,))  # type: ignore[literal-required]
def dataclass(clz: _T) -> _T:
  """Create a class which can be passed to functional transformations.

  NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when
  using PyType.

  Jax transformations such as `jax.jit` and `jax.grad` require objects that are
  immutable and can be mapped over using the `jax.tree_util` methods.
  The `dataclass` decorator makes it easy to define custom classes that can be
  passed safely to Jax. For example::

    >>> from flax import struct
    >>> import jax
    >>> from typing import Any, Callable

    >>> @struct.dataclass
    ... class Model:
    ...   params: Any
    ...   # use pytree_node=False to indicate an attribute should not be touched
    ...   # by Jax transformations.
    ...   apply_fn: Callable = struct.field(pytree_node=False)
    ...
    ...   def __apply__(self, *args):
    ...     return self.apply_fn(*args)

    >>> params = {}
    >>> params_b = {}
    >>> apply_fn = lambda v, x: x
    >>> model = Model(params, apply_fn)

    >>> # model.params = params_b  # Model is immutable. This will raise an error.
    >>> model_b = model.replace(params=params_b)  # Use the replace method instead.

    >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
    >>> # parameters.
    >>> model = Model(params, apply_fn)
    >>> loss_fn = lambda model: 3.
    >>> model_grad = jax.grad(loss_fn)(model)

  Note that dataclasses have an auto-generated ``__init__`` where
  the arguments of the constructor and the attributes of the created
  instance match 1:1. This correspondence is what makes these objects
  valid containers that work with JAX transformations and
  more generally the `jax.tree_util` library.

  Sometimes a "smart constructor" is desired, for example because
  some of the attributes can be (optionally) derived from others.
  The way to do this with Flax dataclasses is to make a static or
  class method that provides the smart constructor.
  This way the simple constructor used by `jax.tree_util` is
  preserved. Consider the following example::

    >>> @struct.dataclass
    ... class DirectionAndScaleKernel:
    ...   direction: jax.Array
    ...   scale: jax.Array
    ...
    ...   @classmethod
    ...   def create(cls, kernel):
    ...     scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
    ...     direction = kernel / scale
    ...     return cls(direction, scale)

  The decorated class is made a frozen ``dataclasses.dataclass`` and gains a
  ``replace`` method. Fields declared with ``field(pytree_node=False)`` are
  handed to ``jax.tree_util`` as auxiliary (static) data; all other fields
  are pytree children. The class is registered with ``jax.tree_util``
  (including key paths) and with Flax state-dict serialization. A class that
  was itself already processed by this decorator is returned unchanged.

  Args:
    clz: the class that will be transformed by the decorator.

  Returns:
    The new class.
  """
  # check if already a flax dataclass
  if '_flax_dataclass' in clz.__dict__:
    return clz

  data_clz = dataclasses.dataclass(frozen=True)(clz)  # type: ignore
  # Partition fields by the ``pytree_node`` metadata written by ``field``;
  # plain dataclass fields without that entry default to pytree children.
  meta_fields = []
  data_fields = []
  for field_info in dataclasses.fields(data_clz):
    is_pytree_node = field_info.metadata.get('pytree_node', True)
    if is_pytree_node:
      data_fields.append(field_info.name)
    else:
      meta_fields.append(field_info.name)

  def replace(self, **updates):
    """Returns a new object replacing the specified fields with new values."""
    return dataclasses.replace(self, **updates)

  data_clz.replace = replace

  # Plain flatten: (children, aux_data) without key paths.
  def iterate_clz(x):
    meta = tuple(getattr(x, name) for name in meta_fields)
    data = tuple(getattr(x, name) for name in data_fields)
    return data, meta

  # Keyed flatten: each child is paired with a ``GetAttrKey`` so tree paths
  # read as attribute accesses.
  def iterate_clz_with_keys(x):
    meta = tuple(getattr(x, name) for name in meta_fields)
    data = tuple(
        (jax.tree_util.GetAttrKey(name), getattr(x, name)) for name in data_fields
    )
    return data, meta

  # Unflatten: rebuild the instance through the generated ``__init__`` using
  # keyword arguments for both static and child fields.
  def clz_from_iterable(meta, data):
    meta_args = tuple(zip(meta_fields, meta))
    data_args = tuple(zip(data_fields, data))
    kwargs = dict(meta_args + data_args)
    return data_clz(**kwargs)

  # Supplying ``iterate_clz`` as ``flatten_func`` lets key-less flattening
  # skip building ``GetAttrKey`` objects (it was previously defined but
  # never registered).
  jax.tree_util.register_pytree_with_keys(
      data_clz, iterate_clz_with_keys, clz_from_iterable, iterate_clz
  )

  def to_state_dict(x):
    # Only pytree children are serialized; static (meta) fields are not.
    state_dict = {
        name: serialization.to_state_dict(getattr(x, name))
        for name in data_fields
    }
    return state_dict

  def from_state_dict(x, state):
    """Restore the state of a data class."""
    state = state.copy()  # copy the state so we can pop the restored fields.
    updates = {}
    for name in data_fields:
      if name not in state:
        raise ValueError(
            f'Missing field {name} in state dict while restoring'
            f' an instance of {clz.__name__},'
            f' at path {serialization.current_path()}'
        )
      value = getattr(x, name)
      value_state = state.pop(name)
      updates[name] = serialization.from_state_dict(
          value, value_state, name=name
      )
    # Anything left over was not a pytree field of this class.
    if state:
      names = ','.join(state.keys())
      raise ValueError(
          f'Unknown field(s) "{names}" in state dict while'
          f' restoring an instance of {clz.__name__}'
          f' at path {serialization.current_path()}'
      )
    return x.replace(**updates)

  serialization.register_serialization_state(
      data_clz, to_state_dict, from_state_dict
  )

  # add a _flax_dataclass flag to distinguish from regular dataclasses
  data_clz._flax_dataclass = True  # type: ignore[attr-defined]

  return data_clz  # type: ignore
# Type variable bound to ``PyTreeNode`` so the ``replace`` stub below is typed
# as returning the same subclass it is called on.
TNode = TypeVar('TNode', bound='PyTreeNode')


@dataclass_transform(field_specifiers=(field,))  # type: ignore[literal-required]
class PyTreeNode:
  """Base class for dataclasses that should act like a JAX pytree node.

  See ``flax.struct.dataclass`` for the ``jax.tree_util`` behavior.
  This base class additionally avoids type checking errors when using PyType.

  Every subclass is passed through ``flax.struct.dataclass`` automatically
  when it is defined, so no decorator is needed.

  Example::

    >>> from flax import struct
    >>> import jax
    >>> from typing import Any, Callable

    >>> class Model(struct.PyTreeNode):
    ...   params: Any
    ...   # use pytree_node=False to indicate an attribute should not be touched
    ...   # by Jax transformations.
    ...   apply_fn: Callable = struct.field(pytree_node=False)
    ...
    ...   def __apply__(self, *args):
    ...     return self.apply_fn(*args)

    >>> params = {}
    >>> params_b = {}
    >>> apply_fn = lambda v, x: x
    >>> model = Model(params, apply_fn)

    >>> # model.params = params_b  # Model is immutable. This will raise an error.
    >>> model_b = model.replace(params=params_b)  # Use the replace method instead.

    >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
    >>> # parameters.
    >>> model = Model(params, apply_fn)
    >>> loss_fn = lambda model: 3.
    >>> model_grad = jax.grad(loss_fn)(model)
  """

  def __init_subclass__(cls, **kwargs):
    # Chain to the rest of the MRO so other bases' hooks still run (e.g.
    # ``typing.Generic`` in ``class Foo(PyTreeNode, Generic[T])``), and let
    # class keyword arguments reach whichever base accepts them.
    super().__init_subclass__(**kwargs)
    # Turn every subclass into a frozen, pytree-registered flax dataclass.
    dataclass(cls)  # pytype: disable=wrong-arg-types

  def __init__(self, *args, **kwargs):
    # Stub for pytype: ``dataclass`` generates the real ``__init__`` on each
    # subclass, which shadows this one.
    raise NotImplementedError

  def replace(self: TNode, **overrides) -> TNode:
    # Stub for pytype: ``dataclass`` attaches the real ``replace`` to each
    # subclass.
    raise NotImplementedError