flax.nn.BatchNorm
class flax.nn.BatchNorm(x, batch_stats=None, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-05, dtype=jax.numpy.float32, bias=True, scale=True, bias_init=zeros, scale_init=ones, axis_name=None, axis_index_groups=None)
Normalizes the input using batch statistics.
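The arithmetic behind this normalization can be sketched in NumPy as follows. This is not the flax.nn implementation: `batch_norm_sketch` is a hypothetical helper, the learned scale (gamma) and bias (beta) are passed in explicitly instead of being created with scale_init and bias_init, the final cast to `dtype` is omitted, and using the biased (population) variance is an assumption of the sketch. Statistics are reduced over every axis except `axis`, so each feature along `axis` gets its own mean and variance.

```python
import numpy as np

def batch_norm_sketch(x, axis=-1, epsilon=1e-5, scale=None, bias=None):
    """Hypothetical helper: normalize x over every axis except `axis`.

    `scale` (gamma) and `bias` (beta) are 1-D arrays with one entry per
    feature along `axis`; passing None mimics scale=False / bias=False.
    """
    x = np.asarray(x, dtype=np.float32)
    feature_axis = axis % x.ndim  # turn a negative axis such as -1 into its index
    # Reduce over all non-feature axes (e.g. the batch axis of a 2-D input).
    reduction_axes = tuple(i for i in range(x.ndim) if i != feature_axis)
    mean = x.mean(axis=reduction_axes, keepdims=True)
    var = x.var(axis=reduction_axes, keepdims=True)  # biased variance (assumption)
    y = (x - mean) / np.sqrt(var + epsilon)  # epsilon guards against var == 0
    # Shape that broadcasts a per-feature vector against x.
    feature_shape = [1] * x.ndim
    feature_shape[feature_axis] = x.shape[feature_axis]
    if scale is not None:
        y = y * np.reshape(scale, feature_shape)
    if bias is not None:
        y = y + np.reshape(bias, feature_shape)
    return y

x = np.array([[1.0, 10.0], [3.0, 30.0]])  # (batch=2, features=2)
y = batch_norm_sketch(x)  # each column now has mean ~0 and variance ~1
```

Because the output only rescales and shifts each feature, it keeps the shape of the input.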
Parameters
x – the input to be normalized.
batch_stats – a flax.nn.Collection used to store an exponential moving average of the batch statistics (default: None).
use_running_average – if True, the statistics stored in batch_stats are used instead of computing batch statistics from the input.
axis – the feature or non-batch axis of the input.
momentum – decay rate for the exponential moving average of the batch statistics.
epsilon – a small float added to the variance to avoid dividing by zero.
dtype – the dtype of the computation (default: float32).
bias – if True, bias (beta) is added.
scale – if True, multiply by scale (gamma). When the next layer is linear (this also holds for, e.g., nn.relu), this can be disabled, since the next layer can perform the scaling.
bias_init – initializer for the bias (default: zeros).
scale_init – initializer for the scale (default: ones).
axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).
axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
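The interplay of momentum and use_running_average can be sketched as a single NumPy step for a (batch, features) input. This is a minimal sketch, not the flax.nn code: `batch_norm_step` is a hypothetical helper, the returned running averages stand in for the statistics held in batch_stats, and starting the averages at mean 0 and variance 1 is an assumption. Following the description of momentum as a decay rate, the stored average keeps a fraction `momentum` of its old value and takes `1 - momentum` from the current batch.

```python
import numpy as np

def batch_norm_step(x, ra_mean, ra_var, use_running_average=False,
                    momentum=0.99, epsilon=1e-5):
    """Hypothetical helper: one forward pass for a (batch, features) input.

    Returns (y, new_ra_mean, new_ra_var); the running averages play the
    role of the statistics stored in batch_stats.
    """
    x = np.asarray(x, dtype=np.float64)
    if use_running_average:
        # Evaluation mode: reuse the stored statistics and leave them unchanged.
        mean, var = ra_mean, ra_var
    else:
        # Training mode: normalize with this batch's own statistics...
        mean, var = x.mean(axis=0), x.var(axis=0)
        # ...and fold them into the exponential moving averages.
        ra_mean = momentum * ra_mean + (1.0 - momentum) * mean
        ra_var = momentum * ra_var + (1.0 - momentum) * var
    y = (x - mean) / np.sqrt(var + epsilon)
    return y, ra_mean, ra_var

x = np.array([[1.0, 10.0], [3.0, 30.0]])
y, ra_mean, ra_var = batch_norm_step(x, np.zeros(2), np.ones(2), momentum=0.9)
# ra_mean is approximately [0.2, 2.0] and ra_var approximately [1.0, 10.9]
```

With the default momentum of 0.99, each batch moves the averages by only 1%, so they track the data slowly but smoothly; at evaluation time, use_running_average=True makes the output independent of the other examples in the batch.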
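The effect of axis_index_groups can be simulated on a single host. In this sketch the Python loop stands in for a grouped jax.lax.pmean across devices; `grouped_batch_stats` is a hypothetical helper, and the sketch assumes every device holds the same number of examples and that the groups cover every device. Averaging each device's mean and mean of squares within a group, then deriving the variance, gives every device in the group the statistics of the group's combined examples.

```python
import numpy as np

def grouped_batch_stats(per_device_x, axis_index_groups):
    """Hypothetical helper for a (devices, batch, features) array.

    Assumes equal per-device batch sizes and groups that cover all devices.
    """
    per_device_x = np.asarray(per_device_x, dtype=np.float64)
    local_mean = per_device_x.mean(axis=1)          # (devices, features)
    local_mean2 = (per_device_x ** 2).mean(axis=1)  # (devices, features)
    mean = np.empty_like(local_mean)
    mean2 = np.empty_like(local_mean2)
    for group in axis_index_groups:
        # Stand-in for a cross-device mean restricted to one group.
        mean[group] = local_mean[group].mean(axis=0)
        mean2[group] = local_mean2[group].mean(axis=0)
    var = mean2 - mean ** 2  # variance from E[x^2] - E[x]^2
    return mean, var

# Four simulated devices, two examples each, one feature.
devices = np.array([[[0.0], [2.0]], [[4.0], [6.0]],
                    [[10.0], [10.0]], [[20.0], [20.0]]])
mean, var = grouped_batch_stats(devices, [[0, 1], [2, 3]])
# Devices 0 and 1 share mean 3 and variance 5; devices 2 and 3 share mean 15 and variance 25.
```

Passing a single group [[0, 1, 2, 3]] instead normalizes over all eight examples at once, which corresponds to reducing over the whole named axis.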
Returns
Normalized inputs, with the same shape as x.
__init__()
Initialize self. See help(type(self)) for accurate signature.
Methods
__init__() – Initialize self.
apply(x[, batch_stats, use_running_average, …]) – Normalizes the input using batch statistics.
call(x[, batch_stats, use_running_average, …]) – Evaluate the module with the given parameters.
create(x[, batch_stats, …]) – Creates a module instance by evaluating the model.
create_by_shape(input_specs, x[, …]) – Creates a module instance using only shape and dtype information.
get_param(name) – Retrieves a parameter within the module's apply function.
init(x[, batch_stats, use_running_average, …]) – Initializes the module parameters.
init_by_shape(input_specs, x[, batch_stats, …]) – Initializes the module parameters.
is_initializing()
is_stateful()
param(name, shape, initializer) – Defines a parameter within the module's apply function.
partial([batch_stats, use_running_average, …]) – Partially applies a module with the given arguments.
shared(*[, name]) – Partially applies a module and shares its parameters across calls.
state(name[, shape, initializer, collection]) – Declare a state variable within the module's apply function.