
MNIST Tutorial#

Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library built upon JAX and currently offered as an experimental module within Flax.

1. Install Flax#

If flax is not installed in your environment, you can install it from PyPI by uncommenting and running the following cell:

# !pip install flax

2. Load the MNIST Dataset#

First, the MNIST dataset is loaded and prepared for training and testing using TensorFlow Datasets (TFDS). Image values are normalized, the data is shuffled and divided into batches, and batches are prefetched to enhance performance.

import tensorflow_datasets as tfds  # TFDS for MNIST
import tensorflow as tf  # TensorFlow operations

tf.random.set_seed(0)  # set random seed for reproducibility

train_steps = 1200
eval_every = 200
batch_size = 32

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize train set
test_ds = test_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255,
    'label': sample['label'],
  }
)  # normalize test set

# repeat the dataset and shuffle it with a buffer of 1024 elements to randomly draw from
train_ds = train_ds.repeat().shuffle(1024)
# group into batches of batch_size, drop the incomplete final batch, limit training to
# train_steps batches, and prefetch the next batch to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
# group into batches of batch_size, drop the incomplete final batch, and prefetch the next batch
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
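
As a quick sanity check (not part of the original pipeline, just an optional peek), you can pull one batch and confirm its shapes and dtypes before building the model:

# Pull a single batch from the training pipeline and inspect it.
batch = next(train_ds.as_numpy_iterator())
print(batch['image'].shape, batch['image'].dtype)  # expected: (32, 28, 28, 1) float32
print(batch['label'].shape, batch['label'].dtype)  # expected: (32,) int64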

3. Define the Network with Flax NNX#

Create a convolutional neural network with Flax NNX by subclassing nnx.Module.

from flax import nnx  # Flax NNX API
from functools import partial

class CNN(nnx.Module):
  """A simple CNN model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

model = CNN(rngs=nnx.Rngs(0))
nnx.display(model)

Run model#

Let’s put our model to the test! We’ll perform a forward pass with arbitrary data and print the results.

import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)

4. Create Optimizer and Metrics#

In Flax NNX, we create an Optimizer object to manage the model’s parameters and apply gradients during training. Optimizer receives a reference to the model, so it can update its parameters, and an optax optimizer that defines the update rules. Additionally, we’ll define a MultiMetric object to keep track of the Accuracy and the Average loss.

import optax

learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)
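
To see how the metrics object behaves on its own, here is a small sketch with made-up logits and labels (purely illustrative, not part of the tutorial’s training flow): update() accumulates per-batch statistics in place, compute() reduces them to scalar values, and reset() clears the accumulated state.

import jax.numpy as jnp

dummy_logits = jnp.array([[2.0, 0.1, -1.0], [0.0, 3.0, 0.5]])  # made-up predictions
dummy_labels = jnp.array([0, 1])                               # made-up true classes
metrics.update(loss=jnp.array(0.42), logits=dummy_logits, labels=dummy_labels)
print(metrics.compute())  # approximately {'accuracy': 1.0, 'loss': 0.42}
metrics.reset()  # clear the accumulated state so training starts fresh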

5. Define step functions#

We define a loss function using cross entropy loss (see more details in optax.softmax_cross_entropy_with_integer_labels()) that our model will optimize over. In addition to the loss, the logits are also returned since they will be used to calculate the accuracy metric during training and testing. During training, we’ll use nnx.value_and_grad to compute the gradients and update the model’s parameters using the optimizer. During both training and testing, the loss and logits are used to calculate the metrics.
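
To make the loss concrete before defining the step functions, here is a tiny worked check with made-up logits (an illustrative sketch, not part of the tutorial code): for an integer label, the loss is the negative log of the softmax probability assigned to the true class.

import jax
import jax.numpy as jnp
import optax

example_logits = jnp.array([[2.0, 0.5, -1.0]])  # one example, three made-up classes
example_label = jnp.array([0])                  # the true class index
loss = optax.softmax_cross_entropy_with_integer_labels(
  logits=example_logits, labels=example_label
)
# Manual computation of the same quantity: -(logit of true class - logsumexp of all logits).
manual = -(example_logits[0, 0] - jax.nn.logsumexp(example_logits[0]))
print(loss[0], manual)  # both are approximately 0.241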

def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits

@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # inplace updates
  optimizer.update(grads)  # inplace updates

@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # inplace updates

The nnx.jit decorator traces the train_step function for just-in-time compilation with XLA, optimizing performance on hardware accelerators. nnx.jit is similar to jax.jit, except it can transform functions that take Flax NNX objects as inputs and outputs.

NOTE: in the code above we performed several in-place updates to the model, optimizer, and metrics, and we did not explicitly return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for more concise and readable code.
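
As a minimal standalone illustration of these reference semantics (a sketch with a hypothetical Counter module, separate from the tutorial’s model), state mutated inside an nnx.jit-transformed function is visible on the original object afterwards, even though nothing is returned:

from flax import nnx
import jax.numpy as jnp

class Count(nnx.Variable):  # a custom Variable type to hold mutable state
  pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count.value += 1

@nnx.jit
def increment(counter: Counter):
  counter()  # in-place update; nothing is returned

counter = Counter()
increment(counter)
print(counter.count.value)  # 1 -- the update propagated out of the jitted call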

6. Train and Evaluate#

Now we train the model for train_steps (1,200) steps, feeding it one batch of data per step, evaluate its performance on the test set every eval_every (200) steps, and log the training and test metrics (loss and accuracy) throughout the process. With this setup the model typically reaches around 98% test accuracy.

metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Run the optimization for one step and make a stateful update to the following:
  # - the model's parameters
  # - the optimizer state
  # - the training loss and accuracy batch metrics
  train_step(model, optimizer, metrics, batch)

  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # evaluate every eval_every steps and at the final step
    # Log training metrics
    for metric, value in metrics.compute().items():  # compute metrics
      metrics_history[f'train_{metric}'].append(value)  # record metrics
    metrics.reset()  # reset metrics for test set

    # Compute metrics on the test set after each training epoch
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Log test metrics
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # reset metrics for next training epoch

    print(
      f"[train] step: {step}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"[test] step: {step}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )
[train] step: 200, loss: 0.3102289140224457, accuracy: 90.08084869384766
[test] step: 200, loss: 0.13239526748657227, accuracy: 95.52284240722656
[train] step: 400, loss: 0.12522409856319427, accuracy: 96.515625
[test] step: 400, loss: 0.07021520286798477, accuracy: 97.8465576171875
[train] step: 600, loss: 0.09092658758163452, accuracy: 97.25
[test] step: 600, loss: 0.08268354833126068, accuracy: 97.30569458007812
[train] step: 800, loss: 0.07523862272500992, accuracy: 97.921875
[test] step: 800, loss: 0.060881033539772034, accuracy: 98.036865234375
[train] step: 1000, loss: 0.063808374106884, accuracy: 98.09375
[test] step: 1000, loss: 0.07719086110591888, accuracy: 97.4258804321289
[train] step: 1199, loss: 0.07750937342643738, accuracy: 97.47173309326172
[test] step: 1199, loss: 0.05415954813361168, accuracy: 98.32732391357422

7. Visualize Metrics#

Use Matplotlib to create plots for loss and accuracy.

import matplotlib.pyplot as plt  # Visualization

# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
(Figure: “Loss” and “Accuracy” subplots showing the train and test curves over the evaluation steps.)

8. Perform inference on the test set#

Define a jitted inference function, pred_step, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.

@nnx.jit
def pred_step(model: CNN, batch):
  logits = model(batch['image'])
  return logits.argmax(axis=1)

test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(model, test_batch)

fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
  ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
  ax.set_title(f'label={pred[i]}')
  ax.axis('off')
(Figure: a 5×5 grid of test images, each titled with its predicted label.)

Congratulations! You made it to the end of the annotated MNIST example.