flax.linen.BatchNorm#

class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

BatchNorm Module.

Usage Note: If we define a model with BatchNorm, for example:

BN = nn.BatchNorm(use_running_average=False, momentum=0.9, epsilon=1e-5,
                  dtype=jnp.float32)

The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection holding the running statistics for all the BatchNorm layers in the model:

vars_initialized = BN.init(key, x)  # {'params': ..., 'batch_stats': ...}

We then update the batch_stats during training by marking the batch_stats collection as mutable in the apply call for our module:

vars_in = {'params': params, 'batch_stats': old_batch_stats}
y, mutated_vars = BN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']

During eval we would define BN with use_running_average=True and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn’t mark it mutable:

vars_in = {'params': params, 'batch_stats': training_batch_stats}
y = BN.apply(vars_in, x)
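
Putting the pieces above together, a minimal end-to-end sketch (the MLP module, layer sizes, and train flag are illustrative, not part of the BatchNorm API):

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(features=4)(x)
        # use_running_average=False during training, True during eval.
        x = nn.BatchNorm(use_running_average=not train, momentum=0.9)(x)
        return nn.relu(x)

x = jnp.ones((8, 3))
model = MLP()
variables = model.init(jax.random.PRNGKey(0), x, train=False)

# Training: mark 'batch_stats' mutable so the running statistics get updated.
y, mutated = model.apply(variables, x, train=True, mutable=['batch_stats'])

# Evaluation: reuse the updated statistics; nothing needs to be mutable.
y_eval = model.apply(
    {'params': variables['params'], 'batch_stats': mutated['batch_stats']},
    x, train=False)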
use_running_average#

if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

Type: Optional[bool]

axis#

the feature or non-batch axis of the input.

Type: int

momentum#

decay rate for the exponential moving average of the batch statistics.

Type: float
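
As a sketch of the running-statistics update (assuming the usual exponential-moving-average convention, with this module's default momentum of 0.99):

def ema_update(running, batch, momentum=0.99):
    # The running statistic decays toward the statistic of the current batch.
    return momentum * running + (1 - momentum) * batch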

epsilon#

a small float added to variance to avoid dividing by zero.

Type: float

dtype#

the dtype of the result (default: infer from input and params).

Type: Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type: Any

use_bias#

if True, bias (beta) is added.

Type: bool

use_scale#

if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.

Type: bool
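
For instance (a sketch; the Block module and the trailing Dense layer are illustrative):

import flax.linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        # The Dense layer that follows can absorb gamma into its kernel,
        # so the per-feature scale is redundant here.
        x = nn.BatchNorm(use_running_average=False, use_scale=False)(x)
        return nn.Dense(features=16)(x)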

bias_init#

initializer for bias, by default, zero.

Type: Callable[[Any, Tuple[int, ...], Any], Any]

scale_init#

initializer for scale, by default, one.

Type: Callable[[Any, Tuple[int, ...], Any], Any]

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

Type: Optional[str]

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

Type: Any
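
A sketch of synchronizing statistics across devices with jax.pmap (the 'batch' axis name, shapes, and replication step below are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Average batch statistics across all devices participating in the pmap.
# axis_index_groups=[[0, 1], [2, 3]] would instead sync within device subgroups.
bn = nn.BatchNorm(use_running_average=False, axis_name='batch')

x = jnp.ones((jax.local_device_count(), 16, 8))   # one per-device batch each
variables = bn.init(jax.random.PRNGKey(0), x[0])  # the axis need not exist at init

def step(v, xb):
    # Mean/variance are combined over the 'batch' device axis before normalizing.
    y, mutated = bn.apply(v, xb, mutable=['batch_stats'])
    return y, mutated['batch_stats']

replicated = jax.device_put_replicated(variables, jax.local_devices())
y, new_stats = jax.pmap(step, axis_name='batch')(replicated, x)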

__call__(x, use_running_average=None)[source]#

Normalizes the input using batch statistics.

NOTE: During initialization (when parameters are mutable) the running average of the batch statistics will not be updated. Therefore, the inputs fed during initialization don't need to match the actual input distribution, and the reduction axis (set with axis_name) does not have to exist.

Parameters
  • x – the input to be normalized.

  • use_running_average – if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

Returns

Normalized inputs (the same shape as inputs).
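
For example, use_running_average can be left unset on the module and supplied per call instead (a sketch; only one of the constructor attribute and the call argument should be provided):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 6))
BN = nn.BatchNorm(momentum=0.9)  # use_running_average deliberately unset

variables = BN.init(jax.random.PRNGKey(0), x, use_running_average=False)

# Training-time call: compute statistics from the batch and update batch_stats.
y, mutated = BN.apply(variables, x, use_running_average=False,
                      mutable=['batch_stats'])

# Eval-time call: reuse the stored running statistics.
y_eval = BN.apply(variables, x, use_running_average=True)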

Methods