Example: Image segmentation with UNETR model#

This example demonstrates how to implement and train a model on image segmentation task. Below, we will be using the Oxford Pets dataset containing images and masks of cats and dogs. We will implement from scratch the UNETR model using Flax NNX. We will train the model on a training set and compute image segmentation metrics on the training and validation sets. We will use Orbax checkpoint manager to store best models during the training.

Prepare image segmentation dataset and dataloaders#

In this section we use the Oxford Pets dataset. We download images and masks and provide a code to work with the dataset. This approach can be easily extended to any image segmentation datasets and users can reuse this code for their own datasets.

In the code below we make a choice of using OpenCV and Pillow to read images and masks as NumPy arrays, Albumentations for data augmentations and grain for batched data loading. Alternatively, one can use tensorflow_dataset or torchvision for the same task.

Requirements installation#

We will need to install the following Python packages:

!pip install -U opencv-python-headless grain albumentations Pillow
!pip install -U flax optax orbax-checkpoint
import jax
import flax
import optax
import orbax.checkpoint as ocp
print("Jax version:", jax.__version__)
print("Flax version:", flax.__version__)
print("Optax version:", optax.__version__)
print("Orbax version:", ocp.__version__)
Jax version: 0.10.1.dev20260519
Flax version: 0.12.7
Optax version: 0.2.8
Orbax version: 0.11.33

Data download#

Let’s download the data and extract images and masks.

# !rm -rf /tmp/data/oxford_pets
# !mkdir -p /tmp/data/oxford_pets
# !wget https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz -O /tmp/data/oxford_pets/images.tar.gz
# !wget https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz -O /tmp/data/oxford_pets/annotations.tar.gz

# !cd /tmp/data/oxford_pets && tar -xf images.tar.gz
# !cd /tmp/data/oxford_pets && tar -xf annotations.tar.gz
# !ls /tmp/data/oxford_pets

We can also inspect the downloaded images folder, listing a subset of these files:

!ls /tmp/data/oxford_pets/images | wc -l
!ls /tmp/data/oxford_pets/images | head
!ls /tmp/data/oxford_pets/annotations/trimaps | wc -l
!ls /tmp/data/oxford_pets/annotations/trimaps | head
7393
Abyssinian_1.jpg
Abyssinian_10.jpg
Abyssinian_100.jpg
Abyssinian_100.mat
Abyssinian_101.jpg
Abyssinian_101.mat
Abyssinian_102.jpg
Abyssinian_102.mat
Abyssinian_103.jpg
Abyssinian_104.jpg
ls: write error: Broken pipe
7390
Abyssinian_1.png
Abyssinian_10.png
Abyssinian_100.png
Abyssinian_101.png
Abyssinian_102.png
Abyssinian_103.png
Abyssinian_104.png
Abyssinian_105.png
Abyssinian_106.png
Abyssinian_107.png
ls: write error: Broken pipe

Train/Eval datasets#

Let’s implement the dataset class providing the access to the images and masks. The class implements __len__ and __getitem__ methods. In this example, we do not have a hard training and validation data split, so we will use the total dataset and make a random training/validation split by indices. For this purpose we provide a helper class to map indices into training and validation parts.

from typing import Any
from pathlib import Path

import cv2
import numpy as np
from PIL import Image  # we'll read images with opencv and use Pillow as a fallback


class OxfordPetsDataset:
    def __init__(self, path: Path):
        assert path.exists(), path
        self.path: Path = path
        self.images = sorted((self.path / "images").glob("*.jpg"))
        self.masks = [
            self.path / "annotations" / "trimaps" / path.with_suffix(".png").name
            for path in self.images
        ]
        assert len(self.images) == len(self.masks), (len(self.images), len(self.masks))

    def __len__(self) -> int:
        return len(self.images)

    def read_image_opencv(self, path: Path):
        img = cv2.imread(str(path))
        if img is not None:
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            None

    def read_image_pillow(self, path: Path):
        img = Image.open(str(path))
        img = img.convert("RGB")
        return np.asarray(img)

    def read_mask(self, path: Path):
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        # mask has values: 1, 2, 3
        # 1 - object mask
        # 2 - background
        # 3 - boundary
        # Define mask as 0-based int values
        mask = mask - 1
        return mask.astype("uint8")

    def __getitem__(self, index: int) -> dict[str, np.ndarray]:
        img_path, mask_path = self.images[index], self.masks[index]
        img = self.read_image_opencv(img_path)
        if img is None:
            # Fallback to Pillow if OpenCV fails to read an image
            img = self.read_image_pillow(img_path)
        mask = self.read_mask(mask_path)
        return {
            "image": img,
            "mask": mask,
        }


class SubsetDataset:
    def __init__(self, dataset, indices: list[int]):
        # Check input indices values:
        for i in indices:
            assert 0 <= i < len(dataset)
        self.dataset = dataset
        self.indices = indices

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, index: int) -> Any:
        i = self.indices[index]
        return self.dataset[i]

Now, let’s define the total dataset and compute data indices for training and validation splits:

seed = 12
train_split = 0.7
dataset_path = Path("/tmp/data/oxford_pets")

dataset = OxfordPetsDataset(dataset_path)

rng = np.random.default_rng(seed=seed)
le = len(dataset)
data_indices = list(range(le))

# Let's remove few indices corresponding to corrupted images
# to avoid libjpeg warnings during the data loading
corrupted_data_indices = [3017, 3425]
for index in corrupted_data_indices:
    data_indices.remove(index)

random_indices = rng.permutation(data_indices)

train_val_split_index = int(train_split * le)
train_indices = random_indices[:train_val_split_index]
val_indices = random_indices[train_val_split_index:]

# Ensure there is no overlapping
assert len(set(train_indices) & set(val_indices)) == 0

train_dataset = SubsetDataset(dataset, indices=train_indices)
val_dataset = SubsetDataset(dataset, indices=val_indices)

print("Training dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
Training dataset size: 5173
Validation dataset size: 2215

To verify our work so far, let’s display few training and validation images and masks:

import matplotlib.pyplot as plt


def display_datapoint(datapoint, label=""):
    img, mask = datapoint["image"], datapoint["mask"]
    if img.dtype in (np.float32, ):
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)
    fig, axs = plt.subplots(1, 3, figsize=(10, 10))
    axs[0].set_title(f"Image{label}")
    axs[0].imshow(img)
    axs[1].set_title(f"Mask{label}")
    axs[1].imshow(mask)
    axs[2].set_title("Image + Mask")
    axs[2].imshow(img)
    axs[2].imshow(mask, alpha=0.5)



display_datapoint(train_dataset[0], label=" (train set)")
display_datapoint(val_dataset[0], label=" (val set)")
../_images/3bfd65b1ec2eb5a5a872c7c6566ce7ead712d536b92c5a1db9ee212e4e3e8d1d.png ../_images/206376806bcb8c4662f1d6b52b8df8c61d2db33a9ce5c8a8884c2496b21a13a1.png

Data augmentations#

Next, let’s define a simple data augmentation pipeline of joined image and mask transformations using Albumentations. We apply geometric and color transformations to increase the diversity of the training data. For more details on the Albumentations transformations, we can check Albumentations reference API.

import albumentations as albu


img_size = 256

train_transforms = albu.Compose([
    albu.Affine(rotate=(-35, 35), fill_mask=1, p=0.3),  # Random rotations -35 to 35 degrees
    albu.RandomResizedCrop(size=(img_size, img_size), scale=(0.7, 1.0)),  # Crop a random part of the input and rescale it to a specified size
    albu.HorizontalFlip(p=0.5),  # Horizontal random flip
    albu.RandomBrightnessContrast(p=0.4),  # Randomly changes the brightness and contrast
    albu.Normalize(),  # Normalize the image and cast to float
])


val_transforms = albu.Compose([
    albu.Resize(width=img_size, height=img_size),
    albu.Normalize(),  # Normalize the image and cast to float
])
output = train_transforms(**train_dataset[0])
img, mask = output["image"], output["mask"]
print("Image array info:", img.dtype, img.shape, img.min(), img.mean(), img.max())
print("Mask array info:", mask.dtype, mask.shape, mask.min(), mask.max())
Image array info: float32 (256, 256, 3) -2.117904 0.12848356 2.5702832
Mask array info: uint8 (256, 256) 0 2
output = val_transforms(**val_dataset[0])
img, mask = output["image"], output["mask"]
print("Image array info:", img.dtype, img.shape, img.min(), img.mean(), img.max())
print("Mask array info:", mask.dtype, mask.shape, mask.min(), mask.max())
Image array info: float32 (256, 256, 3) -2.117904 -0.30076742 2.6399999
Mask array info: uint8 (256, 256) 0 2

Data loaders#

Let’s now use grain to perform data loading, augmentations and batching on a single device using multiple workers. We will create a random index sampler for training and an unshuffled sampler for validation.

from typing import Any, Callable

import grain.python as grain


class DataAugs(grain.MapTransform):
    def __init__(self, transforms: Callable):
        self.albu_transforms = transforms

    def map(self, data):
        output = self.albu_transforms(**data)
        return output
train_batch_size = 72
val_batch_size = 2 * train_batch_size


# Create an IndexSampler with no sharding for single-device computations
train_sampler = grain.IndexSampler(
    len(train_dataset),  # The total number of samples in the data source
    shuffle=True,            # Shuffle the data to randomize the order of samples
    seed=seed,               # Set a seed for reproducibility
    shard_options=grain.NoSharding(),  # No multi-host sharding since this is a single host setup.
    num_epochs=1,            # Iterate over the dataset for one epoch
)

val_sampler = grain.IndexSampler(
    len(val_dataset),  # The total number of samples in the data source
    shuffle=False,         # Do not shuffle the data
    seed=seed,             # Set a seed for reproducibility
    shard_options=grain.NoSharding(),  # No multi-host sharding since this is a single host setup.
    num_epochs=1,          # Iterate over the dataset for one epoch
)
train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,                 # Sampler to determine how to access the data
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among
    worker_buffer_size=2,                  # Count of output batches to produce in advance per worker
    operations=[
        DataAugs(train_transforms),
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

# Validation dataset loader
val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,                   # Sampler to determine how to access the data
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among
    worker_buffer_size=2,
    operations=[
        DataAugs(val_transforms),
        grain.Batch(val_batch_size, drop_remainder=True),  # drop remainder for val loader as we shard on batch dim
    ]
)

# Training dataset loader for evaluation (without dataaugs)
train_eval_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,                 # Sampler to determine how to access the data
    worker_count=4,                        # Number of child processes launched to parallelize the transformations among
    worker_buffer_size=2,                  # Count of output batches to produce in advance per worker
    operations=[
        DataAugs(val_transforms),
        grain.Batch(val_batch_size, drop_remainder=True),
    ]
)
train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))
print("Train images batch info:", type(train_batch["image"]), train_batch["image"].shape, train_batch["image"].dtype)
print("Train masks batch info:", type(train_batch["mask"]), train_batch["mask"].shape, train_batch["mask"].dtype)
Train images batch info: <class 'grain._src.python.ipc.shared_memory_array.SharedMemoryArray'> (72, 256, 256, 3) float32
Train masks batch info: <class 'grain._src.python.ipc.shared_memory_array.SharedMemoryArray'> (72, 256, 256) uint8

Finally, let’s display the training and validation data:

images, masks = train_batch["image"], train_batch["mask"]

for img, mask in zip(images[:3], masks[:3]):
    display_datapoint({"image": img, "mask": mask}, label=" (augmented train set)")
../_images/95c1ae6a18c47ae031ab0ff89998227be3eaf38421d42992d66be55eef2c6f08.png ../_images/d824726ecbc16bae38cde123da56e9a87aa2158be35e86dcc227f2a165fc826f.png ../_images/b8ff1f96cdab050c09fed8a8203bc3454f31f270617cbc27bbb64acdf5548a24.png
images, masks = val_batch["image"], val_batch["mask"]

for img, mask in zip(images[:3], masks[:3]):
    display_datapoint({"image": img, "mask": mask}, label=" (augmented validation set)")
../_images/a674ace04c2340b48679e7b1c5b041df2f765613c32b5396a1b65877b05a0a77.png ../_images/46871c1753e31b5847dedc7509acc4c31d4c513a2eb435511bd64aaf062cb3fb.png ../_images/0cbc62d80577d88333bf745312a534da02fba67e8c2232480fa7ea530af2de44.png

Model for Image Segmentation#

In this section we will implement the UNETR model from scratch using Flax NNX. The reference PyTorch implementation of this model can be found on the MONAI Library GitHub repository.

The UNETR model utilizes a transformer as the encoder to learn sequence representations of the input and to capture the global multi-scale information, while also following the “U-shaped” network design like UNet model: image.png

The UNETR architecture on the image above is processing 3D inputs, but it can be easily adapted to 2D input.

The transformer encoder of UNETR is Vision Transformer (ViT). The feature maps returned by ViT have all the same spatial size: (H / 16, W / 16) and deconvolutions are used to upsample the feature maps. Finally, the feature maps are upsampled and concatenated up to the original image size.

from flax import nnx
import jax.numpy as jnp

Vision Transformer encoder implementation#

Below, we will implement the following modules:

  • Vision Transformer, ViT

    • PatchEmbeddingBlock: patch embedding block, which maps patches of pixels to a sequence of vectors

    • ViTEncoderBlock: vision transformer encoder block

      • MLPBlock: multilayer perceptron block

from dataclasses import dataclass
from jax.sharding import PartitionSpec as P


@dataclass(slots=True, frozen=True)
class ShardingConfig:
    attn_qkvo_weight_ndh: P | None = None  # sharding for Q, K, V, Out weights
    mlp_weight_df: P | None = None
    mlp_weight_fd: P | None = None
    act_btd: P | None = None  # sharding of the activation (B, T, D)
    act_btf: P | None = None
    act_btnh: P | None = None

    fsdp_axis_name: str = "fsdp"

    @staticmethod
    def no_sharding():
        return ShardingConfig()

    @staticmethod
    def fsdp_sharding(fsdp_axis_name: str = "fsdp"):
        fsdp = fsdp_axis_name
        return ShardingConfig(
            attn_qkvo_weight_ndh=P(None, fsdp, None),
            mlp_weight_df=P(fsdp, None),
            mlp_weight_fd=P(None, fsdp),
            act_btd=P(fsdp, None, None),
            act_btf=P(fsdp, None, None),
            act_btnh=P(fsdp, None, None, None),
            fsdp_axis_name=fsdp_axis_name,
        )


@dataclass(slots=True, frozen=True)
class ViTConfig:
    in_channels: int = 3
    img_size: int = 224
    patch_size: int = 16
    num_layers: int = 12
    num_heads: int = 12
    mlp_dim: int = 3072
    hidden_size: int = 768
    dropout_rate: float = 0.1
    sharding: ShardingConfig = ShardingConfig.no_sharding()
from typing import Callable


class PatchEmbedding(nnx.Module):
    def __init__(
        self,
        in_channels: int,  # dimension of input channels.
        img_size: int,  # dimension of input image.
        patch_size: int,  # dimension of patch size.
        hidden_size: int,  # dimension of hidden layer.
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(rngs.params(), (1, n_patches, hidden_size), jnp.float32)
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

    def __call__(
        self, 
        x: jax.Array, 
        rngs: nnx.Rngs | None = None, 
        out_sharding: P | None = None,
    ) -> jax.Array:
        x = self.patch_embeddings(x, out_sharding=out_sharding)
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings, rngs=rngs)
        return embeddings


class MLPBlock(nnx.Module):
    def __init__(
        self,
        hidden_size: int,  # dimension of hidden layer.
        mlp_dim: int,      # dimension of feedforward layer
        dropout_rate: float = 0.0,
        activation_layer: Callable = nnx.gelu,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
        up_proj_sharding: P | None = None,
        down_proj_sharding: P | None = None,
    ):
        self.up_proj = nnx.Linear(
            hidden_size, 
            mlp_dim, 
            rngs=rngs,
            kernel_metadata={"out_sharding": up_proj_sharding},
        )
        self.act = activation_layer
        self.drop = nnx.Dropout(dropout_rate, rngs=rngs)
        self.down_proj = nnx.Linear(
            mlp_dim, 
            hidden_size, 
            rngs=rngs,
            kernel_metadata={"out_sharding": down_proj_sharding},
        )

    def __call__(
        self, 
        x: jax.Array, 
        rngs: nnx.Rngs | None = None,
        up_proj_out_sharding: P | None = None,
        out_sharding: P | None = None,        
    ) -> jax.Array:
        x = self.up_proj(x, out_sharding=up_proj_out_sharding)
        x = self.act(x)
        x = self.drop(x)
        x = self.down_proj(x, out_sharding=out_sharding)
        return self.drop(x)


class ViTEncoderBlock(nnx.Module):
    def __init__(
        self,
        hidden_size: int,  # dimension of hidden layer.
        mlp_dim: int,      # dimension of feedforward layer.
        num_heads: int,    # number of attention heads
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
        sharding: ShardingConfig = ShardingConfig(),
    ) -> None:
        self.mlp = MLPBlock(
            hidden_size, 
            mlp_dim, 
            dropout_rate, 
            rngs=rngs, 
            up_proj_sharding=sharding.mlp_weight_df,
            down_proj_sharding=sharding.mlp_weight_fd,
        )
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
            keep_rngs=False,
            kernel_metadata={"out_sharding": sharding.attn_qkvo_weight_ndh},
            out_kernel_metadata={"out_sharding": sharding.attn_qkvo_weight_ndh},            
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.sharding = sharding

    def __call__(
        self, 
        x: jax.Array,
        rngs: nnx.Rngs | None = None,
    ) -> jax.Array:
        x = x + self.attn(
            self.norm1(x), 
            rngs=rngs,
            out_sharding=self.sharding.act_btd,
            qkv_sharding=self.sharding.act_btnh,            
        )
        x = x + self.mlp(
            self.norm2(x),
            rngs=rngs,
            out_sharding=self.sharding.act_btd,
            up_proj_out_sharding=self.sharding.act_btf,            
        )
        return x


class ViT(nnx.Module):
    def __init__(
        self,
        config,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        if config.hidden_size % config.num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.config = config
        self.patch_embedding = PatchEmbedding(
            in_channels=config.in_channels,
            img_size=config.img_size,
            patch_size=config.patch_size,
            hidden_size=config.hidden_size,
            dropout_rate=config.dropout_rate,
            rngs=rngs,
        )
        self.blocks = nnx.List([
            ViTEncoderBlock(
                config.hidden_size, 
                config.mlp_dim, 
                config.num_heads, 
                config.dropout_rate, 
                rngs=rngs,
                sharding=config.sharding,
            )
            for i in range(config.num_layers)
        ])
        self.norm = nnx.LayerNorm(config.hidden_size, rngs=rngs)

    def __call__(
        self, 
        x: jax.Array,
        rngs: nnx.Rngs | None = None,
    ) -> jax.Array:
        x = self.patch_embedding(
            x, 
            rngs=rngs,
            out_sharding=self.config.sharding.act_btd,
        )
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x, rngs=rngs)
            hidden_states_out.append(x)
        x = self.norm(x)
        return x, hidden_states_out


config = ViTConfig()
mod = ViT(config)
x = jnp.ones((4, 224, 224, 3))
y, hstates = mod(x, rngs=nnx.Rngs(1))
print(y.shape, len(hstates))
del mod, x
(4, 196, 768) 12

At this point we implemented the encoder of the UNETR model. As we can see from the above output, ViT provides one encoded feature map and a list of intermediate feature maps. Three of them will be used in the decoding part.

UNETR blocks implementation#

Now, we can implement remaining blocks and assemble them together in the UNETR implementation

Below, we will implement the following modules:

  • UNETR

    • UnetrBasicBlock: creates the first skip connection from the input.

      • UnetResBlock

    • UnetrPrUpBlock: projection upsampling modules to create skip connections from the intermediate feature maps provided by ViT.

    • UnetrUpBlock: upsampling modules used in the decoder

class Conv2dNormActivation(nnx.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int | None = None,
        groups: int = 1,
        norm_layer: Callable[..., nnx.Module] = nnx.BatchNorm,
        activation_layer: Callable = nnx.relu,
        dilation: int = 1,
        bias: bool | None = None,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.out_channels = out_channels

        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
        if bias is None:
            bias = norm_layer is None

        # sequence integer pairs that give the padding to apply before
        # and after each spatial dimension
        padding = ((padding, padding), (padding, padding))

        layers = [
            nnx.Conv(
                in_channels,
                out_channels,
                kernel_size=(kernel_size, kernel_size),
                strides=(stride, stride),
                padding=padding,
                kernel_dilation=(dilation, dilation),
                feature_group_count=groups,
                use_bias=bias,
                rngs=rngs,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels, rngs=rngs))

        if activation_layer is not None:
            layers.append(activation_layer)

        super().__init__(*layers)
class InstanceNorm(nnx.GroupNorm):
    def __init__(self, num_features, **kwargs):
        num_groups, group_size = num_features, None
        super().__init__(
            num_features,
            num_groups=num_groups,
            group_size=group_size,
            **kwargs,
        )


class UnetResBlock(nnx.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        activation_layer: Callable = nnx.leaky_relu,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.conv_norm_act1 = Conv2dNormActivation(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            rngs=rngs,
        )
        self.conv_norm2 = Conv2dNormActivation(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_layer=norm_layer,
            activation_layer=None,
            rngs=rngs,
        )

        self.downsample = (in_channels != out_channels) or (stride != 1)
        if self.downsample:
            self.conv_norm3 = Conv2dNormActivation(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                norm_layer=norm_layer,
                activation_layer=None,
                rngs=rngs,
            )
        self.act = activation_layer

    def __call__(self, x: jax.Array) -> jax.Array:
        residual = x
        out = self.conv_norm_act1(x)
        out = self.conv_norm2(out)
        if self.downsample:
            residual = self.conv_norm3(residual)
        out += residual
        out = self.act(out)
        return out
class UnetrBasicBlock(nnx.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.layer = UnetResBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_layer=norm_layer,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.layer(x)
class UnetrPrUpBlock(nnx.Module):
    def __init__(
        self,
        in_channels: int,  # number of input channels.
        out_channels: int, # number of output channels.
        num_layer: int,    # number of upsampling blocks.
        kernel_size: int,
        stride: int,
        upsample_kernel_size: int = 2,  # convolution kernel size for transposed convolution layers.
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        upsample_stride = upsample_kernel_size
        self.transp_conv_init = nnx.ConvTranspose(
            in_features=in_channels,
            out_features=out_channels,
            kernel_size=(upsample_kernel_size, upsample_kernel_size),
            strides=(upsample_stride, upsample_stride),
            padding="VALID",
            rngs=rngs,
        )
        self.blocks = nnx.List([
            nnx.Sequential(
                nnx.ConvTranspose(
                    in_features=out_channels,
                    out_features=out_channels,
                    kernel_size=(upsample_kernel_size, upsample_kernel_size),
                    strides=(upsample_stride, upsample_stride),
                    rngs=rngs,
                ),
                UnetResBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    norm_layer=norm_layer,
                    rngs=rngs,
                ),
            )
            for i in range(num_layer)
        ])

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x
class UnetrUpBlock(nnx.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        upsample_kernel_size: int = 2,  # convolution kernel size for transposed convolution layers.
        norm_layer: Callable[..., nnx.Module] = InstanceNorm,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        upsample_stride = upsample_kernel_size
        self.transp_conv = nnx.ConvTranspose(
            in_features=in_channels,
            out_features=out_channels,
            kernel_size=(upsample_kernel_size, upsample_kernel_size),
            strides=(upsample_stride, upsample_stride),
            padding="VALID",
            rngs=rngs,
        )
        self.conv_block = UnetResBlock(
            out_channels + out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_layer=norm_layer,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array, skip: jax.Array) -> jax.Array:
        out = self.transp_conv(x)
        out = jnp.concat((out, skip), axis=-1)
        out = self.conv_block(out)
        return out
@dataclass(slots=True, frozen=True)
class ModelConfig:
    out_channels: int
    vit_config: ViTConfig
    feature_size: int = 16
    norm_layer: Callable[..., nnx.Module] = InstanceNorm


class UNETR(nnx.Module):
    """UNETR model ported to NNX from MONAI implementation:
    - https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/unetr.py
    """
    def __init__(
        self,
        config: ModelConfig,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        self.config = config
        vit_config = config.vit_config
        feat_size = vit_config.img_size // vit_config.patch_size

        self.vit = ViT(
            config=config.vit_config,
            rngs=rngs,
        )
        self.encoder1 = UnetrBasicBlock(
            in_channels=vit_config.in_channels,
            out_channels=config.feature_size,
            kernel_size=3,
            stride=1,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )
        self.encoder2 = UnetrPrUpBlock(
            in_channels=vit_config.hidden_size,
            out_channels=config.feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )
        self.encoder3 = UnetrPrUpBlock(
            in_channels=vit_config.hidden_size,
            out_channels=config.feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )
        self.encoder4 = UnetrPrUpBlock(
            in_channels=vit_config.hidden_size,
            out_channels=config.feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )
        self.decoder5 = UnetrUpBlock(
            in_channels=vit_config.hidden_size,
            out_channels=config.feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )
        self.decoder4 = UnetrUpBlock(
            in_channels=config.feature_size * 8,
            out_channels=config.feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )
        self.decoder3 = UnetrUpBlock(
            in_channels=config.feature_size * 4,
            out_channels=config.feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )
        self.decoder2 = UnetrUpBlock(
            in_channels=config.feature_size * 2,
            out_channels=config.feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_layer=config.norm_layer,
            rngs=rngs,
        )

        self.out = nnx.Conv(
            in_features=config.feature_size,
            out_features=config.out_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        self.proj_axes = (0, 1, 2, 3)
        self.proj_view_shape = [feat_size, feat_size, vit_config.hidden_size]

    def proj_feat(self, x: jax.Array) -> jax.Array:
        new_view = [x.shape[0]] + self.proj_view_shape
        x = x.reshape(new_view)
        x = jnp.permute_dims(x, self.proj_axes)
        return x

    def __call__(
        self, 
        x_in: jax.Array,
        rngs: nnx.Rngs | None = None,
    ) -> jax.Array:
        x, hidden_states_out = self.vit(x_in, rngs=rngs)
        enc1 = self.encoder1(x_in)
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2))
        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3))
        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4))
        dec4 = self.proj_feat(x)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        return self.out(out)
# We'll use a different number of heads to make a smaller model
mesh = jax.make_mesh((jax.device_count(),), ("fsdp",))
with jax.set_mesh(mesh):
    config = ModelConfig(
        out_channels=3, 
        vit_config=ViTConfig(
            num_heads=4,
            img_size=256,
            sharding=ShardingConfig.fsdp_sharding()
        ),
    )
    model = UNETR(config, rngs=nnx.Rngs(0))
    x = jnp.ones((4, 256, 256, 3), out_sharding=jax.P("fsdp"))
    y = model(x, rngs=nnx.Rngs(0))
    print("Predictions shape: ", jax.typeof(y))
E0519 14:57:51.784197   56855 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:57:53.856762   56823 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:00.225197   56823 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:01.807232   56813 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:03.571723   56835 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
Predictions shape:  float32[4@fsdp,256,256,3]

We can visualize and inspect the architecture on the implemented model using nnx.display(model).

Train the model#

In previous sections we defined training and validation dataloaders and the model. In this section we will train the model and define the loss function and the optimizer to perform the parameters optimization.

For the semantic segmentation task, we can define the loss function as a sum of Cross-Entropy and Jaccard loss functions. The Cross-Entropy loss function is a standard loss function for a multi-class classification tasks and the Jaccard loss function helps directly optimizing Intersection-over-Union measure for semantic segmentation.

import optax

num_epochs = 15
total_steps = len(train_dataset) // train_batch_size
learning_rate = 0.003
momentum = 0.9
lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)

iterate_subsample = np.linspace(0, num_epochs * total_steps, 100)
plt.plot(
    np.linspace(0, num_epochs, len(iterate_subsample)),
    [lr_schedule(i) for i in iterate_subsample],
    lw=3,
)
plt.title("Learning rate")
plt.xlabel("Epochs")
plt.ylabel("Learning rate")
plt.grid()
plt.xlim((0, num_epochs))
plt.show()

with jax.set_mesh(mesh):
    optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum), wrt=nnx.Param)
../_images/4b09fa1be3e36e33701656c2e235b601c0fd762e2775cc2ffa96694007f5a34c.png

Let us implement Jaccard loss and the loss function combining Cross-Entropy and Jaccard losses.

def compute_softmax_jaccard_loss(logits, masks, reduction="mean"):
    assert reduction in ("mean", "sum")
    y_pred = nnx.softmax(logits, axis=-1)
    b, c = y_pred.shape[0], y_pred.shape[-1]
    y = nnx.one_hot(masks, num_classes=c, axis=-1)

    y_pred = y_pred.reshape((b, -1, c))
    y = y.reshape((b, -1, c))

    intersection = y_pred * y
    union = y_pred + y - intersection + 1e-8

    intersection = jnp.sum(intersection, axis=1)
    union = jnp.sum(union, axis=1)

    if reduction == "mean":
        intersection = jnp.mean(intersection)
        union = jnp.mean(union)
    elif reduction == "sum":
        intersection = jnp.sum(intersection)
        union = jnp.sum(union)

    return 1.0 - intersection / union


def compute_losses_and_logits(
    model: nnx.Module, 
    images: jax.Array, 
    masks: jax.Array,
    rngs: nnx.Rngs | None = None,
):
    logits = model(images, rngs=rngs)
    xentropy_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=masks
    ).mean()
    jacc_loss = compute_softmax_jaccard_loss(logits=logits, masks=masks)
    loss = xentropy_loss + jacc_loss
    return loss, (xentropy_loss, jacc_loss, logits)

Now, we will implement a confusion matrix metric derived from nnx.Metric. A confusion matrix will help us to compute the Intersection-Over-Union (IoU) metric per class and on average. Finally, we can also compute the accuracy metric using the confusion matrix.

class ConfusionMatrix(nnx.Metric):
    def __init__(
        self,
        num_classes: int,
        average: str | None = None,
    ):
        assert average in (None, "samples", "recall", "precision")
        assert num_classes > 0
        self.num_classes = num_classes
        self.average = average
        self.confusion_matrix = nnx.metrics.MetricState(
            jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)
        )
        self.count = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self):
        self.confusion_matrix.value = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def _check_shape(self, y_pred: jax.Array, y: jax.Array):
        if y_pred.shape[-1] != self.num_classes:
            raise ValueError(f"y_pred does not have correct number of classes: {y_pred.shape[-1]} vs {self.num_classes}")

        if not (y.ndim + 1 == y_pred.ndim):
            raise ValueError(
                f"y_pred must have shape (batch_size, num_classes (currently set to {self.num_classes}), ...) "
                "and y must have shape of (batch_size, ...), "
                f"but given {y.shape} vs {y_pred.shape}."
            )

    def update(self, **kwargs):
        # We assume that y.max() < self.num_classes and y.min() >= 0
        assert "y" in kwargs
        assert "y_pred" in kwargs
        y_pred = kwargs["y_pred"]
        y = kwargs["y"]
        self._check_shape(y_pred, y)
        self.count[...] += y_pred.shape[0]

        y_pred = jnp.argmax(y_pred, axis=-1).ravel()
        y = y.ravel()
        indices = self.num_classes * y + y_pred
        if not jax.typeof(indices).sharding.is_fully_replicated:
            indices = jax.sharding.reshard(indices, jax.P())
        matrix = jnp.bincount(indices, minlength=self.num_classes**2, length=self.num_classes**2)
        matrix = matrix.reshape((self.num_classes, self.num_classes))
        self.confusion_matrix[...] += matrix

    def compute(self) -> jax.Array:
        if self.average:
            confusion_matrix = self.confusion_matrix.get_value().astype("float")
            if self.average == "samples":
                return confusion_matrix / self.count.get_value()
            else:
                return self.normalize(self.confusion_matrix.get_value(), self.average)
        return self.confusion_matrix.get_value()

    @staticmethod
    def normalize(matrix: jax.Array, average: str) -> jax.Array:
        """Normalize given `matrix` with given `average`."""
        if average == "recall":
            return matrix / (jnp.expand_dims(matrix.sum(axis=1), axis=1) + 1e-15)
        elif average == "precision":
            return matrix / (matrix.sum(axis=0) + 1e-15)
        else:
            raise ValueError("Argument average should be one of 'samples', 'recall', 'precision'")


def compute_iou(cm: jax.Array) -> jax.Array:
    return jnp.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - jnp.diag(cm) + 1e-15)


def compute_mean_iou(cm: jax.Array) -> jax.Array:
    return compute_iou(cm).mean()


def compute_accuracy(cm: jax.Array) -> jax.Array:
    return jnp.diag(cm).sum() / (cm.sum() + 1e-15)

Next, let’s define training and evaluation steps:

@nnx.jit(donate_argnames=("model", "optimizer"))
def train_step(
    model: nnx.Module, 
    optimizer: nnx.Optimizer, 
    rngs: nnx.Rngs,
    batch: tuple[jax.Array, jax.Array],
):
    images, masks = batch
    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, (xentropy_loss, jacc_loss, logits)), grads = grad_fn(model, images, masks, rngs.fork())
    optimizer.update(model, grads)  # In-place updates.

    return loss, xentropy_loss, jacc_loss
@nnx.jit
def eval_step(
    model: nnx.Module, batch: tuple[jax.Array, jax.Array], eval_metrics: nnx.MultiMetric
):
    images, masks = batch
    loss, (_, _, logits) = compute_losses_and_logits(model, images, masks)
    eval_metrics.update(
        total_loss=loss,
        y_pred=logits,
        y=masks,
    )

We will also define metrics we want to compute during the evaluation phase: total loss and confusion matrix computed on training and validation datasets. Finally, we define helper objects to store the metrics history. Metrics like IoU per class, mean IoU and accuracy will be computed using the confusion matrix in the evaluation code.

eval_metrics = nnx.MultiMetric(
    total_loss=nnx.metrics.Average('total_loss'),
    confusion_matrix=ConfusionMatrix(num_classes=3),
)


eval_metrics_history = {
    "train_total_loss": [],
    "train_IoU": [],
    "train_mean_IoU": [],
    "train_accuracy": [],

    "val_total_loss": [],
    "val_IoU": [],
    "val_mean_IoU": [],
    "val_accuracy": [],
}

Let us define the training and evaluation logic. We define as well a checkpoint manager to store two best models defined by validation mean IoU metric value.

import time
import orbax.checkpoint as ocp
from orbax.checkpoint.checkpoint_managers import preservation_policy as preservation_policy_lib
from orbax.checkpoint.path import atomicity

# We define a view of the model sharing the weights but with attributes set for evaluation
eval_model = nnx.view(model, deterministic=True)
rngs = nnx.Rngs(12)


def train_one_epoch(epoch):
    start_time = time.time()

    with jax.set_mesh(mesh):
        for step, batch in enumerate(train_loader):
    
            # Convert np.ndarray to jax.Array on GPUs
            images = jax.device_put(batch["image"], device=jax.P("fsdp"))
            masks = jax.device_put(batch["mask"].astype(int), device=jax.P("fsdp"))
            
            total_loss, xentropy_loss, jaccard_loss = train_step(model, optimizer, rngs, (images, masks))
    
            print(
                f"\r[train] epoch: {epoch + 1}/{num_epochs}, iteration: {step}/{total_steps}, "
                f"total loss: {total_loss.item():.4f} ",
                f"xentropy loss: {xentropy_loss.item():.4f} ",
                f"jaccard loss: {jaccard_loss.item():.4f} ",
                end="")
            print("\r", end="")

    elapsed = time.time() - start_time
    print(
        f"\n[train] epoch: {epoch + 1}/{num_epochs}, elapsed time: {elapsed:.2f} seconds"
    )


def evaluate_model(epoch):
    start_time = time.time()

    with jax.set_mesh(mesh):
        # Compute the metrics on the train and val sets after each training epoch.
        for tag, eval_loader in [("train", train_eval_loader), ("val", val_loader)]:
            eval_metrics.reset()  # Reset the eval metrics
            for val_batch in eval_loader:
    
                # Convert np.ndarray to jax.Array on GPUs
                images = jax.device_put(val_batch["image"], device=jax.P("fsdp"))
                masks = jax.device_put(val_batch["mask"].astype(int), device=jax.P("fsdp"))
                    
                eval_step(eval_model, (images, masks), eval_metrics)
    
            for metric, value in eval_metrics.compute().items():
                if metric == "confusion_matrix":
                    eval_metrics_history[f"{tag}_IoU"].append(
                        compute_iou(value)
                    )
                    eval_metrics_history[f"{tag}_mean_IoU"].append(
                        compute_mean_iou(value)
                    )
                    eval_metrics_history[f"{tag}_accuracy"].append(
                        compute_accuracy(value)
                    )
                else:
                    eval_metrics_history[f'{tag}_{metric}'].append(value)
    
            print(
                f"[{tag}] epoch: {epoch + 1}/{num_epochs} "
                f"\n - total loss: {eval_metrics_history[f'{tag}_total_loss'][-1]:0.4f} "
                f"\n - IoU per class: {eval_metrics_history[f'{tag}_IoU'][-1].tolist()} "
                f"\n - Mean IoU: {eval_metrics_history[f'{tag}_mean_IoU'][-1]:0.4f} "
                f"\n - Accuracy: {eval_metrics_history[f'{tag}_accuracy'][-1]:0.4f} "
                "\n"
            )

    elapsed = time.time() - start_time
    print(
        f"[evaluation] epoch: {epoch + 1}/{num_epochs}, elapsed time: {elapsed:.2f} seconds"
    )

    return eval_metrics_history['val_mean_IoU'][-1]


checkpoint_path = ocp.test_utils.erase_and_create_empty("/tmp/output-oxford-model/")
checkpoint_mngr = ocp.CheckpointManager(
  checkpoint_path,
  options=ocp.CheckpointManagerOptions(
      preservation_policy=preservation_policy_lib.LatestN(1),
      temporary_path_class=atomicity.CommitFileTemporaryPath,
  ),
)

def save_model(epoch):
    checkpoint_mngr.save(epoch, args=ocp.args.StandardSave(nnx.state(model)))
    checkpoint_mngr.wait_until_finished()

Now we can start the training. It can take around 45 minutes using a single GPU and use 19GB of GPU memory.

%%time

best_val_mean_iou = 0.0
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    if (epoch % 3 == 0) or (epoch == num_epochs - 1):
        val_mean_iou = evaluate_model(epoch)
        if val_mean_iou > best_val_mean_iou:
            save_model(epoch)
            best_val_mean_iou = val_mean_iou
E0519 14:58:32.077355   56846 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:33.187096   56826 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:35.464285   56851 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:35.925939   56849 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:36.720911   56812 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:38.163433   56820 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:38.678921   56833 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:39.371561   56808 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:40.181144   56841 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:40.444962   56839 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:41.022651   56855 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:43.031146   56818 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:43.543444   56849 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:44.419156   56828 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:44.632559   56835 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 14:58:45.369472   56852 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
[train] epoch: 1/15, iteration: 67/71, total loss: 1.5253  xentropy loss: 0.8606  jaccard loss: 0.6647 
[train] epoch: 1/15, elapsed time: 166.39 seconds
E0519 15:01:03.430295   56840 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:04.671282   56808 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:05.447053   56825 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:05.753242   56839 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:06.452114   56821 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:06.968773   56830 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:07.644647   56850 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:09.240093   56843 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:10.042085   56827 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:01:10.998473   56822 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
[train] epoch: 1/15 
 - total loss: 1.4837 
 - IoU per class: [0.34460192918777466, 0.5718100666999817, 0.09155400097370148] 
 - Mean IoU: 0.3360 
 - Accuracy: 0.6143 

[val] epoch: 1/15 
 - total loss: 1.4854 
 - IoU per class: [0.3414991796016693, 0.570045530796051, 0.09269510954618454] 
 - Mean IoU: 0.3347 
 - Accuracy: 0.6125 

[evaluation] epoch: 1/15, elapsed time: 46.60 seconds
[train] epoch: 2/15, iteration: 67/71, total loss: 1.4079  xentropy loss: 0.7911  jaccard loss: 0.6168 
[train] epoch: 2/15, elapsed time: 63.04 seconds
[train] epoch: 3/15, iteration: 67/71, total loss: 1.3208  xentropy loss: 0.7426  jaccard loss: 0.5781 
[train] epoch: 3/15, elapsed time: 63.87 seconds
[train] epoch: 4/15, iteration: 67/71, total loss: 1.2886  xentropy loss: 0.7257  jaccard loss: 0.5630 
[train] epoch: 4/15, elapsed time: 63.48 seconds
[train] epoch: 4/15 
 - total loss: 1.2749 
 - IoU per class: [0.46864765882492065, 0.6711030602455139, 0.11210985481739044] 
 - Mean IoU: 0.4173 
 - Accuracy: 0.7023 

[val] epoch: 4/15 
 - total loss: 1.2808 
 - IoU per class: [0.46540209650993347, 0.6677289605140686, 0.11328594386577606] 
 - Mean IoU: 0.4155 
 - Accuracy: 0.6998 

[evaluation] epoch: 4/15, elapsed time: 29.41 seconds
[train] epoch: 5/15, iteration: 67/71, total loss: 1.2797  xentropy loss: 0.7253  jaccard loss: 0.5544 
[train] epoch: 5/15, elapsed time: 62.84 seconds
[train] epoch: 6/15, iteration: 67/71, total loss: 1.2547  xentropy loss: 0.7107  jaccard loss: 0.5440 
[train] epoch: 6/15, elapsed time: 62.77 seconds
[train] epoch: 7/15, iteration: 67/71, total loss: 1.2391  xentropy loss: 0.7005  jaccard loss: 0.5386 
[train] epoch: 7/15, elapsed time: 64.31 seconds
[train] epoch: 7/15 
 - total loss: 1.2049 
 - IoU per class: [0.5125540494918823, 0.6961598992347717, 0.11728569120168686] 
 - Mean IoU: 0.4420 
 - Accuracy: 0.7255 

[val] epoch: 7/15 
 - total loss: 1.2156 
 - IoU per class: [0.5070227384567261, 0.6905422210693359, 0.11896316707134247] 
 - Mean IoU: 0.4388 
 - Accuracy: 0.7215 

[evaluation] epoch: 7/15, elapsed time: 29.63 seconds
[train] epoch: 8/15, iteration: 67/71, total loss: 1.2238  xentropy loss: 0.6901  jaccard loss: 0.5337 
[train] epoch: 8/15, elapsed time: 63.24 seconds
[train] epoch: 9/15, iteration: 67/71, total loss: 1.2219  xentropy loss: 0.6897  jaccard loss: 0.5322 
[train] epoch: 9/15, elapsed time: 63.83 seconds
[train] epoch: 10/15, iteration: 67/71, total loss: 1.2129  xentropy loss: 0.6861  jaccard loss: 0.5268 
[train] epoch: 10/15, elapsed time: 62.33 seconds
[train] epoch: 10/15 
 - total loss: 1.1916 
 - IoU per class: [0.5221982002258301, 0.6984012126922607, 0.11644359678030014] 
 - Mean IoU: 0.4457 
 - Accuracy: 0.7289 

[val] epoch: 10/15 
 - total loss: 1.2030 
 - IoU per class: [0.516647458076477, 0.6924700140953064, 0.11903645843267441] 
 - Mean IoU: 0.4427 
 - Accuracy: 0.7249 

[evaluation] epoch: 10/15, elapsed time: 29.98 seconds
[train] epoch: 11/15, iteration: 67/71, total loss: 1.2074  xentropy loss: 0.6832  jaccard loss: 0.5242 
[train] epoch: 11/15, elapsed time: 64.54 seconds
[train] epoch: 12/15, iteration: 67/71, total loss: 1.1970  xentropy loss: 0.6752  jaccard loss: 0.5219 
[train] epoch: 12/15, elapsed time: 63.41 seconds
[train] epoch: 13/15, iteration: 67/71, total loss: 1.1962  xentropy loss: 0.6771  jaccard loss: 0.5191 
[train] epoch: 13/15, elapsed time: 63.26 seconds
[train] epoch: 13/15 
 - total loss: 1.1548 
 - IoU per class: [0.5388510227203369, 0.7134674191474915, 0.14061923325061798] 
 - Mean IoU: 0.4643 
 - Accuracy: 0.7411 

[val] epoch: 13/15 
 - total loss: 1.1695 
 - IoU per class: [0.5313435196876526, 0.7059954404830933, 0.1432698518037796] 
 - Mean IoU: 0.4602 
 - Accuracy: 0.7359 

[evaluation] epoch: 13/15, elapsed time: 29.63 seconds
[train] epoch: 14/15, iteration: 67/71, total loss: 1.1926  xentropy loss: 0.6742  jaccard loss: 0.5184 
[train] epoch: 14/15, elapsed time: 63.99 seconds
[train] epoch: 15/15, iteration: 67/71, total loss: 1.1957  xentropy loss: 0.6767  jaccard loss: 0.5190 
[train] epoch: 15/15, elapsed time: 63.26 seconds
[train] epoch: 15/15 
 - total loss: 1.1403 
 - IoU per class: [0.5411749482154846, 0.7202478051185608, 0.14533428847789764] 
 - Mean IoU: 0.4689 
 - Accuracy: 0.7456 

[val] epoch: 15/15 
 - total loss: 1.1577 
 - IoU per class: [0.5310024619102478, 0.7110327482223511, 0.14793400466442108] 
 - Mean IoU: 0.4633 
 - Accuracy: 0.7390 

[evaluation] epoch: 15/15, elapsed time: 29.74 seconds
CPU times: user 6min 32s, sys: 3min 8s, total: 9min 41s
Wall time: 21min 13s

We can check the saved models:

!ls /tmp/output-oxford-model/
/usr/lib/python3.12/pty.py:95: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid, fd = os.forkpty()
14

and visualize collected metrics:

epochs = [i for i in range(num_epochs) if (i % 3 == 0) or (i == num_epochs - 1)]

plt.plot(epochs, eval_metrics_history["train_total_loss"], label="Loss value on training set")
plt.plot(epochs, eval_metrics_history["val_total_loss"], label="Loss value on validation set")
plt.legend()
<matplotlib.legend.Legend at 0x7d81dc2e5eb0>
../_images/6f0ddb939a5e3f9f539f59764ff69b44571e912b0b7e556000f53f835987b112.png
plt.plot(epochs, eval_metrics_history["train_mean_IoU"], label="Mean IoU on training set")
plt.plot(epochs, eval_metrics_history["val_mean_IoU"], label="Mean IoU on validation set")
plt.legend()
<matplotlib.legend.Legend at 0x7d81e40b4ce0>
../_images/5b0d3e4fca3641a4af79f4c35d269ee6073088ea8a8e3399c973c94e32a0a0dc.png

Next, we will visualize model predictions on validation data:

val_batch = next(iter(val_loader))
images, masks = val_batch["image"], val_batch["mask"]
with jax.set_mesh(mesh):
    x = jax.device_put(images, device=jax.P("fsdp"))
    preds = eval_model(x)
    preds = jnp.argmax(preds, axis=-1)
    preds = jax.sharding.reshard(preds, jax.P())
E0519 15:19:41.770595   56852 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0519 15:19:49.648804   56825 cuda_timer.cc:87] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
def display_image_mask_pred(img, mask, pred, label=""):
    if img.dtype in (np.float32, ):
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)
    fig, axs = plt.subplots(1, 5, figsize=(15, 10))
    axs[0].set_title(f"Image{label}")
    axs[0].imshow(img)
    axs[1].set_title(f"Mask{label}")
    axs[1].imshow(mask)
    axs[2].set_title("Image + Mask")
    axs[2].imshow(img)
    axs[2].imshow(mask, alpha=0.5)
    axs[3].set_title(f"Pred{label}")
    axs[3].imshow(pred)
    axs[4].set_title("Image + Pred")
    axs[4].imshow(img)
    axs[4].imshow(pred, alpha=0.5)
for img, mask, pred in zip(images[:4], masks[:4], preds[:4]):
    display_image_mask_pred(img, mask, pred, label=" (validation set)")
../_images/8c263c5b46ea8c121cbe747cf0b5f8cb88089e3b6874e46a868c6caa0bda541f.png ../_images/f92e1bdc8864e045876249e0b2c628d8d52291b5d374029ba6279dd76da3ce6b.png ../_images/f456228136f456e6d56ab04c72256e4969573d54324e4aa9a4c203c8b814225f.png ../_images/e9e6bc56f5db389d6bc8227f9f97562b1dd0610e9d49d33a37594b5b14ac8c57.png

We can see that model can predict approximative shape of the animal and the background and struggles with predicting the boundary. Carefully choosing hyperparameters and training for more epochs we may achieve better results.