class flax.optim.LAMB(learning_rate=None, beta1=0.9, beta2=0.999, weight_decay=0, eps=1e-06)

Layer-wise Adaptive Moments for Batch training (LAMB) optimizer.

See https://arxiv.org/abs/1904.00962

__init__(learning_rate=None, beta1=0.9, beta2=0.999, weight_decay=0, eps=1e-06)

Constructor for the LAMB optimizer.

  • learning_rate – the step size used to update the parameters.

  • beta1 – the coefficient used for the moving average of the gradient (default: 0.9).

  • beta2 – the coefficient used for the moving average of the squared gradient (default: 0.999).

  • weight_decay – weight decay coefficient to apply (default: 0).

  • eps – epsilon used for Adam update computation (default: 1e-6).
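Taken together, these hyperparameters enter the update rule from the paper above. As a sketch (writing the bias-corrected moments as in Adam, with λ for weight_decay, η for learning_rate, and the paper's layer-wise scaling function taken as the identity):

```latex
m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
```

```latex
r_t = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_{t-1}, \qquad
\theta_t = \theta_{t-1} - \eta\, \frac{\lVert \theta_{t-1} \rVert}{\lVert r_t \rVert}\, r_t
```

The ratio ‖θ‖/‖r‖ is the layer-wise "trust ratio" that distinguishes LAMB from Adam with weight decay: each parameter's step is rescaled so its magnitude tracks the magnitude of the parameter itself.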


__init__([learning_rate, beta1, beta2, ...])

Constructor for the LAMB optimizer.

apply_gradient(hyper_params, params, state, ...)

Applies a gradient for a set of parameters.

apply_param_gradient(step, hyper_params, ...)

Applies a gradient for a single parameter.
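To make the per-parameter step concrete, here is a minimal NumPy sketch of what one LAMB update computes, following the paper's update rule with Adam-style bias correction. This is an illustration with names of my own choosing, not the library's implementation:

```python
import numpy as np

def lamb_step(param, grad, m, v, step,
              lr=0.01, beta1=0.9, beta2=0.999,
              weight_decay=0.0, eps=1e-6):
    """One illustrative LAMB update for a single parameter array.

    Sketch only; the actual flax.optim.LAMB implementation may differ
    in detail (e.g. edge-case handling of zero norms).
    """
    # Adam-style exponential moving averages of the gradient and its square.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    # Bias correction, as in Adam (step is zero-based here).
    t = step + 1
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # Adam direction plus decoupled weight decay.
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    # Layer-wise trust ratio: rescale the step by ||param|| / ||update||.
    param_norm = np.linalg.norm(param)
    update_norm = np.linalg.norm(update)
    if param_norm > 0.0 and update_norm > 0.0:
        trust_ratio = param_norm / update_norm
    else:
        trust_ratio = 1.0
    new_param = param - lr * trust_ratio * update
    return new_param, m, v
```

For example, `lamb_step(params, grads, np.zeros_like(params), np.zeros_like(params), step=0)` performs the first update with freshly initialized moment state.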

create(target[, focus])

Creates a new optimizer for the given target.


init_param_state(param)

Initializes the state for a parameter.


restore_state(opt_target, opt_state, state_dict)

Restore the optimizer target and state from the state dict.

state_dict(target, state)

Returns a dictionary representation of the optimizer target and state, suitable for serialization.

update_hyper_params(**hyper_param_overrides)

Updates the hyper parameters with a set of overrides.