flax.linen.WeightNorm#

class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

L2 weight normalization (Salimans and Kingma, 2016, https://arxiv.org/pdf/1602.07868.pdf).

Weight normalization reparameterizes the weight params so that, for each feature, the l2-norm over the remaining (non-feature) axes is equal to 1. This is implemented as a layer wrapper: each wrapped layer has its params l2-normalized (and, optionally, rescaled by a learned scale) before its __call__ output is computed.
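
Concretely, each weight w is reparameterized as scale * w / ||w||, with the l2-norm taken over the non-feature axes. A minimal sketch of this computation (illustrative only, not the library internals; the names here are hypothetical):

>>> import jax.numpy as jnp
>>> w = jnp.arange(12.).reshape((3, 4))  # hypothetical kernel; trailing axis is the feature axis
>>> eps = 1e-12
>>> norm = jnp.sqrt(jnp.sum(w**2, axis=0, keepdims=True) + eps)  # reduce the non-feature axes
>>> scale = jnp.ones((4,))  # the learnable scale, initialized to ones
>>> w_normalized = scale * w / norm
>>> bool(jnp.allclose(jnp.linalg.norm(w_normalized, axis=0), 1.0))
True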

Example usage:

>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Baz(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Dense(2)(x)

>>> class Bar(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     return x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     # l2-normalize all params of the second Dense layer
...     x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)
...     x = nn.Dense(5)(x)
...     # l2-normalize all kernels in the Bar submodule and all params in the
...     # Baz submodule
...     x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    params: {
        Bar_0: {
            Baz_0: {
                Dense_0: {
                    bias: (2,),
                    kernel: (5, 2),
                },
            },
            Baz_1: {
                Dense_0: {
                    bias: (2,),
                    kernel: (3, 2),
                },
            },
            Dense_0: {
                bias: (3,),
                kernel: (2, 3),
            },
            Dense_1: {
                bias: (3,),
                kernel: (2, 3),
            },
        },
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
        WeightNorm_0: {
            Dense_1/bias/scale: (4,),
            Dense_1/kernel/scale: (4,),
        },
        WeightNorm_1: {
            Bar_0/Baz_0/Dense_0/bias/scale: (2,),
            Bar_0/Baz_0/Dense_0/kernel/scale: (2,),
            Bar_0/Baz_1/Dense_0/bias/scale: (2,),
            Bar_0/Baz_1/Dense_0/kernel/scale: (2,),
            Bar_0/Dense_0/kernel/scale: (3,),
            Bar_0/Dense_1/kernel/scale: (3,),
        },
    },
})
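
The initialized variables are then used with apply as usual; a short sketch (the output shape follows from Bar's final Dense(3)):

>>> y = model.apply(variables, x)
>>> y.shape
(1, 3)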
layer_instance#

Module instance that is wrapped with WeightNorm.

Type

flax.linen.module.Module

epsilon#

A small float added during l2-normalization to avoid division by zero.

Type

float

dtype#

The dtype of the result (default: infer from input and params).

Type

Optional[Any]

param_dtype#

The dtype passed to parameter initializers (default: float32).

Type

Any

use_scale#

If True, creates a learnable scale parameter that is multiplied with the layer_instance variables after l2-normalization.

Type

bool

scale_init#

Initialization function for the scale parameter.

Type

Callable[[Any, Tuple[int, …], Any], Any]
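
For example (a hedged sketch; any initializer with the standard (key, shape, dtype) signature works):

>>> wn = nn.WeightNorm(nn.Dense(4), scale_init=nn.initializers.normal(stddev=0.1))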

feature_axes#

The feature axis or axes. The l2-norm is calculated by reducing the layer_instance variables over the remaining (non-feature) axes; hence a separate l2-norm value is calculated, and a separate scale (if use_scale=True) is learned, for each specified feature. By default, the trailing dimension is treated as the feature axis.

Type

Optional[Union[int, Sequence[int]]]
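
For intuition, a hedged sketch with a conv-style kernel of shape (H, W, C_in, C_out): with the default feature_axes=-1, the norm is reduced over all other axes, yielding one l2-norm (and one scale) per output feature:

>>> w = jnp.ones((3, 3, 16, 32))  # hypothetical conv kernel (H, W, C_in, C_out)
>>> jnp.sqrt(jnp.sum(w**2, axis=(0, 1, 2))).shape  # one norm per output feature
(32,)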

variable_filter#

An optional iterable that contains string items. The WeightNorm layer will selectively apply l2-normalization to the layer_instance variables whose key path (delimited by ‘/’) contains an item of variable_filter as a substring. For example, variable_filter={'kernel'} will only apply l2-normalization to variables whose key path contains the substring ‘kernel’. By default, variable_filter={'kernel'}.

Type

Optional[Iterable]
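
The matching rule behaves like a substring check against the ‘/’-joined key path; the helper below is an illustrative sketch of that behavior, not the library's internal function:

>>> def is_filtered(key_path, variable_filter=('kernel',)):
...   # illustrative only: a variable is l2-normalized if any filter string
...   # occurs as a substring of its '/'-delimited key path
...   return any(f in key_path for f in variable_filter)
>>> is_filtered('Bar_0/Dense_0/kernel')
True
>>> is_filtered('Bar_0/Dense_0/bias')
False
>>> is_filtered('Bar_0/Baz_0/Dense_0/bias', variable_filter=('kernel', 'Baz'))
True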

__call__(*args, **kwargs)[source]#

Compute the l2-norm of the weights in self.layer_instance and normalize the weights using this value before computing the __call__ output.

Parameters
  • *args – positional arguments passed to the __call__ method of self.layer_instance.

  • **kwargs – keyword arguments passed to the __call__ method of self.layer_instance.

Returns

Output of the layer using l2-normalized weights.
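
A minimal end-to-end sketch (assuming WeightNorm is used as the top-level module, which adopts the unbound Dense as a submodule):

>>> layer = nn.WeightNorm(nn.Dense(features=4))
>>> x = jnp.ones((1, 3))
>>> params = layer.init(jax.random.key(0), x)  # x is forwarded to nn.Dense.__call__
>>> layer.apply(params, x).shape
(1, 4)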
