flax.optim.WeightNorm

class flax.optim.WeightNorm(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)

Adds weight normalization to an optimizer def.

See Salimans & Kingma, "Weight Normalization" (https://arxiv.org/abs/1602.07868).

__init__(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)

Constructor for a WeightNorm optimizer.

Weight vectors are decomposed as \(w = g \cdot v / \|v\|_2\), for a scalar scale parameter \(g\) and a raw weight vector \(v\). The original optimizer is then applied in the (g, v) parameterization and the updated parameters are transformed back to w-space, i.e. w, state -> (g, v) --(original optimizer)--> (g', v') -> w', state'.
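For intuition, here is a minimal sketch of the decomposition round trip on a plain JAX array; the names w, g, and v below are illustrative, not the optimizer's internals:

    import jax.numpy as jnp

    w = jnp.array([3.0, 4.0])              # raw weight vector
    g = jnp.linalg.norm(w)                 # scalar scale, ||w||_2 = 5.0
    v = w                                  # direction carrier
    w_round_trip = g * v / jnp.linalg.norm(v)
    assert jnp.allclose(w, w_round_trip)   # w == g * v / ||v||_2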

We assume the output axis of any kernel matrix is the last one, following the TensorFlow convention.
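Concretely, for a 2-D kernel of shape (in_features, out_features) this convention means the norm reduces over every axis except the last, yielding one scale per output feature. A sketch of that reduction (the axis handling shown is an illustration of the convention, not the optimizer's code):

    import jax.numpy as jnp

    kernel = jnp.full((3, 4), 2.0)                  # (in_features, out_features)
    reduce_axes = tuple(range(kernel.ndim - 1))     # all axes except the output axis
    v_norm = jnp.sqrt(jnp.sum(kernel ** 2, axis=reduce_axes, keepdims=True))
    g = v_norm                                      # shape (1, 4): one scale per output
    assert jnp.allclose(g * kernel / v_norm, kernel)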

Parameters
  • wrapped_optimizer – the OptimizerDef to wrap.

  • wn_decay – coefficient of the L2 decay applied to the unnormalized weight vector v (default: 0).

  • wn_eps – additive constant for stability of the normalization (default: 1e-8).
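A hedged usage sketch of the constructor and its parameters; note that flax.optim is the legacy Flax optimizer API (since removed in favor of optax), and the Momentum base optimizer and parameter pytree below are illustrative choices:

    import jax.numpy as jnp
    import flax

    params = {'kernel': jnp.ones((3, 4))}

    # Wrap a base optimizer so its updates happen in (g, v)-space.
    optimizer_def = flax.optim.WeightNorm(
        flax.optim.Momentum(learning_rate=0.1, beta=0.9),
        wn_decay=1e-4,   # L2 decay on the unnormalized vector v
        wn_eps=1e-8,
    )
    optimizer = optimizer_def.create(params)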

Methods

__init__(wrapped_optimizer[, wn_decay, wn_eps])

Constructor for a WeightNorm optimizer.

apply_gradient(hyper_params, params, state, ...)

Applies a gradient for a set of parameters.

apply_param_gradient(step, hyper_params, ...)

Apply a gradient for a single parameter.

create(target[, focus])

Creates a new optimizer for the given target.

init_param_state(param)

Initializes the state for a parameter.

init_state(params)

Initializes the optimizer state for the given parameters.

restore_state(opt_target, opt_state, state_dict)

Restore the optimizer target and state from the state dict.

state_dict(target, state)

Returns the state dict for the given optimizer target and state.

update_hyper_params(**hyper_param_overrides)

Updates the hyper parameters with a set of overrides.
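Putting the methods together, here is a sketch of one update step through the Optimizer wrapper returned by create(). Internally, apply_gradient maps the parameters to (g, v)-space, applies the wrapped optimizer, and maps back to w-space; the loss function and data below are stand-ins:

    import jax
    import jax.numpy as jnp
    import flax

    params = {'kernel': jnp.ones((3, 4))}
    optimizer = flax.optim.WeightNorm(
        flax.optim.Momentum(learning_rate=0.1)).create(params)

    x, y = jnp.ones((2, 3)), jnp.zeros((2, 4))

    def loss_fn(p):
        return jnp.mean((x @ p['kernel'] - y) ** 2)

    grads = jax.grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)  # new Optimizer with updated target and state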