flax.linen.BatchNorm

class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]

BatchNorm Module.

Usage Note: If we define a model with BatchNorm, for example:

import jax.numpy as jnp
import flax.linen as nn

BN = nn.BatchNorm(use_running_average=False, momentum=0.9, epsilon=1e-5,
                  dtype=jnp.float32)

In addition to a 'params' collection, the initialized variables dict will contain a separate 'batch_stats' collection holding the running statistics for all the BatchNorm layers in the model:

vars_initialized = BN.init(key, x)  # {'params': ..., 'batch_stats': ...}
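For a standalone layer applied to a batch of 4-feature inputs, the two collections hold per-feature arrays, roughly as follows (a sketch of the tree, with arrays shown as shapes):

# {'params':      {'scale': (4,), 'bias': (4,)},
#  'batch_stats': {'mean': (4,),  'var': (4,)}}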

We then update the batch_stats during training by marking the batch_stats collection as mutable in the apply method for our module:

vars_in = {'params': params, 'batch_stats': old_batch_stats}
y, mutated_vars = BN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']

During eval we would define BN with use_running_average=True and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn't mark it mutable:

vars_in = {'params': params, 'batch_stats': training_batch_stats}
y = BN.apply(vars_in, x)
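Putting the pieces together, here is a minimal runnable sketch of the full init/train/eval cycle described above (the actual loss and optimizer step are elided):

import jax
import jax.numpy as jnp
import flax.linen as nn

BN = nn.BatchNorm(use_running_average=False, momentum=0.9, epsilon=1e-5)

key = jax.random.PRNGKey(0)
x = jnp.ones((8, 4))  # [batch, features]

variables = BN.init(key, x)
params, batch_stats = variables['params'], variables['batch_stats']

# Training step: statistics are computed from x, and the running
# averages in 'batch_stats' are updated as a side effect.
y, mutated = BN.apply({'params': params, 'batch_stats': batch_stats},
                      x, mutable=['batch_stats'])
batch_stats = mutated['batch_stats']

# Evaluation: reuse the trained statistics without mutating them.
EvalBN = nn.BatchNorm(use_running_average=True, momentum=0.9, epsilon=1e-5)
y_eval = EvalBN.apply({'params': params, 'batch_stats': batch_stats}, x)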
Parameters
  • use_running_average (Optional[bool])
  • axis (int)
  • momentum (float)
  • epsilon (float)
  • dtype (Optional[Any])
  • param_dtype (Any)
  • use_bias (bool)
  • use_scale (bool)
  • bias_init (Callable[[Any, Tuple[int, ...], Any], Any])
  • scale_init (Callable[[Any, Tuple[int, ...], Any], Any])
  • axis_name (Optional[str])
  • axis_index_groups (Any)
  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])
  • name (str)

Return type

None

use_running_average

if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

Type

Optional[bool]

axis

the feature or non-batch axis of the input.

Type

int

momentum

decay rate for the exponential moving average of the batch statistics.

Type

float
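Concretely, this is the decay in the standard exponential-moving-average update that the layer applies to its stored statistics after each training step (shown as pseudocode):

mean = momentum * mean + (1 - momentum) * batch_mean
var  = momentum * var  + (1 - momentum) * batch_var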

epsilon

a small float added to variance to avoid dividing by zero.

Type

float
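For reference, epsilon enters the usual batch-normalization transform (pseudocode; scale and bias are applied only when use_scale / use_bias are enabled):

y = (x - mean) / sqrt(var + epsilon) * scale + bias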

dtype

the dtype of the result (default: infer from input and params).

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any

use_bias

if True, bias (beta) is added.

Type

bool

use_scale

if True, multiply by scale (gamma). When the next layer is linear (this also holds e.g. for nn.relu), this can be disabled since the scaling will be done by the next layer.

Type

bool

bias_init

initializer for bias (default: zeros).

Type

Callable[[Any, Tuple[int, …], Any], Any]

scale_init

initializer for scale (default: ones).

Type

Callable[[Any, Tuple[int, …], Any], Any]

axis_name

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

Type

Optional[str]

axis_index_groups

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

Type

Any
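As an illustration of combining statistics across devices, a minimal sketch of synchronized batch norm under jax.pmap might look like this (hypothetical: 'batch' is an arbitrary axis name and the shapes are placeholders):

import jax
import jax.numpy as jnp
import flax.linen as nn

sync_bn = nn.BatchNorm(use_running_average=False, axis_name='batch')

def step(variables, x):
    # Inside pmap, batch statistics are averaged across all devices
    # mapped over the 'batch' axis before normalizing.
    return sync_bn.apply(variables, x, mutable=['batch_stats'])

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 8, 4))  # [device, batch, features]

# The reduction axis need not exist during init (see the note on __call__).
variables = sync_bn.init(jax.random.PRNGKey(0), x[0])
replicated = jax.device_put_replicated(variables, jax.local_devices())

y, mutated = jax.pmap(step, axis_name='batch')(replicated, x)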

__call__(x, use_running_average=None)[source]

Normalizes the input using batch statistics.

NOTE: During initialization (when parameters are mutable) the running average of the batch statistics will not be updated. Therefore, the inputs fed during initialization don't need to match the actual input distribution, and the reduction axis (set with axis_name) does not have to exist.

Parameters
  • x – the input to be normalized.

  • use_running_average (Optional[bool]) – if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

Returns

Normalized inputs (the same shape as inputs).
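Because __call__ takes its own use_running_average, the constructor value can be left as None and the mode chosen at apply time instead, for example (reusing vars_in from the usage note above):

BN = nn.BatchNorm(use_running_average=None, momentum=0.9)
y_train, mutated = BN.apply(vars_in, x, use_running_average=False,
                            mutable=['batch_stats'])
y_eval = BN.apply(vars_in, x, use_running_average=True)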

Methods

bias_init(key, shape[, dtype])

An initializer that returns a constant array full of zeros.

scale_init(key, shape[, dtype])

An initializer that returns a constant array full of ones.