flax.optim.RMSProp
- class flax.optim.RMSProp(learning_rate=None, beta2=0.9, eps=1e-08, centered=False)
RMSProp optimizer.
- Parameters
learning_rate (Optional[float]) – the step size used to update the parameters.
- __init__(learning_rate=None, beta2=0.9, eps=1e-08, centered=False)
Constructor for the RMSProp optimizer.
- Parameters
learning_rate (Optional[float]) – the step size used to update the parameters.
beta2 – the coefficient used for the exponential moving average of the squared gradient (default: 0.9).
eps – the term added to the gradient magnitude estimate for numerical stability (default: 1e-08).
centered – If True, gradients are normalized by the estimated variance of the gradient; if False, by the uncentered second moment. Setting this to True may help with training, but is slightly more expensive in terms of computation and memory. Defaults to False (a sketch of both update rules follows this list).
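For reference, a minimal sketch of the update rule, assuming the conventional RMSProp formulation rather than text from this page: g_t is the gradient, alpha the learning rate, v_t the running second-moment estimate, and m_t the running mean gradient tracked only when centered=True. Implementations differ on whether eps sits inside or outside the square root; it is shown outside here.

```latex
% Uncentered second moment (centered=False):
v_t      = \beta_2 \, v_{t-1} + (1 - \beta_2) \, g_t^2
\theta_t = \theta_{t-1} - \alpha \, g_t / (\sqrt{v_t} + \epsilon)

% Variance-normalized update (centered=True) additionally tracks the mean gradient:
m_t      = \beta_2 \, m_{t-1} + (1 - \beta_2) \, g_t
\theta_t = \theta_{t-1} - \alpha \, g_t / (\sqrt{v_t - m_t^2} + \epsilon)
```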
Methods
- __init__([learning_rate, beta2, eps, centered]) – Constructor for the RMSProp optimizer.
- apply_gradient(hyper_params, params, state, ...) – Applies a gradient for a set of parameters.
- apply_param_gradient(step, hyper_params, ...) – Apply per-parameter gradients.
- create(target[, focus]) – Creates a new optimizer for the given target.
- init_param_state(param) – Initialize parameter state.
- init_state(params)
- restore_state(opt_target, opt_state, state_dict) – Restore the optimizer target and state from the state dict.
- state_dict(target, state)
- update_hyper_params(**hyper_param_overrides) – Updates the hyper parameters with a set of overrides.
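A minimal usage sketch, assuming the pre-Linen flax.optim API described on this page; the toy parameters, loss function, and data below are illustrative placeholders, not part of the documentation:

```python
import jax
import jax.numpy as jnp
import flax

# Illustrative parameters for a linear model (hypothetical, not from the docs).
params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}

def loss_fn(params, x, y):
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

# Build the optimizer definition, then wrap the target parameters with create().
optimizer_def = flax.optim.RMSProp(learning_rate=1e-3, beta2=0.9,
                                   eps=1e-8, centered=False)
optimizer = optimizer_def.create(params)

x = jnp.ones((8, 3))
y = jnp.zeros((8,))

# One training step: differentiate with respect to optimizer.target, then
# apply the gradient to get a new optimizer with updated params and state.
grads = jax.grad(loss_fn)(optimizer.target, x, y)
optimizer = optimizer.apply_gradient(grads)

# Hyper-parameters can be overridden per call (routed through
# update_hyper_params), e.g. to implement a learning-rate schedule.
optimizer = optimizer.apply_gradient(grads, learning_rate=1e-4)
```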