flax.struct package

Utilities for defining custom classes that can be used with JAX transformations.

flax.struct.dataclass(clz, **kwargs)

Create a class whose instances can be passed to JAX functional transformations such as jax.jit and jax.grad.

Note

Inherit from PyTreeNode instead to avoid type checking issues when using pytype.

JAX transformations such as jax.jit and jax.grad require objects that are immutable and can be mapped over using the jax.tree_util methods. The dataclass decorator makes it easy to define custom classes that can be passed safely to JAX. For example:

>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable

>>> @struct.dataclass
... class Model:
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by JAX transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)
...
...   def __call__(self, *args):
...     return self.apply_fn(self.params, *args)

>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)

>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.

>>> # This class can now be used safely in JAX to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)

Note that dataclasses have an auto-generated __init__ where the arguments of the constructor and the attributes of the created instance match 1:1. This correspondence is what makes these objects valid containers that work with JAX transformations and more generally the jax.tree_util library.

Sometimes a “smart constructor” is desired, for example because some of the attributes can (optionally) be derived from others. With Flax dataclasses, the way to do this is to define a static method or class method that serves as the smart constructor. This preserves the simple constructor that jax.tree_util relies on. Consider the following example:

>>> @struct.dataclass
... class DirectionAndScaleKernel:
...   direction: jax.Array
...   scale: jax.Array
...
...   @classmethod
...   def create(cls, kernel):
...     # Column-wise L2 norm, kept with shape (1, n) so it broadcasts against kernel.
...     scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
...     direction = kernel / scale
...     return cls(direction, scale)

Parameters

clz – the class that will be transformed by the decorator.

Returns

The new class.

class flax.struct.PyTreeNode(*args, **kwargs)

Base class for dataclasses that should act like a JAX pytree node.

See flax.struct.dataclass for the jax.tree_util behavior. This base class additionally avoids type checking errors when using pytype.

Example:

>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable

>>> class Model(struct.PyTreeNode):
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by JAX transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)
...
...   def __call__(self, *args):
...     return self.apply_fn(self.params, *args)

>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)

>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.

>>> # This class can now be used safely in JAX to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)