flax.optim package

Optimizer core

Flax optimizer API.

class flax.optim.Optimizer(optimizer_def, state, target)[source]

Wraps an optimizer with its hyper_params, state, and model parameters.

apply_gradient(grads, **hyper_param_overrides)[source]

Applies a pytree of gradients to the target.

Parameters
  • grads – A pytree of gradients.

  • **hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

A new optimizer with the updated target and state.
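For illustration, a minimal sketch of one update step on a plain parameter pytree (the loss function and parameter names below are hypothetical, not part of the API)::

    import jax
    import jax.numpy as jnp
    from flax import optim

    # Hypothetical scalar loss over a parameter pytree.
    def loss_fn(params):
      return jnp.sum(params['kernel'] ** 2) + jnp.sum(params['bias'] ** 2)

    params = {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))}
    optimizer = optim.GradientDescent(learning_rate=0.1).create(params)

    grads = jax.grad(loss_fn)(optimizer.target)
    # Keyword arguments override hyper parameters for this step only.
    optimizer = optimizer.apply_gradient(grads, learning_rate=0.05)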

compute_gradients(loss_fn)

Computes gradient of loss_fn.

DEPRECATION WARNING: compute_gradient() is deprecated. Use jax.grad() or jax.value_and_grad() instead.

Parameters

loss_fn – a function that receives the target and returns a loss or a tuple of the loss and auxiliary outputs.

Returns

A tuple consisting of the loss, the auxiliary outputs if any, and a list of gradients.

optimize(loss_fn, **hyper_param_overrides)[source]

Optimizes the target with respect to a loss function.

DEPRECATION WARNING: optimize() is deprecated. Use jax.grad() or jax.value_and_grad() and apply_gradient() instead.

Parameters
  • loss_fn – function that receives the target and returns a loss or a tuple of the loss and auxiliary outputs.

  • **hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

A tuple consisting of the new optimizer, the loss, and the auxiliary outputs if any.
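Following the deprecation notice, a call such as optimizer.optimize(loss_fn) can be replaced with an explicit jax.value_and_grad followed by apply_gradient. A minimal, self-contained sketch with a hypothetical scalar loss::

    import jax
    import jax.numpy as jnp
    from flax import optim

    def loss_fn(params):
      # Hypothetical scalar loss.
      return jnp.sum(params ** 2)

    optimizer = optim.GradientDescent(learning_rate=0.1).create(jnp.ones(4))

    # Replaces: optimizer, loss = optimizer.optimize(loss_fn)
    # (for a loss_fn without auxiliary outputs).
    loss, grads = jax.value_and_grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)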

class flax.optim.OptimizerDef(hyper_params)[source]

Base class for optimizers.

apply_gradient(hyper_params, params, state, grads)[source]

Applies a gradient for a set of parameters.

Parameters
  • hyper_params – a named tuple of hyper parameters.

  • params – the parameters that should be updated.

  • state – a named tuple containing the state of the optimizer.

  • grads – the gradient tensors for the parameters.

Returns

A tuple containing the new parameters and the new optimizer state.

apply_param_gradient(step, hyper_params, param, state, grad)[source]

Applies a gradient for a single parameter.

Parameters
  • step – the current step of the optimizer.

  • hyper_params – a named tuple of hyper parameters.

  • param – the parameter that should be updated.

  • state – a named tuple containing the state for this parameter.

  • grad – the gradient tensor for the parameter.

Returns

A tuple containing the new parameter and the new state.

create(target, focus=None)[source]

Creates a new optimizer for the given target.

Parameters
  • target – the object to be optimized. This will typically be an instance of flax.nn.Model.

  • focus – a flax.traverse_util.Traversal that selects which subset of the target is optimized.

Returns

An instance of Optimizer.
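A minimal sketch of creating an optimizer; a plain parameter dict stands in here for a flax.nn.Model instance::

    import jax.numpy as jnp
    from flax import optim

    params = {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))}
    optimizer_def = optim.Adam(learning_rate=1e-3, weight_decay=1e-4)
    optimizer = optimizer_def.create(params)

    # The returned Optimizer bundles the definition, its state, and the target:
    #   optimizer.optimizer_def, optimizer.state, optimizer.target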

init_param_state(param)[source]

Initializes the state for a parameter.

Parameters

param – the parameter for which to initialize the state.

Returns

A named tuple containing the initial optimization state for the parameter.
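init_param_state and apply_param_gradient are the two methods a custom optimizer typically overrides. Below is a minimal sketch of a momentum-style optimizer, assuming flax.struct.dataclass containers for the hyper parameters and per-parameter state (this mirrors how the built-in optimizers are structured, but is an illustration rather than library code)::

    import jax.numpy as jnp
    from flax import optim, struct

    @struct.dataclass
    class _MomentumHyperParams:
      learning_rate: jnp.ndarray
      beta: jnp.ndarray

    @struct.dataclass
    class _MomentumParamState:
      momentum: jnp.ndarray

    class SimpleMomentum(optim.OptimizerDef):
      """Illustrative momentum optimizer built on OptimizerDef."""

      def __init__(self, learning_rate=None, beta=0.9):
        super().__init__(_MomentumHyperParams(learning_rate, beta))

      def init_param_state(self, param):
        # One state entry per parameter: the running momentum buffer.
        return _MomentumParamState(jnp.zeros_like(param))

      def apply_param_gradient(self, step, hyper_params, param, state, grad):
        del step  # unused in this sketch
        new_momentum = state.momentum * hyper_params.beta + grad
        new_param = param - hyper_params.learning_rate * new_momentum
        return new_param, _MomentumParamState(new_momentum)

Such a definition is then used like any built-in optimizer, e.g. SimpleMomentum(learning_rate=0.1).create(params).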

update_hyper_params(**hyper_param_overrides)[source]

Updates the hyper parameters with a set of overrides.

This method is called from Optimizer.apply_gradient() to create the hyper parameters for a specific optimization step.

Parameters

**hyper_param_overrides – the hyper parameter updates will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

The new hyper parameters.
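For illustration, a direct call (in practice the overrides are usually passed to Optimizer.apply_gradient, e.g. to implement a per-step learning rate schedule); the values below are arbitrary::

    from flax import optim

    optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.9)
    # Override one hyper parameter; the others keep their current values.
    hyper_params = optimizer_def.update_hyper_params(learning_rate=0.01)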

class flax.optim.MultiOptimizer(*traversals_and_optimizers)[source]

Combines a set of optimizers by applying each to a subset of the parameters; see the combined sketch after ModelParamTraversal below.

class flax.optim.ModelParamTraversal(filter_fn)[source]

Select model parameters using a name filter.
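The two classes above are typically used together: each ModelParamTraversal selects a subset of the parameters by path name, and MultiOptimizer applies a different optimizer to each subset. A minimal sketch over a plain parameter dict (the filter strings are hypothetical)::

    import jax.numpy as jnp
    from flax import optim

    params = {'dense': {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))}}

    # filter_fn receives the '/'-joined parameter path and the parameter value.
    kernels = optim.ModelParamTraversal(lambda path, _: 'kernel' in path)
    biases = optim.ModelParamTraversal(lambda path, _: 'bias' in path)

    optimizer_def = optim.MultiOptimizer(
        (kernels, optim.Adam(learning_rate=1e-3)),
        (biases, optim.GradientDescent(learning_rate=1e-2)))
    optimizer = optimizer_def.create(params)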

Available optimizers

class flax.optim.Adam(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.0)[source]

Adam optimizer.

See: http://arxiv.org/abs/1412.6980

class flax.optim.Adafactor(learning_rate=None, factored=True, multiply_by_parameter_scale=True, beta1=None, decay_rate=0.8, step_offset=0, clipping_threshold=1.0, weight_decay_rate=None, min_dim_size_to_factor=128, epsilon1=1e-30, epsilon2=0.001)[source]

Adafactor optimizer.

Adafactor is described in https://arxiv.org/abs/1804.04235.

class flax.optim.Adagrad(learning_rate=None, eps=1e-08)[source]

Adagrad optimizer.

class flax.optim.DynamicScale(growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, fin_steps=0, scale=65536.0)[source]

Dynamic loss scaling for mixed precision gradients.

For many models, gradient computations in float16 will result in numerical issues because small/large gradients are flushed to zero/infinity. Dynamic loss scaling is an algorithm that aims to find the largest scalar multiple for which the gradient does not overflow. This way the risk of underflow is minimized.

The value_and_grad method mimics jax.value_and_grad. Besides the loss and gradients, it also outputs an updated DynamicScale instance with the current loss scale factor. This method also returns a boolean value indicating whether the gradients are finite.

Example::

    def loss_fn(p):
      return jnp.asarray(p, jnp.float16) ** 2

    p = jnp.array(1., jnp.float32)
    dyn_scale = optim.DynamicScale(growth_interval=10)
    compute_grad = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))
    for _ in range(100):
      dyn_scale, is_fin, loss, grad = compute_grad(dyn_scale, p)
      p += jnp.where(is_fin, 0.01 * grad, 0.)
      print(loss)

JAX currently cannot execute conditionals efficiently on GPUs, therefore we selectively ignore the gradient update using jax.numpy.where in case of non-finite gradients.

Attrs:
    growth_factor: how much to grow the scalar after a period of finite gradients (default: 2.0).
    backoff_factor: how much to shrink the scalar after a non-finite gradient (default: 0.5).
    growth_interval: after how many steps of finite gradients the scale should be increased (default: 2000).
    fin_steps: indicates how many gradient steps in a row have been finite.
    scale: the current scale by which the loss is multiplied.

class flax.optim.GradientDescent(learning_rate=None)[source]

Gradient descent optimizer.

class flax.optim.LAMB(learning_rate=None, beta1=0.9, beta2=0.999, weight_decay=0, eps=1e-06)[source]

Layerwise adaptive moments for batch (LAMB) optimizer.

See https://arxiv.org/abs/1904.00962

class flax.optim.LARS(learning_rate=None, beta=0.9, weight_decay=0, trust_coefficient=0.001, eps=0, nesterov=False)[source]

Layerwise adaptive rate scaling (LARS) optimizer.

See https://arxiv.org/abs/1708.03888

class flax.optim.Momentum(learning_rate=None, beta=0.9, weight_decay=0, nesterov=False)[source]

Momentum optimizer.

class flax.optim.RMSProp(learning_rate=None, beta2=0.9, eps=1e-08, centered=False)[source]

RMSProp optimizer.

class flax.optim.WeightNorm(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)[source]

Adds weight normalization to an optimizer def.

See https://arxiv.org/abs/1602.07868
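A minimal sketch of wrapping another optimizer definition (the choice of wrapped optimizer and parameters is arbitrary)::

    import jax.numpy as jnp
    from flax import optim

    params = {'kernel': jnp.ones((3, 3))}
    optimizer_def = optim.WeightNorm(optim.Momentum(learning_rate=0.1, beta=0.9))
    optimizer = optimizer_def.create(params)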