flax.optim.Momentum

class flax.optim.Momentum(learning_rate=None, beta=0.9, weight_decay=0, nesterov=False)[source]

Momentum optimizer (stochastic gradient descent with a momentum term).

__init__(learning_rate=None, beta=0.9, weight_decay=0, nesterov=False)[source]

Constructor for the Momentum optimizer.

Parameters
  • learning_rate – the step size used to update the parameters.

  • beta – the coefficient used for the moving average of the gradient (default: 0.9).

  • weight_decay – weight decay coefficient to apply (default: 0).

  • nesterov – whether to use Nesterov momentum (default: False).
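The update rule these hyperparameters control can be sketched in plain Python. This is an illustrative reimplementation, not the Flax source; in particular, folding `weight_decay` into the gradient before the velocity update is an assumption about the implementation.

```python
def momentum_update(param, grad, velocity,
                    learning_rate, beta=0.9, weight_decay=0.0, nesterov=False):
    """One momentum step for a single scalar parameter (illustrative sketch)."""
    if weight_decay != 0:
        # Assumed: weight decay is folded into the gradient (L2-style).
        grad = grad + weight_decay * param
    # Moving average of the gradient, controlled by `beta`.
    new_velocity = beta * velocity + grad
    if nesterov:
        # Nesterov momentum looks one step ahead along the velocity direction.
        step = grad + beta * new_velocity
    else:
        step = new_velocity
    return param - learning_rate * step, new_velocity

# From zero velocity, the first plain-momentum step reduces to vanilla SGD:
p, v = momentum_update(param=1.0, grad=0.5, velocity=0.0, learning_rate=0.1)
# p == 1.0 - 0.1 * 0.5 == 0.95, v == 0.5
```

On subsequent steps the velocity accumulates: with `beta=0.9`, a repeated gradient of 0.5 gives a velocity of `0.9 * 0.5 + 0.5 = 0.95` on the second step.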

Methods

__init__([learning_rate, beta, ...])

Constructor for the Momentum optimizer.

apply_gradient(hyper_params, params, state, ...)

Applies a gradient for a set of parameters.

apply_param_gradient(step, hyper_params, ...)

Applies a gradient for a single parameter.

create(target[, focus])

Creates a new optimizer for the given target.
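The `create` / `apply_gradient` flow can be sketched with a toy stand-in class. `TinyMomentum` and its `target`/`state` attributes are invented names for illustration only, and this sketch mutates in place for brevity, whereas the real API is functional and returns a new optimizer object.

```python
class TinyMomentum:
    """Hypothetical stand-in mirroring the documented interface."""

    def __init__(self, learning_rate=None, beta=0.9):
        self.learning_rate = learning_rate
        self.beta = beta

    def create(self, target):
        # Like init_state: one zero-initialized velocity slot per parameter.
        self.target = dict(target)
        self.state = {k: 0.0 for k in target}
        return self

    def apply_gradient(self, grads, **hyper_param_overrides):
        # update_hyper_params-style overrides, e.g. a per-step learning rate.
        lr = hyper_param_overrides.get("learning_rate", self.learning_rate)
        for k, g in grads.items():
            self.state[k] = self.beta * self.state[k] + g
            self.target[k] -= lr * self.state[k]
        return self

opt = TinyMomentum(learning_rate=0.1).create({"w": 1.0})
opt = opt.apply_gradient({"w": 0.5})
# opt.target["w"] == 0.95
```

The override mechanism means a schedule can pass `learning_rate=...` to each `apply_gradient` call instead of fixing it at construction time.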

init_param_state(param)

Initializes the state for a parameter.

init_state(params)

Initializes the optimizer state for all parameters.

restore_state(opt_target, opt_state, state_dict)

Restore the optimizer target and state from the state dict.

state_dict(target, state)

Returns the optimizer target and state as a serializable state dict.

update_hyper_params(**hyper_param_overrides)

Updates the hyper parameters with a set of overrides.