Source code for flax.training.lr_schedule

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate schedules used in FLAX image classification examples.

Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
**effectively deprecated** in favor of Optax_ schedules. Please refer to
`Optimizer Schedules`_ for more information.

.. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
.. _Optax: https://github.com/deepmind/optax
.. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules
"""

import jax.numpy as jnp
import numpy as np
from absl import logging


def _piecewise_constant(boundaries, values, t):
  # Count how many boundaries lie strictly below step ``t`` and use that count
  # as an index into ``values`` (``values`` has one more entry than
  # ``boundaries``).
  index = jnp.sum(boundaries < t)
  return jnp.take(values, index)
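

# --- Example (not part of the original module): a minimal sketch of how
# ``_piecewise_constant`` indexes into ``values``.  The boundaries and values
# below are illustrative assumptions; ``values`` must have one more entry than
# ``boundaries``.
def _example_piecewise_constant():
  boundaries = np.array([100, 200])
  values = np.array([1.0, 0.1, 0.01])
  return (
      _piecewise_constant(boundaries, values, 50),   # step <= 100       -> 1.0
      _piecewise_constant(boundaries, values, 150),  # 100 < step <= 200 -> 0.1
      _piecewise_constant(boundaries, values, 250),  # step > 200        -> 0.01
  )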


def create_constant_learning_rate_schedule(
    base_learning_rate, steps_per_epoch, warmup_length=0.0
):
  """Create a constant learning rate schedule with optional warmup.

  Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
  **effectively deprecated** in favor of Optax_ schedules. Please refer to
  `Optimizer Schedules`_ for more information.

  .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
  .. _Optax: https://github.com/deepmind/optax
  .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules

  Holds the learning rate constant. This function also offers a learning rate
  warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training
  with large mini-batches.

  Args:
    base_learning_rate: the base learning rate
    steps_per_epoch: the number of iterations per epoch
    warmup_length: if > 0, the learning rate will be modulated by a warmup
      factor that will linearly ramp-up from 0 to 1 over the first
      ``warmup_length`` epochs

  Returns:
    Function ``f(step) -> lr`` that computes the learning rate for a given step.
  """
  logging.warning(
      'Learning rate schedules in ``flax.training`` are effectively deprecated '
      'in favor of Optax schedules. Please refer to '
      'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
      ' for alternatives.'
  )

  def learning_rate_fn(step):
    lr = base_learning_rate
    if warmup_length > 0.0:
      lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch)
    return lr

  return learning_rate_fn
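

# --- Example (not part of the original module): a minimal usage sketch of the
# constant schedule above.  The base learning rate, steps_per_epoch and warmup
# length are illustrative assumptions, not values taken from the Flax examples.
def _example_constant_schedule():
  lr_fn = create_constant_learning_rate_schedule(
      base_learning_rate=0.1, steps_per_epoch=1000, warmup_length=5.0
  )
  # During warmup the rate ramps linearly: half way through the 5 warmup
  # epochs (step 2_500) it is 0.05; from step 5_000 onwards it stays at 0.1.
  return lr_fn(2_500), lr_fn(10_000)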


def create_stepped_learning_rate_schedule(
    base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0
):
  """Create a stepped learning rate schedule with optional warmup.

  Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
  **effectively deprecated** in favor of Optax_ schedules. Please refer to
  `Optimizer Schedules`_ for more information.

  .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
  .. _Optax: https://github.com/deepmind/optax
  .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules

  A stepped learning rate schedule decreases the learning rate by specified
  amounts at specified epochs. The steps are given as the ``lr_sched_steps``
  parameter. A common ImageNet schedule decays the learning rate by a factor
  of 0.1 at epochs 30, 60 and 80. This would be specified as::

    [
      [30, 0.1],
      [60, 0.01],
      [80, 0.001]
    ]

  This function also offers a learning rate warmup as per
  https://arxiv.org/abs/1706.02677, for the purpose of training with large
  mini-batches.

  Args:
    base_learning_rate: the base learning rate
    steps_per_epoch: the number of iterations per epoch
    lr_sched_steps: the schedule as a list of steps, each of which is a
      ``[epoch, lr_factor]`` pair; the step occurs at epoch ``epoch`` and sets
      the learning rate to ``base_learning_rate * lr_factor``
    warmup_length: if > 0, the learning rate will be modulated by a warmup
      factor that will linearly ramp-up from 0 to 1 over the first
      ``warmup_length`` epochs

  Returns:
    Function ``f(step) -> lr`` that computes the learning rate for a given step.
  """
  logging.warning(
      'Learning rate schedules in ``flax.training`` are effectively deprecated '
      'in favor of Optax schedules. Please refer to '
      'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
      ' for alternatives.'
  )
  boundaries = [step[0] for step in lr_sched_steps]
  decays = [step[1] for step in lr_sched_steps]
  boundaries = np.array(boundaries) * steps_per_epoch
  boundaries = np.round(boundaries).astype(int)
  values = np.array([1.0] + decays) * base_learning_rate

  def learning_rate_fn(step):
    lr = _piecewise_constant(boundaries, values, step)
    if warmup_length > 0.0:
      lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch)
    return lr

  return learning_rate_fn
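

# --- Example (not part of the original module): a minimal usage sketch of the
# stepped schedule above, using the ImageNet-style 30/60/80 epoch decay from
# the docstring.  The base learning rate and steps_per_epoch are illustrative
# assumptions.
def _example_stepped_schedule():
  lr_fn = create_stepped_learning_rate_schedule(
      base_learning_rate=0.1,
      steps_per_epoch=1000,
      lr_sched_steps=[[30, 0.1], [60, 0.01], [80, 0.001]],
  )
  # At step 45_000 (epoch 45) one boundary has passed, so the learning rate
  # is 0.1 * 0.1 = 0.01; at step 70_000 (epoch 70) it is 0.1 * 0.01 = 0.001.
  return lr_fn(45_000), lr_fn(70_000)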


def create_cosine_learning_rate_schedule(
    base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0
):
  """Create a cosine learning rate schedule with optional warmup.

  Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are
  **effectively deprecated** in favor of Optax_ schedules. Please refer to
  `Optimizer Schedules`_ for more information.

  .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md
  .. _Optax: https://github.com/deepmind/optax
  .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules

  A cosine learning rate schedule modulates the learning rate with half a
  cosine wave, gradually scaling it to 0 at the end of training.

  This function also offers a learning rate warmup as per
  https://arxiv.org/abs/1706.02677, for the purpose of training with large
  mini-batches.

  Args:
    base_learning_rate: the base learning rate
    steps_per_epoch: the number of iterations per epoch
    halfcos_epochs: the number of epochs to complete half a cosine wave;
      normally the number of epochs used for training
    warmup_length: if > 0, the learning rate will be modulated by a warmup
      factor that will linearly ramp-up from 0 to 1 over the first
      ``warmup_length`` epochs

  Returns:
    Function ``f(step) -> lr`` that computes the learning rate for a given step.
  """
  logging.warning(
      'Learning rate schedules in ``flax.training`` are effectively deprecated '
      'in favor of Optax schedules. Please refer to '
      'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules'
      ' for alternatives.'
  )
  halfwavelength_steps = halfcos_epochs * steps_per_epoch

  def learning_rate_fn(step):
    scale_factor = jnp.cos(step * jnp.pi / halfwavelength_steps) * 0.5 + 0.5
    lr = base_learning_rate * scale_factor
    if warmup_length > 0.0:
      lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch)
    return lr

  return learning_rate_fn
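

# --- Example (not part of the original module): a minimal usage sketch of the
# cosine schedule above.  The 90-epoch horizon and other numbers are
# illustrative assumptions.
def _example_cosine_schedule():
  lr_fn = create_cosine_learning_rate_schedule(
      base_learning_rate=0.1, steps_per_epoch=1000, halfcos_epochs=90
  )
  # The rate starts at 0.1, is roughly 0.05 half way through training
  # (step 45_000) and decays to 0 at step 90_000.
  return lr_fn(0), lr_fn(45_000), lr_fn(90_000)


# A rough Optax replacement, as suggested by the deprecation note above
# (requires the optax package; assumes optax.warmup_cosine_decay_schedule with
# this signature -- behaviour is similar but not bit-for-bit identical):
#
#   import optax
#   lr_fn = optax.warmup_cosine_decay_schedule(
#       init_value=0.0, peak_value=0.1,
#       warmup_steps=5 * 1000, decay_steps=90 * 1000)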