Variables#

class flax.nnx.BatchStat(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#

The mean and variance batch statistics stored in the BatchNorm layer. Note that these are not the learnable scale and bias parameters, but the running averages that are typically used during inference after training:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(3,)
  ),
  'scale': VariableState(
    type=Param,
    value=(3,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(3,)
  )
})
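
BatchStat values are only updated when the layer runs with use_running_average=False; the Module.train() and Module.eval() helpers toggle this flag. A minimal sketch continuing the example above:

>>> layer.train()  # use_running_average=False: update mean/var from each batch
>>> y = layer(jnp.ones((4, 3)))
>>> layer.eval()   # use_running_average=True: normalize with the stored statistics
>>> y = layer(jnp.ones((4, 3)))
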
class flax.nnx.Cache(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#

The autoregressive cache used by MultiHeadAttention when decode=True:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
  'cache_index': VariableState(
    type=Cache,
    value=()
  ),
  'cached_key': VariableState(
    type=Cache,
    value=(1, 2, 3)
  ),
  'cached_value': VariableState(
    type=Cache,
    value=(1, 2, 3)
  )
})
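
Once the cache is initialized, each call in decode mode attends to one step and advances cache_index. A sketch continuing the example above:

>>> y = layer(jnp.ones((1, 3)))  # one autoregressive decoding step
>>> int(layer.cache_index.value)
1
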
class flax.nnx.Intermediate(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#

The Variable type typically used with Module.sow() to collect intermediate values:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x
>>> model = Model(rngs=nnx.Rngs(0))

>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate))
State({
  'i': VariableState(
    type=Intermediate,
    value=((1, 3),)
  )
})
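
Because sow() appends to a tuple on every call, intermediates are usually extracted and cleared between steps. A sketch using nnx.pop(), which removes the matching variables from the model as it returns them:

>>> intermediates = nnx.pop(model, nnx.Intermediate)  # extract and clear
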
class flax.nnx.Param(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#

The canonical learnable parameter. All learnable parameters in NNX layer modules have the Param Variable type:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})
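
Transforms that differentiate, such as nnx.grad, filter for Param by default, so only learnable parameters receive gradients. A minimal sketch continuing the example above:

>>> def loss_fn(model):
...   return jnp.mean(model(jnp.ones((1, 2))) ** 2)
>>> grads = nnx.grad(loss_fn)(layer)  # a State holding a gradient for each Param
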
class flax.nnx.Variable(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#

The base class for all Variable types. Create custom Variable types by subclassing this class. Numerous NNX graph functions can filter for specific Variable types, for example, split(), state(), pop(), and State.filter().

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class CustomVariable(nnx.Variable):
...   pass

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))

>>> linear_variables = nnx.state(model, nnx.Param)
>>> jax.tree.map(jnp.shape, linear_variables)
State({
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})

>>> custom_variable = nnx.state(model, CustomVariable)
>>> jax.tree.map(jnp.shape, custom_variable)
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  )
})

>>> variables = nnx.state(model)
>>> jax.tree.map(jnp.shape, variables)
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  ),
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
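
Since CustomVariable is a distinct Variable type, the same filters also work for partitioning the model, e.g. with split(). A sketch:

>>> graphdef, params, custom = nnx.split(model, nnx.Param, CustomVariable)
>>> model = nnx.merge(graphdef, params, custom)  # reassemble the model
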
class flax.nnx.VariableMetadata(raw_value: 'A', set_value_hooks: 'tuple[SetValueHook[A], ...]' = (), get_value_hooks: 'tuple[GetValueHook[A], ...]' = (), create_value_hooks: 'tuple[CreateValueHook[A], ...]' = (), add_axis_hooks: 'tuple[AddAxisHook[Variable[A]], ...]' = (), remove_axis_hooks: 'tuple[RemoveAxisHook[Variable[A]], ...]' = (), metadata: 'tp.Mapping[str, tp.Any]' = <factory>)[source]#
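
A container that bundles a raw value together with hooks and metadata before a Variable is created from it; initializers wrapped with with_metadata() (see below) return instances of this class.
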
class flax.nnx.VariableState(type, value, **metadata)[source]#
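
The purely-functional snapshot of a Variable, holding its type, value, and metadata. It is the leaf type stored in the State objects returned by split() and state(), as the examples above show. A sketch constructing one directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> vs = nnx.VariableState(nnx.Param, jnp.zeros((3,)))
>>> vs.type is nnx.Param
True
>>> vs.value.shape
(3,)
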
flax.nnx.with_metadata(initializer, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
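
Wraps an initializer so that Variables created from it carry the given hooks and metadata. A sketch attaching sharding metadata to a Linear kernel (the 'din'/'dout' axis names here are illustrative, not an API requirement):

>>> from flax import nnx

>>> kernel_init = nnx.with_metadata(
...   nnx.initializers.lecun_normal(), sharding=('din', 'dout'))
>>> layer = nnx.Linear(2, 3, kernel_init=kernel_init, rngs=nnx.Rngs(0))
>>> layer.kernel.sharding
('din', 'dout')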