Example: Train a Vision Transformer (ViT) for image classification

This example guides you through developing and training a Vision Transformer (ViT) model with Flax NNX. The architecture follows “An Image is Worth 16x16 Words” by Dosovitskiy et al. (2020). You will define a ViT model in Flax NNX and load the pretrained weights of google/vit-base-patch16-224 from HuggingFace, a model pretrained on ImageNet-21k and fine-tuned on ImageNet-1k. You will then fine-tune it on the Food 101 dataset for image classification. Training uses FSDP-style sharding, in which the batch dimension of the activations and the weight matrices of the attention and MLP layers are split along a single fsdp mesh axis. It also uses mixed precision, with bfloat16 activations and float32 parameters. Before fine-tuning, we also check that the pretrained model's outputs are consistent with those of the reference model.

This example is adapted from the JAX AI Stack tutorial Train a Vision Transformer (ViT) for image classification with JAX. The original JAX-based implementation of the ViT model can be found in the google-research/vision_transformer GitHub repository.

Setup

This example uses HuggingFace Datasets for dataset loading, TorchVision for image augmentations, Grain for efficient data loading, tqdm for a progress bar to monitor training, matplotlib for visualization, and HuggingFace Transformers to load the pretrained weights. These libraries can be installed with !pip install -U datasets grain torchvision tqdm matplotlib transformers.

Start by importing JAX, JAX NumPy, Flax NNX, and Optax:

import jax
import jax.numpy as jnp
from flax import nnx
import optax

The ViT architecture

A Vision Transformer (ViT) treats images as sequences of patches and leverages the attention mechanism from transformers. The architecture consists of the following key components:

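  • Patch and position embedding: Breaking down an image into fixed-size patches and embedding each patch into a vector representation. Positional embeddings are added to encode the position of each patch within the original image, so that the model retains spatial information. A learnable class token is prepended to the sequence of patch embeddings, and its output is later used for classification.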

  • Transformer encoder: A stack of transformer encoder blocks processes the input embedded patches. Each block consists of:

    • Multi-Head (Self-)Attention: This allows the model to weigh the importance of different patches relative to each other, capturing relationships within the image.

    • Feed-forward network (MLP): Processes each patch embedding independently, allowing for non-linear transformations.

    • Layer normalization and residual connections: Stabilize training and improve gradient flow through the network.

  • Classification head: The encoder output at the class token position is passed through a final layer normalization and a linear layer, which produces class logits. A softmax turns these logits into class probabilities for prediction.

Figure: ViT architecture.
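To make the patch-embedding step concrete: a ViT-B/16 at 224x224 resolution yields (224 / 16)^2 = 196 patches, plus one class token, for 197 tokens. This matches n_patches + 1 in the model code below. The sketch below uses toy sizes and plain JAX rather than Flax NNX. It checks that flattening non-overlapping patches and applying one shared linear projection gives the same result as a convolution whose kernel size and stride both equal the patch size. This equivalence is why the model can use nnx.Conv for the patch embedding.

```python
import jax
import jax.numpy as jnp


def patchify(images: jax.Array, patch_size: int) -> jax.Array:
    """Split (B, H, W, C) images into flattened, non-overlapping patches.

    Returns shape (B, num_patches, patch_size * patch_size * C), with patches
    ordered row by row starting from the top-left corner.
    """
    b, h, w, c = images.shape
    p = patch_size
    patches = images.reshape(b, h // p, p, w // p, p, c)
    patches = patches.transpose(0, 1, 3, 2, 4, 5)  # (B, H/p, W/p, p, p, C)
    return patches.reshape(b, (h // p) * (w // p), p * p * c)


# Toy sizes for illustration (not the ViT-B/16 defaults).
batch, image_size, channels, patch_size, hidden = 2, 32, 3, 16, 8
toy_images = jax.random.normal(jax.random.key(0), (batch, image_size, image_size, channels))
kernel = jax.random.normal(jax.random.key(1), (patch_size, patch_size, channels, hidden))
highest = jax.lax.Precision.HIGHEST

# Route 1: flatten the patches and apply one shared linear projection.
tokens_linear = jnp.matmul(
    patchify(toy_images, patch_size), kernel.reshape(-1, hidden), precision=highest
)

# Route 2: a convolution whose kernel size and stride both equal the patch size.
conv_out = jax.lax.conv_general_dilated(
    toy_images,
    kernel,
    window_strides=(patch_size, patch_size),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    precision=highest,
)
tokens_conv = conv_out.reshape(batch, -1, hidden)  # (B, num_patches, hidden)

# Prepend a class token and add position embeddings (zeros here for brevity;
# both are learnable parameters in the real model).
cls_token = jnp.zeros((1, 1, hidden))
pos_emb = jnp.zeros((1, tokens_conv.shape[1] + 1, hidden))
sequence = jnp.concatenate([jnp.tile(cls_token, (batch, 1, 1)), tokens_conv], axis=1) + pos_emb

print(tokens_linear.shape, sequence.shape)  # (2, 4, 8) (2, 5, 8)
```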

Defining the model with Flax NNX

from dataclasses import dataclass

from jax.sharding import PartitionSpec as P


@dataclass(slots=True, frozen=True)
class ShardingConfig:

    attn_qkvo_weight_ndh: P | None = None  # sharding for Q, K, V, Out weights
    mlp_weight_df: P | None = None
    mlp_weight_fd: P | None = None
    act_btd: P | None = None  # sharding of the activation (B, T, D)
    act_btf: P | None = None
    act_btnh: P | None = None
    act_bc: P | None = None  # sharding of the final logits

    fsdp_axis_name: str = "fsdp"

    @staticmethod
    def no_sharding():
        return ShardingConfig()

    @staticmethod
    def fsdp_sharding(fsdp_axis_name: str = "fsdp"):
        fsdp = fsdp_axis_name
        return ShardingConfig(
            attn_qkvo_weight_ndh=P(None, fsdp, None),
            mlp_weight_df=P(fsdp, None),
            mlp_weight_fd=P(None, fsdp),
            act_btd=P(fsdp, None, None),
            act_btf=P(fsdp, None, None),
            act_btnh=P(fsdp, None, None, None),
            act_bc=P(fsdp, None),
            fsdp_axis_name=fsdp_axis_name,
        )


@dataclass(slots=True, frozen=True)
class ModelConfig:
    num_classes: int = 1000
    in_channels: int = 3
    img_size: int = 224
    patch_size: int = 16
    num_layers: int = 12
    num_heads: int = 12
    mlp_dim: int = 3072
    hidden_size: int = 768
    dropout_rate: float = 0.1
    sharding: ShardingConfig = ShardingConfig.no_sharding()
    use_bfloat16: bool = True


class VisionTransformer(nnx.Module):
    def __init__(
        self,
        config: ModelConfig,
        *,
        rngs: nnx.Rngs,
    ):
        act_dtype = jnp.bfloat16 if config.use_bfloat16 else jnp.float32
        param_dtype = jnp.float32

        n_patches = (config.img_size // config.patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            config.in_channels,
            config.hidden_size,
            kernel_size=(config.patch_size, config.patch_size),
            strides=(config.patch_size, config.patch_size),
            padding="VALID",
            use_bias=True,
            dtype=act_dtype,  # activations dtype
            param_dtype=param_dtype,  # weights dtype
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(rngs.params(), (1, n_patches + 1, config.hidden_size), param_dtype)
        )  # Shape: (1, n_patches + 1, hidden_size)
        self.dropout = nnx.Dropout(config.dropout_rate)

        self.cls_token = nnx.Param(jnp.zeros((1, 1, config.hidden_size), dtype=param_dtype))
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(config, rngs=rngs) for i in range(config.num_layers)
        ])
        self.final_norm = nnx.LayerNorm(
            config.hidden_size, rngs=rngs, dtype=act_dtype, param_dtype=param_dtype
        )
        self.classifier = nnx.Linear(
            config.hidden_size, 
            config.num_classes, 
            rngs=rngs, 
            dtype=act_dtype, 
            param_dtype=param_dtype,
        )
        self.config = config

    def embed(self, x: jax.Array) -> jax.Array:
        patches = self.patch_embeddings(x, out_sharding=self.config.sharding.act_btd)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])
        cls_token = jnp.tile(self.cls_token.astype(x.dtype), (batch_size, 1, 1))
        if self.config.sharding.act_btd is not None:
            cls_token = jax.device_put(cls_token, device=self.config.sharding.act_btd)
        x = jnp.concat([cls_token, patches], axis=1)
        return x + self.position_embeddings.astype(x.dtype)

    def __call__(self, x: jax.Array, rngs: nnx.Rngs | None = None) -> jax.Array:
        x = self.embed(x)
        x = self.dropout(x, rngs=rngs)
        x = self.encoder(x, rngs=rngs)
        x = self.final_norm(x)
        x = x[:, 0]
        return self.classifier(x, out_sharding=self.config.sharding.act_bc)


class TransformerEncoder(nnx.Module):
    def __init__(
        self,
        config: ModelConfig,
        *,
        rngs: nnx.Rngs,
    ) -> None:
        act_dtype = jnp.bfloat16 if config.use_bfloat16 else jnp.float32
        param_dtype = jnp.float32

        self.norm1 = nnx.LayerNorm(
            config.hidden_size, rngs=rngs, dtype=act_dtype, param_dtype=param_dtype
        )
        self.mha = nnx.MultiHeadAttention(
            num_heads=config.num_heads,
            in_features=config.hidden_size,
            dropout_rate=config.dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            kernel_metadata={"out_sharding": config.sharding.attn_qkvo_weight_ndh},
            out_kernel_metadata={"out_sharding": config.sharding.attn_qkvo_weight_ndh},
            keep_rngs=False,
            dtype=act_dtype, 
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(
            config.hidden_size, rngs=rngs, dtype=act_dtype, param_dtype=param_dtype
        )
        self.mlp_up_proj = nnx.Linear(
            config.hidden_size,
            config.mlp_dim,
            kernel_metadata={"out_sharding": config.sharding.mlp_weight_df},
            dtype=act_dtype, 
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.mlp_down_proj = nnx.Linear(
            config.mlp_dim,
            config.hidden_size,
            kernel_metadata={"out_sharding": config.sharding.mlp_weight_fd},
            dtype=act_dtype, 
            param_dtype=param_dtype,
            rngs=rngs
        )
        self.mlp_drop = nnx.Dropout(config.dropout_rate, rngs=rngs)
        self.config = config

    def attn(self, x: jax.Array, rngs: nnx.Rngs | None = None) -> jax.Array:
        return self.mha(
            x,
            rngs=rngs,
            out_sharding=self.config.sharding.act_btd,
            qkv_sharding=self.config.sharding.act_btnh,
        )

    def mlp(self, x: jax.Array, rngs: nnx.Rngs | None = None) -> jax.Array:
        x = self.mlp_up_proj(x, out_sharding=self.config.sharding.act_btf)
        x = nnx.gelu(x)
        x = self.mlp_drop(x, rngs=rngs)
        x = self.mlp_down_proj(x, out_sharding=self.config.sharding.act_btd)
        return self.mlp_drop(x, rngs=rngs)

    def __call__(self, x: jax.Array, rngs: nnx.Rngs | None = None) -> jax.Array:
        x = x + self.attn(self.norm1(x), rngs=rngs)
        x = x + self.mlp(self.norm2(x), rngs=rngs)
        return x


# We can define and check a model without sharding:
config = ModelConfig()
act_dtype = jnp.bfloat16 if config.use_bfloat16 else jnp.float32
x = jnp.ones((4, 224, 224, 3), dtype=act_dtype)
model = VisionTransformer(config, rngs=nnx.Rngs(1))
y = model(x, rngs=nnx.Rngs(0))
print("Predictions shape: ", jax.typeof(y))
del model, y, x

# We can define and check a model with fsdp-like sharding:
mesh = jax.make_mesh((jax.device_count(),), ("fsdp",))
with jax.set_mesh(mesh):
    config = ModelConfig(sharding=ShardingConfig.fsdp_sharding(fsdp_axis_name="fsdp"))
    act_dtype = jnp.bfloat16 if config.use_bfloat16 else jnp.float32
    x = jnp.ones((4, 224, 224, 3), out_sharding=jax.P("fsdp"), dtype=act_dtype)
    model = VisionTransformer(config, rngs=nnx.Rngs(1))
    y = model(x, rngs=nnx.Rngs(0))
    print("Predictions shape: ", jax.typeof(y))
    del model, y, x
Predictions shape:  bfloat16[4,1000]
Predictions shape:  bfloat16[4@fsdp,1000]
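The ShardingConfig above only names PartitionSpecs. The minimal sketch below shows what the specs used by fsdp_sharding do to a weight or an activation. It uses toy shapes and plain jax.device_put instead of the model's kernel_metadata and out_sharding plumbing. The dimension mapped to "fsdp" is split evenly across devices, and None dimensions are replicated. On a single device, every shard is the full array.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all available devices with a single axis named "fsdp".
demo_mesh = Mesh(np.array(jax.devices()), ("fsdp",))
n_devices = demo_mesh.devices.size

# Toy sizes; the sharded dimensions must be divisible by the device count.
d_model, d_ff, batch_size, seq_len = 4 * n_devices, 6, 2 * n_devices, 3


def put(shape, spec):
    # Place an array of ones on the mesh according to a PartitionSpec.
    return jax.device_put(jnp.ones(shape), NamedSharding(demo_mesh, spec))


w_up = put((d_model, d_ff), P("fsdp", None))                        # like mlp_weight_df
w_down = put((d_ff, d_model), P(None, "fsdp"))                      # like mlp_weight_fd
acts = put((batch_size, seq_len, d_model), P("fsdp", None, None))   # like act_btd

for name, arr in [("w_up", w_up), ("w_down", w_down), ("acts", acts)]:
    shard_shape = arr.addressable_shards[0].data.shape
    print(f"{name}: global {arr.shape}, per-device shard {shard_shape}")
```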

Loading the pretrained weights

In this section, we’ll load pretrained ImageNet weights using HuggingFace’s transformers library.

First, import transformers.ViTForImageClassification, a ViT model with an image classification head on top.

Then, load the weights of google/vit-base-patch16-224 from HuggingFace. This ViT model was pretrained on ImageNet-21k and fine-tuned on ImageNet-1k, both at 224x224 resolution.

We’ll also check that our model’s outputs are consistent with those of the reference model.

from transformers import ViTForImageClassification

# `tf_model` is the HuggingFace Transformers (PyTorch) reference model, not a TensorFlow model.
tf_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Initialize abstract NNX model to avoid memory allocation with random weights
with jax.set_mesh(mesh):
    config = ModelConfig(sharding=ShardingConfig.fsdp_sharding(fsdp_axis_name="fsdp"))
    abs_model = nnx.eval_shape(lambda: VisionTransformer(config, rngs=nnx.Rngs(0)))
# Copies weights from the Transformers ViT model to the NNX ViT model, reshaping layers
# to match the expected shapes in Flax.
def vit_copy_weights(
    *,
    src_model: ViTForImageClassification,
    dst_model: VisionTransformer,
    rngs_seed: int = 0
) -> VisionTransformer:

    assert isinstance(src_model, ViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)
    num_layers = dst_model.config.num_layers
    num_heads = dst_model.config.num_heads
    head_dim = dst_model.config.hidden_size // num_heads
    tf_model_state = src_model.state_dict()

    # Notice the use of `flax.nnx.state`.
    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = dict(nnx.to_flat_state(flax_model_params))

    # Mapping from Flax NNX parameter paths to HuggingFace Transformers (PyTorch)
    # state_dict keys (stored as tuples and joined with "." below).
    params_name_mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): ("vit", "embeddings", "position_embeddings"),
        **{
            ("patch_embeddings", x[0]): ("vit", "embeddings", "patch_embeddings", "projection", x[1])
            for x in [("kernel", "weight"), ("bias", "bias")]
        },
        **{
            ("encoder", "layers", i, "mha", y[0], x[0]): (
                "vit", "layers", str(i), "attention", y[1], x[1]
            )
            for x in [("kernel", "weight"), ("bias", "bias")]
            for y in [("key", "k_proj"), ("value", "v_proj"), ("query", "q_proj")]
            for i in range(num_layers)
        },
        **{
            ("encoder", "layers", i, "mha", "out", x[0]): (
                "vit", "layers", str(i), "attention", "o_proj", x[1]
            )
            for x in [("kernel", "weight"), ("bias", "bias")]
            for i in range(num_layers)
        },
        **{
            ("encoder", "layers", i, y1, x[0]): (
                "vit", "layers", str(i), "mlp", y2, x[1]
            )
            for x in [("kernel", "weight"), ("bias", "bias")]
            for y1, y2 in [("mlp_up_proj", "fc1"), ("mlp_down_proj", "fc2")]
            for i in range(num_layers)
        },
        **{
            ("encoder", "layers", i, y1, x[0]): (
                "vit", "layers", str(i), y2, x[1]
            )
            for x in [("scale", "weight"), ("bias", "bias")]
            for y1, y2 in [("norm1", "layernorm_before"), ("norm2", "layernorm_after")]
            for i in range(num_layers)
        },
        **{
            ("final_norm", x[0]): ("vit", "layernorm", x[1])
            for x in [("scale", "weight"), ("bias", "bias")]
        },
        **{
            ("classifier", x[0]): ("classifier", x[1])
            for x in [("kernel", "weight"), ("bias", "bias")]
        }
    }

    nonvisited = set(tf_model_state.keys())

    for key1, key2 in params_name_mapping.items():
        key2_str = ".".join(key2)
        assert key1 in flax_model_params_fstate, key1
        assert key2_str in tf_model_state, (key1, key2_str, tf_model_state.keys())

        nonvisited.remove(key2_str)

        src_value = tf_model_state[key2_str]
        if key2[-1] == "weight" and len(key2) >= 3 and key2[-3] == "patch_embeddings":
            assert src_value.ndim == 4
            src_value = src_value.permute(2, 3, 1, 0)

        if key2[-1] == "weight" and key2[-2] in ("k_proj", "v_proj", "q_proj"):
            assert src_value.ndim == 2
            src_value = src_value.permute(1, 0)
            src_value = src_value.reshape(src_value.shape[0], num_heads, head_dim)

        if key2[-1] == "weight" and key2[-2] in ("o_proj", "fc1", "fc2", "classifier"):
            assert src_value.ndim == 2
            src_value = src_value.permute(1, 0)
            if key2[-3:] == ("attention", "o_proj", "weight"):
                src_value = src_value.reshape(num_heads, head_dim, src_value.shape[-1])

        if key2[-1] == "bias" and key2[-2] in ("k_proj", "v_proj", "q_proj"):
            assert src_value.ndim == 1
            src_value = src_value.reshape(num_heads, head_dim)

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.shape, (key2, src_value.shape, key1, dst_value.shape)
        dst_value.set_value(jnp.asarray(src_value))
        assert dst_value[...].mean() == jnp.asarray(src_value).mean(), (dst_value[...].mean(), src_value.mean())

    assert len(nonvisited) == 0, nonvisited
    nnx.update(dst_model, nnx.from_flat_state(flax_model_params_fstate))

    # Finally, reseed the RNG streams used by the stochastic (dropout) layers.
    nnx.reseed(dst_model, default=rngs_seed)

    return dst_model


with jax.set_mesh(mesh):
    model = vit_copy_weights(src_model=tf_model, dst_model=abs_model)
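
The permutes and reshapes in vit_copy_weights follow from two layout differences. PyTorch's nn.Linear stores its weight as (out_features, in_features) and computes x @ W.T + b, whereas Flax kernels are (in_features, out_features). In addition, nnx.MultiHeadAttention keeps an explicit heads axis: (in, num_heads, head_dim) for the query, key, and value kernels and (num_heads, head_dim, out) for the output kernel. The NumPy sketch below uses toy sizes, with random arrays standing in for the PyTorch tensors and .T playing the role of .permute(1, 0). It checks that the reshaped kernels compute the same projections.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, hidden, heads = 5, 12, 3
dim_per_head = hidden // heads
x_tok = rng.standard_normal((num_tokens, hidden)).astype(np.float32)

# Random stand-ins for PyTorch nn.Linear weights, stored as (out_features, in_features).
w_q = rng.standard_normal((hidden, hidden)).astype(np.float32)
b_q = rng.standard_normal(hidden).astype(np.float32)
w_o = rng.standard_normal((hidden, hidden)).astype(np.float32)

# Query projection: transpose to (in, out), then split "out" into (heads, head_dim).
kernel_q = w_q.T.reshape(hidden, heads, dim_per_head)
bias_q = b_q.reshape(heads, dim_per_head)
q_torch = (x_tok @ w_q.T + b_q).reshape(num_tokens, heads, dim_per_head)
q_flax = np.einsum("td,dhk->thk", x_tok, kernel_q) + bias_q

# Output projection: transpose to (in, out), then split "in" into (heads, head_dim).
kernel_o = w_o.T.reshape(heads, dim_per_head, hidden)
o_torch = q_torch.reshape(num_tokens, hidden) @ w_o.T
o_flax = np.einsum("thk,hkd->td", q_flax, kernel_o)

# Patch-embedding conv weight: PyTorch (out, in, kh, kw) -> Flax (kh, kw, in, out).
w_conv = rng.standard_normal((8, 3, 16, 16)).astype(np.float32)
kernel_conv = np.transpose(w_conv, (2, 3, 1, 0))

print(np.allclose(q_torch, q_flax, atol=1e-4), np.allclose(o_torch, o_flax, atol=1e-4))
print(kernel_conv.shape)  # (16, 16, 3, 8)
```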

Verifying image prediction

Load a sample image from a URL, perform inference, and compare the predictions to verify the weight transfer:

import torch
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor
from PIL import Image
import requests

url = "https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

inputs = processor(images=image, return_tensors="pt")
tf_model.eval()
with torch.no_grad():
    outputs = tf_model(**inputs)
    logits = outputs.logits.cpu().numpy()

model.eval()
with jax.set_mesh(mesh):
    x = jnp.transpose(jnp.asarray(inputs["pixel_values"]), axes=(0, 2, 3, 1))
    # The model is FSDP-sharded, so the batch dimension must be divisible by the
    # number of devices: replicate the single image once per device.
    x = jnp.concat([x] * jax.device_count(), axis=0)
    output = model(x)
    output = jax.sharding.reshard(output, jax.P())[:1].astype(jnp.float32)

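# Both models predict one of the 1000 ImageNet classes. Check that the logits agree
# within a tolerance that allows for the bfloat16 activations of our model.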
assert jnp.abs(logits[0, :] - output[0, :]).max() < 0.1

ref_class_idx = logits.argmax(-1).item()
pred_class_idx = output.argmax(-1).item()
fig, axs = plt.subplots(1, 2, figsize=(12, 8))
axs[0].set_title(
    f"Reference model:\n{tf_model.config.id2label[ref_class_idx]}\nP={nnx.softmax(logits, axis=-1)[0, ref_class_idx]:.3f}"
)
axs[0].imshow(image)
axs[1].set_title(
    f"Our model:\n{tf_model.config.id2label[pred_class_idx]}\nP={nnx.softmax(output, axis=-1)[0, pred_class_idx]:.3f}"
)
axs[1].imshow(image)

Replace the classifier with a smaller fully-connected layer returning 20 classes instead of 1000:

act_dtype = jnp.bfloat16 if config.use_bfloat16 else jnp.float32

with jax.set_mesh(mesh):
    model.classifier = nnx.Linear(
        model.classifier.in_features, 
        20, 
        rngs=nnx.Rngs(0),
        dtype=act_dtype,
        param_dtype=jnp.float32,
    )


with jax.set_mesh(mesh):
    model.train()    
    x = jnp.ones((4, 224, 224, 3), out_sharding=jax.P("fsdp"), dtype=act_dtype)
    y = model(x, rngs=nnx.Rngs(1))
    print("Predictions shape: ", jax.typeof(y))
Predictions shape:  bfloat16[4@fsdp,20]
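The replacement head follows the same mixed-precision recipe as the rest of the model: float32 parameters (param_dtype) with bfloat16 computation (dtype). Below is a rough plain-JAX sketch of that recipe, not Flax's actual implementation. The parameters are cast to bfloat16 for the matmul, and the loss is reduced in float32. The gradients come back in float32, so the optimizer updates full-precision weights.

```python
import jax
import jax.numpy as jnp

params_f32 = jnp.full((4, 2), 0.5, dtype=jnp.float32)  # float32 "master" weights
inputs_bf16 = jnp.ones((3, 4), dtype=jnp.bfloat16)     # bfloat16 activations


def forward(w, x):
    # Compute in bfloat16 by casting the float32 weights to the activation dtype.
    y = x @ w.astype(jnp.bfloat16)
    # Reduce the loss in float32.
    return jnp.mean(y.astype(jnp.float32) ** 2), y


(mp_loss, mp_outputs), mp_grads = jax.value_and_grad(forward, has_aux=True)(
    params_f32, inputs_bf16
)
print(mp_outputs.dtype, mp_grads.dtype, mp_loss.dtype)  # bfloat16 float32 float32
```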

Food 101 dataset

In this section, we’ll prepare the dataset and train the ViT model. The dataset is Food 101, which consists of 101 food categories with 101,000 images.

Food 101 provides 750 training images and 250 manually reviewed test images per class. The training images were intentionally not cleaned, so they contain some noise, mostly in the form of intense colors and sometimes wrong labels. All images are rescaled to have a maximum side length of 512 pixels. On HuggingFace, the test images form the validation split.

Let’s download the dataset from HuggingFace Datasets and select 20 classes to reduce the dataset size and the model training time. We’ll use TorchVision to transform input images and grain for efficient data loading.

from datasets import load_dataset

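# Select the first 20 classes to reduce the dataset size and the training time.
# This assumes the rows of each split are grouped by class (750 training and
# 250 validation rows per class).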
train_size = 20 * 750
val_size = 20 * 250

train_dataset = load_dataset("ethz/food101", split=f"train[:{train_size}]")
val_dataset = load_dataset("ethz/food101", split=f"validation[:{val_size}]")

# Map the original Food 101 label IDs of the selected classes to contiguous IDs 0-19,
# reading one label every 250 validation rows (one block per class).
labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

print("Training dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
Training dataset size: 15000
Validation dataset size: 5000
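The mapping loop above relies on the first 20 classes occupying consecutive blocks of 250 validation rows. If that ordering cannot be assumed, an order-independent alternative is to derive the contiguous indices from the set of labels that actually occur, as sketched below. The label values here are made up for illustration. Whichever mapping you build, use the same one for the training and validation splits.

```python
import numpy as np

# Hypothetical original label IDs (not real Food 101 IDs).
original_labels = np.array([7, 3, 7, 12, 3, 12, 12])

# Map each distinct original ID to a contiguous index 0..K-1, in sorted order.
unique_ids = np.unique(original_labels)
labels_to_index = {int(label): i for i, label in enumerate(unique_ids)}
index_to_label = {i: label for label, i in labels_to_index.items()}

remapped = np.array([labels_to_index[int(lbl)] for lbl in original_labels])
print(remapped)  # [1 0 1 2 0 2 2]
```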
import matplotlib.pyplot as plt
import numpy as np


def display_datapoints(*datapoints, tag="", names_map=None):
    num_samples = len(datapoints)

    fig, axs = plt.subplots(1, num_samples, figsize=(20, 10))
    for i, datapoint in enumerate(datapoints):
        if isinstance(datapoint, dict):
            img, label = datapoint["image"], datapoint["label"]
        else:
            img, label = datapoint

        if hasattr(img, "dtype") and img.dtype in (np.float32, ):
            img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)

        label_str = f" ({names_map[label]})" if names_map is not None else ""
        axs[i].set_title(f"{tag}Label: {label}{label_str}")
        axs[i].imshow(img)

Visualize a few samples from the training and validation sets:

display_datapoints(
    train_dataset[0], train_dataset[1000], train_dataset[2000], train_dataset[3000],
    tag="(Training) ",
    names_map=train_dataset.features["label"].names
)

display_datapoints(
    val_dataset[0], val_dataset[1000], val_dataset[2000], val_dataset[-1],
    tag="(Validation) ",
    names_map=val_dataset.features["label"].names
)

Next, define the image preprocessing helpers for the training and validation sets. The training transformations also include random augmentations (random resized crops, horizontal flips, and color jitter) to reduce overfitting and make the fine-tuned model more robust.

import numpy as np
from torchvision.transforms import v2 as T


img_size = 224


def to_np_array(pil_image):
    return np.asarray(pil_image.convert("RGB"))


def normalize(image):
    # Scale to [0, 1], then normalize with mean 0.5 and std 0.5 to obtain [-1, 1],
    # matching the preprocessing of the pretrained ViT (ViTImageProcessor).
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std


tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])


tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])


def get_transform(fn):
    def wrapper(batch):
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        # map label index between 0 - 19
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
        return batch
    return wrapper


train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)
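normalize maps pixel values from [0, 255] to [-1, 1]. For visualization, display_datapoints rescales each image by its own minimum and maximum. This restores a viewable range but also stretches the contrast of each image. If you prefer an exact inverse, here is a small NumPy sketch (the helper names are hypothetical):

```python
import numpy as np

MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)


def normalize_uint8(image: np.ndarray) -> np.ndarray:
    # uint8 [0, 255] -> float32 [-1, 1]; same formula as `normalize` above.
    return (image.astype(np.float32) / 255.0 - MEAN) / STD


def denormalize_to_uint8(image: np.ndarray) -> np.ndarray:
    # Inverse of normalize_uint8 for display: float32 [-1, 1] -> uint8 [0, 255].
    return np.clip(np.round((image * STD + MEAN) * 255.0), 0, 255).astype(np.uint8)


pixels = np.array([[[0, 128, 255]]], dtype=np.uint8)  # a 1x1 RGB image
normed = normalize_uint8(pixels)
restored = denormalize_to_uint8(normed)
print(normed, restored)
```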
import grain.python as grain


seed = 12
train_batch_size = 64
val_batch_size = 2 * train_batch_size


# Create a `grain.IndexSampler` without multi-host sharding (single-host setup).
train_sampler = grain.IndexSampler(
    len(train_dataset),  # The total number of samples in the data source.
    shuffle=True,            # Shuffle the data to randomize the order of samples.
    seed=seed,               # Set a seed for reproducibility.
    shard_options=grain.NoSharding(),  # No multi-host sharding since this is a single host setup.
    num_epochs=1,            # Iterate over the dataset for one epoch.
)

val_sampler = grain.IndexSampler(
    len(val_dataset),  # The total number of samples in the data source.
    shuffle=False,         # Do not shuffle the data.
    seed=seed,             # Set a seed for reproducibility.
    shard_options=grain.NoSharding(),  # No multi-host sharding since this is a single host setup.
    num_epochs=1,          # Iterate over the dataset for one epoch.
)


train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,                 # A sampler to determine how to access the data.
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among.
    worker_buffer_size=2,                  # Count of output batches to produce in advance per worker.
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

# Test (validation) dataset `grain.DataLoader`.
val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,                   # A sampler to determine how to access the data.
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among.
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ]
)
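A quick check of the loader arithmetic follows. With drop_remainder=True, the training loader yields 15,000 // 64 = 234 full batches per epoch, which is the total_steps value used below. The validation loader keeps its final, smaller batch. Because the training and evaluation code shards the batch dimension over the fsdp mesh axis, each batch size also has to be divisible by the device count. The helper below is hypothetical and only reproduces that arithmetic.

```python
import math


def loader_stats(num_examples: int, batch_size: int, drop_remainder: bool) -> tuple[int, int]:
    """Return (number of batches, size of the last batch) for one epoch."""
    if drop_remainder:
        return num_examples // batch_size, batch_size
    num_batches = math.ceil(num_examples / batch_size)
    return num_batches, num_examples - (num_batches - 1) * batch_size


train_batches, train_last = loader_stats(15_000, 64, drop_remainder=True)
val_batches, val_last = loader_stats(5_000, 128, drop_remainder=False)
print(train_batches, train_last)  # 234 64
print(val_batches, val_last)      # 40 8

# Device counts for which every batch size is divisible (needed for FSDP sharding).
batch_sizes = {64, 128, val_last}
compatible = [n for n in (1, 2, 4, 8, 16) if all(b % n == 0 for b in batch_sizes)]
print(compatible)  # [1, 2, 4, 8]
```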

Let’s visualize the training and test set batches:

train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))
print("Training batch info:", train_batch["image"].shape, train_batch["image"].dtype, train_batch["label"].shape, train_batch["label"].dtype)
print("Validation batch info:", val_batch["image"].shape, val_batch["image"].dtype, val_batch["label"].shape, val_batch["label"].dtype)
Training batch info: (64, 224, 224, 3) float32 (64,) int64
Validation batch info: (128, 224, 224, 3) float32 (128,) int64
display_datapoints(
    *[(train_batch["image"][i], train_batch["label"][i]) for i in range(5)],
    tag="(Training) ",
    names_map={k: train_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
)
display_datapoints(
    *[(val_batch["image"][i], val_batch["label"][i]) for i in range(5)],
    tag="(Validation) ",
    names_map={k: val_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
)

Defining the optimizer, the loss function, training/test steps, and metrics

In this section, we’ll define the optimizer, the loss function, the training and test step functions, and then begin training the model.

First, initialize the learning rate schedule and the SGD optimizer with Nesterov momentum, using optax.sgd and flax.nnx.Optimizer:

num_epochs = 3
learning_rate = 0.001
momentum = 0.8
total_steps = len(train_dataset) // train_batch_size

lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)

iterate_subsample = np.linspace(0, num_epochs * total_steps, 100)
plt.plot(
    np.linspace(0, num_epochs, len(iterate_subsample)),
    [lr_schedule(i) for i in iterate_subsample],
    lw=3,
)
plt.title("Learning rate")
plt.xlabel("Epochs")
plt.ylabel("Learning rate")
plt.grid()
plt.xlim((0, num_epochs))
plt.show()


with jax.set_mesh(mesh):
    optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True), wrt=nnx.Param)
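optax.linear_schedule interpolates linearly from its initial value to its final value over transition_steps and then holds the final value. For SGD with nesterov=True, our understanding of Optax's momentum trace is as follows: it keeps a trace t <- g + momentum * t and uses g + momentum * t as the update direction, scaled by the learning rate. The NumPy sketch below re-implements both for a scalar parameter. It is an illustration under that assumption, not Optax's code.

```python
import numpy as np


def linear_schedule(step, init_value, end_value, transition_steps):
    # Linear interpolation from init_value to end_value, then held constant.
    frac = np.clip(step / transition_steps, 0.0, 1.0)
    return init_value + frac * (end_value - init_value)


def sgd_nesterov_step(param, grad, trace, lr, momentum):
    # Momentum trace, then the Nesterov "look-ahead" direction g + momentum * trace.
    trace = grad + momentum * trace
    update = grad + momentum * trace
    return param - lr * update, trace


decay_steps = 3 * 234  # num_epochs * total_steps in this example
print(linear_schedule(0, 1e-3, 0.0, decay_steps), linear_schedule(351, 1e-3, 0.0, decay_steps))

# Minimize f(p) = p**2 / 2, whose gradient is p, with a decaying learning rate.
param, trace = 1.0, 0.0
for step in range(5):
    lr = linear_schedule(step, 0.1, 0.0, 5)
    param, trace = sgd_nesterov_step(param, param, trace, lr, momentum=0.8)
print(param)
```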

Define a loss function with optax.softmax_cross_entropy_with_integer_labels:

def compute_losses_and_logits(
    model: nnx.Module,
    images: jax.Array,
    labels: jax.Array,
    rngs: nnx.Rngs | None = None
) -> tuple[jax.Array, jax.Array]:
    logits = model(images, rngs=rngs)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits.astype(jnp.float32), labels=labels
    ).mean()
    return loss, logits
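
optax.softmax_cross_entropy_with_integer_labels returns one loss per example, which compute_losses_and_logits then averages. Per example, the loss equals the negative log-softmax probability of the true class. The logits are cast to float32 first, so this log-softmax is not computed in bfloat16. Below is a small JAX sketch of the formula, useful as a sanity check on toy logits. A uniform row over three classes gives log 3 ≈ 1.0986.

```python
import jax
import jax.numpy as jnp


def softmax_xent_integer_labels(logits: jax.Array, labels: jax.Array) -> jax.Array:
    # Per-example loss: logsumexp(logits) - logits[true class] = -log_softmax[true class].
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]


toy_logits = jnp.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
toy_labels = jnp.array([0, 2])
per_example = softmax_xent_integer_labels(toy_logits, toy_labels)
print(per_example, per_example.mean())
```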

Set up the train and test steps (with nnx.jit and nnx.value_and_grad):

@nnx.jit(donate_argnames=("model", "optimizer"))
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, rngs: nnx.Rngs, batch: tuple[jax.Array, jax.Array]
):
    images, labels = batch
    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, _), grads = grad_fn(model, images, labels, rngs.fork())

    optimizer.update(model, grads)

    return loss


@nnx.jit
def eval_step(
    model: nnx.Module, batch: tuple[jax.Array, jax.Array], eval_metrics: nnx.MultiMetric
):
    images, labels = batch
    loss, logits = compute_losses_and_logits(model, images, labels)
    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )
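In train_step, nnx.value_and_grad(..., has_aux=True) returns ((loss, logits), grads), and only the loss is differentiated. The sketch below shows the same pattern in plain JAX. It uses a hypothetical linear classifier stored as a parameter dict, with a hand-written SGD update in place of nnx.Optimizer and no buffer donation.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.5


def loss_and_logits(params, images, labels):
    logits = images @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()
    return loss, logits


@jax.jit
def sgd_train_step(params, images, labels):
    # has_aux=True: differentiate only the first output (the loss) and pass the
    # second output (the logits) through unchanged.
    (loss, _logits), grads = jax.value_and_grad(loss_and_logits, has_aux=True)(
        params, images, labels
    )
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


# Toy data: 2 classes, label = 1 when the first feature is positive.
toy_images = jax.random.normal(jax.random.key(0), (32, 4))
toy_labels = (toy_images[:, 0] > 0).astype(jnp.int32)
toy_params = {"w": jnp.zeros((4, 2)), "b": jnp.zeros(2)}

losses = []
for _ in range(20):
    toy_params, step_loss = sgd_train_step(toy_params, toy_images, toy_labels)
    losses.append(float(step_loss))
print(losses[0], losses[-1])
```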

Instantiate the metrics with nnx.MultiMetric:

eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)


train_metrics_history = {
    "train_loss": [],
}

eval_metrics_history = {
    "val_loss": [],
    "val_accuracy": [],
}
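The validation loader's last batch is smaller than the others (5,000 = 39 * 128 + 8), so the way metrics are accumulated across batches matters. Below is a minimal sketch of per-example accumulation for accuracy: it counts correct predictions and examples over all batches. This is a hypothetical class for illustration, not the implementation of nnx.metrics.Accuracy. Averaging the per-batch accuracies instead would give (0.75 + 0.0) / 2 = 0.375 for the two batches below.

```python
import numpy as np


class StreamingAccuracy:
    """Accumulates correct predictions over batches of possibly unequal size."""

    def __init__(self):
        self.correct = 0
        self.count = 0

    def update(self, logits: np.ndarray, labels: np.ndarray) -> None:
        self.correct += int((logits.argmax(axis=-1) == labels).sum())
        self.count += labels.size

    def compute(self) -> float:
        return self.correct / self.count


acc = StreamingAccuracy()
acc.update(np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]), np.array([0, 1, 1, 1]))  # 3/4
acc.update(np.array([[0.1, 0.9]]), np.array([0]))  # 0/1
print(acc.compute())  # 0.6
```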
import tqdm


bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"

# We define a view of the model sharing the weights but with attributes set for evaluation
eval_model = nnx.view(model, deterministic=True)
rngs = nnx.Rngs(12)

def train_one_epoch(epoch):
    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch + 1}/{num_epochs}, ",
        total=total_steps,
        bar_format=bar_format,
        leave=True,
    ) as pbar, jax.set_mesh(mesh):
        prev_loss = None
        for batch in train_loader:

            # Transfer the NumPy batch to the devices, sharding the batch dimension along "fsdp"
            images = jax.device_put(batch["image"].astype(act_dtype), device=jax.P("fsdp"))
            labels = jax.device_put(batch["label"].astype(int), device=jax.P("fsdp"))

            loss = train_step(model, optimizer, rngs, (images, labels))
            if prev_loss is not None:
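                # Report the previous step's loss: calling .item() on it does not wait for
                # the step that was just dispatched, so host and devices keep overlapping.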
                train_metrics_history["train_loss"].append(prev_loss.item())
                pbar.set_postfix({"loss": prev_loss.item()})
            prev_loss = loss
            pbar.update(1)


def evaluate_model(epoch):
    # Computes the metrics on the validation set after each training epoch.
    with jax.set_mesh(mesh):
        eval_metrics.reset()  # Reset the eval metrics
        for val_batch in val_loader:

            # Transfer the NumPy batch to the devices, sharding the batch dimension along "fsdp"
            images = jax.device_put(val_batch["image"].astype(act_dtype), device=jax.P("fsdp"))
            labels = jax.device_put(val_batch["label"].astype(int), device=jax.P("fsdp"))

            eval_step(eval_model, (images, labels), eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'val_{metric}'].append(value)

    print(f"[val] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")

Training the model

Begin training the model:

config
ModelConfig(num_classes=1000, in_channels=3, img_size=224, patch_size=16, num_layers=12, num_heads=12, mlp_dim=3072, hidden_size=768, dropout_rate=0.1, sharding=ShardingConfig(attn_qkvo_weight_ndh=P(None, 'fsdp', None), mlp_weight_df=P('fsdp', None), mlp_weight_fd=P(None, 'fsdp'), act_btd=P('fsdp', None, None), act_btf=P('fsdp', None, None), act_btnh=P('fsdp', None, None, None), act_bc=P('fsdp', None), fsdp_axis_name='fsdp'), use_bfloat16=True)
%%time

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
[train] epoch: 1/3, [232/234], loss=0.592 [02:06<00:01]
[val] epoch: 1/3
- total loss: 0.3671
- Accuracy: 0.9210
[train] epoch: 2/3, [232/234], loss=0.389 [01:27<00:00]
[val] epoch: 2/3
- total loss: 0.2645
- Accuracy: 0.9338
[train] epoch: 3/3, [232/234], loss=0.45 [01:26<00:00] 
[val] epoch: 3/3
- total loss: 0.2466
- Accuracy: 0.9370
CPU times: user 2min 36s, sys: 45 s, total: 3min 21s
Wall time: 5min 48s

Visualize the collected metrics:

plt.plot(train_metrics_history["train_loss"], label="Loss value during the training")
plt.legend()
fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].set_title("Loss value on validation set")
axs[0].plot(eval_metrics_history["val_loss"])
axs[1].set_title("Accuracy on validation set")
axs[1].plot(eval_metrics_history["val_accuracy"])

Check the model’s predictions on the test data:

test_indices = [1, 250, 500, 750, 1000, 1234]

test_images = [val_dataset[i]["image"] for i in test_indices]
expected_labels = [val_dataset[i]["label"] for i in test_indices]

with jax.set_mesh(mesh):
    inputs = jnp.asarray(test_images, out_sharding=jax.P("fsdp"))
    preds = eval_model(inputs)
    preds = jax.sharding.reshard(preds, jax.P())
num_samples = len(test_indices)
names_map = train_dataset.features["label"].names

probas = nnx.softmax(preds, axis=1)
pred_labels = probas.argmax(axis=1)


fig, axs = plt.subplots(1, num_samples, figsize=(20, 10))
for i in range(num_samples):
    img, expected_label = test_images[i], expected_labels[i]

    pred_label = pred_labels[i].item()
    proba = probas[i, pred_label].item()
    if img.dtype in (np.float32, ):
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)

    expected_label_str = names_map[inv_labels_mapping[expected_label]]
    pred_label_str = names_map[inv_labels_mapping[pred_label]]
    axs[i].set_title(f"Expected: {expected_label_str} vs \nPredicted: {pred_label_str}, P={proba:.2f}")
    axs[i].imshow(img)

Further reading

In this example, we implemented the ViT model in Flax NNX and fine-tuned it on a subset of the Food 101 dataset.

For further reading, check out other Examples.