
Upgrading my Codebase to Optax

In 2021 we proposed replacing flax.optim with Optax (FLIP #1009), and the Flax optimizers are now effectively deprecated. This guide is aimed at flax.optim users and shows how to update their code to Optax.

See also Optax’s quick start documentation: https://optax.readthedocs.io/en/latest/optax-101.html

Replacing flax.optim with optax

Optax has drop-in replacements for all of Flax’s optimizers. Refer to Optax’s documentation Common Optimizers for API details.

Usage is very similar. The main difference is that Optax does not keep a copy of the parameters, so they must be passed around separately. Flax provides the TrainState utility to store optimizer state, parameters, and other associated data in a single dataclass (it is not used in the code below).

flax.optim

import jax
import flax

@jax.jit
def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  return optimizer.apply_gradient(grads)

optimizer_def = flax.optim.Momentum(
    learning_rate, momentum)
optimizer = optimizer_def.create(variables['params'])

for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)

optax

import jax
import optax

@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)

for batch in get_ds_train():
  params, opt_state = train_step(params, opt_state, batch)

Composable Gradient Transformations

The function optax.sgd() used in the code snippet above is simply a wrapper for the sequential application of two gradient transformations. Instead of using this alias, it is common to use optax.chain() to combine several of these generic building blocks.

Pre-defined alias

# Note that the aliases follow the convention to use positive
# values for the learning rate by default.
tx = optax.sgd(learning_rate, momentum)

Combining transformations

tx = optax.chain(
    # 1. Step: keep a trace of past updates and add to gradients.
    optax.trace(decay=momentum),
    # 2. Step: multiply result from step 1 with negative learning rate.
    # Note that `optax.apply_updates()` simply adds the final updates to the
    # parameters, so we must make sure to flip the sign here for gradient
    # descent.
    optax.scale(-learning_rate),
)

Weight Decay

Some of Flax’s optimizers also include weight decay. In Optax, some optimizers have a weight decay parameter as well (such as optax.adamw()). For the others, weight decay can be added as another gradient transformation, optax.add_decayed_weights(), which adds an update derived from the parameters.

flax.optim

optimizer_def = flax.optim.Adam(
    learning_rate, weight_decay=weight_decay)
optimizer = optimizer_def.create(variables['params'])

optax

# (Note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    # params -= learning_rate * (adam(grads) + params * weight_decay)
    optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the updates:
# tx.update(grads, opt_state, params)

Gradient Clipping

Training can be stabilized by clipping gradients to a global norm (Pascanu et al., 2012). In Flax this is often done by processing the gradients before passing them to the optimizer. With Optax this becomes just another gradient transformation, optax.clip_by_global_norm().

flax.optim

import jax
import jax.numpy as jnp

def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  grads_flat, _ = jax.tree_util.tree_flatten(grads)
  global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
  g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
  grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
  return optimizer.apply_gradient(grads)

optax

tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)

Learning Rate Schedules

For learning rate schedules, Flax allows overwriting hyperparameters when applying the gradients. Optax maintains a step counter and passes it to a function that scales the updates; this function is added with optax.scale_by_schedule(). Optax also lets you specify functions that inject arbitrary scalar hyperparameters into other gradient transformations via optax.inject_hyperparams().

Read more about learning rate schedules in the Learning Rate Scheduling guide.

Read more about schedules defined in Optax under Optimizer Schedules. The standard optimizers (like optax.adam(), optax.sgd(), etc.) also accept a learning rate schedule as the learning_rate parameter.

flax.optim

def train_step(step, optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))

optax

tx = optax.chain(
    optax.trace(decay=momentum),
    # Note that we still want a negative value for scaling the updates!
    optax.scale_by_schedule(lambda step: -schedule(step)),
)

Multiple Optimizers / Updating a Subset of Parameters

In Flax, traversals specify which parameters an optimizer should update, and flax.optim.MultiOptimizer combines traversals to apply different optimizers to different parameters. The equivalent in Optax is optax.masked() combined with optax.chain().

Note that the example below uses flax.traverse_util to create the boolean masks required by optax.masked(). Alternatively, you could create the masks manually, or use optax.multi_transform(), which takes a pytree of labels that assigns a gradient transformation to each parameter.

Beware that optax.masked() flattens the pytree internally, and the inner gradient transformations are only called with that partial, flattened view of the params/gradients. This is usually not a problem, but it makes it hard to nest multiple levels of masked gradient transformations. The inner masks would have to be defined in terms of the partial flattened view, which is not readily available outside the outer mask.

flax.optim

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)

optimizer = flax.optim.MultiOptimizer(
    (kernels, kernel_opt),
    (biases, bias_opt)
).create(variables['params'])

optax

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

tx = optax.chain(
    optax.trace(decay=momentum),
    optax.masked(optax.scale(-learning_rate), kernels_mask),
    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)

Final Words

All of the above patterns can be mixed, and Optax makes it possible to encapsulate all of these transformations in a single place outside the main training loop, which makes testing much easier.