class flax.nn.LayerNorm(x, epsilon=1e-06, dtype=float32, bias=True, scale=True, bias_init=zeros, scale_init=ones)

Applies layer normalization on the input.

It normalizes the activations of the layer for each example in a batch independently, rather than across a batch like batch normalization; i.e., it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.
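A rough sketch of that distinction in plain NumPy (not the Flax module itself): layer normalization computes its statistics over each example's feature axis, while batch normalization computes them over the batch axis.

```python
import numpy as np

x = np.random.default_rng(0).normal(3.0, 2.0, size=(4, 8))  # (batch, features)

# Layer norm: mean/std per example, taken over the feature axis.
layer_normed = (x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)

# Batch norm (for contrast): mean/std per feature, taken over the batch axis.
batch_normed = (x - x.mean(axis=0, keepdims=True)) / x.std(axis=0, keepdims=True)

# Each row of layer_normed now has mean ~0 and std ~1,
# independently of the other examples in the batch.
```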

Parameters:

  • x – the inputs.

  • epsilon – a small float added to the variance to avoid dividing by zero.

  • dtype – the dtype of the computation (default: float32).

  • bias – if True, a bias (beta) is added.

  • scale – if True, multiply by a scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.

  • bias_init – initializer for the bias, by default zero.

  • scale_init – initializer for the scale, by default one.


Returns: the normalized inputs (the same shape as the inputs).
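A minimal NumPy sketch of the forward computation described above, with the epsilon, bias, and scale arguments behaving as documented. The gamma/beta arrays stand in for the module's learned parameters at their initial values (per scale_init and bias_init); this is an illustration, not the Flax implementation.

```python
import numpy as np

def layer_norm(x, epsilon=1e-6, bias=True, scale=True):
    """Normalize x over its last axis; a sketch of what LayerNorm computes."""
    features = x.shape[-1]
    gamma = np.ones(features)   # scale parameter, per scale_init (ones)
    beta = np.zeros(features)   # bias parameter, per bias_init (zeros)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mean) / np.sqrt(var + epsilon)  # epsilon avoids division by zero
    if scale:
        y = y * gamma
    if bias:
        y = y + beta
    return y

x = np.arange(12, dtype=np.float32).reshape(3, 4)
y = layer_norm(x)
# y has the same shape as x; each row has mean ~0 and std ~1.
```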




Methods

apply(x[, epsilon, dtype, bias, scale, …])

Applies layer normalization on the input.

call(x[, epsilon, dtype, bias, scale, …])

Evaluate the module with the given parameters.

create(x[, epsilon, dtype, bias, scale, …])

Creates a module instance by evaluating the model.

create_by_shape(input_specs, x[, epsilon, …])

Creates a module instance using only shape and dtype information.

get_param(name)

Retrieves a parameter within the module’s apply function.

init(x[, epsilon, dtype, bias, scale, …])

Initializes the module parameters.

init_by_shape(input_specs, x[, epsilon, …])

Initializes the module parameters.



param(name, shape, initializer)

Defines a parameter within the module’s apply function.

partial([epsilon, dtype, bias, scale, …])

Partially applies a module with the given arguments.

shared(*[, name])

Partially applies a module and shared parameters for each call.

state(name[, shape, initializer, collection])

Declare a state variable within the module’s apply function.