flax.optim.WeightNorm¶
- class flax.optim.WeightNorm(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)[source]¶
Adds weight normalization to an optimizer def.
See https://arxiv.org/abs/1602.07868
- __init__(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)[source]¶
Constructor for a WeightNorm optimizer.
Weight vectors are decomposed as \(w = g \cdot v / \lVert v \rVert_2\), for a scalar scale parameter \(g\) and a raw weight vector \(v\). The original optimizer is then applied in the \((g, v)\) parameterization and the updated parameters are transformed back to \(w\)-space, i.e. (w, state) -> (g, v) -> (original optimizer) -> (g', v') -> (w', state').
We assume the output axis of any kernel matrix is the last one, as per the TensorFlow convention.
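The reparameterization itself is plain array arithmetic. The following is a minimal JAX sketch of the decomposition described above; the example kernel, the reduction over axis 0 (all non-output axes, per the convention just noted), and all variable names are illustrative and not part of the optimizer's API.

import jax.numpy as jnp

kernel = jnp.array([[1.0, 2.0],
                    [3.0, 4.0],
                    [5.0, 6.0]])  # shape (in_features, out_features); output axis last

eps = 1e-8  # plays the role of wn_eps

# Norm of each output column, reduced over the non-output axes.
v = kernel
v_norm = jnp.sqrt(jnp.sum(jnp.square(v), axis=0, keepdims=True) + eps)

g = v_norm          # scale chosen so that initially g * v / ||v|| reproduces w
w = g * v / v_norm  # reconstructed kernel, equal to `kernel` up to eps

print(jnp.allclose(w, kernel))  # True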
- Parameters
wrapped_optimizer – the inner OptimizerDef to wrap with weight normalization.
wn_decay – L2 decay coefficient applied to the unnormalized weight vector v (default: 0).
wn_eps – additive constant for stability of the normalization (default: 1e-8).
Methods
- __init__(wrapped_optimizer[, wn_decay, wn_eps]) – Constructor for a WeightNorm optimizer.
- apply_gradient(hyper_params, params, state, ...) – Applies a gradient for a set of parameters.
- apply_param_gradient(step, hyper_params, ...) – Apply a gradient for a single parameter.
- create(target[, focus]) – Creates a new optimizer for the given target.
- init_param_state(param) – Initializes the state for a parameter.
- init_state(params)
- restore_state(opt_target, opt_state, state_dict) – Restore the optimizer target and state from the state dict.
- state_dict(target, state)
- update_hyper_params(**hyper_param_overrides) – Updates the hyper parameters with a set of overrides.
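A minimal usage sketch, assuming the legacy flax.optim API: flax.optim.Momentum is used here as an illustrative wrapped optimizer, and the parameter pytree and gradients are made up for the example; only WeightNorm, create, and apply_gradient are taken from this page.

import jax.numpy as jnp
import flax

# Wrap a base optimizer definition with weight normalization.
base_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer_def = flax.optim.WeightNorm(base_def, wn_decay=0.0, wn_eps=1e-8)

# Illustrative parameter pytree; in practice this comes from an initialized model.
params = {'kernel': jnp.ones((3, 2))}
optimizer = optimizer_def.create(params)

# Illustrative gradients with the same structure as the parameters.
grads = {'kernel': jnp.full((3, 2), 0.1)}
optimizer = optimizer.apply_gradient(grads)

new_params = optimizer.target  # updated parameters, already mapped back to w-space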