Case Study 1: Building an Image Classifier with ResNet
Overview
In this case study, we build a complete image classification system using a ResNet architecture trained on the CIFAR-10 dataset. We implement ResNet-18 from scratch, train it with modern best practices, and analyze its behavior using the visualization techniques from Chapter 14. This case study brings together concepts from loss functions (Chapter 8), optimization (Chapter 12), regularization (Chapter 13), and the CNN-specific material from this chapter.
Problem Statement
CIFAR-10 contains 60,000 32x32 color images in 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck), with 50,000 training images and 10,000 test images. While the images are small and low-resolution, achieving high accuracy requires a well-designed CNN with proper training methodology.
Our goals:
1. Implement ResNet-18 from scratch with proper skip connections
2. Train it to achieve over 93% test accuracy on CIFAR-10
3. Analyze training dynamics, learned features, and failure cases
4. Compare with a baseline CNN without skip connections
Architecture Design
Why ResNet for CIFAR-10?
CIFAR-10 images are only 32x32 pixels -- much smaller than the 224x224 images ResNet was originally designed for. We adapt the architecture by:
- Using a 3x3 initial convolution (instead of 7x7) with stride 1 (instead of stride 2)
- Removing the initial max pooling layer
- Keeping the four residual stages but with appropriate channel counts
These modifications preserve the spatial resolution through the early layers, which is crucial for small images. With the original ImageNet stem (a 7x7 stride-2 convolution followed by stride-2 max pooling), a 32x32 input would already be reduced to 8x8 before the first residual stage and to just 2x2 after the third stage, leaving almost no spatial structure for the later layers.
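To make the downsampling arithmetic concrete, the short sketch below traces the spatial resolution of a 32x32 input through both stems using only the stride sequence. It is our own illustration: the helper name trace_resolution is hypothetical, and the integer division assumes "same" padding so that stride-1 stages preserve resolution and stride-2 stages halve it.
from typing import List

def trace_resolution(input_size: int, strides: List[int]) -> List[int]:
    """Track the spatial resolution of a square input through a sequence of strided stages."""
    sizes = [input_size]
    for s in strides:
        sizes.append(sizes[-1] // s)  # stride 1 keeps the size, stride 2 halves it
    return sizes

# Original ImageNet stem: stride-2 conv, stride-2 max pool, then stages with strides 1, 2, 2, 2
print(trace_resolution(32, [2, 2, 1, 2, 2, 2]))  # [32, 16, 8, 8, 4, 2, 1]

# CIFAR-10 adaptation: stride-1 conv, no pooling, then stages with strides 1, 2, 2, 2
print(trace_resolution(32, [1, 1, 2, 2, 2]))     # [32, 32, 32, 16, 8, 4]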
ResNet-18 for CIFAR-10
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from typing import Tuple, List, Optional
import time
torch.manual_seed(42)
class BasicBlock(nn.Module):
"""ResNet Basic Block with skip connection.
Each block contains two 3x3 convolutions with batch normalization
and a residual shortcut connection.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
stride: Stride for the first convolution (used for downsampling).
"""
expansion: int = 1
def __init__(
self, in_channels: int, out_channels: int, stride: int = 1
) -> None:
super().__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False,
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False,
)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=1,
stride=stride, bias=False,
),
nn.BatchNorm2d(out_channels),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with residual connection.
Args:
x: Input tensor of shape (N, C, H, W).
Returns:
Output tensor after residual addition and activation.
"""
identity = self.shortcut(x)
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += identity
out = F.relu(out)
return out
class ResNet18CIFAR(nn.Module):
"""ResNet-18 adapted for CIFAR-10 (32x32 images).
Modifications from standard ResNet-18:
- Initial 3x3 conv with stride 1 instead of 7x7 with stride 2
- No initial max pooling
- Four stages with [2, 2, 2, 2] blocks and [64, 128, 256, 512] channels
Args:
num_classes: Number of output classes. Defaults to 10.
"""
def __init__(self, num_classes: int = 10) -> None:
super().__init__()
self.in_channels = 64
# Initial convolution (no aggressive downsampling for 32x32 images)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
# Residual stages
self.layer1 = self._make_layer(64, num_blocks=2, stride=1) # 32x32
self.layer2 = self._make_layer(128, num_blocks=2, stride=2) # 16x16
self.layer3 = self._make_layer(256, num_blocks=2, stride=2) # 8x8
self.layer4 = self._make_layer(512, num_blocks=2, stride=2) # 4x4
# Classification head
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512, num_classes)
# Weight initialization
self._initialize_weights()
def _make_layer(
self, out_channels: int, num_blocks: int, stride: int
) -> nn.Sequential:
"""Create a residual stage with multiple blocks.
Args:
out_channels: Number of output channels for this stage.
num_blocks: Number of residual blocks.
stride: Stride for the first block (downsampling).
Returns:
Sequential container of residual blocks.
"""
strides = [stride] + [1] * (num_blocks - 1)
layers: List[nn.Module] = []
for s in strides:
layers.append(BasicBlock(self.in_channels, out_channels, s))
self.in_channels = out_channels
return nn.Sequential(*layers)
def _initialize_weights(self) -> None:
"""Initialize weights using Kaiming initialization."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through ResNet-18.
Args:
x: Input tensor of shape (N, 3, 32, 32).
Returns:
Logits tensor of shape (N, num_classes).
"""
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.avg_pool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
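Before moving on to training, a quick shape check (our own sanity test, not part of the pipeline) confirms that the model accepts CIFAR-sized inputs and produces one logit per class:
# Sanity check (illustration only): a random batch of four 32x32 RGB images
# should produce a (4, 10) logits tensor.
model = ResNet18CIFAR(num_classes=10)
dummy_batch = torch.randn(4, 3, 32, 32)
print(model(dummy_batch).shape)                    # torch.Size([4, 10])
print(sum(p.numel() for p in model.parameters()))  # roughly 11 million parameters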
Training Pipeline
Data Preparation
We use standard CIFAR-10 augmentation: random crops with padding and horizontal flips. These simple augmentations are remarkably effective -- they can improve accuracy by 2-4% on CIFAR-10.
def get_cifar10_loaders(
batch_size: int = 128,
num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader]:
"""Create CIFAR-10 data loaders with standard augmentation.
Args:
batch_size: Samples per mini-batch. Defaults to 128.
num_workers: Data loading workers. Defaults to 2.
Returns:
Tuple of (train_loader, test_loader).
"""
# CIFAR-10 normalization statistics
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2470, 0.2435, 0.2616)
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(cifar10_mean, cifar10_std),
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(cifar10_mean, cifar10_std),
])
train_set = datasets.CIFAR10(
root="./data", train=True, download=True, transform=train_transform,
)
test_set = datasets.CIFAR10(
root="./data", train=False, download=True, transform=test_transform,
)
train_loader = DataLoader(
train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,
)
test_loader = DataLoader(
test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers,
)
return train_loader, test_loader
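A brief look at one batch (our own check, assuming the dataset download succeeds) confirms the tensor shapes and that normalization roughly centers the pixel values:
train_loader, test_loader = get_cifar10_loaders(batch_size=128)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])
# After normalization the batch statistics should be close to zero mean and unit std
# (not exactly, because of augmentation and per-batch variation).
print(images.mean().item(), images.std().item())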
Training Loop with Learning Rate Scheduling
We use SGD with momentum and a cosine annealing learning rate schedule, which we found in Chapter 12 to provide smooth convergence without requiring manual tuning of decay milestones.
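To see the schedule in isolation before wiring it into the loop, the short sketch below (our own illustration, using a throwaway parameter and optimizer) prints the learning rate that CosineAnnealingLR produces at a few epochs:
# Illustration only: the cosine annealing schedule over 200 epochs,
# starting at lr=0.1 and decaying smoothly toward zero.
dummy_param = [torch.zeros(1, requires_grad=True)]
opt = optim.SGD(dummy_param, lr=0.1)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=200)
for epoch in range(200):
    if epoch in (0, 50, 100, 150, 199):
        print(f"epoch {epoch:3d}: lr = {sched.get_last_lr()[0]:.4f}")
    opt.step()    # step the (dummy) optimizer before the scheduler to avoid an ordering warning
    sched.step()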
def train_resnet(
num_epochs: int = 200,
batch_size: int = 128,
learning_rate: float = 0.1,
weight_decay: float = 5e-4,
device: Optional[torch.device] = None,
) -> Tuple[nn.Module, dict]:
"""Train ResNet-18 on CIFAR-10 with cosine annealing.
Args:
num_epochs: Number of training epochs. Defaults to 200.
batch_size: Mini-batch size. Defaults to 128.
learning_rate: Initial learning rate. Defaults to 0.1.
weight_decay: L2 regularization strength. Defaults to 5e-4.
device: Device to train on. Auto-detects if None.
Returns:
Tuple of (trained model, training history dict).
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")
# Data
train_loader, test_loader = get_cifar10_loaders(batch_size)
# Model
model = ResNet18CIFAR(num_classes=10).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
# Loss, optimizer, scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
model.parameters(),
lr=learning_rate,
momentum=0.9,
weight_decay=weight_decay,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# Training history
history = {
"train_loss": [], "test_loss": [],
"train_acc": [], "test_acc": [],
"lr": [],
}
best_acc = 0.0
for epoch in range(num_epochs):
# --- Training ---
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
train_total += targets.size(0)
train_correct += predicted.eq(targets).sum().item()
scheduler.step()
# --- Evaluation ---
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item() * inputs.size(0)
_, predicted = outputs.max(1)
test_total += targets.size(0)
test_correct += predicted.eq(targets).sum().item()
# Record metrics
train_acc = train_correct / train_total
test_acc = test_correct / test_total
avg_train_loss = train_loss / train_total
avg_test_loss = test_loss / test_total
current_lr = scheduler.get_last_lr()[0]
history["train_loss"].append(avg_train_loss)
history["test_loss"].append(avg_test_loss)
history["train_acc"].append(train_acc)
history["test_acc"].append(test_acc)
history["lr"].append(current_lr)
if test_acc > best_acc:
best_acc = test_acc
torch.save(model.state_dict(), "best_resnet18_cifar10.pth")
if (epoch + 1) % 20 == 0:
print(
f"Epoch [{epoch+1}/{num_epochs}] "
f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
f"Test Loss: {avg_test_loss:.4f} | Test Acc: {test_acc:.4f} | "
f"LR: {current_lr:.6f}"
)
print(f"\nBest Test Accuracy: {best_acc:.4f}")
return model, history
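A typical driver script looks like the following; the matplotlib plotting of the returned history is our own addition and is not required by the pipeline:
if __name__ == "__main__":
    model, history = train_resnet(num_epochs=200)

    # Plot the accuracy curves recorded during training.
    import matplotlib.pyplot as plt
    epochs = range(1, len(history["train_acc"]) + 1)
    plt.plot(epochs, history["train_acc"], label="train accuracy")
    plt.plot(epochs, history["test_acc"], label="test accuracy")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend()
    plt.savefig("resnet18_cifar10_accuracy.png")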
Analysis and Results
Expected Performance
With the configuration above (200 epochs, SGD with cosine annealing, standard augmentation), ResNet-18 typically achieves:
- Training accuracy: ~99.5%
- Test accuracy: ~93-94%
- The gap indicates some overfitting, which is expected for CIFAR-10
Comparing With and Without Skip Connections
To appreciate the impact of skip connections, we train a plain network with the same depth and layer widths but with the shortcut paths removed:
class PlainNet18CIFAR(nn.Module):
"""ResNet-18 architecture WITHOUT skip connections (plain network).
This serves as an ablation study to demonstrate the importance
of residual connections for training deep networks.
Args:
num_classes: Number of output classes. Defaults to 10.
"""
def __init__(self, num_classes: int = 10) -> None:
super().__init__()
# Same architecture as ResNet-18 but without shortcuts
self.features = nn.Sequential(
# Initial conv
nn.Conv2d(3, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
# Stage 1: 4 conv layers at 64 channels
nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(),
# Stage 2: downsample + 3 conv layers at 128 channels
nn.Conv2d(64, 128, 3, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(),
# Stage 3: downsample + 3 conv layers at 256 channels
nn.Conv2d(128, 256, 3, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
nn.Conv2d(256, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(),
# Stage 4: downsample + 3 conv layers at 512 channels
nn.Conv2d(256, 512, 3, 2, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
nn.Conv2d(512, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
nn.Conv2d(512, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
nn.Conv2d(512, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass without skip connections."""
out = self.features(x)
out = self.avg_pool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
Expected results of this comparison:
- PlainNet-18: ~90-91% test accuracy, slower convergence, higher training loss
- ResNet-18: ~93-94% test accuracy, faster convergence, lower training loss
The difference, while modest for an 18-layer network, becomes dramatic for deeper networks (56+ layers): plain networks hit the degradation problem, where adding layers actually increases training error, while their residual counterparts continue to improve.
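A small sketch (our own addition) verifies that the plain network is a drop-in replacement inside the training loop and compares the sizes of the two models before running the ablation:
# Illustration only: structural comparison of the two architectures.
resnet = ResNet18CIFAR(num_classes=10)
plain = PlainNet18CIFAR(num_classes=10)

def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

print(f"ResNet-18:   {count_params(resnet):,} parameters")
print(f"PlainNet-18: {count_params(plain):,} parameters")

# Both models map (N, 3, 32, 32) inputs to (N, 10) logits, so the ablation only
# requires replacing ResNet18CIFAR with PlainNet18CIFAR inside train_resnet.
x = torch.randn(2, 3, 32, 32)
print(resnet(x).shape, plain(x).shape)  # torch.Size([2, 10]) torch.Size([2, 10])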
Error Analysis
After training, we analyze which classes are most confused:
def compute_confusion_matrix(
model: nn.Module,
test_loader: DataLoader,
device: torch.device,
num_classes: int = 10,
) -> torch.Tensor:
"""Compute the confusion matrix for the test set.
Args:
model: Trained model.
test_loader: Test data loader.
device: Device for computation.
num_classes: Number of classes.
Returns:
Confusion matrix tensor of shape (num_classes, num_classes).
"""
model.eval()
confusion = torch.zeros(num_classes, num_classes, dtype=torch.int64)
with torch.no_grad():
for inputs, targets in test_loader:
inputs = inputs.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
for t, p in zip(targets, predicted.cpu()):
confusion[t, p] += 1
return confusion
Common confusion pairs in CIFAR-10:
- Cat vs. Dog: Both are furry animals with similar shapes at 32x32 resolution
- Automobile vs. Truck: Both are vehicles with wheels
- Deer vs. Horse: Both are quadrupeds with similar body structure
These confusions make intuitive sense and highlight the challenge of fine-grained recognition at low resolution.
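These pairs can also be read directly off the confusion matrix. A small helper along the following lines (the function name top_confusions and the hard-coded class-name list are our own additions) prints the largest off-diagonal entries:
def top_confusions(confusion: torch.Tensor, class_names: List[str], k: int = 5) -> None:
    """Print the k largest off-diagonal entries of a confusion matrix."""
    off_diag = confusion.clone()
    off_diag.fill_diagonal_(0)  # ignore correct predictions
    values, indices = off_diag.flatten().topk(k)
    num_classes = off_diag.size(0)
    for count, idx in zip(values.tolist(), indices.tolist()):
        true_cls, pred_cls = divmod(idx, num_classes)
        print(f"{class_names[true_cls]:>10} -> {class_names[pred_cls]:<10} {count} times")

cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# confusion = compute_confusion_matrix(model, test_loader, device)
# top_confusions(confusion, cifar10_classes)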
Key Takeaways
- Skip connections are essential for training deeper networks, enabling gradient flow and simplifying the optimization landscape.
- Architecture adaptation matters: A CIFAR-10 ResNet needs a different stem design from an ImageNet ResNet because of the smaller input resolution.
- Kaiming initialization combined with batch normalization enables stable training of deep networks from the start.
- Cosine annealing provides smooth, effective learning rate decay without requiring manual milestone tuning.
- Data augmentation is cheap and effective -- random crops and flips provide 2-4% accuracy improvement on CIFAR-10.
- Error analysis reveals intuitive failure modes that align with human perception of visual similarity.
Connection to Later Chapters
The ResNet architecture and skip connection principle will reappear when we study: - LSTM and GRU gates (Chapter 15): Gating mechanisms serve a similar function to skip connections, controlling information flow through time - Transformer architectures (Chapter 16): Residual connections are a core component of every transformer block - Generative models (Chapter 17): Encoder-decoder architectures with skip connections (U-Net) are fundamental to image generation