class flax.nn.GroupNorm(x, num_groups=32, group_size=None, epsilon=1e-06, dtype=<class 'jax._src.numpy.lax_numpy.float32'>, bias=True, scale=True, bias_init=<function zeros>, scale_init=<function ones>)

Applies group normalization to the input (arxiv.org/abs/1803.08494).

This op is similar to batch normalization, but statistics are shared across equally sized groups of channels and are not shared across the batch dimension. Thus, group normalization does not depend on the batch composition and does not need to maintain internal state for storing statistics.

The user should either specify the total number of channel groups or the number of channels per group.
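The relationship between the two arguments can be sketched as follows. This is a minimal illustration, not the library's code: `resolve_num_groups` is a hypothetical helper, and the assumption that exactly one of `num_groups` and `group_size` must be non-None (so `num_groups=None` would be passed alongside an explicit `group_size`) is inferred from the sentence above.

```python
def resolve_num_groups(channels, num_groups=None, group_size=None):
    """Hypothetical helper: derive the number of channel groups.

    Assumes exactly one of `num_groups` / `group_size` is given.
    """
    if (num_groups is None) == (group_size is None):
        raise ValueError("Specify exactly one of num_groups or group_size.")
    if group_size is not None:
        # Channels per group is given; the group count follows from it.
        if channels % group_size != 0:
            raise ValueError("channels must be divisible by group_size.")
        return channels // group_size
    # The group count is given; it must split the channels evenly.
    if channels % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups.")
    return num_groups
```

With 64 channels, for instance, `num_groups=32` and `group_size=2` describe the same grouping.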

  • x – the input of shape N…C, where N is the batch dimension and C is the channels dimension. The … represents an arbitrary number of extra dimensions over which statistics are accumulated.

  • num_groups – the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.

  • group_size – the number of channels in a group.

  • epsilon – A small float added to variance to avoid dividing by zero.

  • dtype – the dtype of the computation (default: float32).

  • bias – If True, bias (beta) is added.

  • scale – If True, multiply by scale (gamma). When the next layer is linear (or, for example, nn.relu), this can be disabled, since the scaling will be done by the next layer.

  • bias_init – Initializer for bias (default: zeros).

  • scale_init – Initializer for scale (default: ones).


Returns: normalized inputs (the same shape as the input x).
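To make the computation concrete, here is a minimal sketch of group normalization in plain `jax.numpy`. It is not the module's actual implementation: `group_norm_sketch` is a hypothetical function, it takes only `num_groups` (not `group_size`), and the optional `scale` and `bias` arrays stand in for the learned gamma and beta parameters.

```python
import jax
import jax.numpy as jnp


def group_norm_sketch(x, num_groups, epsilon=1e-6, scale=None, bias=None):
    """Hypothetical sketch: normalize x of shape (N, ..., C) per channel group."""
    channels = x.shape[-1]
    if channels % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups.")
    group_size = channels // num_groups

    # Split the channel axis: (N, ..., C) -> (N, ..., G, C // G).
    grouped = x.reshape(x.shape[:-1] + (num_groups, group_size))

    # Accumulate statistics over the extra dimensions and the channels inside
    # each group, but never over the batch axis (0) or the group axis (-2).
    reduce_axes = tuple(range(1, grouped.ndim - 2)) + (grouped.ndim - 1,)
    mean = jnp.mean(grouped, axis=reduce_axes, keepdims=True)
    var = jnp.var(grouped, axis=reduce_axes, keepdims=True)

    # epsilon is added to the variance to avoid dividing by zero.
    y = (grouped - mean) * jax.lax.rsqrt(var + epsilon)
    y = y.reshape(x.shape)

    # Optional per-channel scale (gamma) and bias (beta), each of shape (C,).
    if scale is not None:
        y = y * scale
    if bias is not None:
        y = y + bias
    return y


# Usage: a batch of two 4x4 feature maps with 8 channels, in 4 groups.
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 8))
y = group_norm_sketch(x, num_groups=4)
```

Because no statistic is reduced over the batch axis, normalizing one example alone gives the same result as normalizing it as part of a larger batch.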




apply(x[, num_groups, group_size, epsilon, …])

Applies group normalization to the input (arxiv.org/abs/1803.08494).

call(x[, num_groups, group_size, epsilon, …])

Evaluate the module with the given parameters.

create(x[, num_groups, group_size, epsilon, …])

Creates a module instance by evaluating the model.

create_by_shape(input_specs, x[, …])

Creates a module instance using only shape and dtype information.


Retrieves a parameter within the module’s apply function.

init(x[, num_groups, group_size, epsilon, …])

Initializes the module parameters.

init_by_shape(input_specs, x[, num_groups, …])

Initializes the module parameters.



param(name, shape, initializer)

Defines a parameter within the module’s apply function.

partial([num_groups, group_size, epsilon, …])

Partially applies a module with the given arguments.

shared(*[, name])

Partially applies a module and shares parameters for each call.

state(name[, shape, initializer, collection])

Declare a state variable within the module’s apply function.