flax.optim.Adam

class flax.optim.Adam(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.0)[source]

Adam optimizer.

Implements Adam, a stochastic gradient descent (SGD) method that computes individual adaptive learning rates for different parameters from estimates of the first and second moments of the gradients.

Reference: [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980v8) (Kingma and Ba, 2014).
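A minimal usage sketch of the legacy `flax.optim` workflow (requires a Flax version that still ships `flax.optim`). The parameter pytree, loss function, and data below are placeholders, not part of this API reference:

```python
import jax
import jax.numpy as jnp
from flax import optim

# Hypothetical parameter pytree and loss; shapes are placeholders.
params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}

def loss_fn(params, x, y):
    preds = x @ params['w'] + params['b']
    return jnp.mean((preds - y) ** 2)

# Build the optimizer definition, then bind it to the target parameters.
optimizer_def = optim.Adam(learning_rate=1e-3, weight_decay=1e-4)
optimizer = optimizer_def.create(params)

x, y = jnp.ones((8, 3)), jnp.zeros((8, 2))

# One training step: gradients w.r.t. the current target; apply_gradient
# returns a new Optimizer with an updated target and Adam state.
grads = jax.grad(loss_fn)(optimizer.target, x, y)
optimizer = optimizer.apply_gradient(grads)
```

The `Optimizer` wrapper returned by `create` carries both the target and the per-parameter Adam state; `apply_gradient` is functional and returns a new wrapper rather than mutating the old one in place.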

__init__(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.0)[source]

Constructor for the Adam optimizer.

Parameters
  • learning_rate – The learning rate — the step size used to update the parameters.

  • beta1 – The exponential decay rate for the 1st moment estimates, i.e. the coefficient of the moving average of the gradient (default: 0.9).

  • beta2 – The exponential decay rate for the 2nd moment estimates, i.e. the coefficient of the moving average of the gradient magnitude (default: 0.999).

  • eps – A small scalar added to the gradient magnitude estimate to improve numerical stability (default: 1e-8).

  • weight_decay – The weight decay coefficient (default: 0.0). Note that for adaptive gradient algorithms such as Adam this is not equivalent to L2 regularization. The weight decay is scaled by the learning rate schedule, which is consistent with other frameworks such as PyTorch, but differs from the “decoupled weight decay” of https://arxiv.org/abs/1711.05101, where the weight decay is multiplied by the “schedule multiplier” but not by the base learning rate. See the update sketch below for where this term enters.
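For orientation, a sketch of the per-parameter update implied by the cited paper and the weight-decay note above, with \alpha the learning rate, g_t the gradient, \lambda the weight decay, and \hat{m}_t, \hat{v}_t the bias-corrected moment estimates:

```latex
m_t       = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t       = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
\theta_t  = \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) - \alpha\, \lambda\, \theta_{t-1}
```

The last term shows the weight decay scaled by the learning rate, as described in the parameter note.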

Methods

__init__([learning_rate, beta1, beta2, eps, ...])

Constructor for the Adam optimizer.

apply_gradient(hyper_params, params, state, ...)

Applies a gradient for a set of parameters.

apply_param_gradient(step, hyper_params, ...)

Apply a gradient for a single parameter.

create(target[, focus])

Creates a new optimizer for the given target.

init_param_state(param)

Initializes the state for a parameter.

init_state(params)

Initializes the state for a set of parameters.

restore_state(opt_target, opt_state, state_dict)

Restore the optimizer target and state from the state dict.

state_dict(target, state)

Returns a serializable state dict for the optimizer target and state.

update_hyper_params(**hyper_param_overrides)

Updates the hyper parameters with a set of overrides.
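Continuing the sketch above, per-step overrides (for example a learning rate schedule) are commonly passed through the wrapping Optimizer's apply_gradient call, which forwards them to update_hyper_params; `lr_schedule` and `step` here are hypothetical:

```python
import jax.numpy as jnp

def lr_schedule(step):
    # Hypothetical linear warm-up over the first 1000 steps.
    return 1e-3 * jnp.minimum(1.0, step / 1000.0)

step = 0
# The override applies to this call only; the hyperparameters given to the
# constructor are left unchanged for subsequent steps.
optimizer = optimizer.apply_gradient(grads, learning_rate=lr_schedule(step))
```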