Case Study 1: Preventing Overfitting in a Medical Imaging Model
Background
Dr. Sarah Chen's team at a university hospital has collected 3,200 dermoscopic images across four diagnostic categories: melanoma (800), basal cell carcinoma (800), benign nevus (800), and seborrheic keratosis (800). They want to build a deep learning classifier to assist dermatologists in triaging suspicious skin lesions.
The dataset is small by deep learning standards, the classes are distinguished by subtle visual features, and misclassification carries serious consequences: a missed melanoma can be life-threatening. The team's initial model achieved 98.5% training accuracy but only 71.2% validation accuracy after 100 epochs, a clear case of severe overfitting.
This case study walks through the systematic process of identifying and resolving overfitting using the regularization techniques from Chapter 13.
The Initial (Overfitting) Model
The team started with a ResNet-18 trained from scratch:
import torch
import torch.nn as nn
from torchvision import models
torch.manual_seed(42)
# Initial model: ResNet-18 from scratch, no regularization
model_v0 = models.resnet18(weights=None, num_classes=4)
optimizer_v0 = torch.optim.Adam(model_v0.parameters(), lr=1e-3)
criterion_v0 = nn.CrossEntropyLoss()
# Minimal preprocessing, no augmentation
# Only resize and normalize
Results after 100 epochs:

| Metric | Value |
|--------|-------|
| Training accuracy | 98.5% |
| Validation accuracy | 71.2% |
| Generalization gap | 27.3% |
| Training loss | 0.04 |
| Validation loss | 1.62 |
The learning curves showed classic overfitting: training loss steadily decreased while validation loss began increasing after epoch 15.
Step 1: Diagnosis
Before applying regularization, the team diagnosed the problem:
- Generalization gap of 27.3% is severe. The model memorized the training data.
- Small dataset (3,200 images for a model with roughly 11 million parameters) means the model is massively overparameterized relative to the data; the sketch after this list makes these numbers concrete.
- No data augmentation meant each image was seen in exactly the same form every epoch, facilitating memorization.
- No weight decay allowed weights to grow unconstrained.
- No dropout in the classifier head.
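To make the first two points concrete, here is a minimal sketch (not part of the team's code) that reports the trainable parameter count and the generalization gap for the baseline model:
def count_trainable_parameters(model: nn.Module) -> int:
    """Number of parameters that will be updated during training."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
n_params = count_trainable_parameters(model_v0)
print(f"Trainable parameters: {n_params:,}")            # roughly 11 million
print(f"Images per parameter: {3200 / n_params:.6f}")   # far below 1
print(f"Generalization gap:   {0.985 - 0.712:.3f}")     # 0.273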
Step 2: Transfer Learning as Foundation
The single most impactful change for small medical imaging datasets is transfer learning (covered in depth in Chapter 15). Instead of training from scratch, the team used ImageNet-pretrained weights:
torch.manual_seed(42)
# V1: Pretrained backbone with frozen early layers
model_v1 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the final classification layer
model_v1.fc = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(512, 4),
)
# Freeze early layers initially
for name, param in model_v1.named_parameters():
if "layer4" not in name and "fc" not in name:
param.requires_grad = False
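As a sanity check, a short sketch (the optimizer_v1 name and learning rate are assumptions, not details from the team's code) confirms that only layer4 and the new classifier head remain trainable and passes just those parameters to the optimizer:
trainable = [n for n, p in model_v1.named_parameters() if p.requires_grad]
print(f"Trainable tensors: {len(trainable)}")  # only layer4.* and fc.* entries
optimizer_v1 = torch.optim.Adam(
    (p for p in model_v1.parameters() if p.requires_grad),
    lr=1e-3,
)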
Results:

| Metric | Value | Change |
|--------|-------|--------|
| Training accuracy | 94.1% | -4.4% |
| Validation accuracy | 82.8% | +11.6% |
| Generalization gap | 11.3% | -16.0% |
Transfer learning alone reduced the gap from 27.3% to 11.3%, a dramatic improvement. But 11.3% is still too large.
Step 3: Data Augmentation
The team implemented a comprehensive augmentation pipeline tailored for dermoscopy:
from torchvision import transforms
torch.manual_seed(42)
# Medical imaging augmentation pipeline
# Note: vertical flips are included because dermoscopic lesions
# have no canonical orientation (viewing angle varies)
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomRotation(degrees=90),
transforms.ColorJitter(
brightness=0.3,
contrast=0.3,
saturation=0.3,
hue=0.05, # Conservative hue shift - color is diagnostic
),
transforms.RandomAffine(
degrees=0,
translate=(0.1, 0.1),
scale=(0.9, 1.1),
),
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
])
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
])
Key augmentation choices for medical imaging:
- Conservative color jitter: Color information (redness, pigmentation) is diagnostically important, so hue shifts are kept small.
- Rotation up to 90 degrees: Dermoscopic images can be captured at any orientation.
- Random erasing: Simulates partial occlusion by hair, bubbles, or artifacts.
- No aggressive cropping: Important diagnostic features can appear anywhere in the image.
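For completeness, a sketch of how these pipelines might be attached to the images; the directory layout, batch size, and loader settings are assumptions rather than details from the team's setup:
from torch.utils.data import DataLoader
from torchvision import datasets
# Hypothetical layout: data/train/<class_name>/*.jpg and data/val/<class_name>/*.jpg
train_ds = datasets.ImageFolder("data/train", transform=train_transform)
val_ds = datasets.ImageFolder("data/val", transform=val_transform)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4)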
Results with augmentation added:

| Metric | Value | Change from V1 |
|--------|-------|----------------|
| Training accuracy | 89.3% | -4.8% |
| Validation accuracy | 86.1% | +3.3% |
| Generalization gap | 3.2% | -8.1% |
Step 4: Weight Decay and Label Smoothing
Next, the team added weight decay via AdamW and label smoothing:
torch.manual_seed(42)
# Differential learning rates with weight decay
optimizer_v3 = torch.optim.AdamW(
[
{"params": model_v1.layer4.parameters(), "lr": 1e-4},
{"params": model_v1.fc.parameters(), "lr": 1e-3},
],
weight_decay=0.05,
)
# Label smoothing - important for medical imaging where labels
# can be noisy (inter-annotator disagreement is common)
criterion_v3 = nn.CrossEntropyLoss(label_smoothing=0.1)
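With label_smoothing=0.1 and four classes, the hard one-hot target is replaced by a softened distribution. A small illustrative sketch of the equivalent target (the class index is arbitrary):
import torch.nn.functional as F
num_classes = 4
epsilon = 0.1
hard_target = F.one_hot(torch.tensor(2), num_classes).float()
smoothed_target = (1 - epsilon) * hard_target + epsilon / num_classes
print(smoothed_target)  # tensor([0.0250, 0.0250, 0.9250, 0.0250])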
Results:

| Metric | Value | Change from V2 |
|--------|-------|----------------|
| Training accuracy | 87.6% | -1.7% |
| Validation accuracy | 87.4% | +1.3% |
| Generalization gap | 0.2% | -3.0% |
The generalization gap is now essentially zero, indicating the model generalizes well.
Step 5: Mixup for Robustness
To further improve robustness, the team added mixup:
import torch
torch.manual_seed(42)
def mixup_data(
x: torch.Tensor,
y: torch.Tensor,
alpha: float = 0.3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
"""Apply mixup to a batch of dermoscopy images.
Args:
x: Batch of images.
y: Batch of labels.
alpha: Mixup strength.
Returns:
Mixed data, original labels, shuffled labels, mix coefficient.
"""
lam = torch.distributions.Beta(alpha, alpha).sample().item()
batch_size = x.size(0)
index = torch.randperm(batch_size, device=x.device)
mixed_x = lam * x + (1 - lam) * x[index]
return mixed_x, y, y[index], lam
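The loss for a mixed batch weights the criterion against both label sets by the mix coefficient. A sketch of a single training step (it reuses the model, optimizer, criterion, and loader names from the earlier snippets; the loop itself is not from the team's code):
model_v1.train()
for images, labels in train_loader:
    mixed_x, y_a, y_b, lam = mixup_data(images, labels, alpha=0.3)
    optimizer_v3.zero_grad()
    outputs = model_v1(mixed_x)
    # Weighted combination of the losses against both label sets
    loss = lam * criterion_v3(outputs, y_a) + (1 - lam) * criterion_v3(outputs, y_b)
    loss.backward()
    optimizer_v3.step()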
Results:

| Metric | Value | Change from V3 |
|--------|-------|----------------|
| Training accuracy | 85.2% | -2.4% |
| Validation accuracy | 88.7% | +1.3% |
| Generalization gap | -3.5% | (val > train, healthy sign) |
Step 6: Early Stopping and Learning Rate Scheduling
Finally, the team added cosine annealing with warm restarts and early stopping:
torch.manual_seed(42)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer_v3, T_0=20, T_mult=2
)
early_stopping = EarlyStopping(
patience=15,
min_delta=0.001,
restore_best_weights=True,
)
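The EarlyStopping helper is developed in Chapter 13; what follows is a minimal sketch of the behavior assumed here (the step method name and internals are illustrative, not necessarily the exact Chapter 13 implementation):
import copy
class EarlyStopping:
    """Stop training when validation loss stops improving."""
    def __init__(self, patience: int = 15, min_delta: float = 0.001,
                 restore_best_weights: bool = True) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float("inf")
        self.best_state = None
        self.counter = 0
    def step(self, val_loss: float, model: nn.Module) -> bool:
        """Return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                self.best_state = copy.deepcopy(model.state_dict())
            return False
        self.counter += 1
        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_state is not None:
                model.load_state_dict(self.best_state)
            return True
        return False
In this sketch, the training loop calls scheduler.step() once per epoch and breaks out as soon as early_stopping.step(val_loss, model_v1) returns True.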
Final results:

| Metric | Value |
|--------|-------|
| Training accuracy | 86.1% |
| Validation accuracy | 89.3% |
| Test accuracy | 88.7% |
| Generalization gap (test) | -2.6% |
| Best validation epoch | 67 |
Ablation Study
The team performed an ablation to quantify the contribution of each technique:
| Configuration | Val Acc | Gap |
|---|---|---|
| Baseline (no regularization, from scratch) | 71.2% | 27.3% |
| + Transfer learning | 82.8% | 11.3% |
| + Data augmentation | 86.1% | 3.2% |
| + Weight decay + label smoothing | 87.4% | 0.2% |
| + Mixup | 88.7% | -3.5% |
| + Early stopping + LR scheduling | 89.3% | -2.6% |
Lessons Learned
- Transfer learning is the single most impactful technique for small medical imaging datasets. It contributed +11.6% validation accuracy.
- Data augmentation must be domain-aware. The team's initial attempt with aggressive color jitter (hue=0.3) actually hurt performance because it destroyed diagnostic color information. Reducing hue jitter to 0.05 fixed this.
- Label smoothing is especially valuable for medical imaging because inter-annotator agreement is rarely 100%. Different dermatologists may disagree on diagnoses, and label smoothing accounts for this inherent uncertainty.
- The generalization gap is the key metric to watch, not just validation accuracy. A negative gap (validation > training) indicates healthy regularization.
- Regularization techniques compound. Each individual technique provided a modest improvement, but together they transformed a failing model (71.2% validation accuracy) into a clinically useful one (89.3%).
- Do not over-regularize. The team tried adding Dropout2d(0.3) to the convolutional layers on top of all other techniques, and this actually reduced validation accuracy by 1.2%. When the gap is already near zero, adding more regularization pushes toward underfitting.
Deployment Considerations
Before deployment, the team:
- Used Monte Carlo dropout (Exercise 13.23) to estimate prediction uncertainty and flag cases requiring human review (a sketch follows this list).
- Validated on an external dataset from a different hospital to check for distribution shift.
- Set a confidence threshold: predictions with softmax probability below 0.7 are routed to a human dermatologist.
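A sketch of how the Monte Carlo dropout check and the 0.7 confidence threshold might be combined at inference time; the number of stochastic passes and the helper itself are illustrative assumptions rather than the team's exact code, and val_loader comes from the earlier data-loading sketch:
@torch.no_grad()
def predict_with_uncertainty(
    model: nn.Module,
    x: torch.Tensor,
    n_passes: int = 30,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Average softmax outputs over several stochastic forward passes."""
    model.eval()
    # Keep batch norm in eval mode but re-activate dropout layers
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    probs = torch.stack(
        [torch.softmax(model(x), dim=1) for _ in range(n_passes)]
    )
    return probs.mean(dim=0), probs.std(dim=0)
images, _ = next(iter(val_loader))
mean_probs, std_probs = predict_with_uncertainty(model_v1, images)
confidence, prediction = mean_probs.max(dim=1)
needs_review = confidence < 0.7  # route these cases to a dermatologist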
Code Reference
The complete implementation for this case study is available in code/case-study-code.py.