flax.optim package

Optimizer Base Classes

class flax.optim.Optimizer(optimizer_def, state, target)[source]

Flax optimizers are created using the OptimizerDef class. That class specifies the initialization and gradient application logic. Creating an optimizer using the OptimizerDef.create() method will result in an instance of the Optimizer class, which encapsulates the optimization target and state. The optimizer is updated using the method apply_gradient().

Example of constructing an optimizer for a model:

import jax
from flax import optim

optimizer_def = optim.GradientDescent(learning_rate=0.1)
optimizer = optimizer_def.create(model)

The optimizer is then used in a training step as follows:

def train_step(optimizer, data):
  def loss_fn(model):
    y = model(data)
    loss = ... # compute the loss
    aux = ... # compute auxiliary outputs (eg. training metrics)
    return loss, aux
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, aux), grad = grad_fn(optimizer.target)
  new_optimizer = optimizer.apply_gradient(grad)
  return new_optimizer, loss, aux

Distributed training requires only a few additions:

import jax
from flax import jax_utils, optim

optimizer_def = optim.GradientDescent(learning_rate=0.1)
optimizer = optimizer_def.create(model)
optimizer = jax_utils.replicate(optimizer)

def train_step(optimizer, data):
  def loss_fn(model):
    y = model(data)
    loss = ... # compute the loss
    aux = ... # compute auxiliary outputs (eg. training metrics)
    return loss, aux
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, aux), grad = grad_fn(optimizer.target)
  grad = jax.lax.pmean(grad, 'batch')
  new_optimizer = optimizer.apply_gradient(grad)
  return new_optimizer, loss, aux

distributed_train_step = jax.pmap(train_step, axis_name='batch')
optimizer_def

The optimizer definition.

Type

flax.optim.base.OptimizerDef

state

The initial state of the optimizer.

Type

Any

target

The target to optimize.

Type

Any

apply_gradient(grads, **hyper_param_overrides)[source]

Applies a pytree of gradients to the target.

Parameters
  • grads – A pytree of gradients.

  • **hyper_param_overrides – hyper parameters passed to apply_gradient() override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

A new optimizer with the updated target and state.
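
For example, a hyper parameter such as the learning rate can be overridden for a single step. This is a minimal sketch; it assumes learning_rate is a hyper parameter of the optimizer definition in use, as it is for the built-in optimizers:

# Override the learning rate for this step only; other hyper parameters
# keep the defaults stored in the OptimizerDef.
new_optimizer = optimizer.apply_gradient(grad, learning_rate=0.01)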

class flax.optim.OptimizerDef(hyper_params)[source]

Base class for an optimizer definition, which specifies the initialization and gradient application logic.

See docstring of Optimizer for more details.

apply_gradient(hyper_params, params, state, grads)[source]

Applies a gradient for a set of parameters.

Parameters
  • hyper_params – a named tuple of hyper parameters.

  • params – the parameters that should be updated.

  • state – a named tuple containing the state of the optimizer.

  • grads – the gradient tensors for the parameters.

Returns

A tuple containing the new parameters and the new optimizer state.

apply_param_gradient(step, hyper_params, param, state, grad)[source]

Applies a gradient for a single parameter.

Parameters
  • step – the current step of the optimizer.

  • hyper_params – a named tuple of hyper parameters.

  • param – the parameter that should be updated.

  • state – a named tuple containing the state for this parameter.

  • grad – the gradient tensor for the parameter.

Returns

A tuple containing the new parameter and the new state.

create(target, focus=None)[source]

Creates a new optimizer for the given target.

See docstring of Optimizer for more details.

Parameters
  • target – the object to be optimized. This is typically a variable dict returned by flax.linen.Module.init(), but it can also be a container of variable dicts, e.g. (v1, v2) and {'var1': v1, 'var2': v2} are valid inputs as well.

  • focus (Optional[flax.optim.base.ModelParamTraversal]) – a flax.traverse_util.Traversal that selects which subset of the target is optimized. See docstring of MultiOptimizer for an example of how to define a Traversal object.

Returns

An instance of Optimizer.
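
As a short sketch, the target is typically the 'params' collection of a flax.linen variable dict; the module and input shape below are purely illustrative:

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import optim

# Initialize an illustrative module and optimize its parameters.
model = nn.Dense(features=3)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
optimizer = optim.GradientDescent(learning_rate=0.1).create(variables['params'])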

init_param_state(param)[source]

Initializes the state for a parameter.

Parameters

param – the parameter for which to initialize the state.

Returns

A named tuple containing the initial optimization state for the parameter.
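
init_param_state() and apply_param_gradient() are the main hooks a custom optimizer implements. The following is a minimal, illustrative sketch of a stateless sign-of-gradient update rule; the class and hyper parameter names are hypothetical and mirror the struct.dataclass pattern used by the built-in optimizers:

import jax.numpy as jnp
from flax import optim, struct

@struct.dataclass
class _SignSGDHyperParams:
  # Hypothetical hyper parameter container for this sketch.
  learning_rate: float

class SignSGD(optim.OptimizerDef):
  """Illustrative optimizer: step in the direction of the gradient sign."""

  def __init__(self, learning_rate=0.01):
    super().__init__(_SignSGDHyperParams(learning_rate))

  def init_param_state(self, param):
    return ()  # this rule keeps no per-parameter state

  def apply_param_gradient(self, step, hyper_params, param, state, grad):
    del step  # unused by this rule
    new_param = param - hyper_params.learning_rate * jnp.sign(grad)
    return new_param, state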

update_hyper_params(**hyper_param_overrides)[source]

Updates the hyper parameters with a set of overrides.

This method is called from Optimizer.apply_gradient() to create the hyper parameters for a specific optimization step.

Parameters

**hyper_param_overrides – the hyper parameter updates override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

The new hyper parameters.
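
A short sketch of both forms of override; .replace() is available because the hyper parameters are dataclass-style containers, as in the MultiOptimizer example below:

opt_def = optim.Momentum(learning_rate=0.1, beta=0.9)
# Partial override: only the learning rate changes, beta keeps its value.
hp = opt_def.update_hyper_params(learning_rate=0.05)
# Full replacement of the hyper parameters via hyper_params=...
hp = opt_def.update_hyper_params(hyper_params=hp.replace(learning_rate=0.01))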

MultiOptimizer

class flax.optim.MultiOptimizer(*traversals_and_optimizers)[source]

A MultiOptimizer is a subclass of OptimizerDef that is useful for applying separate optimizer algorithms to different subsets of the model parameters.

The example below creates two optimizers using ModelParamTraversal: one to optimize kernel parameters and one to optimize bias parameters. Note that each optimizer is created with a different learning rate:

kernels = optim.ModelParamTraversal(lambda path, _: 'kernel' in path)
biases = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
kernel_opt = optim.Momentum(learning_rate=0.01)
bias_opt = optim.Momentum(learning_rate=0.1)
opt_def = optim.MultiOptimizer((kernels, kernel_opt), (biases, bias_opt))
optimizer = opt_def.create(model)

To train only a subset of the parameters, you can simply use a single ModelParamTraversal instance.
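
For example, the following sketch updates only the bias parameters and leaves every other parameter untouched:

biases = optim.ModelParamTraversal(lambda path, _: 'bias' in path)
opt_def = optim.MultiOptimizer((biases, optim.Momentum(learning_rate=0.1)))
optimizer = opt_def.create(model)  # parameters not matched by `biases` stay fixed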

To update the learning rates of both optimizers online with different learning rate schedules, update the learning rates when applying the gradient. In the following example, the second optimizer performs no updates during the first 1000 steps:

hparams = optimizer.optimizer_def.hyper_params
new_optimizer = optimizer.apply_gradient(
    grads,
    hyper_params=[
      hparams[0].replace(learning_rate=0.2),
      hparams[1].replace(learning_rate=jnp.where(step < 1000, 0., lr)),
    ])
update_hyper_params(**hyper_param_overrides)[source]

Updates the hyper parameters with a set of overrides.

This method is called from Optimizer.apply_gradient() to create the hyper parameters for a specific optimization step. MultiOptimizer will apply the overrides for each sub optimizer.

Parameters

**hyper_param_overrides – the hyper parameter updates override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

The new hyper parameters.

class flax.optim.ModelParamTraversal(filter_fn)[source]

Select model parameters using a name filter.

This traversal operates on a nested dictionary of parameters and selects a subset based on the filter_fn argument.

See MultiOptimizer for an example of how to use ModelParamTraversal to update subsets of the parameter tree with a specific optimizer.

Backward compatibility: when using the old API, the parameters can be encapsulated in a flax.nn.Model instance.

__init__(filter_fn)[source]

Constructs a new ModelParamTraversal.

Parameters

filter_fn – a function that takes a parameter’s full name and its value and returns whether this parameter should be selected or not. The name of a parameter is determined by the module hierarchy and the parameter name (for example: ‘/module/sub_module/parameter_name’).
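
For example, a filter_fn that selects all parameters under a (hypothetical) submodule named 'Dense_0' could look like:

# `path` is the full parameter name, e.g. '/Dense_0/kernel'.
dense0_params = optim.ModelParamTraversal(
    lambda path, value: path.startswith('/Dense_0/'))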

Available optimizers

Adam([learning_rate, beta1, beta2, eps, …])

Adam optimizer.

Adagrad([learning_rate, eps])

Adagrad optimizer.

DynamicScale([growth_factor, …])

Dynamic loss scaling for mixed precision gradients.

GradientDescent([learning_rate])

Gradient descent optimizer.

LAMB([learning_rate, beta1, beta2, …])

Layerwise adaptive moments for batch (LAMB) optimizer.

LARS([learning_rate, beta, weight_decay, …])

Layerwise adaptive rate scaling (LARS) optimizer.

Momentum([learning_rate, beta, …])

Momentum optimizer.

RMSProp([learning_rate, beta2, eps, centered])

RMSProp optimizer.

WeightNorm(wrapped_optimizer[, wn_decay, wn_eps])

Adds weight normalization to an optimizer def.
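
As a brief sketch, these definitions are constructed and combined in the same way as GradientDescent above; the keyword values here are illustrative and follow the signatures listed in the table:

adam_def = optim.Adam(learning_rate=1e-3, beta1=0.9, beta2=0.999)
wn_adam_def = optim.WeightNorm(adam_def)  # add weight normalization to Adam
optimizer = wn_adam_def.create(params)    # `params` as in the examples above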