flax.linen.LayerNorm

class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, parent=<flax.linen.module._Sentinel object>, name=None)

Layer normalization (https://arxiv.org/abs/1607.06450).

Operates on the last axis of the input data.

It normalizes the activations of the layer for each example in a batch independently, rather than across the batch as Batch Normalization does. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.
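The per-example normalization described above can be sketched in plain jax.numpy. This is an illustration only, not the library's implementation: the helper name layer_norm_sketch is hypothetical, and the learned scale and bias are left out.

```python
import jax
import jax.numpy as jnp


def layer_norm_sketch(x, epsilon=1e-6):
    """Hypothetical helper: normalize over the last axis, no scale or bias."""
    # Statistics are computed per example, over the last axis only.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    # epsilon keeps the denominator positive when the variance is zero.
    return (x - mean) / jnp.sqrt(var + epsilon)


# A batch of 4 examples with 8 features each, shifted and scaled on purpose.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 3.0 + 5.0
y = layer_norm_sketch(x)

row_means = jnp.mean(y, axis=-1)  # one value per example, close to 0
row_stds = jnp.std(y, axis=-1)    # one value per example, close to 1
```

Each row is normalized with its own statistics, so the result for one example does not depend on the other examples in the batch.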

Parameters
  • epsilon (float) –

  • dtype (Optional[Any]) –

  • param_dtype (Any) –

  • use_bias (bool) –

  • use_scale (bool) –

  • bias_init (Callable[[Any, Tuple[int, ...], Any], Any]) –

  • scale_init (Callable[[Any, Tuple[int, ...], Any], Any]) –

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –

  • name (str) –

Return type

None

epsilon

A small float added to variance to avoid dividing by zero.

Type

float
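As a small illustration of why epsilon matters, consider an example whose features are all equal, so its variance is exactly zero. The sketch below uses plain jax.numpy rather than the module itself, and the variable names are made up for the example.

```python
import jax.numpy as jnp

# A single example whose features are all identical.
x_const = jnp.full((1, 4), 2.0)
mean = jnp.mean(x_const, axis=-1, keepdims=True)
var = jnp.var(x_const, axis=-1, keepdims=True)  # zero for a constant row

# Without epsilon the normalization computes 0 / 0, which yields NaN.
without_eps = (x_const - mean) / jnp.sqrt(var)

# With epsilon the denominator is sqrt(1e-6), so the result is a finite 0.
with_eps = (x_const - mean) / jnp.sqrt(var + 1e-6)
```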

dtype

The dtype of the result (default: inferred from the input and the parameters).

Type

Optional[Any]

param_dtype

The dtype passed to the parameter initializers (default: float32).

Type

Any

use_bias

If True, bias (beta) is added.

Type

bool

use_scale

If True, the normalized output is multiplied by a learned scale (gamma). When the next layer is linear (or, for example, nn.relu), this can be disabled because that layer can perform the scaling.

Type

bool

bias_init

Initializer for the bias (default: zeros).

Type

Callable[[Any, Tuple[int, ...], Any], Any]

scale_init

Initializer for the scale (default: ones).

Type

Callable[[Any, Tuple[int, ...], Any], Any]

__call__(x)

Applies layer normalization on the input.

Parameters

x – the inputs

Returns

Normalized inputs (the same shape as inputs).

Methods

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

scale_init(shape[, dtype])

An initializer that returns a constant array full of ones.