variables

class flax.experimental.nnx.BatchStat(value, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)
class flax.experimental.nnx.Cache(value, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)
class flax.experimental.nnx.Empty
class flax.experimental.nnx.Intermediate(value, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)
class flax.experimental.nnx.Param(value, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)
class flax.experimental.nnx.Rng(value, *, tag, **metadata)
class flax.experimental.nnx.Variable(value, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)
class flax.experimental.nnx.VariableMetadata(raw_value: A, set_value_hooks: tuple[Callable[[Variable[A], A], A], ...] = (), get_value_hooks: tuple[Callable[[Variable[A], A], A], ...] = (), create_value_hooks: tuple[Callable[[Variable[A], A], A], ...] = (), add_axis_hooks: tuple[Callable[[Variable[A], str, int], None], ...] = (), remove_axis_hooks: tuple[Callable[[Variable[A], str, int], None], ...] = (), metadata: Mapping[str, Any] = <factory>)
flax.experimental.nnx.with_metadata(initializer, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)