
Annotated MNIST

This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using the Flax Linen API and train the network for image classification on the MNIST dataset.

Note: This notebook is based on Flax’s official MNIST Example. If you notice any differences between the two, feel free to create a pull request to synchronize this Colab with the actual example.

1. Imports

Import JAX, JAX NumPy, Flax, ordinary NumPy, and TensorFlow Datasets (TFDS). Flax can use any data-loading pipeline and this example demonstrates how to utilize TFDS.

!pip install -q flax
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

2. Define network

Create a convolutional neural network with the Linen API by subclassing `Module`. Because the architecture in this example is relatively simple (you’re just stacking layers), you can define the inlined submodules directly within the __call__ method and wrap it with the `@compact` decorator.

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

3. Define loss

Define a cross-entropy loss function using just `jax.numpy` that takes the model’s logits and label vectors and returns a scalar loss. The labels can be one-hot encoded with `jax.nn.one_hot`, as demonstrated below.

Note that for demonstration purposes, we return nn.log_softmax() from the model and then simply multiply these (normalized) logits with the labels. In our examples/mnist folder we actually return non-normalized logits and then use optax.softmax_cross_entropy() to compute the loss, which has the same result.

def cross_entropy_loss(*, logits, labels):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
  return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

4. Metric computation

For loss and accuracy metrics, create a separate function:

def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
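On a toy batch (values are illustrative), the argmax comparison behaves as expected: each row’s largest logit picks the predicted class, and comparing against the true labels and averaging yields the accuracy.

```python
import jax.numpy as jnp

logits = jnp.array([[0.1, 2.0, -1.0],   # predicts class 1
                    [1.5, 0.0,  0.3],   # predicts class 0
                    [0.2, 0.1,  3.0]])  # predicts class 2
labels = jnp.array([1, 0, 1])           # third prediction is wrong

accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
# accuracy is 2/3: two of the three predictions match the labels
```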

5. Loading data

Define a function that loads and prepares the MNIST dataset and converts the samples to floating-point numbers.

def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds

6. Create train state

A common pattern in Flax is to create a single dataclass that represents the entire training state, including step number, parameters, and optimizer state.

Adding the optimizer and model to this state has the additional benefit that we only need to pass around a single argument to functions like train_step() (see below).

Because this is such a common pattern, Flax provides the flax.training.train_state.TrainState class that serves most basic use cases. Usually one would subclass it to add more data to be tracked, but in this example we can use it without any modifications.

def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

7. Training step

A function that:

  • Evaluates the neural network given the parameters and a batch of input images with Module.apply.

  • Computes the cross_entropy_loss.

  • Evaluates the loss function and its gradient using jax.value_and_grad.

  • Applies a pytree of gradients to the optimizer to update the model’s parameters.

  • Computes the metrics using compute_metrics.

Use JAX’s `@jit` decorator to trace the entire train_step function and just-in-time compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.

@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics

8. Evaluation step

Create a function that evaluates your model on the test set with `Module.apply`.

@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'])

9. Train function

Define a training function that:

  • Shuffles the training data before each epoch using jax.random.permutation, which takes a PRNGKey as a parameter.

  • Runs an optimization step for each batch.

  • Retrieves the training metrics from the device with jax.device_get and computes their mean across each batch in an epoch.

  • Returns the optimizer with updated parameters.

def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state

10. Eval function

Create a model evaluation function that:

  • Evaluates the model on the test set using eval_step.

  • Retrieves the evaluation metrics from the device with jax.device_get.

  • Converts the metrics, stored in a JAX pytree, to plain Python scalars and returns the loss and accuracy.

def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']

11. Download data

train_ds, test_ds = get_datasets()

12. Seed randomness

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
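JAX’s PRNG keys are explicit and deterministic: splitting the same seed always yields the same keys, which keeps training runs reproducible, while the two halves of a split are distinct from each other. A quick check:

```python
import jax

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
key_a2, key_b2 = jax.random.split(jax.random.PRNGKey(0))

assert (key_a == key_a2).all() and (key_b == key_b2).all()  # deterministic
assert not (key_a == key_b).all()  # the two new keys differ from each other
```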

13. Initialize train state

Remember that the create_train_state function initializes both the model parameters and the optimizer, and puts both into the training state dataclass that is returned.

learning_rate = 0.1
momentum = 0.9
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

14. Train and evaluate

Once training and testing are complete after 10 epochs, the output should show that your model was able to reach approximately 99% accuracy.

num_epochs = 10
batch_size = 32
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))
train epoch: 1, loss: 0.1293, accuracy: 96.04
 test epoch: 1, loss: 0.06, accuracy: 97.93
train epoch: 2, loss: 0.0509, accuracy: 98.43
 test epoch: 2, loss: 0.04, accuracy: 98.88
train epoch: 3, loss: 0.0332, accuracy: 99.04
 test epoch: 3, loss: 0.03, accuracy: 98.87
train epoch: 4, loss: 0.0230, accuracy: 99.30
 test epoch: 4, loss: 0.04, accuracy: 98.83
train epoch: 5, loss: 0.0207, accuracy: 99.37
 test epoch: 5, loss: 0.05, accuracy: 98.78
train epoch: 6, loss: 0.0178, accuracy: 99.44
 test epoch: 6, loss: 0.03, accuracy: 99.11
train epoch: 7, loss: 0.0142, accuracy: 99.52
 test epoch: 7, loss: 0.04, accuracy: 98.77
train epoch: 8, loss: 0.0137, accuracy: 99.60
 test epoch: 8, loss: 0.04, accuracy: 98.87
train epoch: 9, loss: 0.0096, accuracy: 99.69
 test epoch: 9, loss: 0.04, accuracy: 98.98
train epoch: 10, loss: 0.0108, accuracy: 99.67
 test epoch: 10, loss: 0.04, accuracy: 99.01

Congrats! You made it to the end of the annotated MNIST example. You can revisit the same example, structured differently as a couple of Python modules, test modules, config files, another Colab, and documentation, in the examples/mnist folder of Flax’s Git repo.