Source code for flax.optim.base

# Copyright 2022 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax Optimizer api."""

import dataclasses
from typing import Any, List, Tuple, Optional
import warnings

from .. import jax_utils
from .. import serialization
from .. import struct
from .. import traverse_util

import jax
import jax.numpy as jnp

from ..core import FrozenDict, unfreeze

# Backwards compatibility symbol import.
ModelParamTraversal = traverse_util.ModelParamTraversal


@struct.dataclass
class OptimizerState:
  step: jnp.ndarray
  param_states: Any


class OptimizerDef:
  """Base class for an optimizer definition, which specifies the
  initialization and gradient application logic.

  See docstring of :class:`Optimizer` for more details.
  """

  def __init__(self, hyper_params):
    self.hyper_params = hyper_params
    warnings.warn(
        'Use `optax` instead of `flax.optim`. Refer to the update guide '
        'https://flax.readthedocs.io/en/latest/howtos/optax_update_guide.html '
        'for detailed instructions.', DeprecationWarning)
  def apply_param_gradient(self, step, hyper_params, param, state, grad):
    """Apply a gradient for a single parameter.

    Args:
      step: the current step of the optimizer.
      hyper_params: a named tuple of hyper parameters.
      param: the parameter that should be updated.
      state: a named tuple containing the state for this parameter.
      grad: the gradient tensor for the parameter.
    Returns:
      A tuple containing the new parameter and the new state.
    """
    raise NotImplementedError()
  def init_param_state(self, param):
    """Initializes the state for a parameter.

    Args:
      param: the parameter for which to initialize the state.
    Returns:
      A named tuple containing the initial optimization state for the
      parameter.
    """
    raise NotImplementedError()
  def apply_gradient(self, hyper_params, params, state, grads):
    """Applies a gradient for a set of parameters.

    Args:
      hyper_params: a named tuple of hyper parameters.
      params: the parameters that should be updated.
      state: a named tuple containing the state of the optimizer.
      grads: the gradient tensors for the parameters.
    Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
    step = state.step
    params_flat, treedef = jax.tree_flatten(params)
    states_flat = treedef.flatten_up_to(state.param_states)
    grads_flat = treedef.flatten_up_to(grads)
    out = [self.apply_param_gradient(step, hyper_params, param, state, grad)
           for param, state, grad in zip(params_flat, states_flat, grads_flat)]

    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
    new_params = jax.tree_unflatten(treedef, new_params_flat)
    new_param_states = jax.tree_unflatten(treedef, new_states_flat)
    new_state = OptimizerState(step + 1, new_param_states)
    return new_params, new_state
  def init_state(self, params):
    param_states = jax.tree_map(self.init_param_state, params)
    state = OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states)
    return state
  def update_hyper_params(self, **hyper_param_overrides):
    """Updates the hyper parameters with a set of overrides.

    This method is called from :meth:`Optimizer.apply_gradient` to create the
    hyper parameters for a specific optimization step.

    Args:
      **hyper_param_overrides: the hyper parameter updates will override the
        defaults specified in the `OptimizerDef`. Pass `hyper_params=...` to
        replace all hyper parameters.
    Returns:
      The new hyper parameters.
    """
    hp = hyper_param_overrides.pop('hyper_params', self.hyper_params)
    if hyper_param_overrides:
      hp = hp.replace(**hyper_param_overrides)
    return hp
  def create(self, target, focus: Optional['ModelParamTraversal'] = None):
    """Creates a new optimizer for the given target.

    See docstring of :class:`Optimizer` for more details.

    Args:
      target: the object to be optimized. This is typically a variable dict
        returned by `flax.linen.Module.init()`, but it can also be a container
        of variables dicts, e.g. `(v1, v2)` and `{'var1': v1, 'var2': v2}` are
        valid inputs as well.
      focus: a `flax.traverse_util.Traversal` that selects which subset of
        the target is optimized. See docstring of :class:`MultiOptimizer`
        for an example of how to define a `Traversal` object.
    Returns:
      An instance of `Optimizer`.
    """
    opt_def = self
    if focus:
      opt_def = MultiOptimizer((focus, opt_def))
    state = opt_def.init_state(target)
    return Optimizer(opt_def, state, target)
  def state_dict(self, target, state):
    return serialization.to_state_dict({
      'target': serialization.to_state_dict(target),
      'state': serialization.to_state_dict(state)
    })

  def restore_state(self, opt_target, opt_state, state_dict):
    """Restore the optimizer target and state from the state dict.

    This function accepts the current optimizer target and state. This lets us
    know the exact structure of the optimizer target and state, and lets us add
    assertions that shapes and dtypes don't change.

    In practice, none of the values in `opt_target` and `opt_state` are
    actually used; only the tree structure, shapes and dtypes are.

    Args:
      opt_target: the optimizer target.
      opt_state: the optimizer state.
      state_dict: the state dict containing the desired new state of the
        optimizer.
    Returns:
      a tuple of the optimizer target and state with the restored values from
      the state dict.
    """
    opt_target = serialization.from_state_dict(opt_target, state_dict['target'])
    opt_state = serialization.from_state_dict(opt_state, state_dict['state'])
    return opt_target, opt_state
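

# Illustrative only, not part of flax.optim: a minimal sketch of how a concrete
# optimizer is built on top of `OptimizerDef`. A subclass supplies its hyper
# parameters plus `init_param_state` and `apply_param_gradient` for a single
# parameter; the pytree plumbing in `apply_gradient` above does the rest. The
# names `_ExampleGradientDescentHyperParams` and `_ExampleGradientDescent` are
# made up for this sketch.
@struct.dataclass
class _ExampleGradientDescentHyperParams:
  learning_rate: jnp.ndarray


class _ExampleGradientDescent(OptimizerDef):
  """Plain gradient descent, shown here purely as an example."""

  def __init__(self, learning_rate=0.01):
    hyper_params = _ExampleGradientDescentHyperParams(learning_rate)
    super().__init__(hyper_params)

  def init_param_state(self, param):
    # Plain gradient descent keeps no per-parameter state.
    return ()

  def apply_param_gradient(self, step, hyper_params, param, state, grad):
    del step  # unused by plain gradient descent
    new_param = param - hyper_params.learning_rate * grad
    return new_param, state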


class _NoAux:
  """Placeholder used to indicate a lack of auxiliary outputs."""
  pass


class Optimizer(struct.PyTreeNode):
  """
  Flax optimizers are created using the :class:`OptimizerDef` class. That class
  specifies the initialization and gradient application logic. Creating an
  optimizer using the :meth:`OptimizerDef.create` method will result in an
  instance of the :class:`Optimizer` class, which encapsulates the optimization
  target and state. The optimizer is updated using the method
  :meth:`apply_gradient`.

  Example of constructing an optimizer for a model::

    from flax import optim
    optimizer_def = optim.GradientDescent(learning_rate=0.1)
    optimizer = optimizer_def.create(model)

  The optimizer is then used in a training step as follows::

    def train_step(optimizer, data):
      def loss_fn(model):
        y = model(data)
        loss = ...  # compute the loss
        aux = ...   # compute auxiliary outputs (eg. training metrics)
        return loss, aux
      grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
      (loss, aux), grad = grad_fn(optimizer.target)
      new_optimizer = optimizer.apply_gradient(grad)
      return new_optimizer, loss, aux

  Distributed training only requires a few extra additions::

    from flax import optim
    optimizer_def = optim.GradientDescent(learning_rate=0.1)
    optimizer = optimizer_def.create(model)
    optimizer = jax_utils.replicate(optimizer)

    def train_step(optimizer, data):
      def loss_fn(model):
        y = model(data)
        loss = ...  # compute the loss
        aux = ...   # compute auxiliary outputs (eg. training metrics)
        return loss, aux
      grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
      (loss, aux), grad = grad_fn(optimizer.target)
      grad = jax.lax.pmean(grad, 'batch')
      new_optimizer = optimizer.apply_gradient(grad)
      return new_optimizer, loss, aux

    distributed_train_step = jax.pmap(train_step, axis_name='batch')

  Attributes:
    optimizer_def: The optimizer definition.
    state: The initial state of the optimizer.
    target: The target to be optimized.
  """

  optimizer_def: OptimizerDef = struct.field(pytree_node=False)
  state: Any = struct.field(pytree_node=True)
  target: Any = struct.field(pytree_node=True)
  def apply_gradient(self, grads, **hyper_param_overrides):
    """Applies a pytree of gradients to the target.

    Args:
      grads: A pytree of gradients.
      **hyper_param_overrides: the hyper parameters passed to apply_gradient
        will override the defaults specified in the `OptimizerDef`.
        Pass `hyper_params=...` to replace all hyper parameters.
    Returns:
      A new optimizer with the updated target and state.
    """
    hyper_params = self.optimizer_def.update_hyper_params(
        **hyper_param_overrides)
    new_target, new_state = self.optimizer_def.apply_gradient(
        hyper_params, self.target, self.state, grads)
    return self.replace(target=new_target, state=new_state)

  def compute_gradient(self, loss_fn):
    """Computes the gradient of loss_fn.

    DEPRECATION WARNING:
    compute_gradient() is deprecated.
    Use jax.grad() or jax.value_and_grad() instead.

    Args:
      loss_fn: a function that receives the target and returns a loss or a
        tuple of the loss and auxiliary outputs.
    Returns:
      A tuple consisting of the loss, the auxiliary outputs if any,
      and the gradient.
    """
    warnings.warn('compute_gradient() will be removed soon.'
                  ' Use jax.grad() or jax.value_and_grad() instead.',
                  DeprecationWarning)
    def loss_wrapper(target):
      loss_and_aux = loss_fn(target)
      if isinstance(loss_and_aux, jnp.ndarray):
        return loss_and_aux, _NoAux
      else:
        return loss_and_aux
    grad_fn = jax.value_and_grad(loss_wrapper, has_aux=True)
    (loss, aux), grad = grad_fn(self.target)
    if aux is _NoAux:
      return loss, grad
    else:
      return loss, aux, grad
  compute_gradients = compute_gradient

  def optimize(self, loss_fn, **hyper_param_overrides):
    """Optimizes the target with respect to a loss function.

    DEPRECATION WARNING:
    optimize() is deprecated.
    Use jax.grad() or jax.value_and_grad() and apply_gradient() instead.

    Args:
      loss_fn: function that receives the target and returns a loss or a
        tuple of the loss and auxiliary outputs.
      **hyper_param_overrides: the hyper parameters passed to apply_gradient
        will override the defaults specified in the `OptimizerDef`.
        Pass `hyper_params=...` to replace all hyper parameters.
    Returns:
      A tuple consisting of the new optimizer, the loss, and the auxiliary
      outputs if any.
    """
    warnings.warn('optimize() will be removed soon.'
                  ' Use jax.grad() or jax.value_and_grad()'
                  ' and apply_gradient() instead.',
                  DeprecationWarning)

    output_and_grad = self.compute_gradient(loss_fn)
    grad = output_and_grad[-1]
    optimizer = self.apply_gradient(grad, **hyper_param_overrides)
    return (optimizer,) + output_and_grad[:-1]

  def replicate(self, devices=None, axis_name='batch'):
    """Replicates an optimizer for data parallel training.

    A replicated optimizer will automatically average the gradients across
    devices. For this to work correctly the optimize method should be called
    within the context of a `jax.pmap` call with the correct axis_name.

    DEPRECATION WARNING:
    replicate() is deprecated.
    Use jax_utils.replicate() instead.

    Args:
      devices: an optional list of devices defining which devices this
        optimizer is replicated to (default: all local devices).
      axis_name: the axis_name used for gradient averaging across devices.
    Returns:
      The replicated optimizer.
    """
    if devices is None:
      devices = jax.local_devices()
    optimizer_def = ReplicatedOptimizer(self.optimizer_def, devices, axis_name)
    optimizer = jax_utils.replicate(self, devices=devices)
    return optimizer.replace(optimizer_def=optimizer_def)

  def unreplicate(self):
    """Un-replicates an optimizer.

    This will create a new optimizer with the target and state of the first
    device this optimizer was replicated to. After this call the optimizer
    and the target can be used outside of a `jax.pmap` call.

    DEPRECATION WARNING:
    unreplicate() is deprecated.
    Use jax_utils.unreplicate() instead.

    Returns:
      The optimizer that is no longer replicated.
""" if not isinstance(self.optimizer_def, ReplicatedOptimizer): raise ValueError('Cannot unreplicate an optimizer ' 'that is not replicated.') optimizer_def = self.optimizer_def.optimizer_def optimizer = jax_utils.unreplicate(self) return optimizer.replace(optimizer_def=optimizer_def) def state_dict(self): return self.optimizer_def.state_dict(self.target, self.state) def restore_state(self, state): target, state = self.optimizer_def.restore_state( self.target, self.state, state) return self.replace(target=target, state=state)


# Optimizer serialization is handled by the state_dict and restore_state
# methods of the OptimizerDef. Currently, this is used to store only a single
# copy of a replicated optimizer.
serialization.register_serialization_state(
    Optimizer, Optimizer.state_dict, Optimizer.restore_state,
    override=True)


class ReplicatedOptimizer(OptimizerDef):
  """Data parallel optimizer.

  DEPRECATION WARNING:
  ReplicatedOptimizer will be removed soon.
  Use `jax_utils.replicate(optimizer)` and `lax.pmean(grad)` to explicitly
  control the replication of the optimizer and the cross replica averaging
  over gradients, respectively.
  """

  def __init__(self, optimizer_def, devices=None, axis_name='batch'):
    super().__init__(optimizer_def.hyper_params)
    if devices is None:
      devices = jax.local_devices()
    self.optimizer_def = optimizer_def
    self.devices = devices
    self.axis_name = axis_name

  def init_state(self, params):
    return self.optimizer_def.init_state(params)

  def _cross_replica_mean(self, grad):
    axis_size = jax.lax.psum(1, axis_name=self.axis_name)
    return jax.lax.psum(grad, axis_name=self.axis_name) / axis_size

  def apply_gradient(self, hyper_params, params, state, grads):
    grads = jax.tree_map(self._cross_replica_mean, grads)
    return self.optimizer_def.apply_gradient(hyper_params, params, state, grads)

  def update_hyper_params(self, **hyper_param_overrides):
    return self.optimizer_def.update_hyper_params(**hyper_param_overrides)

  def state_dict(self, target, state):
    state_dict = self.optimizer_def.state_dict(target, state)
    # Only the first copy of the parameters and optimizer state are stored.
    state_dict = jax.tree_map(lambda x: x[0], state_dict)
    return state_dict

  def restore_state(self, target, opt_state, state_dict):
    # Replicate the parameters and state to all devices.
    state_dict = jax_utils.replicate(state_dict, devices=self.devices)
    return self.optimizer_def.restore_state(target, opt_state, state_dict)


@dataclasses.dataclass
class _ShapeDtype:
  shape: Any
  dtype: Any
  _value: Any
  _indices: List[int]

  @classmethod
  def create(cls, value):
    if not isinstance(value, jnp.ndarray):
      value = jnp.array(value)
    return cls(shape=value.shape, dtype=value.dtype, _value=value, _indices=[])


class MultiOptimizer(OptimizerDef):
  """
  A MultiOptimizer is a subclass of :class:`OptimizerDef` and is useful for
  applying separate optimizer algorithms to different subsets of the model
  parameters.

  The example below creates two optimizers using
  :class:`flax.traverse_util.ModelParamTraversal`: one to optimize ``kernel``
  parameters and one to optimize ``bias`` parameters. Note that each optimizer
  is created with a different learning rate::

    kernels = traverse_util.ModelParamTraversal(lambda path, _: 'kernel' in path)
    biases = traverse_util.ModelParamTraversal(lambda path, _: 'bias' in path)
    kernel_opt = optim.Momentum(learning_rate=0.01)
    bias_opt = optim.Momentum(learning_rate=0.1)
    opt_def = MultiOptimizer((kernels, kernel_opt), (biases, bias_opt))
    optimizer = opt_def.create(model)

  In order to train only a subset of the parameters, you can simply use a
  single :class:`flax.traverse_util.ModelParamTraversal` instance.

  If you want to update the learning rates of both optimizers online with
  different learning rate schedules, you should update the learning rates when
  applying the gradient. In the following example, the second optimizer is not
  doing any optimization during the first 1000 steps::

    hparams = optimizer.optimizer_def.hyper_params
    new_optimizer = optimizer.apply_gradient(
        grads,
        hyper_params=[
          hparams[0].replace(learning_rate=0.2),
          hparams[1].replace(learning_rate=jnp.where(step < 1000, 0., lr)),
        ])
  """

  def __init__(
      self,
      *traversals_and_optimizers: Tuple[traverse_util.Traversal, OptimizerDef]):
    """Create a new MultiOptimizer.

    See docstring of :class:`MultiOptimizer` for more details.

    Args:
      *traversals_and_optimizers: pairs of flax.traverse_util.Traversal and
        `flax.optim.OptimizerDef` instances.
    """
    traversals, sub_optimizers = zip(*traversals_and_optimizers)
    hyper_params = [opt.hyper_params for opt in sub_optimizers]
    super().__init__(hyper_params)
    self.traversals = traversals
    self.sub_optimizers = sub_optimizers

  def init_state(self, params):
    param_states = jax.tree_map(_ShapeDtype.create, params)
    overlap = False
    for idx, (traversal, opt) in enumerate(
        zip(self.traversals, self.sub_optimizers)):
      for match in traversal.iterate(param_states):
        match._indices.append(idx)
        overlap |= len(match._indices) > 1
    if overlap:
      raise ValueError(
          'Multiple optimizers match the same leaves: ' +
          str(jax.tree_map(lambda match: match._indices, param_states)))
    for traversal, opt in zip(self.traversals, self.sub_optimizers):
      param_states = traversal.update(
          lambda x: opt.init_param_state(x._value), param_states)
    # Use None as initial state for params that are not optimized by any
    # sub optimizer.
    param_states = jax.tree_map(
        lambda x: None if isinstance(x, _ShapeDtype) else x, param_states)
    return OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states)

  def apply_gradient(self, hyper_params, params, state, grads):
    new_params = params
    it = zip(self.traversals, self.sub_optimizers, hyper_params)
    new_param_states = jax.tree_map(_ShapeDtype.create, params)
    for focus, opt, hp in it:
      ps = tuple(focus.iterate(params))
      gs = tuple(focus.iterate(grads))
      ss = tuple(focus.iterate(state.param_states))
      prev_ss = OptimizerState(state.step, ss)
      new_ps, new_ss = opt.apply_gradient(hp, ps, prev_ss, gs)
      new_params = focus.set(list(new_ps), new_params)
      new_param_states = focus.set(list(new_ss.param_states), new_param_states)
    # Update state to None when param is not optimized by any sub optimizer.
    new_param_states = jax.tree_map(
        lambda x: None if isinstance(x, _ShapeDtype) else x, new_param_states)
    return new_params, OptimizerState(state.step + 1, new_param_states)

  def update_hyper_params(self, **hyper_param_overrides):
    """Updates the hyper parameters with a set of overrides.

    This method is called from :meth:`Optimizer.apply_gradient` to create the
    hyper parameters for a specific optimization step.
    MultiOptimizer will apply the overrides for each sub optimizer.

    Args:
      **hyper_param_overrides: the hyper parameter updates will override the
        defaults specified in the `OptimizerDef`. Pass `hyper_params=...` to
        replace all hyper parameters.
    Returns:
      The new hyper parameters.
    """
    hps = hyper_param_overrides.pop('hyper_params', self.hyper_params)
    if hyper_param_overrides:
      hps = [hp.replace(**hyper_param_overrides) for hp in hps]
    return hps
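

# Usage sketch (illustrative, not part of the module): training only a subset
# of the parameters by passing a traversal as `focus` to `OptimizerDef.create`,
# which wraps the optimizer in a single-traversal MultiOptimizer. `params` is
# assumed to be a variable dict as returned by `flax.linen.Module.init()`, and
# `_ExampleGradientDescent` is the illustrative optimizer sketched after
# `OptimizerDef` above.
def _example_partial_training_setup(params):
  # Select every parameter whose path contains 'kernel'.
  kernels = traverse_util.ModelParamTraversal(lambda path, _: 'kernel' in path)
  opt_def = _ExampleGradientDescent(learning_rate=0.1)
  # Only the matched parameters receive updates in apply_gradient; all other
  # parameters keep a `None` optimizer state and are left unchanged.
  return opt_def.create(params, focus=kernels)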