# Learning Rate Scheduling

The learning rate is considered one of the most important hyperparameters for training deep neural networks, but choosing it can be quite hard. To simplify this, one can use a so-called cyclic learning rate, which virtually eliminates the need for experimentally finding the best value and schedule for the global learning rate. Instead of monotonically decreasing the learning rate, this method lets the learning rate cyclically vary between reasonable boundary values. Here we will show you how to implement a triangular learning rate scheduler, as described in the paper “Cyclical Learning Rates for Training Neural Networks”.

We will show you how to…

• define a learning rate schedule

• train a simple model using that schedule

The triangular schedule makes your learning rate vary as a triangle wave during training, so over the course of a period (`steps_per_cycle` training steps) the value will start at `lr_min`, increase linearly to `lr_max`, and then decrease again to `lr_min`.

```
def create_triangular_schedule(lr_min, lr_max, steps_per_cycle):
  top = (steps_per_cycle + 1) // 2
  def learning_rate_fn(step):
    cycle_step = step % steps_per_cycle
    if cycle_step < top:
      lr = lr_min + cycle_step / top * (lr_max - lr_min)
    else:
      lr = lr_max - (cycle_step - top) / top * (lr_max - lr_min)
    return lr
  return learning_rate_fn
```
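As a quick sanity check, the schedule can be evaluated at a few steps. With the illustrative values `lr_min=0.0`, `lr_max=1.0`, and `steps_per_cycle=10` (chosen here only to make the numbers easy to read), the rate ramps up over the first half of the cycle, peaks at step 5, and ramps back down:

```python
# Definition repeated from above so this snippet runs standalone.
def create_triangular_schedule(lr_min, lr_max, steps_per_cycle):
    top = (steps_per_cycle + 1) // 2
    def learning_rate_fn(step):
        cycle_step = step % steps_per_cycle
        if cycle_step < top:
            return lr_min + cycle_step / top * (lr_max - lr_min)
        return lr_max - (cycle_step - top) / top * (lr_max - lr_min)
    return learning_rate_fn

schedule = create_triangular_schedule(lr_min=0.0, lr_max=1.0, steps_per_cycle=10)

# One full cycle (rounded to avoid floating-point noise):
values = [round(schedule(step), 3) for step in range(10)]
print(values)  # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]
```

Because the step counter is taken modulo `steps_per_cycle`, the schedule repeats indefinitely: `schedule(10)` returns the same value as `schedule(0)`.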

To use the schedule, create a learning rate function by passing the hyperparameters to `create_triangular_schedule`, and then use that function to compute the learning rate for each update. For example, using this schedule on MNIST requires changing the `train_step` function:

**Default learning rate**

```
@jax.jit
def train_step(optimizer, batch):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, batch['label'])
  return optimizer, metrics
```

**Triangular learning rate schedule**

```
@jax.jit
def train_step(optimizer, batch, learning_rate_fn):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  step = optimizer.state.step
  lr = learning_rate_fn(step)
  optimizer = optimizer.apply_gradient(grad, {"learning_rate": lr})
  metrics = compute_metrics(logits, batch['label'])
  return optimizer, metrics
```

And the `train_epoch` function:

**Default learning rate**

```
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    optimizer, metrics = train_step(optimizer, batch)
    batch_metrics.append(metrics)
  # Compute mean of metrics across each batch in epoch.
  batch_metrics = jax.device_get(batch_metrics)
  epoch_metrics = {
      k: np.mean([metrics[k] for metrics in batch_metrics])
      for k in batch_metrics[0]}
  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
               epoch, epoch_metrics['loss'],
               epoch_metrics['accuracy'] * 100)
  return optimizer, epoch_metrics
```

**Triangular learning rate schedule**

```
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  # 4 cycles per epoch
  learning_rate_fn = create_triangular_schedule(
      3e-3, 3e-2, steps_per_epoch // 4)
  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    optimizer, metrics = train_step(optimizer, batch, learning_rate_fn)
    batch_metrics.append(metrics)
  # Compute mean of metrics across each batch in epoch.
  batch_metrics = jax.device_get(batch_metrics)
  epoch_metrics = {
      k: np.mean([metrics[k] for metrics in batch_metrics])
      for k in batch_metrics[0]}
  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f',
               epoch, epoch_metrics['loss'],
               epoch_metrics['accuracy'] * 100)
  return optimizer, epoch_metrics
```
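To make the "4 cycles per epoch" comment concrete, here is a rough sketch of how the schedule parameters work out. MNIST has 60,000 training images; the batch size of 32 used below is an assumed example value, not taken from the code above:

```python
# Rough sketch: schedule parameters for MNIST with an assumed batch size of 32.
train_ds_size = 60000                          # MNIST training set size
batch_size = 32                                # assumed example value
steps_per_epoch = train_ds_size // batch_size  # optimizer steps per epoch
steps_per_cycle = steps_per_epoch // 4         # 4 triangular cycles per epoch

print(steps_per_epoch, steps_per_cycle)  # 1875 468
```

With these numbers, each triangular cycle spans 468 steps, so the learning rate sweeps from `3e-3` up to `3e-2` and back four times in every epoch.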