flax.linen.GroupNorm

class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, parent=<flax.linen.module._Sentinel object>, name=None)

Group normalization (arxiv.org/abs/1803.08494).

This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not across the batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should specify either the total number of channel groups or the number of channels per group.

Parameters
  • num_groups (Optional[int]) –

  • group_size (Optional[int]) –

  • epsilon (float) –

  • dtype (Optional[Any]) –

  • param_dtype (Any) –

  • use_bias (bool) –

  • use_scale (bool) –

  • bias_init (Callable[[Any, Tuple[int, ...], Any], Any]) –

  • scale_init (Callable[[Any, Tuple[int, ...], Any], Any]) –

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –

  • name (str) –

Return type

None

num_groups

The total number of channel groups, which must divide the number of input channels. The default value of 32 is the one proposed in the original group normalization paper.

Type

Optional[int]

group_size

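The number of channels in each group, which must divide the number of input channels. Since num_groups defaults to 32, pass num_groups=None when setting group_size.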

Type

Optional[int]

epsilon

A small float added to the variance to avoid dividing by zero.

Type

float
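For reference, a sketch of where epsilon enters, written in the notation of the group normalization paper rather than anything defined on this page. For sample n and a group g with channel set C_g, S is the set of positions along the extra dimensions, and γ and β are the per-channel scale and bias:

```latex
\mu_{n,g} = \frac{1}{|S|\,|C_g|} \sum_{i \in S} \sum_{c \in C_g} x_{n,i,c},
\qquad
\sigma^2_{n,g} = \frac{1}{|S|\,|C_g|} \sum_{i \in S} \sum_{c \in C_g} \left( x_{n,i,c} - \mu_{n,g} \right)^2

y_{n,i,c} = \gamma_c \, \frac{x_{n,i,c} - \mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \epsilon}} + \beta_c,
\qquad c \in C_g
```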

dtype

The dtype of the result (default: inferred from the input and params).

Type

Optional[Any]

param_dtype

The dtype passed to parameter initializers (default: float32).

Type

Any

use_bias

If True, bias (beta) is added.

Type

bool

use_scale

If True, multiply by scale (gamma). When the next layer is linear (or, e.g., nn.relu), this can be disabled, since the scaling can be absorbed by the next layer.

Type

bool

bias_init

Initializer for the bias; zeros by default.

Type

Callable[[Any, Tuple[int, ...], Any], Any]

scale_init

Initializer for the scale; ones by default.

Type

Callable[[Any, Tuple[int, ...], Any], Any]

__call__(x)

Applies group normalization to the input (arxiv.org/abs/1803.08494).

Parameters

x – the input of shape N...C, where N is a batch dimension and C is a channel dimension. The ... represents an arbitrary number of extra dimensions over which statistics are accumulated.

Returns

Normalized inputs (the same shape as inputs).

Methods

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

scale_init(shape[, dtype])

An initializer that returns a constant array full of ones.