graph
- flax.nnx.dataclass(cls=None, /, *, init=True, eq=True, order=False, unsafe_hash=False, match_args=True, kw_only=False, slots=False)
Turns an nnx.Object subclass into a dataclass and defines its pytree node attributes from its type hints.
nnx.dataclass can be used to create pytree dataclass types with type hints instead of the __data__ attribute. By default, all fields are treated as nodes; to mark a field as static, annotate it with nnx.Static[T]. Example:
    from flax import nnx
    import jax

    @nnx.dataclass
    class Foo(nnx.Object):
        a: int
        b: jax.Array
        c: nnx.Static[int]

    tree = Foo(a=1, b=jax.numpy.array(1), c=1)
    assert len(jax.tree.leaves(tree)) == 2  # a and b
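The node/static split that the decorator derives from type hints can be illustrated without Flax. The following is a minimal pure-Python sketch, not Flax's implementation: Static, _StaticMarker, partition_fields, tree_flatten and tree_unflatten are hypothetical names, and modeling Static[T] as typing.Annotated[T, marker] is an assumption made only for illustration. Node fields become the flattened children and static fields become auxiliary data, which is the (children, aux_data) split that JAX custom pytree nodes use.

```python
import dataclasses
import typing
from typing import Annotated


class _StaticMarker:
    """Hypothetical metadata tag placed inside Annotated[...] to flag a static field."""


class Static:
    """Hypothetical stand-in for nnx.Static: Static[T] == Annotated[T, _StaticMarker]."""

    def __class_getitem__(cls, item):
        return Annotated[item, _StaticMarker]


def partition_fields(cls):
    """Split dataclass fields into (node_names, static_names) using type hints."""
    hints = typing.get_type_hints(cls, include_extras=True)  # keep Annotated metadata
    nodes, static = [], []
    for field in dataclasses.fields(cls):
        hint = hints[field.name]
        is_static = (
            typing.get_origin(hint) is Annotated
            and _StaticMarker in typing.get_args(hint)[1:]
        )
        (static if is_static else nodes).append(field.name)
    return nodes, static


def tree_flatten(obj):
    """Node values become children; static (name, value) pairs become aux data."""
    nodes, static = partition_fields(type(obj))
    children = tuple(getattr(obj, name) for name in nodes)
    aux = tuple((name, getattr(obj, name)) for name in static)
    return children, aux


def tree_unflatten(cls, aux, children):
    """Rebuild an instance from aux data and children (assumes all fields are init fields)."""
    nodes, _ = partition_fields(cls)
    kwargs = dict(zip(nodes, children))
    kwargs.update(aux)
    return cls(**kwargs)


@dataclasses.dataclass
class Foo:
    a: int
    b: float
    c: Static[int]  # static: excluded from the children


children, aux = tree_flatten(Foo(a=1, b=2.0, c=3))
print(children)  # (1, 2.0)
print(aux)       # (('c', 3),)
```

In real JAX code, a flatten/unflatten pair like this would be handed to jax.tree_util.register_pytree_node, whose unflatten callback receives (aux_data, children); nnx.dataclass spares you from writing it by hand.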
nnx.dataclass raises a ValueError if the class does not derive from nnx.Object, if the parent Object has pytree set to anything other than 'strict', or if the class has a __data__ attribute. nnx.dataclass does not accept repr; it always sets repr=False so that the default __repr__ inherited from Object is not overwritten.
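The validation rules and the fixed repr=False can be sketched with the standard dataclasses module. This is a hypothetical illustration, not Flax code: Base stands in for nnx.Object, _pytree_mode is an assumed attribute recording the pytree class keyword, and checked_dataclass forwards only a subset of the dataclass options. Because the decorator has no repr parameter and always passes repr=False, the __repr__ inherited from the base class survives.

```python
import dataclasses


class Base:
    """Hypothetical stand-in for nnx.Object; subclasses choose a pytree mode."""

    def __init_subclass__(cls, pytree="strict", **kwargs):
        super().__init_subclass__(**kwargs)
        cls._pytree_mode = pytree  # assumed attribute recording the class keyword

    def __repr__(self):
        return f"<{type(self).__name__} custom repr>"


def checked_dataclass(cls=None, /, *, init=True, eq=True, order=False,
                      unsafe_hash=False):
    """Hypothetical decorator: validate cls, then apply dataclasses.dataclass.

    There is deliberately no `repr` parameter; repr=False is always passed.
    """

    def wrap(cls):
        if not (isinstance(cls, type) and issubclass(cls, Base)):
            raise ValueError(f"{cls.__name__} must derive from Base")
        if cls._pytree_mode != "strict":
            raise ValueError(f"{cls.__name__} must use pytree='strict'")
        if hasattr(cls, "__data__"):
            raise ValueError(f"{cls.__name__} must not define __data__")
        return dataclasses.dataclass(
            cls, init=init, repr=False, eq=eq, order=order, unsafe_hash=unsafe_hash
        )

    # Support both @checked_dataclass and @checked_dataclass(...)
    return wrap if cls is None else wrap(cls)


@checked_dataclass
class Good(Base):
    a: int


print(repr(Good(a=1)))  # prints <Good custom repr>


class Plain:  # does not derive from Base
    a: int


class Loose(Base, pytree=False):  # non-strict pytree mode
    a: int


class WithData(Base):  # legacy __data__ attribute
    __data__ = ("a",)
    a: int


for bad in (Plain, Loose, WithData):
    try:
        checked_dataclass(bad)
    except ValueError as err:
        print(err)
```

Calling checked_dataclass(repr=True) fails with a TypeError simply because the keyword does not exist, which is one straightforward way to "not accept" an option.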