Normalization#

class flax.nnx.BatchNorm(*args, **kwargs)[source]#

BatchNorm Module.

To calculate the batch norm on the input and update the batch statistics, call the train() method (or pass use_running_average=False to the constructor or at call time).

To use the stored batch statistics' running average, call the eval() method (or pass use_running_average=True to the constructor or at call time).

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(6,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(6,)
  ),
  'scale': VariableState(
    type=Param,
    value=(6,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(6,)
  )
})

>>> # calculate batch norm on input and update batch statistics
>>> layer.train()
>>> y = layer(x)
>>> batch_stats1 = nnx.state(layer, nnx.BatchStat)
>>> y = layer(x)
>>> batch_stats2 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all()
>>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all()

>>> # use stored batch statistics' running average
>>> layer.eval()
>>> y = layer(x)
>>> batch_stats3 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
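
The running statistics are tracked as an exponential moving average controlled by momentum. As a rough sketch (this assumes the conventional update running_mean = momentum * running_mean + (1 - momentum) * batch_mean and a running mean initialized to zeros; it is an illustration, not a specification of the implementation), the first training step of a fresh layer can be checked numerically:

>>> fresh = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> fresh.train()
>>> _ = fresh(x)
>>> fresh_stats = nnx.state(fresh, nnx.BatchStat)
>>> # assumed update rule: new_mean = momentum * old_mean + (1 - momentum) * batch_mean
>>> expected_mean = (1 - 0.9) * x.mean(axis=0)
>>> assert jnp.allclose(fresh_stats['mean'].value, expected_mean, atol=1e-5)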
num_features#

the number of input features.

use_running_average#

if True, the stored batch statistics will be used instead of computing the batch statistics on the input.

axis#

the feature or non-batch axis of the input.

momentum#

decay rate for the exponential moving average of the batch statistics.

epsilon#

a small float added to variance to avoid dividing by zero.

dtype#

the dtype of the result (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

use_bias#

if True, bias (beta) is added.

use_scale#

if True, multiply by scale (gamma). When the next layer is linear (this also applies to, e.g., nnx.relu), this can be disabled since the scaling will be done by the next layer.

bias_init#

initializer for bias, by default, zero.

scale_init#

initializer for scale, by default, one.

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

use_fast_variance#

If True, use a faster but less numerically stable calculation for the variance.

rngs#

rng key.

__call__(x, use_running_average=None, *, mask=None)[source]#

Normalizes the input using batch statistics.

Parameters
  • x – the input to be normalized.

  • use_running_average – if True, the stored batch statistics will be used instead of computing the batch statistics on the input. The use_running_average flag passed at call time takes precedence over the flag passed to the constructor (see the sketch below).

Returns

Normalized inputs (the same shape as inputs).
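
A minimal sketch of the call-time override, reusing the layer and x from the example above: putting the layer in training mode and then passing use_running_average=True for a single call makes that call behave like eval(), so the stored statistics are used and left unchanged.

>>> layer.train()
>>> stats_before = nnx.state(layer, nnx.BatchStat)
>>> y = layer(x, use_running_average=True)  # call-time flag wins over train()
>>> stats_after = nnx.state(layer, nnx.BatchStat)
>>> assert (stats_before['mean'].value == stats_after['mean'].value).all()
>>> assert (stats_before['var'].value == stats_after['var'].value).all()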


class flax.nnx.LayerNorm(*args, **kwargs)[source]#

Layer normalization (https://arxiv.org/abs/1607.06450).

LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.

Example usage:

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
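
As a quick sketch of the behaviour described above (assuming the default reduction_axes=-1 together with the default scale of ones and bias of zeros), the output has mean close to 0 and standard deviation close to 1 along the feature axis of every example:

>>> import jax.numpy as jnp
>>> assert jnp.allclose(y.mean(axis=-1), 0.0, atol=1e-3)
>>> assert jnp.allclose(y.std(axis=-1), 1.0, atol=1e-3)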
num_features#

the number of input features.

epsilon#

A small float added to variance to avoid dividing by zero.

dtype#

the dtype of the result (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

use_bias#

If True, bias (beta) is added.

use_scale#

If True, multiply by scale (gamma). When the next layer is linear (this also applies to, e.g., nnx.relu), this can be disabled since the scaling will be done by the next layer.

bias_init#

Initializer for bias, by default, zero.

scale_init#

Initializer for scale, by default, one.

reduction_axes#

Axes for computing normalization statistics.

feature_axes#

Feature axes for learned bias and scaling.

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

use_fast_variance#

If True, use a faster but less numerically stable calculation for the variance.

rngs#

rng key.

__call__(x, *, mask=None)[source]#

Applies layer normalization on the input.

Parameters

x – the inputs

Returns

Normalized inputs (the same shape as inputs).


class flax.nnx.RMSNorm(*args, **kwargs)[source]#

RMS Layer normalization (https://arxiv.org/abs/1910.07467).

RMSNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. Unlike LayerNorm, which re-centers the mean to 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations.

Example usage:

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
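
As a sketch of the normalization described above (using the formula from the paper, y = x / sqrt(mean(x**2) + epsilon) * scale, with epsilon passed explicitly and the default scale of ones), the output can be reproduced manually:

>>> import jax.numpy as jnp
>>> rms = nnx.RMSNorm(num_features=6, epsilon=1e-6, rngs=nnx.Rngs(0))
>>> manual = x / jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)
>>> assert jnp.allclose(rms(x), manual, atol=1e-5)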
num_features#

the number of input features.

epsilon#

A small float added to variance to avoid dividing by zero.

dtype#

the dtype of the result (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

use_scale#

If True, multiply by scale (gamma). When the next layer is linear (this also applies to, e.g., nnx.relu), this can be disabled since the scaling will be done by the next layer.

scale_init#

Initializer for scale, by default, one.

reduction_axes#

Axes for computing normalization statistics.

feature_axes#

Feature axes for learned bias and scaling.

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

use_fast_variance#

If True, use a faster but less numerically stable calculation for the variance.

rngs#

rng key.

__call__(x, mask=None)[source]#

Applies RMS layer normalization on the input.

Parameters

x – the inputs

Returns

Normalized inputs (the same shape as inputs).
