flax.struct package
Utilities for defining custom classes that can be used with JAX transformations.
- flax.struct.dataclass(clz)
Create a class which can be passed to functional transformations.
NOTE: Inherit from PyTreeNode instead to avoid type checking issues when using PyType.

JAX transformations such as jax.jit and jax.grad require objects that are immutable and can be mapped over using the jax.tree_util methods. The dataclass decorator makes it easy to define custom classes that can be passed safely to JAX. For example:

>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable
>>> @struct.dataclass
... class Model:
...     params: Any
...     # use pytree_node=False to indicate an attribute should not be touched
...     # by Jax transformations.
...     apply_fn: Callable = struct.field(pytree_node=False)
...
...     def __apply__(self, *args):
...         return self.apply_fn(*args)
>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)
>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.
>>> # This class can now be used safely in Jax to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)
Note that dataclasses have an auto-generated __init__ where the arguments of the constructor and the attributes of the created instance match 1:1. This correspondence is what makes these objects valid containers that work with JAX transformations and, more generally, with the jax.tree_util library.

Sometimes a “smart constructor” is desired, for example because some of the attributes can be (optionally) derived from others. The way to do this with Flax dataclasses is to make a static or class method that provides the smart constructor. This way the simple constructor used by jax.tree_util is preserved. Consider the following example:

>>> @struct.dataclass
... class DirectionAndScaleKernel:
...     direction: jax.Array
...     scale: jax.Array
...
...     @classmethod
...     def create(cls, kernel):
...         scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
...         direction = kernel / scale
...         return cls(direction, scale)
- Parameters
clz – the class that will be transformed by the decorator.
- Returns
The new class.
- class flax.struct.PyTreeNode(*args, **kwargs)
Base class for dataclasses that should act like a JAX pytree node.
See flax.struct.dataclass for the jax.tree_util behavior. This base class additionally avoids type checking errors when using PyType.

Example:
>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable
>>> class Model(struct.PyTreeNode):
...     params: Any
...     # use pytree_node=False to indicate an attribute should not be touched
...     # by Jax transformations.
...     apply_fn: Callable = struct.field(pytree_node=False)
...
...     def __apply__(self, *args):
...         return self.apply_fn(*args)
>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)
>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.
>>> # This class can now be used safely in Jax to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)