# flax.optim package

## Optimizer core

Flax optimizer API.

class flax.optim.Optimizer(optimizer_def, state, target)[source]

Wraps an optimizer with its hyper_params, state, and model parameters.

apply_gradient(grads, **hyper_param_overrides)[source]

Applies a pytree of gradients to the target.

Parameters

• **hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

A new optimizer with the updated target and state.
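Note that `apply_gradient` does not mutate the optimizer; it returns a new one. The functional-update pattern can be sketched in plain Python (a hypothetical stand-in for illustration, not the actual flax.optim implementation):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class SgdState:
    step: int

@dataclass(frozen=True)
class TinyOptimizer:
    # Mirrors the Optimizer triple of hyper params, state, and target,
    # reduced to a single scalar parameter for illustration.
    learning_rate: float
    state: SgdState
    target: float

    def apply_gradient(self, grad, **hyper_param_overrides):
        # Overrides passed per call take precedence over the defaults.
        lr = hyper_param_overrides.get("learning_rate", self.learning_rate)
        new_target = self.target - lr * grad
        new_state = SgdState(step=self.state.step + 1)
        # A new optimizer is returned; self is left untouched.
        return replace(self, state=new_state, target=new_target)

opt = TinyOptimizer(learning_rate=0.1, state=SgdState(step=0), target=2.0)
new_opt = opt.apply_gradient(grad=4.0)
faster = opt.apply_gradient(grad=4.0, learning_rate=0.5)
```

Because the original optimizer is unchanged, the caller threads the returned optimizer through the training loop.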

compute_gradients(loss_fn)

Computes the gradient of loss_fn with respect to the target.

Parameters

loss_fn – a function that receives the target and returns a loss or a tuple of the loss and auxiliary outputs.

Returns

A tuple consisting of the loss, the auxiliary outputs if any, and the gradients.

optimize(loss_fn, **hyper_param_overrides)[source]

Optimizes the target with respect to a loss function.

Parameters
• loss_fn – function that receives the target and returns a loss or a tuple of the loss and auxiliary outputs.

• **hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

A tuple consisting of the new optimizer, the loss, and the auxiliary outputs if any.
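In other words, optimize composes compute_gradients and apply_gradient into a single step. The composition can be sketched in plain Python, with a hand-written analytic gradient standing in for autodiff (all names here are hypothetical, not the flax API):

```python
def loss_fn(p):
    return (p - 3.0) ** 2          # quadratic loss, minimum at p = 3

def grad_fn(p):
    return 2.0 * (p - 3.0)         # analytic gradient of loss_fn

def optimize(target, learning_rate=0.1):
    # "compute_gradients": evaluate the loss and its gradient at the target.
    loss, grad = loss_fn(target), grad_fn(target)
    # "apply_gradient": produce the updated target.
    new_target = target - learning_rate * grad
    return new_target, loss

p, loss = 0.0, None
for _ in range(100):
    p, loss = optimize(p)
```

After enough steps the target converges to the minimizer of the loss.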

class flax.optim.OptimizerDef(hyper_params)[source]

Base class for optimizers.

apply_gradient(hyper_params, params, state, grads)[source]

Applies a gradient for a set of parameters.

Parameters
• hyper_params – a named tuple of hyper parameters.

• params – the parameters that should be updated.

• state – a named tuple containing the state of the optimizer.

• grads – the gradients to be applied to the parameters.

Returns

A tuple containing the new parameters and the new optimizer state.

apply_param_gradient(step, hyper_params, param, state, grad)[source]

Apply a gradient for a single parameter.

Parameters
• step – the current step of the optimizer.

• hyper_params – a named tuple of hyper parameters.

• param – the parameter that should be updated.

• state – a named tuple containing the state for this parameter.

• grad – the gradient for this parameter.

Returns

A tuple containing the new parameter and the new state.
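Concrete optimizers implement their update rule in this per-parameter form. A plain-Python sketch of a momentum rule in this shape (a hypothetical illustration of the pattern, not the flax class):

```python
from collections import namedtuple

MomentumState = namedtuple("MomentumState", ["velocity"])
HyperParams = namedtuple("HyperParams", ["learning_rate", "beta"])

def init_param_state(param):
    # Fresh state for one parameter: a zero velocity.
    return MomentumState(velocity=0.0)

def apply_param_gradient(step, hyper_params, param, state, grad):
    # Classic momentum: accumulate a velocity, then step along it.
    new_velocity = hyper_params.beta * state.velocity + grad
    new_param = param - hyper_params.learning_rate * new_velocity
    return new_param, MomentumState(velocity=new_velocity)

hp = HyperParams(learning_rate=0.1, beta=0.9)
param, state = 1.0, init_param_state(1.0)
param, state = apply_param_gradient(0, hp, param, state, grad=2.0)
```

The base class maps such a rule over every leaf of the parameter pytree.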

create(target, focus=None)[source]

Creates a new optimizer for the given target.

Parameters
• target – the object to be optimized. This will typically be an instance of flax.nn.Model.

• focus – a flax.traverse_util.Traversal that selects which subset of the target is optimized.

Returns

An instance of Optimizer.

init_param_state(param)[source]

Initializes the state for a parameter.

Parameters

param – the parameter for which to initialize the state.

Returns

A named tuple containing the initial optimization state for the parameter.

update_hyper_params(**hyper_param_overrides)[source]

Updates the hyper parameters with a set of overrides.

This method is called from Optimizer.apply_gradient to create the hyper parameters for a specific optimization step.

Parameters

**hyper_param_overrides – the hyper parameter overrides passed here will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns

The new hyper parameters.

class flax.optim.MultiOptimizer(*traversals_and_optimizers)[source]

Combine a set of optimizers by applying each to a subset of the parameters.
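Conceptually, each gradient leaf is routed to the optimizer whose traversal selects it. A dictionary-based sketch of that routing, using simple name-prefix predicates instead of real traversals (all names hypothetical):

```python
# Flat "pytree" of parameters and matching gradients.
params = {"embed/kernel": 1.0, "dense/kernel": 2.0}
grads = {"embed/kernel": 0.5, "dense/kernel": 0.5}

# Two "optimizers": a slow SGD for embeddings, a faster one for dense layers.
# Each rule pairs a selector predicate with a learning rate.
rules = [
    (lambda name: name.startswith("embed"), 0.01),
    (lambda name: name.startswith("dense"), 0.1),
]

def apply_gradient(params, grads):
    new_params = {}
    for name, p in params.items():
        for selects, lr in rules:
            if selects(name):
                # Only the first matching rule updates this parameter.
                new_params[name] = p - lr * grads[name]
                break
    return new_params

updated = apply_gradient(params, grads)
```

This mirrors how a MultiOptimizer applies each wrapped optimizer only to the parameters its traversal selects.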

class flax.optim.ModelParamTraversal(filter_fn)[source]

Select model parameters using a name filter.
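The filter_fn receives a parameter's path and value and returns True when the parameter should be selected. A common use is selecting only kernels (not biases), e.g. for weight decay; here is a sketch of such a predicate applied to a flat dict (a simplified illustration of the idea):

```python
def filter_fn(path, value):
    # Select kernels; skip biases so they are not weight-decayed.
    return "bias" not in path

params = {"dense/kernel": 1.0, "dense/bias": 2.0}
selected = [p for p in params if filter_fn(p, params[p])]
```

Such a traversal can then be paired with an optimizer in a MultiOptimizer, or passed as the focus argument to OptimizerDef.create.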

## Available optimizers

class flax.optim.Adam(learning_rate=None, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.0)[source]

class flax.optim.Adafactor(learning_rate=None, factored=True, multiply_by_parameter_scale=True, beta1=None, decay_rate=0.8, step_offset=0, clipping_threshold=1.0, weight_decay_rate=None, min_dim_size_to_factor=128, epsilon1=1e-30, epsilon2=0.001)[source]

class flax.optim.Adagrad(learning_rate=None, eps=1e-08)[source]

class flax.optim.DynamicScale(growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, fin_steps=0, scale=65536.0)[source]

Dynamic loss scaling for mixed precision gradients.

For many models, gradient computations in float16 will run into numerical issues because small/large gradients are flushed to zero/infinity. Dynamic loss scaling is an algorithm that aims to find the largest scalar multiple for which the gradient does not overflow, thereby minimizing the risk of underflow.

The value_and_grad method mimics jax.value_and_grad. Besides the loss and gradients, it also outputs an updated DynamicScale instance with the current loss scale factor. This method also returns a boolean value indicating whether the gradients are finite.

Example:

```python
import jax
import jax.numpy as jnp
from flax import optim

def loss_fn(p):
    return jnp.asarray(p, jnp.float16) ** 2

p = jnp.array(1., jnp.float32)
dyn_scale = optim.DynamicScale(growth_interval=10)
compute_grad = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))
for _ in range(100):
    dyn_scale, is_fin, loss, grad = compute_grad(dyn_scale, p)
    p += jnp.where(is_fin, 0.01 * grad, 0.)
```

JAX currently cannot execute conditionals efficiently on GPUs, so we selectively ignore the gradient update using jax.numpy.where in case of non-finite gradients.

Attrs:

• growth_factor: how much to grow the scalar after a period of finite gradients (default: 2.0).

• backoff_factor: how much to shrink the scalar after a non-finite gradient (default: 0.5).

• growth_interval: after how many steps of finite gradients the scale should be increased (default: 2000).

• fin_steps: indicates how many gradient steps in a row have been finite.

• scale: the current scale by which the loss is multiplied.

class flax.optim.GradientDescent(learning_rate=None)[source]

class flax.optim.LAMB(learning_rate=None, beta1=0.9, beta2=0.999, weight_decay=0, eps=1e-06)[source]

Layerwise adaptive moments for batch (LAMB) optimizer.

class flax.optim.LARS(learning_rate=None, beta=0.9, weight_decay=0, trust_coefficient=0.001, eps=0, nesterov=False)[source]

Layerwise adaptive rate scaling (LARS) optimizer.
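The core idea of LARS is to scale each layer's step by a trust ratio proportional to the layer's weight norm over its gradient norm. A simplified sketch of just that ratio (it omits the weight decay, eps, and momentum terms of the full optimizer; names are hypothetical):

```python
import math

def lars_lr(param, grad, base_lr=0.1, trust_coefficient=0.001):
    # Layerwise trust ratio: scale the base step by ||w|| / ||g||.
    w_norm = math.sqrt(sum(x * x for x in param))
    g_norm = math.sqrt(sum(x * x for x in grad))
    return base_lr * trust_coefficient * w_norm / g_norm

lr = lars_lr([3.0, 4.0], [0.6, 0.8])  # ||w|| = 5, ||g|| = 1
```

Layers with large weights relative to their gradients thus take proportionally larger steps.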

class flax.optim.Momentum(learning_rate=None, beta=0.9, weight_decay=0, nesterov=False)[source]

Momentum optimizer.

class flax.optim.RMSProp(learning_rate=None, beta2=0.9, eps=1e-08, centered=False)[source]

RMSProp optimizer.

class flax.optim.WeightNorm(wrapped_optimizer, wn_decay=0, wn_eps=1e-08)[source]

Adds weight normalization to an optimizer def.
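Weight normalization reparameterizes each weight as a direction times a scale, w = g · v / ||v||, so the wrapped optimizer updates the direction and scale separately. A small sketch of the decomposition itself (not the flax wrapper):

```python
import math

v = [3.0, 4.0]  # direction parameter
g = 2.0         # scale parameter

# ||v|| = 5 for this vector; the effective weight is g * v / ||v||.
norm = math.sqrt(sum(x * x for x in v))
w = [g * x / norm for x in v]
```

Decoupling the norm of the weight from its direction in this way often stabilizes optimization.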