# flax.nn.BatchNorm

class flax.nn.BatchNorm(x, batch_stats=None, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-05, dtype=<class 'jax._src.numpy.lax_numpy.float32'>, bias=True, scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None)

Normalizes the input using batch statistics.
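As a rough illustration of what normalizing with batch statistics means, the NumPy sketch below reduces over every axis except the feature axis given by axis, adds epsilon to the variance, and optionally applies scale (gamma) and bias (beta). The function name batch_norm_sketch is hypothetical, and the code is not the Flax implementation; it only mirrors the textbook computation (jax.numpy offers the same array functions used here).

```python
import numpy as np


def batch_norm_sketch(x, scale=None, bias=None, axis=-1, epsilon=1e-5):
    """Hypothetical helper: normalize x over all axes except the feature axis."""
    feature_axis = axis % x.ndim
    reduce_axes = tuple(i for i in range(x.ndim) if i != feature_axis)

    # Per-feature statistics computed from the current batch only.
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    y = (x - mean) / np.sqrt(var + epsilon)

    # Shape used to broadcast per-feature parameters against x.
    param_shape = [1] * x.ndim
    param_shape[feature_axis] = x.shape[feature_axis]
    if scale is not None:
        y = y * np.reshape(scale, param_shape)  # gamma
    if bias is not None:
        y = y + np.reshape(bias, param_shape)   # beta
    return y


# Three examples with two features on very different scales.
x = np.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
y = batch_norm_sketch(x)
y_affine = batch_norm_sketch(x, scale=np.array([2.0, 2.0]), bias=np.array([1.0, 1.0]))
```

After normalization each feature column of y has approximately zero mean and unit variance, regardless of its original scale.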

Parameters
• x – the input to be normalized.

• batch_stats – a flax.nn.Collection used to store an exponential moving average of the batch statistics (default: None).

• use_running_average – if True, the statistics stored in batch_stats are used instead of statistics computed from the input batch.

• axis – the feature or non-batch axis of the input.

• momentum – decay rate for the exponential moving average of the batch statistics.

• epsilon – a small float added to variance to avoid dividing by zero.

• dtype – the dtype of the computation (default: float32).

• bias – if True, bias (beta) is added.

• scale – if True, multiply by scale (gamma). When the next layer is linear (or, e.g., nn.relu), this can be disabled, since the scaling will be done by the next layer.

• bias_init – initializer for bias (default: zeros).

• scale_init – initializer for scale (default: ones).

• axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

• axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
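The grouping described for axis_index_groups can be pictured with a single-process simulation. This is only a sketch: the "devices" are plain NumPy arrays, grouped_batch_stats is a hypothetical helper, and no jax.pmap or jax.lax.psum call is involved. It shows that with [[0, 1], [2, 3]] each device ends up with statistics pooled over the examples in its own group.

```python
import numpy as np


def grouped_batch_stats(shards, axis_index_groups):
    """Hypothetical helper: per-device (mean, var) pooled within each group."""
    stats = [None] * len(shards)
    for group in axis_index_groups:
        # Pool the examples held by every device in this group.
        pooled = np.concatenate([shards[i] for i in group], axis=0)
        group_stats = (pooled.mean(axis=0), pooled.var(axis=0))
        for i in group:
            stats[i] = group_stats
    return stats


# Four simulated devices; device d holds two examples whose value is d.
shards = [np.full((2, 1), float(d)) for d in range(4)]
stats = grouped_batch_stats(shards, [[0, 1], [2, 3]])
```

Devices 0 and 1 share one set of statistics and devices 2 and 3 share another, so the two halves are normalized independently.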

Returns

The normalized input, with the same shape as x.
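The roles of momentum and use_running_average can be sketched as follows. The snippet assumes the usual exponential-moving-average form, new = momentum * old + (1 - momentum) * batch, with running statistics initialized to zero mean and unit variance; both the helper names and that initialization are assumptions made for illustration, not taken from the Flax source.

```python
import numpy as np


def update_running_stats(running_mean, running_var, x, momentum=0.99):
    """Hypothetical helper: fold the batch statistics of x into the averages."""
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    new_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * running_var + (1.0 - momentum) * batch_var
    return new_mean, new_var


def normalize_with_running_stats(x, running_mean, running_var, epsilon=1e-5):
    """Hypothetical helper: what use_running_average=True corresponds to."""
    return (x - running_mean) / np.sqrt(running_var + epsilon)


# Assumed initial state: zero mean, unit variance for each of two features.
running_mean = np.zeros(2)
running_var = np.ones(2)

# One training step updates the stored averages from the batch.
x = np.array([[1.0, 10.0], [3.0, 30.0]])
running_mean, running_var = update_running_stats(running_mean, running_var, x)

# At evaluation time the stored averages replace the batch statistics.
y_eval = normalize_with_running_stats(x, running_mean, running_var)
```

With momentum close to 1 the stored averages move slowly, so a single batch shifts them only slightly toward its own mean and variance.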

__init__()

Initialize self. See help(type(self)) for accurate signature.

Methods

• __init__() – Initialize self.

• apply(x[, batch_stats, use_running_average, …]) – Normalizes the input using batch statistics.

• call(x[, batch_stats, use_running_average, …]) – Evaluate the module with the given parameters.

• create(x[, batch_stats, …]) – Creates a module instance by evaluating the model.

• create_by_shape(input_specs, x[, …]) – Creates a module instance using only shape and dtype information.

• get_param(name) – Retrieves a parameter within the module’s apply function.

• init(x[, batch_stats, use_running_average, …]) – Initializes the module parameters.

• init_by_shape(input_specs, x[, batch_stats, …]) – Initialize the module parameters.

• is_initializing()

• is_stateful()

• param(name, shape, initializer) – Defines a parameter within the module’s apply function.

• partial([batch_stats, use_running_average, …]) – Partially applies a module with the given arguments.

• shared(*[, name]) – Partially applies a module and shared parameters for each call.

• state(name[, shape, initializer, collection]) – Declare a state variable within the module’s apply function.