flax.linen.LayerNorm

class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, parent=<flax.linen.module._Sentinel object>, name=None)

Layer normalization (https://arxiv.org/abs/1607.06450).

LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.
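In other words, with the default settings each example is normalized using statistics computed over the last axis. A minimal sketch of the computation, assuming freshly initialized parameters (scale = 1, bias = 0) and the default epsilon of 1e-6:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> import numpy as np

>>> x = jax.random.normal(jax.random.key(0), (3, 6))
>>> mean = x.mean(axis=-1, keepdims=True)
>>> var = x.var(axis=-1, keepdims=True)
>>> y_manual = (x - mean) / jnp.sqrt(var + 1e-6)

>>> layer = nn.LayerNorm()
>>> y = layer.apply(layer.init(jax.random.key(1), x), x)
>>> np.testing.assert_allclose(y, y_manual, atol=1e-5)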

NOTE: This normalization operation is identical to InstanceNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters).

Example usage:

>>> import flax.linen as nn
>>> import jax
>>> import numpy as np

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nn.LayerNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)

>>> y = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x)
>>> y2 = nn.GroupNorm(num_groups=1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

>>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
>>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

epsilon

A small float added to variance to avoid dividing by zero.

Type: float

dtype

The dtype of the result (default: inferred from input and params).

Type: Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype

The dtype passed to parameter initializers (default: float32).

Type: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
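For instance, a small mixed-precision sketch (illustrative; it assumes bfloat16 is supported on the backend), where parameters are kept in float32 while the computation and output use bfloat16:

>>> import jax.numpy as jnp
>>> layer = nn.LayerNorm(dtype=jnp.bfloat16, param_dtype=jnp.float32)
>>> variables = layer.init(jax.random.key(1), x)
>>> y = layer.apply(variables, x)
>>> assert y.dtype == jnp.bfloat16
>>> assert variables['params']['scale'].dtype == jnp.float32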

use_bias

If True, bias (beta) is added.

Type: bool

use_scale

If True, multiply by scale (gamma). When the next layer is linear (this also applies to, e.g., nn.relu), this can be disabled since the scaling will be done by the next layer.

Type: bool
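As an illustration (nn.Dense here is just a stand-in for any linear next layer), a sketch of dropping the redundant scale parameter:

>>> block = nn.Sequential([nn.LayerNorm(use_scale=False), nn.Dense(features=8)])
>>> variables = block.init(jax.random.key(1), x)  # the LayerNorm learns only a bias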

bias_init

Initializer for bias (default: zeros).

Type: Union[jax.nn.initializers.Initializer, Callable[[…], Any]]

scale_init

Initializer for scale (default: ones).

Type: Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
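A brief sketch of swapping in different initializers from flax.linen.initializers (the constant values are arbitrary, for illustration only):

>>> layer = nn.LayerNorm(scale_init=nn.initializers.constant(0.5),
...                      bias_init=nn.initializers.constant(0.1))
>>> params = layer.init(jax.random.key(1), x)['params']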

reduction_axes

Axes for computing normalization statistics.

Type: Union[int, Sequence[int]]

feature_axes

Feature axes for learned bias and scaling.

Type: Union[int, Sequence[int]]
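To illustrate how the two interact: normalizing the (3, 4, 5, 6) input from the example above over its last two axes, while also learning parameters over those axes, yields scale and bias of shape (5, 6) (a sketch):

>>> layer = nn.LayerNorm(reduction_axes=(-2, -1), feature_axes=(-2, -1))
>>> variables = layer.init(jax.random.key(1), x)
>>> shapes = jax.tree_util.tree_map(jnp.shape, variables['params'])
>>> # shapes == {'scale': (5, 6), 'bias': (5, 6)}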

axis_name

The axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize: just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.

Type: Optional[str]
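A rough sketch of the pmap case (illustrative only: the shapes are hypothetical and it is only meaningful with more than one device):

>>> layer = nn.LayerNorm(axis_name='devices')
>>> # x_sharded: (num_devices, batch, features_per_device); inside pmap each
>>> # device normalizes its feature slice using statistics combined across
>>> # the 'devices' axis:
>>> # y = jax.pmap(layer.apply, axis_name='devices')(variables, x_sharded)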

axis_index_groups

Groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

Type: Any

use_fast_variance

If True, use a faster, but less numerically stable, calculation for the variance.

Type: bool
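Conceptually, the two options correspond to the following ways of computing the variance (a sketch of the math, not Flax's internal code):

>>> xm = x.mean(axis=-1, keepdims=True)
>>> var_fast = (x ** 2).mean(axis=-1, keepdims=True) - xm ** 2    # E[x^2] - E[x]^2
>>> var_stable = ((x - xm) ** 2).mean(axis=-1, keepdims=True)     # E[(x - E[x])^2]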

__call__(x, *, mask=None)

Applies layer normalization on the input.

Parameters
  • x – The inputs.

  • mask – Binary array of shape broadcastable to the inputs tensor, indicating the positions for which the mean and variance should be computed (see the sketch below).

Returns

Normalized inputs (the same shape as inputs).
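A hedged sketch of masking, assuming True marks the positions included in the statistics (the mask pattern here is purely illustrative):

>>> import jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 6))
>>> mask = jnp.array([[True, True, True, True, False, False]] * 2)
>>> layer = nn.LayerNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> y = layer.apply(variables, x, mask=mask)  # stats use only the first four features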

Methods