Case Study 1: Building an Image Classifier with ResNet

Overview

In this case study, we build a complete image classification system using a ResNet architecture trained on the CIFAR-10 dataset. We implement ResNet-18 from scratch, train it with modern best practices, and analyze its behavior using the visualization techniques from Chapter 14. This case study brings together concepts from loss functions (Chapter 8), optimization (Chapter 12), regularization (Chapter 13), and the CNN-specific material from this chapter.

Problem Statement

CIFAR-10 contains 60,000 32x32 color images in 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck), with 50,000 training images and 10,000 test images. While the images are small and low-resolution, achieving high accuracy requires a well-designed CNN with proper training methodology.

Our goals:

  1. Implement ResNet-18 from scratch with proper skip connections.
  2. Train it to over 93% test accuracy on CIFAR-10.
  3. Analyze training dynamics, learned features, and failure cases.
  4. Compare with a baseline CNN without skip connections.

Architecture Design

Why ResNet for CIFAR-10?

CIFAR-10 images are only 32x32 pixels -- much smaller than the 224x224 images ResNet was originally designed for. We adapt the architecture by:

  - Using a 3x3 initial convolution (instead of 7x7) with stride 1 (instead of stride 2)
  - Removing the initial max pooling layer
  - Keeping the four residual stages but with appropriate channel counts

These modifications preserve spatial resolution through the early layers, which is crucial for small images. With the original ResNet stem, a 32x32 input would already be down to 2x2 by the third residual stage and 1x1 by the fourth, leaving almost no spatial information for the later layers.
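
The arithmetic behind this claim can be checked with a short helper (illustrative only; conv_out is not part of the case study code). It traces the spatial size of a 32x32 input through the standard ImageNet stem versus the CIFAR stem, then through the four residual stages with strides [1, 2, 2, 2]:

def conv_out(size: int, kernel: int, stride: int, pad: int) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * pad - kernel) // stride + 1

# Standard ImageNet stem: 7x7 conv with stride 2, then 3x3 max pool with stride 2
imagenet_stem = conv_out(conv_out(32, 7, 2, 3), 3, 2, 1)  # 32 -> 16 -> 8
# CIFAR stem used here: 3x3 conv with stride 1, no pooling
cifar_stem = conv_out(32, 3, 1, 1)  # 32 -> 32

for name, size in [("ImageNet stem", imagenet_stem), ("CIFAR stem", cifar_stem)]:
    trace = []
    for stage_stride in (1, 2, 2, 2):  # strides of the four residual stages
        size = conv_out(size, 3, stage_stride, 1)
        trace.append(size)
    print(f"{name}: stage outputs {trace}")
# ImageNet stem: stage outputs [8, 4, 2, 1]
# CIFAR stem: stage outputs [32, 16, 8, 4]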

ResNet-18 for CIFAR-10

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from typing import Tuple, List, Optional
import time

torch.manual_seed(42)


class BasicBlock(nn.Module):
    """ResNet Basic Block with skip connection.

    Each block contains two 3x3 convolutions with batch normalization
    and a residual shortcut connection.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Stride for the first convolution (used for downsampling).
    """

    expansion: int = 1

    def __init__(
        self, in_channels: int, out_channels: int, stride: int = 1
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with residual connection.

        Args:
            x: Input tensor of shape (N, C, H, W).

        Returns:
            Output tensor after residual addition and activation.
        """
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        out = F.relu(out)
        return out


class ResNet18CIFAR(nn.Module):
    """ResNet-18 adapted for CIFAR-10 (32x32 images).

    Modifications from standard ResNet-18:
    - Initial 3x3 conv with stride 1 instead of 7x7 with stride 2
    - No initial max pooling
    - Four stages with [2, 2, 2, 2] blocks and [64, 128, 256, 512] channels

    Args:
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.in_channels = 64

        # Initial convolution (no aggressive downsampling for 32x32 images)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages
        self.layer1 = self._make_layer(64, num_blocks=2, stride=1)   # 32x32
        self.layer2 = self._make_layer(128, num_blocks=2, stride=2)  # 16x16
        self.layer3 = self._make_layer(256, num_blocks=2, stride=2)  # 8x8
        self.layer4 = self._make_layer(512, num_blocks=2, stride=2)  # 4x4

        # Classification head
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

        # Weight initialization
        self._initialize_weights()

    def _make_layer(
        self, out_channels: int, num_blocks: int, stride: int
    ) -> nn.Sequential:
        """Create a residual stage with multiple blocks.

        Args:
            out_channels: Number of output channels for this stage.
            num_blocks: Number of residual blocks.
            stride: Stride for the first block (downsampling).

        Returns:
            Sequential container of residual blocks.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers: List[nn.Module] = []
        for s in strides:
            layers.append(BasicBlock(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def _initialize_weights(self) -> None:
        """Initialize weights using Kaiming initialization."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through ResNet-18.

        Args:
            x: Input tensor of shape (N, 3, 32, 32).

        Returns:
            Logits tensor of shape (N, num_classes).
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
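
As a quick smoke test (a suggested check, not part of the case study code itself), we can pass a dummy batch through the model to confirm the output shape and parameter count before committing to a long training run:

model = ResNet18CIFAR(num_classes=10)
dummy = torch.randn(4, 3, 32, 32)  # fake mini-batch of four CIFAR-sized images
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 10])
print(sum(p.numel() for p in model.parameters()))  # roughly 11.2 million parameters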

Training Pipeline

Data Preparation

We use standard CIFAR-10 augmentation: random crops with padding and horizontal flips. These simple augmentations are remarkably effective -- they can improve accuracy by 2-4% on CIFAR-10.

def get_cifar10_loaders(
    batch_size: int = 128,
    num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader]:
    """Create CIFAR-10 data loaders with standard augmentation.

    Args:
        batch_size: Samples per mini-batch. Defaults to 128.
        num_workers: Data loading workers. Defaults to 2.

    Returns:
        Tuple of (train_loader, test_loader).
    """
    # CIFAR-10 normalization statistics
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2470, 0.2435, 0.2616)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])

    train_set = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=train_transform,
    )
    test_set = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=test_transform,
    )

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,
    )
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers,
    )

    return train_loader, test_loader
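
As a quick check (illustrative only), one batch from the training loader should have the expected shapes, and the pixel values should be roughly standardized after normalization:

train_loader, test_loader = get_cifar10_loaders(batch_size=128)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])
print(round(images.mean().item(), 3), round(images.std().item(), 3))  # near 0 and 1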

Training Loop with Learning Rate Scheduling

We use SGD with momentum and a cosine annealing learning rate schedule, which we found in Chapter 12 to provide smooth convergence without requiring manual tuning of decay milestones.
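
For intuition about the schedule itself, the standalone sketch below (using a throwaway parameter and optimizer, purely illustrative) prints a few values of the cosine curve: the learning rate decays slowly at first, fastest in the middle, and flattens toward zero at the end.

dummy_param = [torch.zeros(1, requires_grad=True)]
dummy_opt = optim.SGD(dummy_param, lr=0.1)
dummy_sched = optim.lr_scheduler.CosineAnnealingLR(dummy_opt, T_max=200)
lrs = []
for _ in range(200):
    lrs.append(dummy_opt.param_groups[0]["lr"])
    dummy_opt.step()
    dummy_sched.step()
print([round(lr, 4) for lr in (lrs[0], lrs[50], lrs[100], lrs[150], lrs[199])])
# [0.1, 0.0854, 0.05, 0.0146, 0.0]

The full training loop is below.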

def train_resnet(
    num_epochs: int = 200,
    batch_size: int = 128,
    learning_rate: float = 0.1,
    weight_decay: float = 5e-4,
    device: Optional[torch.device] = None,
) -> Tuple[nn.Module, dict]:
    """Train ResNet-18 on CIFAR-10 with cosine annealing.

    Args:
        num_epochs: Number of training epochs. Defaults to 200.
        batch_size: Mini-batch size. Defaults to 128.
        learning_rate: Initial learning rate. Defaults to 0.1.
        weight_decay: L2 regularization strength. Defaults to 5e-4.
        device: Device to train on. Auto-detects if None.

    Returns:
        Tuple of (trained model, training history dict).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Training on: {device}")

    # Data
    train_loader, test_loader = get_cifar10_loaders(batch_size)

    # Model
    model = ResNet18CIFAR(num_classes=10).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # Loss, optimizer, scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=0.9,
        weight_decay=weight_decay,
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Training history
    history = {
        "train_loss": [], "test_loss": [],
        "train_acc": [], "test_acc": [],
        "lr": [],
    }

    best_acc = 0.0

    for epoch in range(num_epochs):
        # --- Training ---
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

        scheduler.step()

        # --- Evaluation ---
        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                test_total += targets.size(0)
                test_correct += predicted.eq(targets).sum().item()

        # Record metrics
        train_acc = train_correct / train_total
        test_acc = test_correct / test_total
        avg_train_loss = train_loss / train_total
        avg_test_loss = test_loss / test_total
        current_lr = scheduler.get_last_lr()[0]

        history["train_loss"].append(avg_train_loss)
        history["test_loss"].append(avg_test_loss)
        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)
        history["lr"].append(current_lr)

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), "best_resnet18_cifar10.pth")

        if (epoch + 1) % 20 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
                f"Test Loss: {avg_test_loss:.4f} | Test Acc: {test_acc:.4f} | "
                f"LR: {current_lr:.6f}"
            )

    print(f"\nBest Test Accuracy: {best_acc:.4f}")
    return model, history
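
Once training finishes, the returned history can be plotted to inspect the dynamics discussed next. The sketch below assumes matplotlib is available; the output filename is arbitrary:

import matplotlib.pyplot as plt

model, history = train_resnet(num_epochs=200)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
ax_loss.plot(history["train_loss"], label="train")
ax_loss.plot(history["test_loss"], label="test")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Cross-entropy loss")
ax_loss.legend()
ax_acc.plot(history["train_acc"], label="train")
ax_acc.plot(history["test_acc"], label="test")
ax_acc.set_xlabel("Epoch")
ax_acc.set_ylabel("Accuracy")
ax_acc.legend()
fig.tight_layout()
fig.savefig("resnet18_cifar10_training_curves.png")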

Analysis and Results

Expected Performance

With the configuration above (200 epochs, SGD with cosine annealing, standard augmentation), ResNet-18 typically achieves:

  - Training accuracy: ~99.5%
  - Test accuracy: ~93-94%

The gap indicates some overfitting, which is expected for CIFAR-10.

Comparing With and Without Skip Connections

To appreciate the impact of skip connections, we train the same stack of convolutional layers with the shortcut paths removed (a plain network):

class PlainNet18CIFAR(nn.Module):
    """ResNet-18 architecture WITHOUT skip connections (plain network).

    This serves as an ablation study to demonstrate the importance
    of residual connections for training deep networks.

    Args:
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Same architecture as ResNet-18 but without shortcuts
        self.features = nn.Sequential(
            # Initial conv
            nn.Conv2d(3, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
            # Stage 1: 4 conv layers at 64 channels
            nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
            # Stage 2: downsample + 3 conv layers at 128 channels
            nn.Conv2d(64, 128, 3, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
            # Stage 3: downsample + 3 conv layers at 256 channels
            nn.Conv2d(128, 256, 3, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
            # Stage 4: downsample + 3 conv layers at 512 channels
            nn.Conv2d(256, 512, 3, 2, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass without skip connections."""
        out = self.features(x)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
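
Because PlainNet-18 has no 1x1 projection convolutions in its shortcuts, its parameter count is slightly lower than ResNet-18's; the quick comparison below (illustrative) shows the two models are otherwise closely matched in capacity, so any accuracy gap is attributable to the skip connections:

for name, net in [("ResNet-18", ResNet18CIFAR()), ("PlainNet-18", PlainNet18CIFAR())]:
    n_params = sum(p.numel() for p in net.parameters())
    print(f"{name}: {n_params:,} parameters")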

Expected results of this comparison:

  - PlainNet-18: ~90-91% test accuracy, slower convergence, higher training loss
  - ResNet-18: ~93-94% test accuracy, faster convergence, lower training loss

The difference, while modest for an 18-layer network, becomes dramatic for deeper networks (56+ layers), where plain networks can fail to converge entirely.

Error Analysis

After training, we analyze which classes are most confused:

def compute_confusion_matrix(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    num_classes: int = 10,
) -> torch.Tensor:
    """Compute the confusion matrix for the test set.

    Args:
        model: Trained model.
        test_loader: Test data loader.
        device: Device for computation.
        num_classes: Number of classes.

    Returns:
        Confusion matrix tensor of shape (num_classes, num_classes).
    """
    model.eval()
    confusion = torch.zeros(num_classes, num_classes, dtype=torch.int64)

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            for t, p in zip(targets, predicted.cpu()):
                confusion[t, p] += 1

    return confusion
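
A short follow-up (illustrative; it assumes model, test_loader, and device from the training run are in scope, and uses torchvision's CIFAR-10 class order) turns the confusion matrix into per-class accuracies and the most frequent off-diagonal confusions:

classes = ("airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")

confusion = compute_confusion_matrix(model, test_loader, device)
per_class_acc = confusion.diag().float() / confusion.sum(dim=1).float()
for name, acc in zip(classes, per_class_acc):
    print(f"{name:>10s}: {acc:.3f}")

# Three largest off-diagonal entries (most common confusions)
off_diag = confusion.clone()
off_diag.fill_diagonal_(0)
values, flat_indices = off_diag.flatten().topk(3)
for count, idx in zip(values, flat_indices):
    true_cls, pred_cls = divmod(idx.item(), len(classes))
    print(f"{classes[true_cls]} predicted as {classes[pred_cls]}: {count.item()} images")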

Common confusion pairs in CIFAR-10:

  - Cat vs. Dog: Both are furry animals with similar shapes at 32x32 resolution
  - Automobile vs. Truck: Both are vehicles with wheels
  - Deer vs. Horse: Both are quadrupeds with similar body structure

These confusions make intuitive sense and highlight the challenge of fine-grained recognition at low resolution.

Key Takeaways

  1. Skip connections are essential for training deeper networks, enabling gradient flow and simplifying the optimization landscape.
  2. Architecture adaptation matters: A CIFAR-10 ResNet needs different stem design than an ImageNet ResNet due to the smaller input resolution.
  3. Kaiming initialization combined with batch normalization enables stable training of deep networks from the start.
  4. Cosine annealing provides smooth, effective learning rate decay without requiring manual milestone tuning.
  5. Data augmentation is cheap and effective -- random crops and flips provide 2-4% accuracy improvement on CIFAR-10.
  6. Error analysis reveals intuitive failure modes that align with human perception of visual similarity.

Connection to Later Chapters

The ResNet architecture and skip connection principle will reappear when we study:

  - LSTM and GRU gates (Chapter 15): Gating mechanisms serve a similar function to skip connections, controlling information flow through time
  - Transformer architectures (Chapter 16): Residual connections are a core component of every transformer block
  - Generative models (Chapter 17): Encoder-decoder architectures with skip connections (U-Net) are fundamental to image generation