flax.linen.RMSNorm
- class flax.linen.RMSNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, parent=<flax.linen.module._Sentinel object>, name=None)
RMS Layer normalization (https://arxiv.org/abs/1910.07467).
RMSNorm normalizes the activations of the layer for each example in a batch independently, rather than across the batch as Batch Normalization does. Unlike LayerNorm, which re-centers the activations to zero mean and normalizes by their standard deviation, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations.
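The difference between the two normalizations can be sketched in plain jax.numpy. The helpers `rms_norm_sketch` and `layer_norm_sketch` below are hypothetical names written for illustration; they are not part of Flax and omit the learned scale, so treat them as a minimal reading of the description above rather than the library's implementation.

```python
import jax
import jax.numpy as jnp


def rms_norm_sketch(x, epsilon=1e-6, axis=-1):
    # Mean of squares over the reduction axis; the mean is NOT subtracted.
    mean_sq = jnp.mean(jnp.square(x), axis=axis, keepdims=True)
    # Multiply by 1 / sqrt(mean_sq + epsilon), i.e. divide by the RMS.
    return x * jax.lax.rsqrt(mean_sq + epsilon)


def layer_norm_sketch(x, epsilon=1e-6, axis=-1):
    # LayerNorm re-centers first, then divides by the standard deviation.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), axis=axis, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + epsilon)


x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 4.0, 4.0]])
y_rms = rms_norm_sketch(x)   # each row is rescaled to RMS close to 1
y_ln = layer_norm_sketch(x)  # each row is re-centered to mean 0
```

The constant row `[4, 4, 4]` shows the contrast: without re-centering, the RMS sketch maps it to values close to 1, while the LayerNorm sketch maps it to zeros.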
- Example::
>>> import jax.numpy as jnp
>>> import jax
>>> import flax.linen as nn
...
>>> x = jax.random.uniform(jax.random.key(0), (2, 3))
>>> layer = nn.RMSNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> y = layer.apply(variables, x)
- epsilon
A small float added to the mean of squares (which RMSNorm uses in place of the variance) to avoid dividing by zero.
- Type
float
- dtype
The dtype of the result (default: infer from input and params).
- Type
Optional[Any]
- param_dtype
The dtype passed to parameter initializers (default: float32).
- Type
Any
- use_scale
If True, multiply by scale (gamma). When the next layer is linear (or, e.g., nn.relu), this can be disabled since the scaling will be done by the next layer.
- Type
bool
- scale_init
Initializer for scale (default: ones).
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- reduction_axes
Axes for computing normalization statistics.
- Type
Union[int, Sequence[int]]
- feature_axes
Feature axes for the learned scale.
- Type
Union[int, Sequence[int]]
- axis_name
The axis name used to combine normalization statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.
- Type
Optional[str]
- axis_index_groups
Groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
- Type
Any
- __call__(x)
Applies RMS layer normalization to the input.
- Parameters
x – the inputs
- Returns
Normalized inputs (the same shape as inputs).
Methods