Normalization#

class flax.nnx.BatchNorm(*args, **kwargs)[source]#

BatchNorm Module.

To compute batch norm on the input and update the batch statistics, call the train() method (or pass use_running_average=False in the constructor or at call time).

To instead use the running average of the stored batch statistics, call the eval() method (or pass use_running_average=True in the constructor or at call time).

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(6,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(6,)
  ),
  'scale': VariableState(
    type=Param,
    value=(6,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(6,)
  )
})

>>> # calculate batch norm on input and update batch statistics
>>> layer.train()
>>> y = layer(x)
>>> batch_stats1 = nnx.state(layer, nnx.BatchStat)
>>> y = layer(x)
>>> batch_stats2 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all()
>>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all()

>>> # use stored batch statistics' running average
>>> layer.eval()
>>> y = layer(x)
>>> batch_stats3 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
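
The running statistics follow an exponential moving average controlled by momentum. As a minimal sketch of that update (assuming the conventional rule new = momentum * old + (1 - momentum) * batch_statistic and a fresh layer whose running mean starts at zeros; the names fresh and expected are illustrative only):

>>> # assumes the EMA update: new = momentum * old + (1 - momentum) * batch_mean
>>> fresh = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> fresh.train()
>>> _ = fresh(x)
>>> expected = 0.9 * jnp.zeros(6) + 0.1 * x.mean(axis=0)
>>> assert jnp.allclose(fresh.mean.value, expected, atol=1e-5)
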
num_features#

the number of input features.

use_running_average#

if True, the stored batch statistics will be used instead of computing the batch statistics on the input.

axis#

the feature or non-batch axis of the input.

momentum#

decay rate for the exponential moving average of the batch statistics.

epsilon#

a small float added to variance to avoid dividing by zero.

dtype#

the dtype of the result (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

use_bias#

if True, bias (beta) is added.

use_scale#

if True, multiply by scale (gamma). When the next layer is linear (also e.g. nnx.relu), this can be disabled since the scaling will be done by the next layer.

bias_init#

initializer for bias, by default, zero.

scale_init#

initializer for scale, by default, one.

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

use_fast_variance#

If true, use a faster, but less numerically stable, calculation for the variance.

rngs#

rng key.

__call__(x, use_running_average=None, *, mask=None)[source]#

Normalizes the input using batch statistics.

Parameters
  • x – the input to be normalized.

  • use_running_average – if true, the stored batch statistics will be used instead of computing the batch statistics on the input. The use_running_average flag passed into the call method will take precedence over the use_running_average flag passed into the constructor.

Returns

Normalized inputs (the same shape as inputs).
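
For instance, use_running_average=True at call time uses (and leaves unchanged) the stored statistics even while the module is in training mode. A minimal sketch, reusing layer and x from the example above:

>>> layer.train()
>>> stats_before = nnx.state(layer, nnx.BatchStat)
>>> y = layer(x, use_running_average=True)  # call-time flag takes precedence over train()
>>> stats_after = nnx.state(layer, nnx.BatchStat)
>>> assert (stats_before['mean'].value == stats_after['mean'].value).all()
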


class flax.nnx.LayerNorm(*args, **kwargs)[source]#

Layer normalization (https://arxiv.org/abs/1607.06450).

LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.

Example usage:

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
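
The per-example statistics along the feature axis can be checked directly; a quick sketch using the y just computed (tolerances are illustrative):

>>> import jax.numpy as jnp
>>> assert jnp.allclose(y.mean(axis=-1), 0., atol=1e-5)
>>> assert jnp.allclose(y.std(axis=-1), 1., atol=1e-2)
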
num_features#

the number of input features.

epsilon#

A small float added to variance to avoid dividing by zero.

dtype#

the dtype of the result (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

use_bias#

If True, bias (beta) is added.

use_scale#

If True, multiply by scale (gamma). When the next layer is linear (also e.g. nnx.relu), this can be disabled since the scaling will be done by the next layer.

bias_init#

Initializer for bias, by default, zero.

scale_init#

Initializer for scale, by default, one.

reduction_axes#

Axes for computing normalization statistics.

feature_axes#

Feature axes for learned bias and scaling.

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

use_fast_variance#

If true, use a faster, but less numerically stable, calculation for the variance.

rngs#

rng key.

__call__(x, *, mask=None)[source]#

Applies layer normalization on the input.

Parameters

x – the inputs

Returns

Normalized inputs (the same shape as inputs).


class flax.nnx.RMSNorm(*args, **kwargs)[source]#

RMS Layer normalization (https://arxiv.org/abs/1910.07467).

RMSNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations.

Example usage:

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
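
With the default scale of ones this amounts to dividing each example by its root mean square; a quick sketch, assuming the default epsilon of 1e-6 (an assumption here; check the constructor if it differs):

>>> import jax.numpy as jnp
>>> rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)  # epsilon assumed to be 1e-6
>>> assert jnp.allclose(y, x / rms, atol=1e-5)
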
num_features#

the number of input features.

epsilon#

A small float added to variance to avoid dividing by zero.

dtype#

the dtype of the result (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

use_scale#

If True, multiply by scale (gamma). When the next layer is linear (also e.g. nnx.relu), this can be disabled since the scaling will be done by the next layer.

scale_init#

Initializer for scale, by default, one.

reduction_axes#

Axes for computing normalization statistics.

feature_axes#

Feature axes for learned bias and scaling.

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

use_fast_variance#

If true, use a faster, but less numerically stable, calculation for the variance.

rngs#

rng key.

__call__(x, mask=None)[source]#

Applies RMS layer normalization on the input.

Parameters

x – the inputs

Returns

Normalized inputs (the same shape as inputs).


class flax.nnx.GroupNorm(*args, **kwargs)[source]#

Group normalization (https://arxiv.org/abs/1803.08494).

This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels rather than across the batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should specify either the total number of channel groups or the number of channels per group.

Note

LayerNorm is a special case of GroupNorm where num_groups=1.

Example usage:

>>> from flax import nnx
>>> import jax
>>> import numpy as np
...
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
...
>>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
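
Specifying group_size instead of num_groups yields the same grouping (6 channels with group_size=2 form 3 groups); a minimal sketch, assuming the two arguments are mutually exclusive so num_groups is set to None explicitly:

>>> y3 = nnx.GroupNorm(num_features=6, num_groups=None, group_size=2, rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y3, layer(x))
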
num_features#

the number of input features/channels.

num_groups#

the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.

group_size#

the number of channels in a group.

epsilon#

A small float added to variance to avoid dividing by zero.

dtype#

the dtype of the result (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

use_bias#

If True, bias (beta) is added.

use_scale#

If True, multiply by scale (gamma). When the next layer is linear (also e.g. nnx.relu), this can be disabled since the scaling will be done by the next layer.

bias_init#

Initializer for bias, by default, zero.

scale_init#

Initializer for scale, by default, one.

reduction_axes#

List of axes used for computing normalization statistics. This list must include the final dimension, which is assumed to be the feature axis. Furthermore, if the input used at call time has additional leading axes compared to the data used for initialization, for example due to batching, then the reduction axes need to be defined explicitly.

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

use_fast_variance#

If true, use a faster, but less numerically stable, calculation for the variance.

rngs#

rng key.

__call__(x, *, mask=None)[source]#

Applies group normalization to the input (https://arxiv.org/abs/1803.08494).

Parameters
  • x – the input of shape (..., self.num_features), where self.num_features is a channels dimension and ... represents an arbitrary number of extra dimensions that can be used to accumulate statistics over. If no reduction axes have been specified, then all additional dimensions ... will be used to accumulate statistics, apart from the leading dimension, which is assumed to represent the batch.

  • mask – Binary array of shape broadcastable to inputs tensor, indicating the positions for which the mean and variance should be computed.

Returns

Normalized inputs (the same shape as inputs).
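
A minimal sketch of the mask argument, reusing layer and x from the example above; the full-shape boolean mask used here (excluding the last position along one spatial axis from the statistics) is just one broadcastable choice:

>>> import jax.numpy as jnp
>>> mask = jnp.ones(x.shape, dtype=bool).at[:, :, -1, :].set(False)  # illustrative mask
>>> y_masked = layer(x, mask=mask)
>>> assert y_masked.shape == x.shape
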
