# Learning Rate Scheduling¶

The learning rate is considered one of the most important hyperparameters for
training deep neural networks, but choosing it can be quite hard.
To simplify this, one can use a so-called *cyclic learning rate*, which
virtually eliminates the need for experimentally finding the best value and
schedule for the global learning rate. Instead of monotonically decreasing the
learning rate, this method lets the learning rate cyclically vary between
reasonable boundary values.
Here we will show you how to implement a triangular learning rate scheduler,
as described in the paper “Cyclical Learning Rates for Training Neural Networks”.

We will show you how to:

- define a learning rate schedule
- train a simple model using that schedule

The triangular schedule makes your learning rate vary as a triangle wave during training, so over the course of a period (`steps_per_cycle` training steps) the value will start at `lr_min`, increase linearly to `lr_max`, and then decrease again to `lr_min`.

```
def create_triangular_schedule(lr_min, lr_max, steps_per_cycle):
  """Build a triangular (cyclic) learning rate schedule.

  Args:
    lr_min: learning rate at the start and end of each cycle.
    lr_max: learning rate at the peak of each cycle.
    steps_per_cycle: number of training steps in one full cycle.

  Returns:
    A function mapping a step index to its learning rate.
  """
  # The peak is reached halfway through the cycle (rounded up).
  peak_step = (steps_per_cycle + 1) // 2

  def learning_rate_fn(step):
    position = step % steps_per_cycle
    amplitude = lr_max - lr_min
    if position < peak_step:
      # Rising edge: interpolate linearly from lr_min up to lr_max.
      return lr_min + position / peak_step * amplitude
    # Falling edge: interpolate linearly from lr_max back down to lr_min.
    return lr_max - ((position - peak_step) / peak_step) * amplitude

  return learning_rate_fn
```

To use the schedule, create a learning rate function by passing the hyperparameters to the `create_triangular_schedule` function, and then use that function to compute the learning rate for your updates. For example, using this schedule on MNIST would require changing the `train_step` function:

| Default learning rate | Triangular learning rate schedule |
| --- | --- |

```
@jax.jit
def train_step(optimizer, batch):
  """Run one optimization step with the optimizer's fixed learning rate."""
  def loss_fn(params):
    # Forward pass; logits are returned as auxiliary output so that
    # metrics can be computed outside the loss function.
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  new_optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, batch['label'])
  return new_optimizer, metrics
``` |
```
import functools

# learning_rate_fn is a plain Python callable, not a JAX array, so it
# cannot be traced as a regular argument — plain @jax.jit would raise
# "not a valid JAX type". Mark it static instead.
@functools.partial(jax.jit, static_argnums=(2,))
def train_step(optimizer, batch, learning_rate_fn):
  """Run one optimization step, with the learning rate given by a schedule.

  Args:
    optimizer: flax optimizer wrapping the model parameters; its state
      tracks the current step count.
    batch: dict with 'image' and 'label' arrays.
    learning_rate_fn: maps a step index to a learning rate (static arg).

  Returns:
    The updated optimizer and the metrics for this batch.
  """
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  # The optimizer state tracks how many steps have been taken so far;
  # feed that into the schedule to get this step's learning rate.
  step = optimizer.state.step
  lr = learning_rate_fn(step)
  optimizer = optimizer.apply_gradient(grad, {"learning_rate": lr})
  metrics = compute_metrics(logits, batch['label'])
  return optimizer, metrics
``` |

And the `train_epoch` function:

| Default learning rate | Triangular learning rate schedule |
| --- | --- |

```
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
  """Train for a single epoch.

  Args:
    optimizer: flax optimizer holding the model parameters.
    train_ds: dict of dataset arrays with 'image' and 'label' keys.
    batch_size: number of examples per training step.
    epoch: epoch index, used only for logging.
    rng: JAX PRNG key used to shuffle the dataset.

  Returns:
    The updated optimizer and a dict of mean metrics for the epoch.
  """
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  # Shuffle the example indices, then drop the remainder so that every
  # batch is exactly batch_size examples.
  # Fixed garbled call: was "jax.m random.permutation".
  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    optimizer, metrics = train_step(optimizer, batch)
    batch_metrics.append(metrics)
  # compute mean of metrics across each batch in epoch.
  batch_metrics = jax.device_get(batch_metrics)
  epoch_metrics = {
      k: np.mean([metrics[k] for metrics in batch_metrics])
      for k in batch_metrics[0]}
  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
               epoch_metrics['loss'], epoch_metrics['accuracy'] * 100)
  return optimizer, epoch_metrics
``` |
```
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
  """Train for a single epoch using the triangular learning rate schedule."""
  num_examples = len(train_ds['image'])
  steps_per_epoch = num_examples // batch_size
  # 4 cycles per epoch
  learning_rate_fn = create_triangular_schedule(
      3e-3, 3e-2, steps_per_epoch // 4)
  # Shuffle the example indices and group them into full batches,
  # discarding any remainder that doesn't fill a batch.
  shuffled = jax.random.permutation(rng, num_examples)
  shuffled = shuffled[:steps_per_epoch * batch_size]
  batch_indices = shuffled.reshape((steps_per_epoch, batch_size))
  per_batch_metrics = []
  for indices in batch_indices:
    batch = {name: array[indices, ...] for name, array in train_ds.items()}
    optimizer, metrics = train_step(optimizer, batch, learning_rate_fn)
    per_batch_metrics.append(metrics)
  # compute mean of metrics across each batch in epoch.
  per_batch_metrics = jax.device_get(per_batch_metrics)
  epoch_metrics = {}
  for key in per_batch_metrics[0]:
    epoch_metrics[key] = np.mean([m[key] for m in per_batch_metrics])
  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
               epoch_metrics['loss'], epoch_metrics['accuracy'] * 100)
  return optimizer, epoch_metrics
``` |