variables#
- class flax.nnx.BatchStat(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
The mean and variance batch statistics stored in the
BatchNorm
layer. Note, these are not the learnable scale and bias parameters, but rather the running average statistics that are typically used during post-training inference:>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': VariableState( type=Param, value=(3,) ), 'mean': VariableState( type=BatchStat, value=(3,) ), 'scale': VariableState( type=Param, value=(3,) ), 'var': VariableState( type=BatchStat, value=(3,) ) })
- class flax.nnx.Cache(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
Autoregressive cache in
MultiHeadAttention
:>>> from flax import nnx >>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention( ... num_heads=2, ... in_features=3, ... qkv_features=6, ... out_features=6, ... decode=True, ... rngs=nnx.Rngs(0), ... )
>>> layer.init_cache((1, 3)) >>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache)) State({ 'cache_index': VariableState( type=Cache, value=() ), 'cached_key': VariableState( type=Cache, value=(1, 2, 3) ), 'cached_value': VariableState( type=Cache, value=(1, 2, 3) ) })
- class flax.nnx.Intermediate(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
Variable
type that is typically used for Module.sow()
:>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'i', x) ... x = self.linear2(x) ... return x >>> model = Model(rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y = model(x) >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate)) State({ 'i': VariableState( type=Intermediate, value=((1, 3),) ) })
- class flax.nnx.Param(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
The canonical learnable parameter. All learnable parameters in NNX layer modules will have the
Param
Variable
type:>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) })
- class flax.nnx.Variable(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
The base class for all
Variable
types. Create custom Variable
types by subclassing this class. Numerous NNX graph functions can filter for specific Variable
types, for example, split(),
state(),
pop(),
and State.filter().
Example usage:
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> class CustomVariable(nnx.Variable): ... pass >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.custom_variable = CustomVariable(jnp.ones((1, 3))) ... def __call__(self, x): ... return self.linear(x) + self.custom_variable >>> model = Model(rngs=nnx.Rngs(0)) >>> linear_variables = nnx.state(model, nnx.Param) >>> jax.tree.map(jnp.shape, linear_variables) State({ 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> custom_variable = nnx.state(model, CustomVariable) >>> jax.tree.map(jnp.shape, custom_variable) State({ 'custom_variable': VariableState( type=CustomVariable, value=(1, 3) ) }) >>> variables = nnx.state(model) >>> jax.tree.map(jnp.shape, variables) State({ 'custom_variable': VariableState( type=CustomVariable, value=(1, 3) ), 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } })
- class flax.nnx.VariableMetadata(raw_value: 'A', set_value_hooks: 'tuple[SetValueHook[A], ...]' = (), get_value_hooks: 'tuple[GetValueHook[A], ...]' = (), create_value_hooks: 'tuple[CreateValueHook[A], ...]' = (), add_axis_hooks: 'tuple[AddAxisHook[Variable[A]], ...]' = (), remove_axis_hooks: 'tuple[RemoveAxisHook[Variable[A]], ...]' = (), metadata: 'tp.Mapping[str, tp.Any]' = <factory>)[source]#