flax.optim package#
Flax Optimizer API.
Note that with FLIP #1009 the optimizers in flax.optim were effectively deprecated in favor of Optax optimizers. By now, Optax should support all of the original features from flax.optim (otherwise, please create a GitHub issue on optax).
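For migration reference, below is a minimal, hedged sketch of the corresponding Optax workflow: optax.sgd stands in for optim.GradientDescent, and model, params, and data are placeholder names from your own setup.

    import jax
    import optax

    # Sketch only: optax.sgd plays the role of optim.GradientDescent, and the
    # optimizer state is carried explicitly instead of inside an Optimizer object.
    tx = optax.sgd(learning_rate=0.1)
    opt_state = tx.init(params)  # `params` is the pytree previously passed to create()

    def train_step(params, opt_state, data):
        def loss_fn(params):
            y = model.apply({'params': params}, data)  # placeholder model call
            loss = ...  # compute the loss
            return loss
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = tx.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss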
Optimizer Base Classes#
- class flax.optim.Optimizer(optimizer_def, state, target)[source]#
Flax optimizers are created using the OptimizerDef class. That class specifies the initialization and gradient application logic. Creating an optimizer using the OptimizerDef.create() method will result in an instance of the Optimizer class, which encapsulates the optimization target and state. The optimizer is updated using the method apply_gradient().
Example of constructing an optimizer for a model:
    from flax import optim

    optimizer_def = optim.GradientDescent(learning_rate=0.1)
    optimizer = optimizer_def.create(model)
The optimizer is then used in a training step as follows:
    def train_step(optimizer, data):
        def loss_fn(model):
            y = model(data)
            loss = ...  # compute the loss
            aux = ...   # compute auxiliary outputs (e.g. training metrics)
            return loss, aux
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, aux), grad = grad_fn(optimizer.target)
        new_optimizer = optimizer.apply_gradient(grad)
        return new_optimizer, loss, aux
Distributed training only requires a few extra additions:
    from flax import optim

    optimizer_def = optim.GradientDescent(learning_rate=0.1)
    optimizer = optimizer_def.create(model)
    optimizer = jax_utils.replicate(optimizer)

    def train_step(optimizer, data):
        def loss_fn(model):
            y = model(data)
            loss = ...  # compute the loss
            aux = ...   # compute auxiliary outputs (e.g. training metrics)
            return loss, aux
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, aux), grad = grad_fn(optimizer.target)
        grad = jax.lax.pmean(grad, 'batch')
        new_optimizer = optimizer.apply_gradient(grad)
        return new_optimizer, loss, aux

    distributed_train_step = jax.pmap(train_step, axis_name='batch')
- optimizer_def#
The optimizer definition.
- state#
The initial state of the optimizer.
- Type
Any
- target#
The target to optimize.
- Type
Any
- apply_gradient(grads, **hyper_param_overrides)[source]#
Applies a pytree of gradients to the target.
- Parameters
grads – A pytree of gradients.
**hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.
- Returns
A new optimizer with the updated target and state.
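As an illustrative, hedged example, hyper parameters can be overridden per step by keyword; here schedule, step, and grads are placeholder names, and learning_rate is assumed to be one of the optimizer definition's hyper parameters:

    # Sketch: override the default learning rate for this step only.
    # `schedule`, `step`, and `grads` are placeholders defined elsewhere.
    new_optimizer = optimizer.apply_gradient(grads, learning_rate=schedule(step))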
- class flax.optim.OptimizerDef(hyper_params)[source]#
Base class for an optimizer definition, which specifies the initialization and gradient application logic.
See docstring of Optimizer for more details.
- apply_gradient(hyper_params, params, state, grads)[source]#
Applies a gradient for a set of parameters.
- Parameters
hyper_params – a named tuple of hyper parameters.
params – the parameters that should be updated.
state – a named tuple containing the state of the optimizer.
grads – the gradient tensors for the parameters.
- Returns
A tuple containing the new parameters and the new optimizer state.
- apply_param_gradient(step, hyper_params, param, state, grad)[source]#
Apply a gradient for a single parameter.
- Parameters
step – the current step of the optimizer.
hyper_params – a named tuple of hyper parameters.
param – the parameter that should be updated.
state – a named tuple containing the state for this parameter.
grad – the gradient tensor for the parameter.
- Returns
A tuple containing the new parameter and the new state.
- create(target, focus=None)[source]#
Creates a new optimizer for the given target.
See docstring of Optimizer for more details.
- Parameters
target – the object to be optimized. This is typically a variable dict returned by flax.linen.Module.init(), but it can also be a container of variable dicts, e.g. (v1, v2) and {'var1': v1, 'var2': v2} are valid inputs as well.
focus – a flax.traverse_util.Traversal that selects which subset of the target is optimized. See docstring of MultiOptimizer for an example of how to define a Traversal object.
- Returns
An instance of Optimizer.
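A hedged sketch of the typical pattern with a flax.linen module (model, rng, and x are placeholder names):

    # Sketch: create an optimizer for the parameter dict of a flax.linen module.
    variables = model.init(rng, x)                  # {'params': ...}
    optimizer_def = optim.Adam(learning_rate=1e-3)
    optimizer = optimizer_def.create(variables['params'])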
- init_param_state(param)[source]#
Initializes the state for a parameter.
- Parameters
param – the parameter for which to initialize the state.
- Returns
A named tuple containing the initial optimization state for the parameter.
- update_hyper_params(**hyper_param_overrides)[source]#
Updates the hyper parameters with a set of overrides.
This method is called from Optimizer.apply_gradient() to create the hyper parameters for a specific optimization step.
- Parameters
**hyper_param_overrides – the hyper parameter updates will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.
- Returns
The new hyper parameters.
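Putting these methods together, here is a hedged sketch of a custom optimizer definition built on OptimizerDef; the _SgdHyperParams and _SgdParamState dataclasses and the SimpleSgd class are illustrative names, not part of flax.optim:

    import jax.numpy as jnp
    from flax import optim, struct

    @struct.dataclass
    class _SgdHyperParams:
        # Hyper parameters need .replace() for update_hyper_params() to work;
        # flax.struct.dataclass provides that.
        learning_rate: jnp.ndarray

    @struct.dataclass
    class _SgdParamState:
        pass  # plain gradient descent keeps no per-parameter state

    class SimpleSgd(optim.OptimizerDef):
        """Illustrative gradient descent optimizer built on OptimizerDef."""

        def __init__(self, learning_rate=0.1):
            super().__init__(_SgdHyperParams(jnp.asarray(learning_rate)))

        def init_param_state(self, param):
            return _SgdParamState()

        def apply_param_gradient(self, step, hyper_params, param, state, grad):
            new_param = param - hyper_params.learning_rate * grad
            return new_param, _SgdParamState()

The resulting definition is then used like the built-in ones, e.g. SimpleSgd(learning_rate=0.1).create(params).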
MultiOptimizer#
- class flax.optim.MultiOptimizer(*traversals_and_optimizers)[source]#
A MultiOptimizer is a subclass of OptimizerDef and is useful for applying separate optimizer algorithms to various subsets of the model parameters.
The example below creates two optimizers using flax.traverse_util.ModelParamTraversal: one to optimize kernel parameters and one to optimize bias parameters. Note that each optimizer is created with a different learning rate:

    kernels = traverse_util.ModelParamTraversal(lambda path, _: 'kernel' in path)
    biases = traverse_util.ModelParamTraversal(lambda path, _: 'bias' in path)
    kernel_opt = optim.Momentum(learning_rate=0.01)
    bias_opt = optim.Momentum(learning_rate=0.1)
    opt_def = MultiOptimizer((kernels, kernel_opt), (biases, bias_opt))
    optimizer = opt_def.create(model)
In order to train only a subset of the parameters, you can simply use a single flax.traverse_util.ModelParamTraversal instance.
If you want to update the learning rates of both optimizers online with different learning rate schedules, you should update the learning rates when applying the gradient. In the following example, the second optimizer is not doing any optimization during the first 1000 steps:

    hparams = optimizer.optimizer_def.hyper_params
    new_optimizer = optimizer.apply_gradient(
        grads,
        hyper_params=[
            hparams[0].replace(learning_rate=0.2),
            hparams[1].replace(learning_rate=jnp.where(step < 1000, 0., lr)),
        ])
- update_hyper_params(**hyper_param_overrides)[source]#
Updates the hyper parameters with a set of overrides.
This method is called from Optimizer.apply_gradient() to create the hyper parameters for a specific optimization step. MultiOptimizer will apply the overrides for each sub optimizer.
- Parameters
**hyper_param_overrides – the hyper parameter updates will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.
- Returns
The new hyper parameters.
Available optimizers#
Adam            | Adam optimizer.
AdaBelief       | AdaBelief optimizer.
Adafactor       | Adafactor optimizer.
Adagrad         | Adagrad optimizer.
Adadelta        | Adadelta optimizer.
GradientDescent | Gradient descent optimizer.
LAMB            | Layerwise adaptive moments for batch (LAMB) optimizer.
LARS            | Layerwise adaptive rate scaling (LARS) optimizer.
Momentum        | Momentum optimizer.
RMSProp         | RMSProp optimizer.
WeightNorm      | Adds weight normalization to an optimizer def.
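As a hedged illustration, each definition listed above is constructed and used the same way as GradientDescent in the earlier examples; the constructor arguments shown here are typical hyper parameters, so consult each class for its exact signature:

    from flax import optim

    # Sketch: a few of the optimizer definitions listed above. Argument names
    # (learning_rate, beta, weight_decay) are assumed hyper parameters.
    adam_def = optim.Adam(learning_rate=1e-3, weight_decay=1e-4)
    momentum_def = optim.Momentum(learning_rate=0.1, beta=0.9)

    # WeightNorm wraps another optimizer definition.
    wn_def = optim.WeightNorm(optim.Adam(learning_rate=1e-3))

    optimizer = adam_def.create(params)  # `params` is a placeholder parameter pytree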