flax.linen.BatchNorm#
- class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
BatchNorm Module.
Usage Note: If we define a model with BatchNorm, for example:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> BN = nn.BatchNorm(momentum=0.9, epsilon=1e-5, dtype=jnp.float32)
The initialized variables dict will contain, in addition to a ‘params’ collection, a separate ‘batch_stats’ collection holding the running statistics for all the BatchNorm layers in a model:
>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> variables = BN.init(jax.random.key(1), x, use_running_average=False)
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'batch_stats': {'mean': (6,), 'var': (6,)}, 'params': {'bias': (6,), 'scale': (6,)}}
We then update the batch_stats during training by specifying that the batch_stats collection is mutable in the apply method for our module:
>>> y, new_batch_stats = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=False)
During eval we would call apply with use_running_average=True and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn’t mark it mutable:
>>> y = BN.apply(variables, x, use_running_average=True)
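Putting this together, a minimal training-step sketch might look as follows. The Model, the dummy loss, and train_step below are hypothetical assumptions for illustration, not part of the BatchNorm API; they show only the usual pattern of threading the mutated batch_stats collection back through apply:
>>> class Model(nn.Module):
...     @nn.compact
...     def __call__(self, x, train: bool):
...         x = nn.Dense(6)(x)
...         x = nn.BatchNorm(use_running_average=not train)(x)
...         return x
>>> model = Model()
>>> variables = model.init(jax.random.key(2), x, train=False)
>>> params, batch_stats = variables['params'], variables['batch_stats']
>>> def train_step(params, batch_stats, x):
...     def loss_fn(params):
...         # apply with batch_stats mutable; updates holds the new running statistics
...         y, updates = model.apply(
...             {'params': params, 'batch_stats': batch_stats},
...             x, train=True, mutable=['batch_stats'])
...         return jnp.mean(y ** 2), updates['batch_stats']  # dummy loss for illustration
...     (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
...     return loss, new_batch_stats, grads
>>> loss, batch_stats, grads = train_step(params, batch_stats, x)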
- use_running_average#
if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
- Type
Optional[bool]
- axis#
the feature or non-batch axis of the input.
- Type
int
- momentum#
decay rate for the exponential moving average of the batch statistics (see the update formulas after this attribute list).
- Type
float
- epsilon#
a small float added to variance to avoid dividing by zero.
- Type
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type
Optional[Any]
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type
Any
- use_bias#
if True, bias (beta) is added.
- Type
bool
- use_scale#
if True, multiply by scale (gamma). When the next layer is linear (this also holds for e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
- Type
bool
- bias_init#
initializer for bias (default: zeros).
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- scale_init#
initializer for scale (default: ones).
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- axis_name#
the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize: just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. A multi-device sketch follows this attribute list.
- Type
Optional[str]
- axis_index_groups#
groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
- Type
Any
- use_fast_variance#
if True, use a faster, but less numerically stable, calculation for the variance.
- Type
bool
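For reference, the running statistics are an exponential moving average controlled by momentum, and epsilon keeps the normalization numerically stable. In the standard batch-norm formulation (with $m$ the momentum, $\gamma$ the scale, and $\beta$ the bias):

$$\hat{\mu} \leftarrow m\,\hat{\mu} + (1 - m)\,\mu_{\text{batch}}, \qquad \hat{\sigma}^2 \leftarrow m\,\hat{\sigma}^2 + (1 - m)\,\sigma^2_{\text{batch}}$$

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $(\mu, \sigma^2)$ are the batch statistics during training and the running averages $(\hat{\mu}, \hat{\sigma}^2)$ when use_running_average is True.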
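To illustrate axis_name, here is a minimal multi-device sketch, assuming the data carries a leading device axis; the axis name 'devices', the shapes, and the replication step are assumptions for illustration. On a single-device machine this still runs with a device count of 1:
>>> BN_sync = nn.BatchNorm(use_running_average=False, momentum=0.9, axis_name='devices')
>>> n_dev = jax.local_device_count()
>>> xs = jax.random.normal(jax.random.key(3), (n_dev, 4, 6))  # one batch per device
>>> sync_vars = BN_sync.init(jax.random.key(4), xs[0])  # the reduction axis need not exist at init
>>> def step(v, x_shard):
...     # inside pmap, batch statistics are averaged across the 'devices' axis
...     return BN_sync.apply(v, x_shard, mutable=['batch_stats'])
>>> replicated = jax.tree_util.tree_map(lambda a: jnp.stack([a] * n_dev), sync_vars)
>>> ys, updates = jax.pmap(step, axis_name='devices')(replicated, xs)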
- __call__(x, use_running_average=None, mask=None)[source]#
Normalizes the input using batch statistics.
NOTE: During initialization (when self.is_initializing() is True) the running average of the batch statistics will not be updated. Therefore, the inputs fed during initialization don’t need to match the actual input distribution, and the reduction axis (set with axis_name) does not have to exist.
- Parameters
x – the input to be normalized.
use_running_average – if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
mask – binary array of shape broadcastable to the inputs tensor, indicating the positions for which the mean and variance should be computed (see the example below).
- Returns
Normalized inputs (the same shape as inputs).
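As an example of mask, the following hypothetical snippet excludes two padded rows from the statistics; the padding layout is an assumption for illustration:
>>> x_padded = jnp.concatenate([jax.random.normal(jax.random.key(5), (3, 6)), jnp.zeros((2, 6))])
>>> mask = (jnp.arange(5) < 3)[:, None]  # shape (5, 1), broadcastable to (5, 6); True rows are counted
>>> y, _ = BN.apply(variables, x_padded, use_running_average=False, mutable=['batch_stats'], mask=mask)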