flax.optim.AdaBelief¶
- class flax.optim.AdaBelief(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-16, weight_decay=0.0)[source]¶
AdaBelief optimizer.
Implements AdaBelief - an adaptive learning rate optimizer that achieves fast convergence, generalisation, and stability. It adapts the step size depending on its “belief” in the gradient direction — the optimizer adaptively scales step size by the difference between the predicted and observed gradients. AdaBelief is a modified version of Adam and contains the same number of parameters.
Reference: [AdaBelief optimizer: adapting stepsizes by the belief in observed gradients](https://arxiv.org/abs/2010.07468) (Juntang Zhuang et al. 2020).
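To make the update rule concrete, here is a minimal sketch of one AdaBelief step for a single parameter array, written in plain jax.numpy and following the description above and the hyper parameters documented below (eps added inside the square root, AdamW-style weight decay). The function and variable names are illustrative and do not mirror the module's internal implementation.

```python
import jax.numpy as jnp

def adabelief_update(param, grad, m, s, step,
                     learning_rate=1e-3, beta1=0.9, beta2=0.999,
                     eps=1e-16, weight_decay=0.0):
  """One illustrative AdaBelief step for a single parameter array."""
  # First moment: exponential moving average of the gradient (the "prediction").
  m = beta1 * m + (1.0 - beta1) * grad
  # Second moment: exponential moving average of the squared difference
  # between the observed gradient and its prediction (the "belief" term).
  s = beta2 * s + (1.0 - beta2) * jnp.square(grad - m)
  # Bias correction, as in Adam (step counts from 0).
  t = step + 1
  m_hat = m / (1.0 - beta1 ** t)
  s_hat = s / (1.0 - beta2 ** t)
  # eps sits inside the square root; Adam instead adds it outside.
  param = param - learning_rate * m_hat / jnp.sqrt(s_hat + eps)
  # AdamW-style decoupled weight decay, scaled by the learning rate.
  param = param - learning_rate * weight_decay * param
  return param, m, s
```

Applied repeatedly with the running m and s, this reproduces the behaviour described above: when successive gradients agree with the prediction m, the belief term s stays small and the effective step is large; when they disagree, the step shrinks.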
- learning_rate¶
The learning rate — the step size used to update the parameters.
- beta1¶
The exponential decay rate for the 1st moment estimates, i.e. the coefficient used to compute the moving average of the gradient (default: 0.9).
- beta2¶
The exponential decay rate for the 2nd moment estimates, i.e. the coefficient used to compute the moving average of the squared difference between the gradient and its first-moment estimate (default: 0.999).
- eps¶
A small scalar added to the second-moment estimate to improve numerical stability (default: 1e-16). Note that AdaBelief adds eps inside the square root, whereas Adam adds it outside the square root (with a default of 1e-8).
- weight_decay¶
The AdamW-style weight decay rate, applied relative to the learning rate (default: 0.0).
- __init__(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-16, weight_decay=0.0)[source]¶
Constructor for the AdaBelief optimizer.
- Parameters
learning_rate – The step size used to update the parameters.
beta1 – The coefficient used for the moving average of the gradient (default: 0.9).
beta2 – The coefficient used for the moving average of the squared difference between the gradient and its moving average (default: 0.999).
eps – The term added to the second-moment estimate for numerical stability (default: 1e-16). Note that AdaBelief adds eps inside the square root, whereas Adam adds it outside (with a default of 1e-8).
weight_decay – AdamW style weight decay rate (relative to learning rate) (default: 0.0).
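As a usage sketch for this constructor, assuming the legacy flax.optim workflow in which an AdaBelief instance acts as an optimizer definition and create (listed under Methods below) wraps a parameter pytree into an optimizer, one might write:

```python
import jax.numpy as jnp
import flax

# Placeholder parameter pytree; any pytree of arrays can serve as the target.
params = {'kernel': jnp.zeros((4, 2)), 'bias': jnp.zeros((2,))}

# The AdaBelief instance only stores the hyper parameters ...
optimizer_def = flax.optim.AdaBelief(
    learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-16, weight_decay=1e-4)

# ... and create() bundles the target with freshly initialized per-parameter state.
optimizer = optimizer_def.create(params)
```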
Methods
- __init__([learning_rate, beta1, beta2, eps, ...]): Constructor for the AdaBelief optimizer.
- apply_gradient(hyper_params, params, state, ...): Applies a gradient for a set of parameters.
- apply_param_gradient(step, hyper_params, ...): Apply a gradient for a single parameter.
- create(target[, focus]): Creates a new optimizer for the given target.
- init_param_state(param): Initializes the state for a parameter.
- init_state(params)
- restore_state(opt_target, opt_state, state_dict): Restore the optimizer target and state from the state dict.
- state_dict(target, state)
- update_hyper_params(**hyper_param_overrides): Updates the hyper parameters with a set of overrides.
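Continuing the sketch above, a single training step with the optimizer returned by create might look as follows. Here loss_fn is a placeholder, and the assumption that the wrapper's apply_gradient forwards keyword overrides such as learning_rate to update_hyper_params comes from the legacy flax.optim API rather than from this page.

```python
import jax
import jax.numpy as jnp

# Placeholder loss: sum of squared parameter values (illustration only).
def loss_fn(params):
  return sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))

# Differentiate with respect to the optimizer's current target.
loss, grads = jax.value_and_grad(loss_fn)(optimizer.target)

# apply_gradient returns a new optimizer holding the updated target and state.
# The learning_rate keyword is assumed to be passed through as a
# hyper-parameter override (see update_hyper_params above).
optimizer = optimizer.apply_gradient(grads, learning_rate=5e-4)
```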