flax.linen.GroupNorm
- class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, parent=<flax.linen.module._Sentinel object>, name=None)
Group normalization (arxiv.org/abs/1803.08494).
This op is similar to batch normalization, but statistics are shared within equally sized groups of channels rather than across the batch dimension. Group normalization therefore does not depend on the batch composition and does not need to maintain internal state for storing statistics. The user should specify either the total number of channel groups (num_groups) or the number of channels per group (group_size).
- num_groups
The total number of channel groups. The default value of 32 is the one proposed in the original group normalization paper.
- Type
Optional[int]
- group_size
The number of channels in a group.
- Type
Optional[int]
- epsilon
A small float added to the variance to avoid division by zero.
- Type
float
- dtype
The dtype of the result (default: infer from input and params).
- Type
Optional[Any]
- param_dtype
The dtype passed to parameter initializers (default: float32).
- Type
Any
- use_bias
If True, bias (beta) is added.
- Type
bool
- use_scale
If True, multiply by scale (gamma). When the next layer is linear (or, for example, nn.relu), this can be disabled, since the scaling will be done by the next layer.
- Type
bool
- bias_init
Initializer for the bias (default: zeros).
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- scale_init
Initializer for the scale (default: ones).
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- __call__(x)
Applies group normalization to the input (arxiv.org/abs/1803.08494).
- Parameters
x – the input of shape N…C, where N is the batch dimension and C is the channels dimension. The … stands for an arbitrary number of extra dimensions over which statistics are accumulated.
- Returns
Normalized inputs (the same shape as inputs).
Methods