
Getting Started#

This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using the Flax Linen API and train the network for image classification on the MNIST dataset.

Note: This notebook is based on Flax’s official MNIST Example. If you notice any differences between the two, feel free to open a pull request to synchronize this Colab with the actual example.

# Check GPU
!nvidia-smi -L
GPU 0: Tesla T4 (UUID: GPU-d0ec4f05-c4a8-7952-556f-b668f22fe9c7)

1. Imports#

Import JAX, JAX NumPy, Flax, Optax, ordinary NumPy, and TensorFlow Datasets (TFDS). Flax can use any data-loading pipeline; this example demonstrates how to use TFDS.

!pip install -q flax
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

2. Define network#

Create a convolutional neural network with the Linen API by subclassing Module. Because the architecture in this example is relatively simple (you’re just stacking layers), you can define the submodules inline within the __call__ method and decorate that method with @nn.compact.

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

3. Define loss#

We simply use optax.softmax_cross_entropy(). Note that this function expects both logits and labels to have shape [batch, num_classes]. Since the labels are read from TFDS as integer values, we first need to convert them to a one-hot encoding.

Optax’s loss function returns a vector of shape [batch], so we take its mean to obtain a single scalar value ready for optimization.

def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=10)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
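To make the computation concrete, here is a hand-written sketch of the same quantity using only JAX. The function name `softmax_cross_entropy_by_hand` and the toy inputs are illustrative assumptions; the formula is the usual negative sum of one-hot labels times log-softmax.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy_by_hand(logits, labels):
  """Mean cross-entropy from integer labels, written out explicitly."""
  labels_onehot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  per_example = -jnp.sum(labels_onehot * log_probs, axis=-1)  # shape [batch]
  return per_example.mean()                                   # scalar


# Toy batch: 2 examples, 3 classes.
logits = jnp.array([[2.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0]])
labels = jnp.array([0, 2])
loss = softmax_cross_entropy_by_hand(logits, labels)
print(loss)
```

For the second example the logits are uniform, so its per-example loss is log(3) regardless of the label.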

4. Metric computation#

For loss and accuracy metrics, create a separate function:

def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
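The accuracy line can be checked on a toy batch. The logits and labels below are made-up values for illustration only.

```python
import jax.numpy as jnp

logits = jnp.array([[0.1, 2.0, -1.0],
                    [3.0, 0.2, 0.1],
                    [0.0, 0.5, 0.4],
                    [1.0, 0.0, 2.0]])
labels = jnp.array([1, 0, 2, 2])

predictions = jnp.argmax(logits, -1)         # index of the largest logit per row
accuracy = jnp.mean(predictions == labels)   # fraction of correct predictions
print(predictions, accuracy)
```

Three of the four rows are predicted correctly, so the mean of the boolean comparison is 0.75.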

5. Loading data#

Define a function that loads and prepares the MNIST dataset and converts the samples to floating-point numbers.

def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds

6. Create train state#

A common pattern in Flax is to create a single dataclass that represents the entire training state, including step number, parameters, and optimizer state.

Adding the optimizer and model to this state also means we only need to pass a single argument to functions like train_step() (see below).

Because this is such a common pattern, Flax provides the class flax.training.train_state.TrainState, which serves most basic use cases. Usually one would subclass it to add more data to be tracked, but in this example we can use it without any modifications.

def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

7. Training step#

Define a function that:

  • Evaluates the neural network given the parameters and a batch of input images with the Module.apply method.

  • Computes the cross_entropy_loss loss function.

  • Evaluates the gradient of the loss function using jax.grad.

  • Applies a pytree of gradients to the optimizer to update the model’s parameters.

  • Computes the metrics using compute_metrics (defined earlier).

Use JAX’s @jit decorator to trace the entire train_step function and just-in-time compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.

@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.grad(loss_fn, has_aux=True)
  grads, logits = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics
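The `has_aux=True` pattern used above can be tried in isolation. This sketch uses a one-parameter linear model (an assumption for illustration, unrelated to the CNN) whose loss function returns the predictions as an auxiliary output, and wraps the gradient computation in `jax.jit`.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
  pred = x * w
  loss = jnp.mean((pred - y) ** 2)
  return loss, pred  # (value to differentiate, auxiliary output)


@jax.jit
def grad_and_pred(w, x, y):
  # With has_aux=True, jax.grad returns (gradient, aux) rather than the
  # gradient alone; only the first element of loss_fn's output is differentiated.
  grads, pred = jax.grad(loss_fn, has_aux=True)(w, x, y)
  return grads, pred


x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])
grads, pred = grad_and_pred(1.0, x, y)
print(grads, pred)
```

Returning the logits as the auxiliary value is what lets train_step compute metrics without a second forward pass.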

8. Evaluation step#

Create a function that evaluates your model on the test set with Module.apply:

@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'])

9. Train function#

Define a training function that:

  • Shuffles the training data before each epoch using jax.random.permutation, which takes a PRNGKey as a parameter (see “JAX - The Sharp Bits” for how JAX handles random numbers).

  • Runs an optimization step for each batch.

  • Retrieves the training metrics from the device with jax.device_get and computes their mean across each batch in an epoch.

  • Prints the training loss and accuracy metrics and returns the train state with updated parameters.

def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state
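The shuffling and batching logic can be examined on a toy dataset. In this sketch the dataset size, batch size, and array contents are assumptions chosen so that the leftover examples are easy to spot.

```python
import jax
import jax.numpy as jnp

train_ds_size = 10   # toy dataset size (assumption for illustration)
batch_size = 4
steps_per_epoch = train_ds_size // batch_size    # 2 full batches

rng = jax.random.PRNGKey(0)
perms = jax.random.permutation(rng, train_ds_size)
perms = perms[:steps_per_epoch * batch_size]     # drop the 2 leftover indices
perms = perms.reshape((steps_per_epoch, batch_size))

# Fake dataset: 'image' is just 10x the label, so pairs are easy to verify.
data = {'image': jnp.arange(train_ds_size) * 10.0,
        'label': jnp.arange(train_ds_size)}
batches = [{k: v[perm, ...] for k, v in data.items()} for perm in perms]

print(perms)
print(batches[0])
```

Indexing every array in the dictionary with the same `perm` keeps images and labels aligned after shuffling.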

10. Eval function#

Create a model evaluation function that:

  • Retrieves the evaluation metrics from the device with jax.device_get.

  • Copies the metrics data stored in a JAX pytree.

def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']
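As a small illustration of the two steps above, this sketch moves a made-up metrics dictionary to the host and converts each leaf to a plain Python number. The metric values are placeholders, not results from the model.

```python
import jax
import jax.numpy as jnp

metrics = {'loss': jnp.array(0.25), 'accuracy': jnp.array(0.5)}

metrics = jax.device_get(metrics)  # pytree of NumPy arrays on the host
summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)  # Python floats

print(summary)
```

`tree_map` applies the function to every leaf while preserving the dictionary structure, so the keys are unchanged.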

11. Download data#

train_ds, test_ds = get_datasets()

12. Seed randomness#

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
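JAX random numbers are a pure function of the key, which is why the key is split rather than reused. The sketch below shows this with a few normal draws; the variable names `a`, `b`, and `c` are illustrative.

```python
import jax

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)   # two new, independent keys

a = jax.random.normal(init_rng, (3,))
b = jax.random.normal(init_rng, (3,))   # same key: identical numbers
c = jax.random.normal(rng, (3,))        # different key: different numbers

print(a, b, c)
```

Reusing a key silently repeats the same "random" values, so each consumer of randomness should receive its own key from a split.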

13. Initialize train state#

Remember that the create_train_state function initializes both the model parameters and the optimizer, and puts both into the training state dataclass that is returned.

learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

14. Train and evaluate#

Once training and evaluation finish after 10 epochs, the output should show that your model reaches approximately 99% accuracy on the test set.

num_epochs = 10
batch_size = 32
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one epoch of optimization steps over the training batches
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))
train epoch: 1, loss: 0.1398, accuracy: 95.77
 test epoch: 1, loss: 0.05, accuracy: 98.59
train epoch: 2, loss: 0.0488, accuracy: 98.50
 test epoch: 2, loss: 0.05, accuracy: 98.41
train epoch: 3, loss: 0.0347, accuracy: 98.96
 test epoch: 3, loss: 0.04, accuracy: 98.79
train epoch: 4, loss: 0.0242, accuracy: 99.26
 test epoch: 4, loss: 0.04, accuracy: 99.04
train epoch: 5, loss: 0.0212, accuracy: 99.37
 test epoch: 5, loss: 0.03, accuracy: 99.00
train epoch: 6, loss: 0.0170, accuracy: 99.49
 test epoch: 6, loss: 0.03, accuracy: 99.09
train epoch: 7, loss: 0.0114, accuracy: 99.63
 test epoch: 7, loss: 0.03, accuracy: 99.08
train epoch: 8, loss: 0.0093, accuracy: 99.71
 test epoch: 8, loss: 0.03, accuracy: 99.11
train epoch: 9, loss: 0.0109, accuracy: 99.65
 test epoch: 9, loss: 0.04, accuracy: 99.08
train epoch: 10, loss: 0.0103, accuracy: 99.68
 test epoch: 10, loss: 0.03, accuracy: 99.01

Congrats! You made it to the end of the annotated MNIST example. You can revisit the same example, but structured differently as a couple of Python modules, test modules, config files, another Colab, and documentation in Flax’s Git repo: