class flax.optim.WeightNorm(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)

Adds weight normalization to an optimizer definition (OptimizerDef).

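See Salimans & Kingma, "Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks" (2016): https://arxiv.org/abs/1602.07868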

__init__(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)

Constructor for a WeightNorm optimizer.

Each weight vector is decomposed as \(w = g \, v / \|v\|_2\), where g is a scalar scale parameter and v is the raw (unnormalized) weight vector. The wrapped optimizer is applied to this (g, v) parameterization, and the updated parameters are then mapped back to w-space:

w, state → (g, v) → [wrapped optimizer] → (g′, v′) → w′, state′
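The round trip can be sketched in plain Python for a single weight vector. The helper names to_gv, from_gv, gv_grads and weight_norm_sgd_step are hypothetical. Plain SGD stands in for the wrapped optimizer, and (g, v) is re-derived from w on every step with v = w. This is a simplification: a real implementation may keep v in the optimizer state instead. The gradients with respect to g and v follow from the chain rule applied to \(w = g \, v / \|v\|_2\).

```python
import math


def to_gv(w):
    """Split a weight vector w into (g, v) with w == g * v / ||v||_2.

    Initialising v = w gives g = ||w||_2, so the round trip is exact.
    """
    g = math.sqrt(sum(x * x for x in w))
    return g, list(w)


def from_gv(g, v):
    """Map (g, v) back to w-space: w = g * v / ||v||_2."""
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]


def gv_grads(g, v, grad_w):
    """Chain rule from dL/dw to (dL/dg, dL/dv) for w = g * v / ||v||_2."""
    norm = math.sqrt(sum(x * x for x in v))
    u = [x / norm for x in v]  # unit direction v / ||v||_2
    grad_g = sum(gw * ui for gw, ui in zip(grad_w, u))  # projection of grad_w onto u
    # Remove the component of grad_w along u, then rescale by g / ||v||_2.
    grad_v = [(g / norm) * (gw - grad_g * ui) for gw, ui in zip(grad_w, u)]
    return grad_g, grad_v


def weight_norm_sgd_step(w, grad_w, lr):
    """w -> (g, v) -> plain SGD on (g, v) -> (g', v') -> w'.

    Plain SGD is an assumed stand-in for the wrapped optimizer.
    """
    g, v = to_gv(w)
    grad_g, grad_v = gv_grads(g, v, grad_w)
    g_new = g - lr * grad_g
    v_new = [x - lr * gx for x, gx in zip(v, grad_v)]
    return from_gv(g_new, v_new)
```

For example, to_gv([3.0, 4.0]) returns (5.0, [3.0, 4.0]). Because gv_grads removes the gradient component along v, the v-gradient is orthogonal to v, and the norm of the updated w′ equals g′: the scale and the direction are learned separately.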

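The output axis of any kernel matrix is assumed to be the last one, following the TensorFlow convention. Each slice along that axis is one weight vector with its own scale g, and its norm is taken over all remaining axes.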
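With the output axis last, a 2-D kernel of shape (in_features, out_features) holds one weight vector per column, so there is one scale per output unit. The following plain-Python sketch stores the kernel as a list of rows. The helper names decompose_kernel and recompose_kernel are hypothetical and not part of Flax.

```python
import math


def decompose_kernel(kernel):
    """Per-output-unit weight-norm split of a 2-D kernel.

    kernel[i][j]: i indexes the input axis, j the output axis (last axis,
    TensorFlow convention). Each column j is one weight vector, so it gets
    its own scalar scale g[j]; the norm is taken over the input axis.
    """
    n_in, n_out = len(kernel), len(kernel[0])
    g = [math.sqrt(sum(kernel[i][j] ** 2 for i in range(n_in))) for j in range(n_out)]
    # Keep v equal to the kernel; the direction of column j is v[:, j] / g[j].
    v = [row[:] for row in kernel]
    return g, v


def recompose_kernel(g, v):
    """w[:, j] = g[j] * v[:, j] / ||v[:, j]||_2 for every output unit j."""
    n_in, n_out = len(v), len(v[0])
    norms = [math.sqrt(sum(v[i][j] ** 2 for i in range(n_in))) for j in range(n_out)]
    return [[g[j] * v[i][j] / norms[j] for j in range(n_out)] for i in range(n_in)]
```

A column of all zeros would make its norm zero and the division undefined, which is the kind of case wn_eps is meant to guard against.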

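Parameters:

  • wrapped_optimizer – the OptimizerDef to wrap; it is applied to the (g, v) parameterization.

  • wn_decay – coefficient of L2 decay applied to the unnormalized weight vector v (default: 0, i.e. no decay).

  • wn_eps – additive constant for numerical stability of the normalization (default: 1e-8).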
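The two hyperparameters can be illustrated by extending the chain-rule gradients for \(w = g \, v / \|v\|_2\). Where each term enters is an assumption of this sketch: wn_eps is added inside the square root of the norm, and wn_decay adds a wn_decay * v term to the gradient of the raw vector v. gv_grads_with_decay is a hypothetical helper, not part of the Flax API.

```python
import math


def gv_grads_with_decay(g, v, grad_w, wn_decay=0.0, wn_eps=1e-8):
    """Gradients w.r.t. (g, v), with wn_eps and wn_decay folded in.

    Assumptions: wn_eps is added inside the square root of ||v||_2, and
    wn_decay adds an L2 term wn_decay * v to the gradient of the raw vector v.
    """
    norm = math.sqrt(sum(x * x for x in v) + wn_eps)  # eps keeps the norm > 0
    u = [x / norm for x in v]  # (approximately) unit direction
    grad_g = sum(gw * ui for gw, ui in zip(grad_w, u))
    grad_v = [(g / norm) * (gw - grad_g * ui) for gw, ui in zip(grad_w, u)]
    # L2 decay on the unnormalized weight vector v.
    grad_v = [gv + wn_decay * x for gv, x in zip(grad_v, v)]
    return grad_g, grad_v
```

With wn_eps > 0 the gradients stay finite even when v is the zero vector, and with zero incoming gradient a nonzero wn_decay still shrinks v toward zero.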


