flax.optim.RMSProp

class flax.optim.RMSProp(learning_rate=None, beta2=0.9, eps=1e-08, centered=False)

RMSProp optimizer. Each parameter's step is divided by a running estimate of the magnitude of its recent gradients.
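For reference, the standard RMSProp update can be written with this page's parameter names: η is learning_rate, β₂ is beta2, ε is eps, g_t is the gradient at step t, and the averages v and m start at zero. Two details are assumptions about this implementation rather than facts stated on this page: that ε is added outside the square root, and that the centered variant tracks the gradient mean with the same β₂.

```latex
\begin{aligned}
v_t &= \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^2 \\
m_t &= \beta_2\, m_{t-1} + (1 - \beta_2)\, g_t && \text{(centered only)} \\
\hat{v}_t &= \begin{cases} v_t - m_t^2 & \text{if centered} \\ v_t & \text{otherwise} \end{cases} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{g_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
```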

Parameters

learning_rate (Optional[float]) – the step size used to update the parameters.

__init__(learning_rate=None, beta2=0.9, eps=1e-08, centered=False)

Constructor for the RMSProp optimizer.

Parameters
  • learning_rate (Optional[float]) – the step size used to update the parameters.

  • beta2 – the decay rate of the exponential moving average of the squared gradients, which serves as the gradient magnitude estimate (default: 0.9).

  • eps – a small constant added to the gradient magnitude estimate for numerical stability, so that the update never divides by zero (default: 1e-08).

  • centered – If True, gradients are normalized by the estimated variance of the gradient; if False, by the uncentered second moment. Setting this to True may help with training, but is slightly more expensive in terms of computation and memory. Defaults to False.
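To show how the four constructor arguments interact, here is a minimal pure-Python sketch of a single RMSProp step for one scalar parameter. The helper `rmsprop_step` is hypothetical and is not part of flax.optim; the real optimizer applies the same kind of arithmetic elementwise to arrays. It assumes eps is added after the square root and that the centered mean uses the same beta2.

```python
import math
from typing import Optional, Tuple


def rmsprop_step(
    param: float,
    grad: float,
    v: float,
    mg: Optional[float],
    learning_rate: float,
    beta2: float = 0.9,
    eps: float = 1e-8,
    centered: bool = False,
) -> Tuple[float, float, Optional[float]]:
    """Hypothetical helper: one RMSProp step for a single scalar parameter.

    Returns the updated parameter, the squared-gradient average and the
    gradient average (None when centered is False).
    """
    # Exponential moving average of the squared gradient, decayed by beta2.
    new_v = beta2 * v + (1.0 - beta2) * grad * grad
    if centered:
        # Assumed: the gradient mean is tracked with the same decay rate beta2.
        new_mg = beta2 * mg + (1.0 - beta2) * grad
        # Variance estimate: second moment minus the squared mean.
        denom_sq = new_v - new_mg * new_mg
    else:
        new_mg = mg  # not used by the uncentered variant
        denom_sq = new_v
    # Assumed: eps is added after the square root, keeping the divisor positive.
    new_param = param - learning_rate * grad / (math.sqrt(denom_sq) + eps)
    return new_param, new_v, new_mg


# Same first gradient for both variants, starting from zero-initialised averages.
p_plain, _, _ = rmsprop_step(1.0, 2.0, v=0.0, mg=None, learning_rate=0.1)
p_centered, _, mg = rmsprop_step(1.0, 2.0, v=0.0, mg=0.0, learning_rate=0.1,
                                 centered=True)
```

In this example the centered step is larger (about 0.333 versus 0.316) because subtracting the squared mean shrinks the denominator; the extra `mg` value is the additional memory the `centered` description mentions.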

Methods

__init__([learning_rate, beta2, eps, centered])

Constructor for the RMSProp optimizer.

apply_gradient(hyper_params, params, state, ...)

Applies a gradient for a set of parameters.

apply_param_gradient(step, hyper_params, ...)

Applies per-parameter gradients.

create(target[, focus])

Creates a new optimizer for the given target.

init_param_state(param)

Initializes the per-parameter state.

init_state(params)

restore_state(opt_target, opt_state, state_dict)

Restore the optimizer target and state from the state dict.

state_dict(target, state)

update_hyper_params(**hyper_param_overrides)

Updates the hyper parameters with a set of overrides.
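To make the method summaries more concrete, the following minimal pure-Python sketch shows one way the pieces could fit together for a flat dictionary of scalar parameters. Everything in it (`HyperParams`, `ParamState`, and the free functions named after the methods) is hypothetical and is not Flax's implementation. It omits the `step` argument, pytree handling and `focus`, and only illustrates the division of work the table describes: `init_param_state` creates per-parameter state, `apply_gradient` maps `apply_param_gradient` over all parameters, and `update_hyper_params` returns an overridden copy of the hyperparameters.

```python
import dataclasses
import math
from typing import Dict, Optional, Tuple

# Hypothetical sketch; not the flax.optim implementation.


@dataclasses.dataclass(frozen=True)
class HyperParams:
    # Mirrors the constructor arguments documented above (hypothetical container).
    learning_rate: Optional[float] = None
    beta2: float = 0.9
    eps: float = 1e-8
    centered: bool = False


@dataclasses.dataclass(frozen=True)
class ParamState:
    v: float                    # moving average of squared gradients
    mg: Optional[float] = None  # moving average of gradients (centered only)


def init_param_state(param: float, hp: HyperParams) -> ParamState:
    # Both averages start at zero; mg exists only in centered mode.
    return ParamState(v=0.0, mg=0.0 if hp.centered else None)


def apply_param_gradient(hp: HyperParams, param: float, state: ParamState,
                         grad: float) -> Tuple[float, ParamState]:
    assert hp.learning_rate is not None, "no learning rate provided"
    v = hp.beta2 * state.v + (1.0 - hp.beta2) * grad * grad
    if hp.centered:
        mg = hp.beta2 * state.mg + (1.0 - hp.beta2) * grad
        denom_sq = v - mg * mg
    else:
        mg = state.mg
        denom_sq = v
    new_param = param - hp.learning_rate * grad / (math.sqrt(denom_sq) + hp.eps)
    return new_param, ParamState(v, mg)


def apply_gradient(hp: HyperParams, params: Dict[str, float],
                   states: Dict[str, ParamState], grads: Dict[str, float]
                   ) -> Tuple[Dict[str, float], Dict[str, ParamState]]:
    # Apply the per-parameter rule to every entry of a flat parameter dict.
    new_params, new_states = {}, {}
    for name, param in params.items():
        new_params[name], new_states[name] = apply_param_gradient(
            hp, param, states[name], grads[name])
    return new_params, new_states


def update_hyper_params(hp: HyperParams, **overrides) -> HyperParams:
    # Return a copy with some fields replaced; the original is left unchanged.
    return dataclasses.replace(hp, **overrides)


hp = HyperParams(learning_rate=0.1)
params = {"w": 1.0, "b": -0.5}
states = {name: init_param_state(p, hp) for name, p in params.items()}
grads = {"w": 2.0, "b": 0.0}
params, states = apply_gradient(hp, params, states, grads)
# Halve the step size for later updates without rebuilding the optimizer.
smaller = update_hyper_params(hp, learning_rate=0.05)
```

A frozen dataclass is used so that an override produces a new object instead of mutating the one other code may still hold. Note that the zero gradient for `"b"` leaves it unchanged: the divisor falls back to `eps` rather than zero.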