Case Study 2: Training a Deep Network on CIFAR-10
Overview
In this case study, we train a deep convolutional network on CIFAR-10 using the complete training pipeline from Chapter 12. Starting from a basic setup that reaches about 75% test accuracy, we add one technique at a time and measure the impact of each addition: He initialization, batch normalization, learning rate scheduling, data augmentation, residual connections, and AdamW with weight decay and gradient clipping. The final model achieves over 93% test accuracy.
This case study demonstrates that the gap between a mediocre model and a strong one is often not the architecture itself but the training recipe.
The Dataset
CIFAR-10 contains 60,000 32x32 color images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. There are 50,000 training images and 10,000 test images. Despite the small image size, CIFAR-10 is a challenging benchmark because the classes are visually diverse and some inter-class distinctions (e.g., cat vs. dog) require fine-grained feature recognition.
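The `Normalize` constants used in the code below, (0.4914, 0.4822, 0.4465) and (0.2470, 0.2435, 0.2616), are intended as the per-channel mean and standard deviation of the training pixels on a [0, 1] scale. The following NumPy sketch shows one way to recompute such statistics. It assumes the images arrive as an `(N, H, W, C)` `uint8` array, which is the layout torchvision's `datasets.CIFAR10` exposes through its `.data` attribute.

```python
from typing import Tuple

import numpy as np


def channel_stats(images: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Per-channel mean and (population) std of an (N, H, W, C) uint8 array,
    with pixel values rescaled from [0, 255] to [0, 1]."""
    means, stds = [], []
    for c in range(images.shape[-1]):
        # Process one channel at a time to limit peak memory use.
        channel = images[..., c].astype(np.float64) / 255.0
        means.append(channel.mean())
        stds.append(channel.std())
    return np.array(means), np.array(stds)


# Usage (downloads CIFAR-10 on first run):
# from torchvision import datasets
# train = datasets.CIFAR10(root="./data", train=True, download=True)
# mean, std = channel_stats(train.data)
```

Compute the statistics on the training split only. Reusing them for the test set keeps test data out of preprocessing decisions.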
The Baseline Model (V0)
We start with a simple CNN trained with minimal engineering:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

torch.manual_seed(42)


class BaselineCNN(nn.Module):
    """Simple CNN without modern training techniques.

    Architecture: 3 conv blocks -> FC layers -> 10 classes.
    """

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Three 2x2 max-pools reduce 32x32 inputs to 4x4 feature maps.
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 128 * 4 * 4)
        return self.classifier(x)


# Basic transforms: tensor conversion and per-channel input normalization only
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True,
                                 transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True,
                                transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

model = BaselineCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)  # plain SGD, no momentum

# Train for 50 epochs with a constant learning rate, no data augmentation,
# and no batch normalization.
# Result: ~75% test accuracy
```
Baseline result: ~75% test accuracy after 50 epochs. The model underfits: training accuracy is only ~82%.
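The listing above builds the model, data, and optimizer but leaves out the loop that runs the 50 epochs. Below is a minimal sketch of such a loop. The helper names `train_one_epoch` and `evaluate` are hypothetical and do not come from the Chapter 12 pipeline. Later versions would add their scheduler and clipping calls to the same loop.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                    optimizer: torch.optim.Optimizer, device: str = "cpu") -> float:
    """One pass over the training data; returns the mean training loss."""
    model.train()                       # enable training behavior (e.g. batch norm stats)
    total_loss, total = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        total += labels.size(0)
    return total_loss / total


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Classification accuracy in [0, 1]."""
    model.eval()                        # use inference behavior
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total


# With the baseline objects defined above:
# for epoch in range(50):
#     train_loss = train_one_epoch(model, train_loader, criterion, optimizer)
#     test_acc = evaluate(model, test_loader)
```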
Incremental Improvements
V1: He Initialization
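The first fix is explicit He (Kaiming) initialization. PyTorch's default for `Conv2d` and `Linear` also calls `kaiming_uniform_`, but with `a=sqrt(5)`. That setting draws weights from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), which has a variance of 1/(3 · fan_in). This is six times smaller than the He variance of 2/fan_in recommended for ReLU networks, so we re-initialize explicitly: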
```python
def init_weights(module: nn.Module) -> None:
    """He initialization for conv and linear layers."""
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


model.apply(init_weights)
```
V1 result: ~78% test accuracy (+3 points). Proper initialization helps but is not transformative for this shallow network.
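The effect of the initializer is clearer in a deep stack than in a three-block CNN. The sketch below is a hypothetical experiment, not part of the case study. It feeds standard-normal inputs through bias-free `Linear` + ReLU layers and compares PyTorch's default initialization with `kaiming_normal_`. The depth of 10 and width of 512 are arbitrary choices.

```python
import torch
import torch.nn as nn


def final_preactivation_std(init: str, depth: int = 10, width: int = 512,
                            seed: int = 0) -> float:
    """Feed standard-normal inputs through `depth` bias-free Linear + ReLU layers
    and return the std of the last layer's pre-activations.

    init="default" keeps PyTorch's built-in initialization;
    init="he" re-initializes each weight with kaiming_normal_ (fan_in, ReLU gain).
    """
    torch.manual_seed(seed)
    x = torch.randn(1000, width)
    z = x
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            if init == "he":
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            z = layer(x)           # pre-activation
            x = torch.relu(z)      # input to the next layer
    return z.std().item()


print(final_preactivation_std("default"))  # shrinks toward zero with depth
print(final_preactivation_std("he"))       # stays on the order of 1
```

With He initialization, the signal scale stays roughly constant from layer to layer. With the default initialization, it shrinks geometrically with depth. That is one plausible reason the gain is modest for a network as shallow as the baseline.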
V2: Batch Normalization
Adding batch normalization after each convolution stabilizes training and allows higher learning rates:
```python
class CNNV2(nn.Module):
    """CNN with batch normalization after every conv and hidden linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32),
            nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64),
            nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128),
            nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


model = CNNV2()
model.apply(init_weights)  # keep the He initialization from V1
```
With batch norm, we can safely raise the learning rate to 0.1. We also add momentum of 0.9:
```python
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
```
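V2 result: ~85% test accuracy (+7 points). This is the largest single improvement in the study. Note, though, that it bundles batch norm with the higher learning rate and momentum that batch norm makes safe.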
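In training mode, batch normalization standardizes each channel using the mean and variance computed over the batch and spatial dimensions. It then applies a learnable scale and shift, initialized to 1 and 0. The sketch below reproduces that computation by hand for a freshly constructed `nn.BatchNorm2d` and checks the result against the layer. The helper `batchnorm2d_train` is a hypothetical name for illustration. At evaluation time, the layer instead uses running averages of these statistics.

```python
import torch
import torch.nn as nn


def batchnorm2d_train(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Training-mode batch norm with scale=1 and shift=0 (a fresh layer's values).

    Statistics are per channel, over the batch (N) and spatial (H, W) dims,
    using the biased variance, as nn.BatchNorm2d does when normalizing.
    """
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
x = torch.randn(8, 4, 5, 5) * 3.0 + 2.0   # poorly scaled activations: mean ~2, std ~3
bn = nn.BatchNorm2d(4)                     # new modules start in training mode
manual = batchnorm2d_train(x)
reference = bn(x)
print(torch.allclose(manual, reference, atol=1e-5))
print(manual.mean(dim=(0, 2, 3)))          # per-channel means, now close to 0
```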
V3: Learning Rate Schedule
Adding cosine annealing with warmup:
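```python
import math

import torch.optim as optim


def warmup_cosine_schedule(
    optimizer: optim.Optimizer, warmup_epochs: int, total_epochs: int
) -> optim.lr_scheduler.LambdaLR:
    """Linear warmup followed by cosine annealing; call .step() once per epoch."""

    def lr_lambda(epoch: int) -> float:
        # Multiplier applied to the optimizer's base learning rate at this epoch.
        if epoch < warmup_epochs:
            # Ramp linearly to 1; epoch + 1 keeps epoch 0 from training at lr = 0.
            return float(epoch + 1) / float(max(1, warmup_epochs))
        progress = float(epoch - warmup_epochs) / float(
            max(1, total_epochs - warmup_epochs)
        )
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


scheduler = warmup_cosine_schedule(optimizer, warmup_epochs=5, total_epochs=200)
# Call scheduler.step() at the end of every epoch.
```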
We also increase training to 200 epochs since cosine annealing benefits from longer schedules.
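V3 result: ~88% test accuracy (+3 points). The schedule keeps the learning rate from staying too high late in training. Because this step also quadruples the training budget from 50 to 200 epochs, the gain reflects both changes.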
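The plain-Python sketch below evaluates a warmup + cosine multiplier epoch by epoch. It assumes the settings used here: base learning rate 0.1, 5 warmup epochs, and 200 total epochs. The rule ramps linearly so that epoch 0 already trains at base_lr / warmup_epochs rather than at zero. `LambdaLR` sets the learning rate to the base learning rate times this factor for the current epoch.

```python
import math


def warmup_cosine_factor(epoch: int, warmup_epochs: int, total_epochs: int) -> float:
    """Multiplier on the base learning rate: linear warmup, then cosine decay to 0."""
    if epoch < warmup_epochs:
        return float(epoch + 1) / float(warmup_epochs)
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 0.1
for epoch in (0, 2, 4, 5, 50, 100, 150, 199):
    lr = base_lr * warmup_cosine_factor(epoch, warmup_epochs=5, total_epochs=200)
    print(f"epoch {epoch:3d}: lr = {lr:.5f}")
```

The learning rate peaks at the end of warmup and then decreases monotonically. It approaches zero only in the last few epochs.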
V4: Data Augmentation
Standard CIFAR-10 augmentation includes random crops and horizontal flips:
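```python
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Augment training images only; the test set keeps the original `transform`.
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True,
                                 transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
```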
V4 result: ~91% test accuracy (+3 points). With only 50,000 training images and a 200-epoch schedule, data augmentation is crucial for preventing overfitting.
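`RandomCrop(32, padding=4)` zero-pads each image by 4 pixels on every side and cuts out a random 32×32 window, so the content shifts by up to 4 pixels in each direction. `RandomHorizontalFlip()` mirrors the image with probability 0.5. The NumPy sketch below implements the same two operations on an `(H, W, C)` array. It is meant for illustration, not as a replacement for the torchvision transforms, and `random_crop_flip` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def random_crop_flip(img: np.ndarray, padding: int = 4,
                     rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Zero-pad an (H, W, C) image, take a random H x W crop, and flip it
    horizontally with probability 0.5."""
    if rng is None:
        rng = np.random.default_rng()
    h, w, _ = img.shape
    padded = np.pad(img, ((padding, padding), (padding, padding), (0, 0)),
                    mode="constant")
    top = rng.integers(0, 2 * padding + 1)    # offset in 0..2*padding inclusive
    left = rng.integers(0, 2 * padding + 1)
    out = padded[top:top + h, left:left + w]
    if rng.random() < 0.5:
        out = out[:, ::-1]                    # mirror left-right
    return out


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
views = [random_crop_flip(image, rng=rng) for _ in range(4)]  # four training views
```

Each epoch sees a freshly sampled view of every image. In effect, this enlarges the training set without collecting new data.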
V5: Deeper Architecture with Residual Connections
We switch to a ResNet-style architecture with skip connections. Its basic building block is:
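```python
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv-BN layers plus an identity shortcut.

    Input and output have the same number of channels and spatial size.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),  # BN supplies the shift
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add the shortcut before the final ReLU: out = relu(F(x) + x).
        return self.relu(self.block(x) + x)
```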
V5 result: ~92.5% test accuracy (+1.5 points). Skip connections make deeper networks trainable, because gradients can flow directly through the identity path.
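The text shows only the residual block, not the full V5 network. The sketch below shows one plausible way to assemble a small ResNet-style CIFAR-10 model from that block. All of the following are assumptions for illustration, not the configuration that produced the V5 numbers: the stage widths (64, 128, 256), the two blocks per stage, the strided conv used for downsampling between stages, and the class name `SmallResNet`.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Basic residual block (as in the text): channels and spatial size preserved."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.block(x) + x)


def conv_bn_relu(in_ch: int, out_ch: int, stride: int = 1) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


class SmallResNet(nn.Module):
    """Hypothetical ResNet-style network: stem, three stages, global average pool."""

    def __init__(self, widths=(64, 128, 256), blocks_per_stage: int = 2,
                 num_classes: int = 10) -> None:
        super().__init__()
        layers = [conv_bn_relu(3, widths[0])]            # stem at 32x32
        in_ch = widths[0]
        for i, width in enumerate(widths):
            if i > 0:
                # Halve the spatial size and change width between stages
                # (a plain strided conv, not ResNet's projection shortcut).
                layers.append(conv_bn_relu(in_ch, width, stride=2))
            layers.extend(ResidualBlock(width) for _ in range(blocks_per_stage))
            in_ch = width
        self.features = nn.Sequential(*layers)           # 8x8 maps for 32x32 inputs
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_ch, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x))
        return self.fc(torch.flatten(x, 1))


model = SmallResNet()
logits = model(torch.randn(2, 3, 32, 32))    # shape: (2, 10)
```

Global average pooling replaces the flatten-plus-wide-FC head of the earlier models. As a result, the classifier no longer depends on a fixed 4×4 feature-map size.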
V6: AdamW + Weight Decay + Gradient Clipping
For the final push, we switch to AdamW with proper weight decay and add gradient clipping:
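```python
# Separate parameters for weight decay. Biases and normalization scales/shifts
# are 1-D tensors; conv kernels and linear weights have 2+ dimensions.
decay_params = []
no_decay_params = []
for param in model.parameters():
    if param.ndim <= 1:
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = optim.AdamW([
    {"params": decay_params, "weight_decay": 0.05},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=1e-3)
# Keep the warmup + cosine schedule from V3, attached to the new optimizer.
scheduler = warmup_cosine_schedule(optimizer, warmup_epochs=5, total_epochs=200)

# In the training loop, clip after backward() and before step():
#     loss.backward()
#     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#     optimizer.step()
```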
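V6 result: ~93.5% test accuracy (+1 point). AdamW applies weight decay directly to the weights instead of folding it into the adaptive gradient update, and clipping bounds the gradient norm at every step. Together they add the final point of accuracy.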
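A name-based test such as `"bn" in name` only works if the batch-norm modules are literally named `bn`. Inside `nn.Sequential` they are indexed instead (for example `features.1.weight`), so their scales would silently receive weight decay. The sketch below uses a toy model and a hypothetical helper `split_decay_params`. It selects the no-decay group by tensor dimensionality and shows where gradient clipping sits in a training step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_decay_params(model: nn.Module):
    """Return (decay, no_decay) parameter lists.

    decay:    weights with >= 2 dims (conv kernels, linear matrices)
    no_decay: 1-D tensors (biases, BatchNorm/LayerNorm scales and shifts)
    """
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        (no_decay if param.ndim <= 1 else decay).append(param)
    return decay, no_decay


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Flatten(), nn.Linear(8 * 4 * 4, 10),
)
decay, no_decay = split_decay_params(model)
optimizer = torch.optim.AdamW([
    {"params": decay, "weight_decay": 0.05},
    {"params": no_decay, "weight_decay": 0.0},
], lr=1e-3)

# What a name-based rule would decay: it includes the BatchNorm scale "1.weight".
name_based_decay = [n for n, _ in model.named_parameters()
                    if not ("bn" in n or "bias" in n)]

# One training step: clip after backward() and before step().
x, y = torch.randn(4, 3, 4, 4), torch.randint(0, 10, (4,))
loss = F.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

`clip_grad_norm_` returns the total gradient norm measured before clipping, which is worth logging to see how often clipping actually activates.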
Results Summary
| Version | Technique Added | Test Accuracy | Gain (percentage points) |
|---|---|---|---|
| V0 | Baseline CNN | 75.0% | n/a |
| V1 | He initialization | 78.0% | +3.0 |
| V2 | Batch normalization + LR 0.1 + momentum | 85.0% | +7.0 |
| V3 | Warmup + cosine LR schedule, 200 epochs | 88.0% | +3.0 |
| V4 | Data augmentation | 91.0% | +3.0 |
| V5 | Residual connections | 92.5% | +1.5 |
| V6 | AdamW + weight decay + grad clip | 93.5% | +1.0 |
Analysis
Diminishing Returns
The improvements follow a pattern of diminishing returns. The early techniques (initialization, batch norm, LR schedule, augmentation) each add 3 to 7 percentage points, while the later ones add 1 to 1.5 points each. This is typical: the "low-hanging fruit" of training engineering provides the largest gains.
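The table's gains are absolute percentage-point changes. Another common way to read such results is as the fraction of the remaining test error each step removes. The short sketch below recomputes both from the reported accuracies; the dictionary simply restates the table.

```python
accuracies = {"V0": 75.0, "V1": 78.0, "V2": 85.0, "V3": 88.0,
              "V4": 91.0, "V5": 92.5, "V6": 93.5}  # test accuracy in %, from the table


def step_gains(acc: dict) -> list:
    """For each consecutive pair: (step, gain in points, fraction of remaining error removed)."""
    names = list(acc)
    rows = []
    for prev, cur in zip(names, names[1:]):
        gain = acc[cur] - acc[prev]                  # percentage points
        remaining_error = 100.0 - acc[prev]          # test error before this step, in %
        rows.append((f"{prev}->{cur}", gain, gain / remaining_error))
    return rows


for step, gain, frac in step_gains(accuracies):
    print(f"{step}: +{gain:.1f} points, {frac:.1%} of remaining error removed")
print("total gain:", sum(g for _, g, _ in step_gains(accuracies)), "points")
```

By this measure, V2 removes roughly 32% of the remaining error (22% down to 15%), and V6 removes roughly 13% (7.5% down to 6.5%). The last two steps remove less of the remaining error than V2 through V4, though the gap is narrower than the raw point gains suggest.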
Interaction Effects
The techniques are not independent. Batch normalization enables higher learning rates, which in turn makes the learning rate schedule more impactful. Data augmentation prevents overfitting, which allows the model to benefit from more epochs. Skip connections enable depth, which increases the model's capacity to leverage the augmented data.
What We Did Not Do
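Several advanced techniques could push accuracy further:
- Mixed precision training: typically speeds up training (on the order of 2x on GPUs with Tensor Cores) with little or no accuracy loss
- CutMix / Mixup: could add another 0.5-1 points
- Label smoothing: could add ~0.3 points
- Stochastic depth: useful for deeper ResNets
- AutoAugment / RandAugment: stronger augmentation policies (AutoAugment's policy is found by search; RandAugment samples operations randomly and has only two hyperparameters)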
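None of these techniques were used here. As one illustration, here is roughly how mixed precision and label smoothing could be added to a training step. The sketch follows the pattern in PyTorch's automatic mixed precision documentation:
- `torch.autocast` runs the forward pass in float16 on CUDA.
- `GradScaler` scales the loss to avoid float16 gradient underflow.
- Gradients are unscaled before clipping so that `max_norm` applies to the true gradients.

The toy linear model, the `amp_train_step` name, and the smoothing value of 0.1 are assumptions. On a CPU-only machine, AMP is simply disabled and the step runs in float32.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"   # assumption: AMP only on CUDA in this sketch


def amp_train_step(model, images, labels, criterion, optimizer, scaler,
                   max_norm: float = 1.0) -> float:
    """One training step with autocast, loss scaling, and gradient clipping."""
    optimizer.zero_grad()
    with torch.autocast(device_type=device, enabled=use_amp):  # float16 on CUDA
        loss = criterion(model(images), labels)
    scaler.scale(loss).backward()        # scaled loss -> scaled gradients
    scaler.unscale_(optimizer)           # restore true gradients before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)               # skips the update if grads contain inf/NaN
    scaler.update()                      # adjust the scale factor for the next step
    return loss.item()


model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
# Label smoothing: the target keeps 90% of its mass on the true class and
# spreads 10% uniformly over all classes.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # newer releases: torch.amp.GradScaler("cuda", ...)
images = torch.randn(8, 3, 32, 32, device=device)
labels = torch.randint(0, 10, (8,), device=device)
loss = amp_train_step(model, images, labels, criterion, optimizer, scaler)
```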
Key Takeaways
- The training recipe matters more than the architecture for moderate-sized models. Going from the baseline to V6 improved accuracy by 18.5 percentage points. Of those, 16 points came from V1 through V4, which keep the same three-block CNN (adding only batch normalization layers); the switch to residual connections contributed 1.5 points.
- Batch normalization gave the largest single gain in this study (+7 points, measured together with the higher learning rate and momentum it makes safe). It is a strong default for any CNN.
- Data augmentation is essential when the training set is small (by deep learning standards). Without it, the model memorizes training images.
- Proper learning rate scheduling prevents the optimizer from overshooting late in training. Cosine annealing with warmup is a reliable default.
- Weight decay should be excluded from bias and normalization parameters. These parameters have different roles and should not be shrunk toward zero.
- Developing incrementally and measuring each change is the right approach. Add one technique at a time, measure the effect, and build on what works.
Discussion Questions
- If you could only use three of the six techniques, which would you choose and why?
- The V5 model uses residual connections but is still relatively shallow. How would you expect the results to change with a much deeper network (e.g., 50 layers)?
- At what point does additional training engineering yield diminishing returns? How do you decide when to stop tuning and deploy?
- This case study used CIFAR-10, which has balanced classes. How would the training recipe change for an imbalanced dataset?