# Source code for flax.nn.normalization

```
# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization modules for Flax."""
from . import base
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp
_no_init = lambda rng, shape: ()
def _absolute_dims(rank, dims):
return tuple([rank + dim if dim < 0 else dim for dim in dims])
[docs]class BatchNorm(base.Module):
"""DEPRECATION WARNING:
The `flax.nn` module is Deprecated, use `flax.linen` instead.
Learn more and find an upgrade guide at
https://github.com/google/flax/blob/main/flax/linen/README.md"
BatchNorm Module."""
def apply(self,
x,
batch_stats=None,
use_running_average=False,
axis=-1,
momentum=0.99,
epsilon=1e-5,
dtype=jnp.float32,
bias=True,
scale=True,
bias_init=initializers.zeros,
scale_init=initializers.ones,
axis_name=None,
axis_index_groups=None):
"""Normalizes the input using batch statistics.
Args:
x: the input to be normalized.
batch_stats: a `flax.nn.Collection` used to store an exponential moving
average of the batch statistics (default: None).
use_running_average: if true, the statistics stored in batch_stats
will be used instead of computing the batch statistics on the input.
axis: the feature or non-batch axis of the input.
momentum: decay rate for the exponential moving average of
the batch statistics.
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
bias: if True, bias (beta) is added.
scale: if True, multiply by scale (gamma).
When the next layer is linear (also e.g. nn.relu), this can be disabled
since the scaling will be done by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For example,
`[[0, 1], [2, 3]]` would independently batch-normalize over the examples
on the first two and last two devices. See `jax.lax.psum` for more details.
Returns:
Normalized inputs (the same shape as inputs).
"""
x = jnp.asarray(x, jnp.float32)
axis = axis if isinstance(axis, tuple) else (axis,)
axis = _absolute_dims(x.ndim, axis)
feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
if self.is_stateful() or batch_stats:
ra_mean = self.state('mean', reduced_feature_shape,
initializers.zeros, collection=batch_stats)
ra_var = self.state('var', reduced_feature_shape,
initializers.ones, collection=batch_stats)
else:
ra_mean = None
ra_var = None
if use_running_average:
if ra_mean is None:
raise ValueError('when use_running_averages is True '
'either use a stateful context or provide batch_stats')
mean, var = ra_mean.value, ra_var.value
else:
mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
if axis_name is not None and not self.is_initializing():
concatenated_mean = jnp.concatenate([mean, mean2])
mean, mean2 = jnp.split(
lax.pmean(
concatenated_mean,
axis_name=axis_name,
axis_index_groups=axis_index_groups), 2)
var = mean2 - lax.square(mean)
if ra_mean and not self.is_initializing():
ra_mean.value = momentum * ra_mean.value + (1 - momentum) * mean
ra_var.value = momentum * ra_var.value + (1 - momentum) * var
y = x - mean.reshape(feature_shape)
mul = lax.rsqrt(var + epsilon)
if scale:
mul = mul * self.param(
'scale', reduced_feature_shape, scale_init).reshape(feature_shape)
y = y * mul
if bias:
y = y + self.param(
'bias', reduced_feature_shape, bias_init).reshape(feature_shape)
return jnp.asarray(y, dtype)
[docs]class LayerNorm(base.Module):
"""DEPRECATION WARNING:
The `flax.nn` module is Deprecated, use `flax.linen` instead.
Learn more and find an upgrade guide at
https://github.com/google/flax/blob/main/flax/linen/README.md"
Layer normalization (https://arxiv.org/abs/1607.06450).
Operates on the last axis of the input data.
"""
def apply(self,
x,
epsilon=1e-6,
dtype=jnp.float32,
bias=True,
scale=True,
bias_init=initializers.zeros,
scale_init=initializers.ones):
"""Applies layer normalization on the input.
It normalizes the activations of the layer for each given example in a
batch independently, rather than across a batch like Batch Normalization.
i.e. applies a transformation that maintains the mean activation within
each example close to 0 and the activation standard deviation close to 1.
Args:
x: the inputs
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
bias: If True, bias (beta) is added.
scale: If True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
Returns:
Normalized inputs (the same shape as inputs).
"""
x = jnp.asarray(x, jnp.float32)
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
var = mean2 - lax.square(mean)
mul = lax.rsqrt(var + epsilon)
if scale:
mul = mul * jnp.asarray(self.param('scale', (features,), scale_init),
dtype)
y = (x - mean) * mul
if bias:
y = y + jnp.asarray(self.param('bias', (features,), bias_init), dtype)
return jnp.asarray(y, dtype)
[docs]class GroupNorm(base.Module):
"""DEPRECATION WARNING:
The `flax.nn` module is Deprecated, use `flax.linen` instead.
Learn more and find an upgrade guide at
https://github.com/google/flax/blob/main/flax/linen/README.md"
Group normalization (arxiv.org/abs/1803.08494)."""
def apply(self,
x,
num_groups=32,
group_size=None,
epsilon=1e-6,
dtype=jnp.float32,
bias=True,
scale=True,
bias_init=initializers.zeros,
scale_init=initializers.ones):
"""Applies group normalization to the input (arxiv.org/abs/1803.08494).
This op is similar to batch normalization, but statistics are shared across
equally-sized groups of channels and not shared across batch dimension.
Thus, group normalization does not depend on the batch composition and does
not require maintaining internal state for storing statistics.
The user should either specify the total number of channel groups or the
number of channels per group.
Args:
x: the input of shape N...C, where N is a batch dimension and C is a
channels dimensions. `...` represents an arbitrary number of extra
dimensions that are used to accumulate statistics over.
num_groups: the total number of channel groups. The default value of 32 is
proposed by the original group normalization paper.
group_size: the number of channels in a group.
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
bias: If True, bias (beta) is added.
scale: If True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
Returns:
Normalized inputs (the same shape as inputs).
"""
x = jnp.asarray(x, jnp.float32)
if ((num_groups is None and group_size is None) or
(num_groups is not None and group_size is not None)):
raise ValueError('Either `num_groups` or `group_size` should be '
'specified, but not both of them.')
if group_size is not None:
channels = x.shape[-1]
if channels % group_size != 0:
raise ValueError('Number of channels ({}) is not multiple of the '
'group size ({}).'.format(channels, group_size))
num_groups = channels // group_size
input_shape = x.shape
group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups)
x = x.reshape(group_shape)
reduction_axis = [d for d in range(1, x.ndim - 2)] + [x.ndim - 1]
mean = jnp.mean(x, axis=reduction_axis, keepdims=True)
mean_of_squares = jnp.mean(jnp.square(x), axis=reduction_axis,
keepdims=True)
var = mean_of_squares - jnp.square(mean)
x = (x - mean) * lax.rsqrt(var + epsilon)
x = x.reshape(input_shape)
feature_shape = tuple([1 for d in input_shape[:-1]] + [input_shape[-1]])
if scale:
x = x * self.param('scale', feature_shape, scale_init)
if bias:
x = x + self.param('bias', feature_shape, bias_init)
return x.astype(dtype)
```