flax.linen.WeightNorm#
- class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
L2 weight normalization (https://arxiv.org/pdf/1602.07868.pdf).
Weight normalization normalizes the weight params so that the l2-norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params l2-normalized before computing its __call__ output.

Example usage:
>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Baz(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Dense(2)(x)

>>> class Bar(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     return x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     # l2-normalize all params of the second Dense layer
...     x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)
...     x = nn.Dense(5)(x)
...     # l2-normalize all kernels in the Bar submodule and all params in the
...     # Baz submodule
...     x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> flax.core.freeze(jax.tree_map(jnp.shape, variables))
FrozenDict({
    params: {
        Bar_0: {
            Baz_0: {
                Dense_0: {
                    bias: (2,),
                    kernel: (5, 2),
                },
            },
            Baz_1: {
                Dense_0: {
                    bias: (2,),
                    kernel: (3, 2),
                },
            },
            Dense_0: {
                bias: (3,),
                kernel: (2, 3),
            },
            Dense_1: {
                bias: (3,),
                kernel: (2, 3),
            },
        },
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
        WeightNorm_0: {
            Dense_1/bias/scale: (4,),
            Dense_1/kernel/scale: (4,),
        },
        WeightNorm_1: {
            Bar_0/Baz_0/Dense_0/bias/scale: (2,),
            Bar_0/Baz_0/Dense_0/kernel/scale: (2,),
            Bar_0/Baz_1/Dense_0/bias/scale: (2,),
            Bar_0/Baz_1/Dense_0/kernel/scale: (2,),
            Bar_0/Dense_0/kernel/scale: (3,),
            Bar_0/Dense_1/kernel/scale: (3,),
        },
    },
})
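The attribute descriptions below imply that, per feature, the effective weight used by the wrapped layer is scale * v / ||v||, with the l2-norm taken over the non-feature axes. The following is a minimal, illustrative check of that relationship, continuing the session above; the Tiny module, the variable names, and the exact parameter paths are assumptions modeled on the default naming shown in the example, and the small epsilon term is ignored:

>>> class Tiny(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.WeightNorm(nn.Dense(4))(x)

>>> x = jnp.ones((1, 3))
>>> tiny_vars = Tiny().init(jax.random.key(0), x)
>>> kernel = tiny_vars['params']['Dense_0']['kernel']  # (3, 4)
>>> bias = tiny_vars['params']['Dense_0']['bias']  # (4,)
>>> scale = tiny_vars['params']['WeightNorm_0']['Dense_0/kernel/scale']  # (4,)
>>> # Assumed effective weight: per-column l2-normalized kernel, rescaled.
>>> w_eff = scale * kernel / jnp.linalg.norm(kernel, axis=0, keepdims=True)
>>> y = Tiny().apply(tiny_vars, x)
>>> bool(jnp.allclose(y, x @ w_eff + bias, atol=1e-5))
True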
- layer_instance#
Module instance that is wrapped with WeightNorm
- epsilon#
A small float added to l2-normalization to avoid dividing by zero.
- Type
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type
Optional[Any]
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type
Any
- use_scale#
If True, creates a learnable variable scale that is multiplied with the layer_instance variables after l2-normalization.
- Type
bool
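A brief, hedged sketch of use_scale, continuing the session above (module and variable names are illustrative): with use_scale=False the wrapped variables are still l2-normalized, but no learnable scale parameters are created.

>>> class NoScale(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.WeightNorm(nn.Dense(4), use_scale=False)(x)

>>> ns_vars = NoScale().init(jax.random.key(0), jnp.ones((1, 3)))
>>> ns_shapes = jax.tree_util.tree_map(jnp.shape, ns_vars)
>>> # Expected (assumption): only the Dense kernel and bias appear in the
>>> # parameter tree; there are no '.../scale' entries.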
- scale_init#
Initialization function for the scale parameter.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- feature_axes#
The feature axes dimension(s). The l2-norm is calculated by reducing the layer_instance variables over the remaining (non-feature) axes. Therefore a separate l2-norm value is calculated and a separate scale (if use_scale=True) is learned for each specified feature. By default, the trailing dimension is treated as the feature axis.
- Type
Optional[Union[int, Sequence[int]]]
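As a hedged sketch of feature_axes, continuing the session above (the resulting scale shape is an assumption): for a Dense kernel of shape (3, 4), the default feature_axes=-1 reduces over axis 0 and learns one norm and scale per output feature, while feature_axes=0 reduces over axis 1 and learns one per input feature.

>>> class PerInput(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.WeightNorm(nn.Dense(4), feature_axes=0)(x)

>>> pi_vars = PerInput().init(jax.random.key(0), jnp.ones((1, 3)))
>>> kernel_scale = pi_vars['params']['WeightNorm_0']['Dense_0/kernel/scale']
>>> # Expected shape (assumption): (3,), one entry per input feature, rather
>>> # than the default (4,) per output feature.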
- variable_filter#
An optional iterable that contains string items. The WeightNorm layer will selectively apply l2-normalization to the layer_instance variables whose key path (delimited by '/') has a match with variable_filter. For example, variable_filter={'kernel'} will only apply l2-normalization to variables whose key path contains 'kernel'. By default, variable_filter={'kernel'}.
- Type
Optional[Iterable]
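For a concrete illustration of variable_filter=None, continuing the session above, every variable of the wrapped layer, including the bias, gets its own norm and scale; the parameter naming follows the pattern shown in the larger example:

>>> class AllVars(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)

>>> av_vars = AllVars().init(jax.random.key(0), jnp.ones((1, 3)))
>>> sorted(av_vars['params']['WeightNorm_0'].keys())
['Dense_0/bias/scale', 'Dense_0/kernel/scale']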
- __call__(*args, **kwargs)[source]#
Compute the l2-norm of the weights in self.layer_instance and normalize the weights using this value before computing the __call__ output.
- Parameters
*args – positional arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
**kwargs – keyword arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
- Returns
Output of the layer using l2-normalized weights.
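Continuing the example above, the wrapper forwards its positional and keyword arguments unchanged to the wrapped layers, so the model is applied like any other module:

>>> y = model.apply(variables, x)
>>> y.shape
(1, 3)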