class flax.nn.BatchNorm(x, batch_stats=None, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-05, dtype=jnp.float32, bias=True, scale=True, bias_init=initializers.zeros, scale_init=initializers.ones, axis_name=None, axis_index_groups=None)

Normalizes the input using batch statistics.

Parameters:

  • x – the input to be normalized.

  • batch_stats – a flax.nn.Collection used to store an exponential moving average of the batch statistics (default: None).

  • use_running_average – if True, the statistics stored in batch_stats are used instead of computing the batch statistics of the input.

  • axis – the feature (non-batch) axis of the input. Statistics are computed over all other axes.

  • momentum – decay rate for the exponential moving average of the batch statistics.

  • epsilon – a small float added to the variance to avoid division by zero.

  • dtype – the dtype of the computation (default: float32).

  • bias – if True, bias (beta) is added.

  • scale – if True, multiply by scale (gamma). When the next layer is linear (or, e.g., nn.relu), this can be disabled, since the scaling can be done by the next layer.

  • bias_init – initializer for the bias (default: zeros).

  • scale_init – initializer for the scale (default: ones).

  • axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

  • axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
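The parameters above can be tied together in a minimal functional sketch. It is not the flax.nn implementation: the hypothetical helper `batch_norm_sketch` takes the running statistics and the scale (gamma) and bias (beta) arrays as explicit arguments instead of keeping them in a `flax.nn.Collection` and module parameters, and it ignores `dtype` and `axis_name`. It assumes the usual reading of `momentum` as a decay rate, i.e. the moving average is updated as `momentum * old + (1 - momentum) * new`.

```python
import jax.numpy as jnp


def batch_norm_sketch(x, ra_mean, ra_var, gamma=None, beta=None, *,
                      use_running_average=False, axis=-1,
                      momentum=0.99, epsilon=1e-5):
    """Hypothetical functional batch norm; returns (y, new_ra_mean, new_ra_var)."""
    axis = axis % x.ndim
    # Statistics are reduced over every axis except the feature axis.
    reduction_axes = tuple(i for i in range(x.ndim) if i != axis)
    # Shape that broadcasts a per-feature vector against x.
    feature_shape = tuple(x.shape[i] if i == axis else 1 for i in range(x.ndim))

    if use_running_average:
        # Inference: use the stored statistics and leave them unchanged.
        mean, var = ra_mean, ra_var
    else:
        # Training: statistics of the current batch, with var = E[x^2] - E[x]^2.
        mean = jnp.mean(x, axis=reduction_axes)
        var = jnp.mean(jnp.square(x), axis=reduction_axes) - jnp.square(mean)
        # Assumed exponential-moving-average update with decay rate `momentum`.
        ra_mean = momentum * ra_mean + (1 - momentum) * mean
        ra_var = momentum * ra_var + (1 - momentum) * var

    # epsilon keeps the denominator away from zero.
    y = (x - mean.reshape(feature_shape)) / jnp.sqrt(var.reshape(feature_shape) + epsilon)
    if gamma is not None:  # plays the role of scale=True
        y = y * gamma.reshape(feature_shape)
    if beta is not None:   # plays the role of bias=True
        y = y + beta.reshape(feature_shape)
    return y, ra_mean, ra_var


x = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples, 3 features
y, ra_mean, ra_var = batch_norm_sketch(x, jnp.zeros(3), jnp.ones(3), momentum=0.9)
# Each feature column has batch variance 11.25, so ra_var is 0.9 * 1 + 0.1 * 11.25 = 2.025.
```

Passing `gamma=None` corresponds to `scale=False`, the case the `scale` description has in mind when the following layer can absorb the scaling.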


Returns:

The normalized inputs, with the same shape as x.
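How `axis_name` combines batch statistics across devices can be illustrated without flax.nn. In this sketch the hypothetical function `cross_replica_stats` averages each replica's E[x] and E[x^2] with `jax.lax.pmean` and only then derives the variance; `jax.vmap` with an `axis_name` stands in for `jax.pmap` so the example runs on a single device. It assumes every replica holds the same number of examples, and `axis_index_groups` is left out.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cross_replica_stats(x_local, axis_name="batch"):
    """Hypothetical helper: per-feature mean and variance over all replicas.

    Must run inside a transformation that binds `axis_name`
    (jax.pmap on real devices, or jax.vmap to simulate them).
    """
    mean = jnp.mean(x_local, axis=0)
    mean_sq = jnp.mean(jnp.square(x_local), axis=0)
    # Average E[x] and E[x^2] across replicas (assumes equal batch sizes),
    # then derive the global variance from them.
    mean, mean_sq = lax.pmean((mean, mean_sq), axis_name=axis_name)
    return mean, mean_sq - jnp.square(mean)


# 4 simulated devices, each holding 2 examples with 2 features.
x = jnp.arange(16.0).reshape(4, 2, 2)
mean, var = jax.vmap(cross_replica_stats, axis_name="batch")(x)
# Every simulated device receives the same global statistics: mean [7, 8], var [21, 21].
```

Averaging the per-replica variances instead would give 1.0 here, because each simulated device only sees two neighbouring values; this is why the sketch reduces the first and second moments rather than the variances.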


Methods:

apply(x[, batch_stats, use_running_average, …])

Normalizes the input using batch statistics.

call(x[, batch_stats, use_running_average, …])

Evaluate the module with the given parameters.

create(x[, batch_stats, …])

Creates a module instance by evaluating the model.

create_by_shape(input_specs, x[, …])

Creates a module instance using only shape and dtype information.


Retrieves a parameter within the module’s apply function.

init(x[, batch_stats, use_running_average, …])

Initializes the module parameters.

init_by_shape(input_specs, x[, batch_stats, …])

Initializes the module parameters.



param(name, shape, initializer)

Defines a parameter within the module’s apply function.

partial([batch_stats, use_running_average, …])

Partially applies a module with the given arguments.

shared(*[, name])

Partially applies a module and shares its parameters across calls.

state(name[, shape, initializer, collection])

Declares a state variable within the module’s apply function.