object
- flax.nnx.Object
alias of Pytree
- flax.nnx.data(value, /)
Annotates an attribute as pytree data.
The return value of data must be assigned directly to an Object attribute, which will then be registered as a pytree data attribute.
Example:
```python
from flax import nnx
import jax

class Foo(nnx.Pytree):
  def __init__(self):
    self.data_attr = nnx.data(42)  # pytree data
    self.static_attr = "hello"     # static attribute

foo = Foo()
assert jax.tree.leaves(foo) == [42]
```
- Parameters
value – The value to annotate as data.
- Returns
A value which will register the attribute as data on assignment.
- flax.nnx.Data
Data marks attributes of a class as pytree data using type annotations.
Data annotations must be used at the class level and apply to all instances. Using Data is recommended when type annotations are already present or required, e.g. for dataclasses.
Example:
```python
from flax import nnx
import jax
import dataclasses

@dataclasses.dataclass
class Foo(nnx.Pytree):
  a: nnx.Data[int]  # Annotates `a` as pytree data
  b: str            # `b` is not pytree data

foo = Foo(a=42, b='hello')
assert jax.tree.leaves(foo) == [42]
```
alias of A[A]
- flax.nnx.is_data_type(value, /)
Checks if a value is a registered data type.
This function checks whether the value is a registered data type, meaning it is automatically recognized as pytree data when assigned to an Object attribute.
Data types are:
- jax.Arrays
- np.ndarrays
- ArrayRefs
- Variables (Param, BatchStat, RngState, etc.)
- All graph nodes (Object, Module, Rngs, etc.)
- Any type registered with nnx.register_data_type
Example:
```python
from flax import nnx
import jax.numpy as jnp

module = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
blocks = [module, module, module]

assert nnx.is_data_type(jnp.array(42))  # Arrays are data
assert nnx.is_data_type(nnx.Param(1))   # Variables are data
assert nnx.is_data_type(nnx.Rngs(0))    # Objects are data
assert nnx.is_data_type(module)         # Objects are data
assert not nnx.is_data_type(0.)         # float is not data
assert not nnx.is_data_type(1)          # int is not data
assert not nnx.is_data_type("hello")    # str is not data
assert not nnx.is_data_type(blocks)     # list is not data
```
- Parameters
value – The value to check.
- Returns
True if the value is a registered data type, False otherwise.
- flax.nnx.register_data_type(type_, /)
Registers a type as a pytree data type recognized by Object.
Values of a type registered as data are automatically recognized as data when assigned to an Object attribute. This means they do not need to be wrapped in nnx.data(…) for Object to mark the attribute they are assigned to as data.
Example:
```python
from flax import nnx
import jax
from dataclasses import dataclass

@dataclass(frozen=True)
class MyType:
  value: int

nnx.register_data_type(MyType)

class Foo(nnx.Pytree):
  def __init__(self, a):
    self.a = MyType(a)  # Automatically treated as data
    self.b = "hello"    # str is not registered as data

foo = Foo(42)
assert nnx.is_data_type(foo.a)  # True
assert jax.tree.leaves(foo) == [MyType(value=42)]
```