Normalization#
- class flax.experimental.nnx.BatchNorm(*args, **kwargs)[source]#
BatchNorm Module (https://arxiv.org/abs/1502.03167).
BatchNorm normalizes activations over the batch dimension and keeps exponential moving averages of the batch statistics in batch_stats for use at evaluation time.
- use_running_average#
If True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
- axis#
The feature or non-batch axis of the input.
- momentum#
Decay rate for the exponential moving average of the batch statistics.
- epsilon#
A small float added to the variance to avoid division by zero.
- dtype#
The dtype of the result (default: inferred from input and params).
- param_dtype#
The dtype passed to parameter initializers (default: float32).
- use_bias#
If True, a bias (beta) is added.
- use_scale#
If True, multiply by a learned scale (gamma). When the next layer is linear (this also applies to, e.g., nn.relu), this can be disabled, since the scaling will be done by the next layer.
- bias_init#
Initializer for the bias (default: zeros).
- scale_init#
Initializer for the scale (default: ones).
- axis_name#
The axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).
- axis_index_groups#
Groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and the last two devices. See jax.lax.psum for more details.
- use_fast_variance#
If True, use a faster but less numerically stable calculation for the variance.
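A minimal usage sketch (assuming the nnx constructor convention of passing the feature count as num_features together with an nnx.Rngs, and that use_running_average can be flipped as a plain module attribute):

```python
import jax.numpy as jnp
from flax.experimental import nnx

# Construct a BatchNorm layer for 4 features; rngs seeds parameter init.
bn = nnx.BatchNorm(num_features=4, momentum=0.9, rngs=nnx.Rngs(0))

x = jnp.ones((8, 4))  # (batch, features)

# Training mode (use_running_average=False, the assumed default):
# normalize with this batch's statistics and update the running
# averages stored in batch_stats.
y_train = bn(x)

# Evaluation mode: normalize with the stored running statistics.
bn.use_running_average = True  # assumed mutable-attribute toggle
y_eval = bn(x)
```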
- class flax.experimental.nnx.LayerNorm(*args, **kwargs)[source]#
Layer normalization (https://arxiv.org/abs/1607.06450).
LayerNorm normalizes the activations of the layer for each example in a batch independently, rather than across the batch as Batch Normalization does. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.
- epsilon#
A small float added to the variance to avoid division by zero.
- dtype#
The dtype of the result (default: inferred from input and params).
- param_dtype#
The dtype passed to parameter initializers (default: float32).
- use_bias#
If True, a bias (beta) is added.
- use_scale#
If True, multiply by a learned scale (gamma). When the next layer is linear (this also applies to, e.g., nn.relu), this can be disabled, since the scaling will be done by the next layer.
- bias_init#
Initializer for the bias (default: zeros).
- scale_init#
Initializer for the scale (default: ones).
- reduction_axes#
Axes for computing normalization statistics.
- feature_axes#
Feature axes for learned bias and scaling.
- axis_name#
The axis name used to combine normalization statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. if the array being normalized is sharded across devices within a pmap.
- axis_index_groups#
Groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently normalize over the examples on the first two and the last two devices. See jax.lax.psum for more details.
- use_fast_variance#
If True, use a faster but less numerically stable calculation for the variance.
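A short sketch of LayerNorm on a small batch (same assumed num_features/rngs constructor convention):

```python
import jax.numpy as jnp
from flax.experimental import nnx

ln = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))

x = jnp.arange(12.0).reshape(2, 6)  # two examples, six features each
y = ln(x)

# Each example (row) is normalized independently: with the default
# zeros bias and ones scale, each row of y has mean ~0 and std ~1.
print(y.mean(axis=-1), y.std(axis=-1))
```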
- class flax.experimental.nnx.RMSNorm(*args, **kwargs)[source]#
RMS Layer normalization (https://arxiv.org/abs/1910.07467).
RMSNorm normalizes the activations of the layer for each example in a batch independently, rather than across the batch as Batch Normalization does. Unlike LayerNorm, which re-centers the mean to 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations.
- epsilon#
A small float added to the variance to avoid division by zero.
- dtype#
The dtype of the result (default: inferred from input and params).
- param_dtype#
The dtype passed to parameter initializers (default: float32).
- use_scale#
If True, multiply by a learned scale (gamma). When the next layer is linear (this also applies to, e.g., nn.relu), this can be disabled, since the scaling will be done by the next layer.
- scale_init#
Initializer for the scale (default: ones).
- reduction_axes#
Axes for computing normalization statistics.
- feature_axes#
Feature axes for the learned scaling.
- axis_name#
The axis name used to combine normalization statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. if the array being normalized is sharded across devices within a pmap.
- axis_index_groups#
Groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently normalize over the examples on the first two and the last two devices. See jax.lax.psum for more details.
- use_fast_variance#
If True, use a faster but less numerically stable calculation for the variance.
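A sketch contrasting RMSNorm with LayerNorm (same assumed constructor convention, and assuming the default epsilon of 1e-6); the manual line spells out the computation described above:

```python
import jax.numpy as jnp
from flax.experimental import nnx

rms = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

x = jnp.arange(12.0).reshape(2, 6)
y = rms(x)

# No re-centering: each row is divided by its root mean square,
# sqrt(mean(x**2) + epsilon), then multiplied by the learned scale
# (ones at initialization).
manual = x / jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)
print(jnp.allclose(y, manual, atol=1e-5))
```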