flax.linen.LayerNorm

class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, parent=<flax.linen.module._Sentinel object>, name=None)

Layer normalization (https://arxiv.org/abs/1607.06450).

LayerNorm normalizes the activations of the layer for each example in a batch independently, rather than across the batch as Batch Normalization does. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.

epsilon

A small float added to the variance to avoid division by zero.

Type

float

dtype

The dtype of the result (default: inferred from the input and params).

Type

Optional[Any]

param_dtype

The dtype passed to parameter initializers (default: float32).

Type

Any

use_bias

If True, bias (beta) is added.

Type

bool

use_scale

If True, multiply by scale (gamma). When the next layer is linear (or, e.g., nn.relu), this can be disabled, since the scaling can be done by the next layer.

Type

bool

bias_init

Initializer for the bias; zeros by default.

Type

Callable[[Any, Tuple[int, ...], Any], Any]

scale_init

Initializer for the scale; ones by default.

Type

Callable[[Any, Tuple[int, ...], Any], Any]

reduction_axes

Axes over which the normalization statistics (mean and variance) are computed.

Type

Union[int, Iterable[int]]

feature_axes

Axes along which the learned bias and scale vary; each parameter has one entry per position along these axes.

Type

Union[int, Iterable[int]]

__call__(x)

Applies layer normalization to the input.

Parameters

x – the inputs

Returns

The normalized inputs, with the same shape as x.
