flax.optim package¶
Optimizer core¶
Flax optimizer API.
class flax.optim.Optimizer(optimizer_def, state, target)[source]¶
Wraps an optimizer with its hyper_params, state, and model parameters.
apply_gradient(grads, **hyper_param_overrides)[source]¶
Applies a pytree of gradients to the target.
- Parameters
grads – A pytree of gradients.
**hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.
- Returns
A new optimizer with the updated target and state.
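For illustration only, a minimal training-step sketch (the parameter pytree, loss function, and learning rates below are made-up stand-ins): gradients are taken with respect to optimizer.target and then applied, optionally overriding a hyper parameter for that step::

    import jax
    import jax.numpy as jnp
    from flax import optim

    # Toy parameter pytree standing in for real model parameters.
    params = {'w': jnp.ones((3,))}
    optimizer = optim.Adam(learning_rate=0.1).create(params)

    def loss_fn(params):
        # Toy quadratic loss used only for this sketch.
        return jnp.sum(params['w'] ** 2)

    grads = jax.grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)                      # default hyper params
    optimizer = optimizer.apply_gradient(grads, learning_rate=0.01)  # per-step override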
compute_gradients(loss_fn)¶
Computes the gradient of loss_fn.
DEPRECATION WARNING: compute_gradients() is deprecated. Use jax.grad() or jax.value_and_grad() instead.
- Parameters
loss_fn – a function that receives the target and returns a loss or a tuple of the loss and auxiliary outputs.
- Returns
A tuple consisting of the loss, the auxiliary outputs if any, and the gradients.
optimize(loss_fn, **hyper_param_overrides)[source]¶
Optimizes the target with respect to a loss function.
DEPRECATION WARNING: optimize() is deprecated. Use jax.grad() or jax.value_and_grad() together with apply_gradient() instead.
- Parameters
loss_fn – a function that receives the target and returns a loss or a tuple of the loss and auxiliary outputs.
**hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.
- Returns
A tuple consisting of the new optimizer, the loss, and the auxiliary outputs if any.
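As a hedged sketch of the replacement recommended above for both deprecated methods, using jax.value_and_grad together with apply_gradient (loss_fn and optimizer as in the earlier sketch)::

    # Without auxiliary outputs:
    loss, grads = jax.value_and_grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)

    # With auxiliary outputs, when loss_fn returns (loss, aux):
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)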
class flax.optim.OptimizerDef(hyper_params)[source]¶
Base class for optimizers.
apply_gradient(hyper_params, params, state, grads)[source]¶
Applies a gradient for a set of parameters.
- Parameters
hyper_params – a named tuple of hyper parameters.
params – the parameters that should be updated.
state – a named tuple containing the state of the optimizer.
grads – the gradient tensors for the parameters.
- Returns
A tuple containing the new parameters and the new optimizer state.
apply_param_gradient(step, hyper_params, param, state, grad)[source]¶
Applies a gradient for a single parameter.
- Parameters
step – the current step of the optimizer.
hyper_params – a named tuple of hyper parameters.
param – the parameter that should be updated.
state – a named tuple containing the state for this parameter.
grad – the gradient tensor for the parameter.
- Returns
A tuple containing the new parameter and the new state.
create(target, focus=None)[source]¶
Creates a new optimizer for the given target.
- Parameters
target – the object to be optimized. This will typically be an instance of flax.nn.Model.
focus – a flax.traverse_util.Traversal that selects which subset of the target is optimized.
- Returns
An instance of Optimizer.
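A brief sketch of create with and without focus. The traversal below uses flax.optim.ModelParamTraversal, which is assumed here to be available as a flax.traverse_util.Traversal that filters parameters by their path; adjust to whichever traversal your version provides::

    import jax.numpy as jnp
    from flax import optim

    # Nested toy parameter pytree standing in for real model parameters.
    params = {'dense': {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))}}

    # Optimize everything in the target.
    optimizer = optim.Momentum(learning_rate=0.1, beta=0.9).create(params)

    # Optimize only the parameters whose path contains 'kernel'
    # (ModelParamTraversal is an assumption; any Traversal works as focus).
    kernels = optim.ModelParamTraversal(lambda path, _: 'kernel' in path)
    kernel_optimizer = optim.Momentum(learning_rate=0.1).create(params, focus=kernels)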
init_param_state(param)[source]¶
Initializes the state for a parameter.
- Parameters
param – the parameter for which to initialize the state.
- Returns
A named tuple containing the initial optimization state for the parameter.
update_hyper_params(**hyper_param_overrides)[source]¶
Updates the hyper parameters with a set of overrides.
This method is called from Optimizer.apply_gradient to create the hyper parameters for a specific optimization step.
- Parameters
**hyper_param_overrides – the hyper parameter updates will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.
- Returns
The new hyper parameters.
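To illustrate the extension points of this base class (init_param_state and apply_param_gradient), here is a minimal sketch of a plain gradient-descent OptimizerDef. It assumes flax.struct.dataclass for the hyper-parameter container (so that overrides can be applied) and an empty tuple as the per-parameter state::

    import jax.numpy as jnp
    from flax import optim, struct

    @struct.dataclass
    class _GradientDescentHyperParams:
        learning_rate: jnp.ndarray

    class SimpleGradientDescent(optim.OptimizerDef):
        """Plain gradient descent, written as an OptimizerDef subclass (sketch)."""

        def __init__(self, learning_rate=None):
            super().__init__(_GradientDescentHyperParams(learning_rate))

        def init_param_state(self, param):
            # Plain gradient descent keeps no per-parameter state.
            return ()

        def apply_param_gradient(self, step, hyper_params, param, state, grad):
            del step, state  # unused for plain gradient descent
            new_param = param - hyper_params.learning_rate * grad
            return new_param, ()

It would then be used like any other optimizer definition, e.g. SimpleGradientDescent(learning_rate=0.1).create(params).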
Available optimizers¶
class flax.optim.Adam(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.0)[source]¶
Adam optimizer.
class flax.optim.Adafactor(learning_rate=None, factored=True, multiply_by_parameter_scale=True, beta1=None, decay_rate=0.8, step_offset=0, clipping_threshold=1.0, weight_decay_rate=None, min_dim_size_to_factor=128, epsilon1=1e-30, epsilon2=0.001)[source]¶
Adafactor optimizer.
Adafactor is described in https://arxiv.org/abs/1804.04235.
class flax.optim.DynamicScale(growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, fin_steps=0, scale=65536.0)[source]¶
Dynamic loss scaling for mixed precision gradients.
For many models, gradient computations in float16 run into numerical issues because small/large gradients are flushed to zero/infinity. Dynamic loss scaling is an algorithm that aims to find the largest scalar multiple for which the gradient does not overflow, which minimizes the risk of underflow.
The value_and_grad method mimics jax.value_and_grad. Besides the loss and the gradients, it also outputs an updated DynamicScale instance with the current loss scale factor, and a boolean value indicating whether the gradients are finite.
Example::

    def loss_fn(p):
        return jnp.asarray(p, jnp.float16) ** 2

    p = jnp.array(1., jnp.float32)
    dyn_scale = optim.DynamicScale(growth_interval=10)
    compute_grad = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))
    for _ in range(100):
        dyn_scale, is_fin, loss, grad = compute_grad(dyn_scale, p)
        p += jnp.where(is_fin, 0.01 * grad, 0.)
        print(loss)
JAX currently cannot execute conditionals efficiently on GPUs, so we selectively ignore the gradient update using jax.numpy.where when the gradients are non-finite.
- Attrs:
growth_factor: how much to grow the scalar after a period of finite gradients (default: 2.0).
backoff_factor: how much to shrink the scalar after a non-finite gradient (default: 0.5).
growth_interval: after how many steps of finite gradients the scale should be increased (default: 2000).
fin_steps: indicates how many gradient steps in a row have been finite.
scale: the current scale by which the loss is multiplied.
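A hedged sketch of combining DynamicScale with an Optimizer, continuing the names from the sketches above (optimizer, loss_fn) plus a DynamicScale instance dyn_scale: apply the gradients, then roll back to the previous optimizer whenever the step was not finite. This assumes the Optimizer is itself a pytree, so jax.tree_multimap (the multi-argument tree-mapping helper available in JAX at the time of this API) can select between old and new leaves::

    dyn_scale, is_fin, loss, grads = dyn_scale.value_and_grad(loss_fn)(optimizer.target)
    new_optimizer = optimizer.apply_gradient(grads)
    # Keep the whole previous optimizer (step, state, target) when any gradient was non-finite.
    optimizer = jax.tree_multimap(
        lambda new, old: jnp.where(is_fin, new, old), new_optimizer, optimizer)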
class flax.optim.LAMB(learning_rate=None, beta1=0.9, beta2=0.999, weight_decay=0, eps=1e-06)[source]¶
Layerwise adaptive moments for batch (LAMB) optimizer.
class flax.optim.LARS(learning_rate=None, beta=0.9, weight_decay=0, trust_coefficient=0.001, eps=0, nesterov=False)[source]¶
Layerwise adaptive rate scaling (LARS) optimizer.
class flax.optim.Momentum(learning_rate=None, beta=0.9, weight_decay=0, nesterov=False)[source]¶
Momentum optimizer.
class flax.optim.RMSProp(learning_rate=None, beta2=0.9, eps=1e-08, centered=False)[source]¶
RMSProp optimizer.
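The optimizer definitions above all accept learning_rate as a hyper parameter, so a schedule can be applied through the per-step override described under apply_gradient. A brief hedged sketch, with a made-up schedule_fn and assuming the step counter is kept at optimizer.state.step (optimizer and grads as in the earlier sketches)::

    def schedule_fn(step):
        # Hypothetical schedule: simple exponential decay.
        return 0.1 * (0.99 ** step)

    lr = schedule_fn(optimizer.state.step)
    optimizer = optimizer.apply_gradient(grads, learning_rate=lr)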