flax.optim.AdaBelief

class flax.optim.AdaBelief(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-16, weight_decay=0.0)[source]

AdaBelief optimizer.

Implements AdaBelief, an adaptive learning rate optimizer that achieves fast convergence, good generalisation, and training stability. It adapts the step size according to its “belief” in the current gradient direction: the step size is scaled by the difference between the predicted and the observed gradients. AdaBelief is a modified version of Adam and has the same number of parameters.

Reference: [AdaBelief optimizer: adapting stepsizes by the belief in observed gradients](https://arxiv.org/abs/2010.07468) (Juntang Zhuang et al. 2020).
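As a rough illustration (a sketch of the per-parameter rule from the paper, not the Flax source; all names below are hypothetical), one AdaBelief step can be written as:

```python
import jax.numpy as jnp

def adabelief_update(param, grad, m, s, step, lr, beta1=0.9, beta2=0.999, eps=1e-16):
    """One AdaBelief step for a single parameter (sketch of the paper's algorithm; step starts at 1)."""
    m = beta1 * m + (1 - beta1) * grad                   # EMA of the gradient (the "prediction")
    s = beta2 * s + (1 - beta2) * (grad - m) ** 2 + eps  # EMA of the squared prediction error (the "belief")
    m_hat = m / (1 - beta1 ** step)                      # bias correction, as in Adam
    s_hat = s / (1 - beta2 ** step)
    new_param = param - lr * m_hat / (jnp.sqrt(s_hat) + eps)
    return new_param, m, s
```

When the observed gradient stays close to the prediction m, the squared deviation (grad - m) ** 2 is small, so the denominator shrinks and the optimizer takes a larger step; when the prediction is poor, the step shrinks.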

learning_rate

The learning rate — the step size used to update the parameters.

beta1

The exponential decay rate for the 1st moment estimates, i.e. the coefficient used to compute the moving average of the gradient (default: 0.9).

beta2

The exponential decay rate for the 2nd moment estimates, i.e. the coefficient used to compute the moving average of the gradient magnitude (default: 0.999).

eps

A small scalar added to the gradient magnitude estimate to improve numerical stability (default: 1e-16). Note that AdaBelief adds eps inside the square root, whereas Adam adds it outside the square root and uses a larger default of 1e-8; see the comparison sketch after these attribute descriptions.

weight_decay

The AdamW-style weight decay rate, applied relative to the learning rate (default: 0.0).
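The eps note above can be made concrete by comparing where the term enters the step-size denominator. This is a schematic with hypothetical v_hat / s_hat values, not the library code:

```python
import jax.numpy as jnp

v_hat = jnp.asarray(1e-12)  # hypothetical bias-corrected second-moment estimate (Adam)
s_hat = jnp.asarray(1e-12)  # hypothetical bias-corrected "belief" estimate (AdaBelief)

adam_denominator      = jnp.sqrt(v_hat) + 1e-8   # Adam: eps added outside the square root
adabelief_denominator = jnp.sqrt(s_hat + 1e-16)  # AdaBelief: eps added inside the square root
```

Because eps sits under the square root in AdaBelief, its effective floor on the denominator is roughly sqrt(eps), i.e. 1e-8 for the default value, which is why the default is so much smaller than Adam's 1e-8.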

__init__(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-16, weight_decay=0.0)[source]

Constructor for the AdaBelief optimizer.

Parameters
  • learning_rate – The step size used to update the parameters.

  • beta1 – The coefficient used for the moving average of the gradient (default: 0.9).

  • beta2 – The coefficient used for the moving average of the gradient magnitude (default: 0.999).

  • eps – The term added to the gradient magnitude estimate for numerical stability (default: 1e-16). Note that AdaBelief adds eps inside the square root, whereas Adam adds it outside the square root and uses a larger default of 1e-8.

  • weight_decay – AdamW-style weight decay rate (relative to the learning rate) (default: 0.0); a short sketch of how it enters the update follows below.
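Because weight_decay is applied AdamW-style (decoupled from the gradient statistics and scaled by the learning rate), its effect can be sketched as below; adaptive_step is a hypothetical stand-in for the AdaBelief update m_hat / (sqrt(s_hat) + eps):

```python
def apply_weight_decay(param, adaptive_step, learning_rate, weight_decay=0.0):
    # Decoupled (AdamW-style) decay: shrink the parameter in proportion to
    # learning_rate * weight_decay, independently of the adaptive step.
    return param - learning_rate * (adaptive_step + weight_decay * param)
```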

Methods

__init__([learning_rate, beta1, beta2, eps, ...])

Constructor for the AdaBelief optimizer.

apply_gradient(hyper_params, params, state, ...)

Applies a gradient for a set of parameters.

apply_param_gradient(step, hyper_params, ...)

Apply a gradient for a single parameter.

create(target[, focus])

Creates a new optimizer for the given target.

init_param_state(param)

Initializes the state for a parameter.

init_state(params)

Initializes the optimizer state for the given parameters.

restore_state(opt_target, opt_state, state_dict)

Restore the optimizer target and state from the state dict.

state_dict(target, state)

Returns a dictionary of the optimizer target and state suitable for serialization.

update_hyper_params(**hyper_param_overrides)

Updates the hyper parameters with a set of overrides.
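For context, a minimal training-loop sketch using this (legacy flax.optim) interface; the toy parameters, loss function, and batch below are hypothetical:

```python
import jax
import jax.numpy as jnp
import flax

params = {'w': jnp.ones((3,))}                       # toy parameter pytree
batch = {'x': jnp.ones((3,)), 'y': jnp.zeros((3,))}  # toy data

def loss_fn(params, batch):
    return jnp.sum((params['w'] * batch['x'] - batch['y']) ** 2)

optimizer_def = flax.optim.AdaBelief(learning_rate=1e-3, weight_decay=1e-4)
optimizer = optimizer_def.create(params)             # Optimizer wrapping target + optimizer state

for _ in range(10):
    grad = jax.grad(loss_fn)(optimizer.target, batch)
    optimizer = optimizer.apply_gradient(grad)       # returns a new Optimizer with updated target
```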