class flax.optim.LAMB(learning_rate=None, beta1=0.9, beta2=0.999, weight_decay=0, eps=1e-06)

Layer-wise Adaptive Moments for Batch training (LAMB) optimizer.

See https://arxiv.org/abs/1904.00962

__init__(learning_rate=None, beta1=0.9, beta2=0.999, weight_decay=0, eps=1e-06)

Constructor for the LAMB optimizer.

  • learning_rate – the step size used to update the parameters.

  • beta1 – the coefficient used for the moving average of the gradient (default: 0.9).

  • beta2 – the coefficient used for the moving average of the squared gradient (default: 0.999).

  • weight_decay – weight decay coefficient to apply (default: 0).

  • eps – epsilon used for Adam update computation (default: 1e-6).
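Taken together, these hyperparameters enter the update rule from the paper above. As a sketch (writing the bias-corrected moments as in Adam, with λ for weight_decay, η for learning_rate, and the paper's layer-wise scaling function taken as the identity):

```latex
m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
```

```latex
r_t = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_{t-1}, \qquad
\theta_t = \theta_{t-1} - \eta\, \frac{\lVert \theta_{t-1} \rVert}{\lVert r_t \rVert}\, r_t
```

The ratio ‖θ‖/‖r‖ is the layer-wise "trust ratio" that distinguishes LAMB from Adam with weight decay: each parameter's step is rescaled so its magnitude tracks the magnitude of the parameter itself.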


__init__([learning_rate, beta1, beta2, ...])

Constructor for the LAMB optimizer.

apply_gradient(hyper_params, params, state, ...)

Applies a gradient for a set of parameters.

apply_param_gradient(step, hyper_params, ...)

Applies a gradient for a single parameter.
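To make the per-parameter step concrete, here is a minimal NumPy sketch of what one LAMB update computes, following the paper's update rule with Adam-style bias correction. This is an illustration with names of my own choosing, not the library's implementation:

```python
import numpy as np

def lamb_step(param, grad, m, v, step,
              lr=0.01, beta1=0.9, beta2=0.999,
              weight_decay=0.0, eps=1e-6):
    """One illustrative LAMB update for a single parameter array.

    Sketch only; the actual flax.optim.LAMB implementation may differ
    in detail (e.g. edge-case handling of zero norms).
    """
    # Adam-style exponential moving averages of the gradient and its square.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    # Bias correction, as in Adam (step is zero-based here).
    t = step + 1
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # Adam direction plus decoupled weight decay.
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    # Layer-wise trust ratio: rescale the step by ||param|| / ||update||.
    param_norm = np.linalg.norm(param)
    update_norm = np.linalg.norm(update)
    if param_norm > 0.0 and update_norm > 0.0:
        trust_ratio = param_norm / update_norm
    else:
        trust_ratio = 1.0
    new_param = param - lr * trust_ratio * update
    return new_param, m, v
```

For example, `lamb_step(params, grads, np.zeros_like(params), np.zeros_like(params), step=0)` performs the first update with freshly initialized moment state.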

create(target[, focus])

Creates a new optimizer for the given target.


init_param_state(param)

Initializes the state for a parameter.


restore_state(opt_target, opt_state, state_dict)

Restore the optimizer target and state from the state dict.

state_dict(target, state)

Returns a dictionary representation of the optimizer target and state, suitable for serialization.

update_hyper_params(**hyper_param_overrides)

Updates the hyper parameters with a set of overrides.