class flax.nn.BatchNorm(x, batch_stats=None, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-05, dtype=jax.numpy.float32, bias=True, scale=True, bias_init=zeros, scale_init=ones, axis_name=None, axis_index_groups=None)

Normalizes the input using batch statistics.
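The arithmetic behind this normalization can be sketched in plain jax.numpy. This is a minimal illustration of the training-mode path only, not the library's implementation: the function name batch_norm_sketch is hypothetical, and it assumes the statistics are reduced over every axis except the feature axis given by axis, with epsilon, scale, and bias playing the roles documented below.

```python
import jax.numpy as jnp


def batch_norm_sketch(x, scale, bias, axis=-1, epsilon=1e-5):
    """Hypothetical helper: normalize x with statistics of the current batch."""
    # Reduce over every axis except the feature axis.
    feature_axis = axis % x.ndim
    reduce_axes = tuple(i for i in range(x.ndim) if i != feature_axis)
    mean = jnp.mean(x, axis=reduce_axes, keepdims=True)
    var = jnp.var(x, axis=reduce_axes, keepdims=True)

    # Reshape scale (gamma) and bias (beta) so they broadcast along the feature axis.
    param_shape = [1] * x.ndim
    param_shape[feature_axis] = x.shape[feature_axis]

    # epsilon keeps the division well defined when the variance is zero.
    y = (x - mean) / jnp.sqrt(var + epsilon)
    return y * scale.reshape(param_shape) + bias.reshape(param_shape)


# 3 examples with 2 features each; features live on the last axis.
x = jnp.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
y = batch_norm_sketch(x, scale=jnp.ones(2), bias=jnp.zeros(2))
```

With scale set to ones and bias set to zeros, each feature column of y has approximately zero mean and unit variance, and y has the same shape as x.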

  • x – the input to be normalized.

  • batch_stats – a flax.nn.Collection used to store an exponential moving average of the batch statistics (default: None).

  • use_running_average – if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input (default: False).

  • axis – the feature or non-batch axis of the input.

  • momentum – decay rate for the exponential moving average of the batch statistics.

  • epsilon – a small float added to variance to avoid dividing by zero.

  • dtype – the dtype of the computation (default: float32).

  • bias – if True, bias (beta) is added.

  • scale – if True, multiply by scale (gamma). When the next layer is linear (this also holds for e.g. nn.relu), scale can be disabled, since the next layer will do the scaling.

  • bias_init – initializer for bias (default: zeros).

  • scale_init – initializer for scale (default: ones).

  • axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

  • axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
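As a rough illustration of momentum and axis_index_groups, the sketch below uses plain Python floats in place of per-feature arrays. Both helper names are hypothetical, and the grouping function only simulates what a cross-device mean would produce; it assumes every device holds the same number of examples, so that averaging per-device means equals the mean over the group's combined batch.

```python
def update_running_average(running, batch_stat, momentum=0.99):
    """Hypothetical helper: one exponential-moving-average step."""
    # Keep `momentum` of the old value and blend in the new batch statistic.
    return momentum * running + (1.0 - momentum) * batch_stat


def grouped_mean(per_device_means, axis_index_groups):
    """Hypothetical helper: average per-device means within each device group."""
    combined = list(per_device_means)
    for group in axis_index_groups:
        group_mean = sum(per_device_means[i] for i in group) / len(group)
        for i in group:
            combined[i] = group_mean
    return combined


# With momentum=0.99 the running mean moves slowly toward the batch mean.
running_mean = 0.0
for _ in range(3):
    running_mean = update_running_average(running_mean, batch_stat=2.0)

# Four simulated devices; devices 0-1 and 2-3 are reduced independently.
device_means = [1.0, 3.0, 10.0, 20.0]
combined = grouped_mean(device_means, [[0, 1], [2, 3]])
```

Here combined holds one shared mean per group, so devices 0 and 1 normalize with one statistic while devices 2 and 3 normalize with another. A momentum close to 1 makes the stored statistics change only slightly per step.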


Returns the normalized inputs (the same shape as x).


Initialize self.

apply(x[, batch_stats, use_running_average, …])

Normalizes the input using batch statistics.

call(x[, batch_stats, use_running_average, …])

Evaluate the module with the given parameters.

create(x[, batch_stats, …])

Creates a module instance by evaluating the model.

create_by_shape(input_specs, x[, …])

Creates a module instance using only shape and dtype information.


Retrieves a parameter within the module’s apply function.

init(x[, batch_stats, use_running_average, …])

Initializes the module parameters.

init_by_shape(input_specs, x[, batch_stats, …])

Initializes the module parameters using only shape and dtype information.



param(name, shape, initializer)

Defines a parameter within the module’s apply function.

partial([batch_stats, use_running_average, …])

Partially applies a module with the given arguments.

shared(*[, name])

Partially applies a module and shares its parameters across calls.

state(name[, shape, initializer, collection])

Declare a state variable within the module’s apply function.