flax.linen.SpectralNorm#
- class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Spectral normalization. See:
- https://arxiv.org/abs/1802.05957
- https://arxiv.org/abs/1805.08318
- https://arxiv.org/abs/1809.11096
Spectral normalization normalizes the weight params so that the spectral norm of the weight matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params spectrally normalized before computing its __call__ output.

Usage Note: The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection that will contain a u vector and sigma value, which are intermediate values used when performing spectral normalization. During training, we pass in update_stats=True and mutable=['batch_stats'] so that u and sigma are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in update_stats=False to ensure we get deterministic behavior from the model.

Example usage:
>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(3)(x)
...     # only spectral normalize the params of the second Dense layer
...     x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
...     x = nn.Dense(5)(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 5))
>>> model = Foo()
>>> variables = model.init(jax.random.PRNGKey(0), x, train=False)
>>> flax.core.freeze(jax.tree_map(jnp.shape, variables))
FrozenDict({
    batch_stats: {
        SpectralNorm_0: {
            Dense_1/kernel/sigma: (),
            Dense_1/kernel/u: (1, 4),
        },
    },
    params: {
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
    },
})

>>> # train
>>> def train_step(variables, x, y):
...   def loss_fn(params):
...     logits, updates = model.apply(
...       {'params': params, 'batch_stats': variables['batch_stats']},
...       x,
...       train=True,
...       mutable=['batch_stats'],
...     )
...     loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y))
...     return loss, updates
...
...   (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
...     variables['params']
...   )
...   return {
...     'params': jax.tree_map(
...       lambda p, g: p - 0.1 * g, variables['params'], grads
...     ),
...     'batch_stats': updates['batch_stats'],
...   }, loss

>>> for _ in range(10):
...   variables, loss = train_step(variables, x, y)

>>> # inference / eval
>>> out = model.apply(variables, x, train=False)
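For intuition, the sigma stored in 'batch_stats' is a power-iteration estimate of the largest singular value of the wrapped kernel, and the kernel is divided by it. Below is a minimal standalone sketch of that computation (a hypothetical power_iteration_sigma helper, not Flax's internal implementation), assuming a 2D kernel W of shape (in, out) and a running u vector of shape (1, out):

>>> import jax.numpy as jnp
>>> def power_iteration_sigma(W, u, n_steps=1, eps=1e-12):
...   # refine the dominant singular-vector estimates of W
...   for _ in range(n_steps):
...     v = u @ W.T                          # (1, in)
...     v = v / (jnp.linalg.norm(v) + eps)   # l2-normalize; eps avoids division by zero
...     u = v @ W                            # (1, out)
...     u = u / (jnp.linalg.norm(u) + eps)
...   sigma = (v @ W @ u.T)[0, 0]            # estimate of the spectral norm of W
...   return W / sigma, u, sigma             # normalized kernel has spectral norm ~1

With update_stats=True, the refreshed u and sigma are written back into the 'batch_stats' collection, so each call resumes power iteration from the previous estimate rather than starting over.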
- layer_instance#
Module instance that is wrapped with SpectralNorm
- n_steps#
How many steps of power iteration to perform to approximate the singular value of the weight params.
- Type
int
- epsilon#
A small float added to l2-normalization to avoid dividing by zero.
- Type
float
- dtype#
the dtype of the result (default: infer from input and params).
- Type
Optional[Any]
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type
Any
- error_on_non_matrix#
Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer.
- Type
bool
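As an illustration of the default behavior, a higher-order kernel such as a convolution kernel can be viewed as a matrix by flattening its leading dimensions (a sketch of the convention described above, not the exact reshape Flax performs internally):

>>> import jax.numpy as jnp
>>> kernel = jnp.ones((3, 3, 8, 16))            # e.g. a conv kernel (H, W, Cin, Cout)
>>> mat = kernel.reshape(-1, kernel.shape[-1])  # leading dims flattened: (72, 16)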
- collection_name#
Name of the collection to store intermediate values used when performing spectral normalization.
- Type
str
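For instance, choosing a custom collection name changes both where u and sigma are stored and which collection must be marked mutable during training (a hypothetical variant of the example above):

>>> import flax.linen as nn
>>> layer = nn.SpectralNorm(nn.Dense(4), collection_name='spectral_stats')
>>> # during training, mark the custom collection as mutable instead:
>>> # model.apply(variables, x, train=True, mutable=['spectral_stats'])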
- __call__(*args, update_stats, **kwargs)[source]#
Compute the largest singular value of the weights in self.layer_instance using power iteration and normalize the weights using this value before computing the __call__ output.
- Parameters
- *args – positional arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
- update_stats – if True, update the internal u vector and sigma value after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time.
- **kwargs – keyword arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
- Returns
Output of the layer using spectrally normalized weights.