flax.linen.SpectralNorm#

class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Spectral normalization. See: Spectral Normalization for Generative Adversarial Networks (Miyato et al., 2018), https://arxiv.org/abs/1802.05957; Self-Attention Generative Adversarial Networks (Zhang et al., 2018), https://arxiv.org/abs/1805.08318; Large Scale GAN Training for High Fidelity Natural Image Synthesis (Brock et al., 2018), https://arxiv.org/abs/1809.11096.

Spectral normalization rescales the weight params so that the spectral norm (i.e. the largest singular value) of each weight matrix is equal to 1. This is implemented as a layer wrapper: each wrapped layer has its params spectrally normalized before its __call__ output is computed.
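For intuition, here is a minimal sketch of the underlying computation (not Flax's internal implementation): power iteration estimates the largest singular value of a weight matrix W, and W is divided by that estimate so that the result has spectral norm close to 1.

>>> import jax, jax.numpy as jnp
>>> W = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
>>> u = jax.random.normal(jax.random.PRNGKey(1), (1, 4))  # persistent u vector
>>> for _ in range(5):  # a few rounds of power iteration (cf. n_steps)
...   v = W @ u.T
...   v = v / (jnp.linalg.norm(v) + 1e-12)  # epsilon guards the division
...   u = v.T @ W
...   u = u / (jnp.linalg.norm(u) + 1e-12)
>>> sigma = (v.T @ W @ u.T).squeeze()  # estimated largest singular value
>>> W_sn = W / sigma  # rescaled weights with spectral norm ~1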

Usage Note: The initialized variables dict will contain, in addition to a ‘params’ collection, a separate ‘batch_stats’ collection that will contain a u vector and sigma value, which are intermediate values used when performing spectral normalization. During training, we pass in update_stats=True and mutable=['batch_stats'] so that u and sigma are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in update_stats=False to ensure we get deterministic behavior from the model.

Example usage:

>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(3)(x)
...     # only spectral normalize the params of the second Dense layer
...     x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
...     x = nn.Dense(5)(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 5))
>>> model = Foo()
>>> variables = model.init(jax.random.PRNGKey(0), x, train=False)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    batch_stats: {
        SpectralNorm_0: {
            Dense_1/kernel/sigma: (),
            Dense_1/kernel/u: (1, 4),
        },
    },
    params: {
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
    },
})

>>> # train
>>> def train_step(variables, x, y):
...   def loss_fn(params):
...     logits, updates = model.apply(
...         {'params': params, 'batch_stats': variables['batch_stats']},
...         x,
...         train=True,
...         mutable=['batch_stats'],
...     )
...     loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y))
...     return loss, updates
...
...   (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
...       variables['params']
...   )
...   return {
...       'params': jax.tree_util.tree_map(
...           lambda p, g: p - 0.1 * g, variables['params'], grads
...       ),
...       'batch_stats': updates['batch_stats'],
...   }, loss
>>> for _ in range(10):
...   variables, loss = train_step(variables, x, y)

>>> # inference / eval
>>> out = model.apply(variables, x, train=False)
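
As an optional sanity check (relying on the variable tree printed above), the sigma stored by power iteration can be compared against the true largest singular value of the unnormalized Dense_1 kernel:

>>> sigma = variables['batch_stats']['SpectralNorm_0']['Dense_1/kernel/sigma']
>>> true_sigma = jnp.linalg.svd(
...     variables['params']['Dense_1']['kernel'], compute_uv=False
... )[0]
>>> # after several updates, sigma should approximate true_sigma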

layer_instance#

Module instance that is wrapped with SpectralNorm.

Type: flax.linen.module.Module

n_steps#

How many steps of power iteration to perform to approximate the singular value of the weight params.

Type: int
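One step per call is typically enough in practice because the u vector persists in the ‘batch_stats’ collection and improves across calls; a larger n_steps trades extra compute for a more accurate per-call estimate. For example (the value 3 here is arbitrary):

>>> layer = nn.SpectralNorm(nn.Dense(4), n_steps=3)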

epsilon#

A small float added to l2-normalization to avoid dividing by zero.

Type: float

dtype#

The dtype of the result (default: infer from input and params).

Type: Optional[Any]

param_dtype#

The dtype passed to parameter initializers (default: float32).

Type: Any

error_on_non_matrix#

Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer.

Type: bool
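For intuition, a sketch of the flattening described above; the exact reshape shown here is an assumption about the implementation, not a documented contract:

>>> conv_kernel = jnp.ones((3, 3, 8, 16))  # e.g. an (H, W, in, out) conv kernel
>>> # assumed: leading dims collapsed so spectral norm applies to a matrix
>>> conv_kernel.reshape(-1, conv_kernel.shape[-1]).shape
(72, 16)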

collection_name#

Name of the collection to store intermediate values used when performing spectral normalization.

Type: str

__call__(*args, update_stats, **kwargs)[source]#

Compute the largest singular value of the weights in self.layer_instance using power iteration and normalize the weights using this value before computing the __call__ output.

Parameters
  • *args – positional arguments passed to the __call__ method of the layer instance in self.layer_instance.

  • update_stats – if True, update the internal u vector and sigma value after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time.

  • **kwargs – keyword arguments passed to the __call__ method of the layer instance in self.layer_instance.

Returns

Output of the layer using spectrally normalized weights.
