Case Study 1: Preventing Overfitting in a Medical Imaging Model

Background

Dr. Sarah Chen's team at a university hospital has collected 3,200 dermoscopic images across four diagnostic categories: melanoma (800), basal cell carcinoma (800), benign nevus (800), and seborrheic keratosis (800). They want to build a deep learning classifier to assist dermatologists in triaging suspicious skin lesions.

The dataset is small by deep learning standards, the diagnostic categories are distinguished by subtle visual features, and misclassification carries serious consequences: a missed melanoma could be life-threatening. The team's initial model achieved 98.5% training accuracy but only 71.2% validation accuracy after 100 epochs, a clear case of severe overfitting.

This case study walks through the systematic process of identifying and resolving overfitting using the regularization techniques from Chapter 13.

The Initial (Overfitting) Model

The team started with a ResNet-18 trained from scratch:

import torch
import torch.nn as nn
from torchvision import models

torch.manual_seed(42)

# Initial model: ResNet-18 from scratch, no regularization
model_v0 = models.resnet18(weights=None, num_classes=4)

optimizer_v0 = torch.optim.Adam(model_v0.parameters(), lr=1e-3)
criterion_v0 = nn.CrossEntropyLoss()

# Minimal preprocessing, no augmentation
# Only resize and normalize
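
# A pipeline along these lines is assumed for V0 (the exact resize
# target is illustrative, not taken from the team's code):
from torchvision import transforms

baseline_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])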

Results after 100 epochs:

| Metric | Value |
|--------|-------|
| Training accuracy | 98.5% |
| Validation accuracy | 71.2% |
| Generalization gap | 27.3% |
| Training loss | 0.04 |
| Validation loss | 1.62 |

The learning curves showed classic overfitting: training loss steadily decreased while validation loss began increasing after epoch 15.

Step 1: Diagnosis

Before applying regularization, the team diagnosed the problem:

  1. Generalization gap of 27.3% is severe. The model memorized the training data.
  2. Small dataset (3,200 images for a model with 11 million parameters) means the model is massively overparameterized relative to the data (see the parameter-count check after this list).
  3. No data augmentation meant each image was seen in exactly the same form every epoch, facilitating memorization.
  4. No weight decay allowed weights to grow unconstrained.
  5. No dropout in the classifier head.
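
The mismatch in item 2 is easy to quantify; a quick check along these lines, counting the parameters of model_v0 defined above:

n_params = sum(p.numel() for p in model_v0.parameters())
print(f"Parameters: {n_params:,}")  # roughly 11.2 million for ResNet-18 with 4 classes
print(f"Training images per parameter: {3200 / n_params:.6f}")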

Step 2: Transfer Learning as Foundation

The single most impactful change for small medical imaging datasets is transfer learning (covered in depth in Chapter 15). Instead of training from scratch, the team used ImageNet-pretrained weights:

torch.manual_seed(42)

# V1: Pretrained backbone with frozen early layers
model_v1 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Replace the final classification layer
model_v1.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(512, 4),
)

# Freeze early layers initially
for name, param in model_v1.named_parameters():
    if "layer4" not in name and "fc" not in name:
        param.requires_grad = False

Results:

| Metric | Value | Change |
|--------|-------|--------|
| Training accuracy | 94.1% | -4.4% |
| Validation accuracy | 82.8% | +11.6% |
| Generalization gap | 11.3% | -16.0% |

Transfer learning alone reduced the gap from 27.3% to 11.3%, a dramatic improvement. But 11.3% is still too large.

Step 3: Data Augmentation

The team implemented a comprehensive augmentation pipeline tailored for dermoscopy:

from torchvision import transforms

torch.manual_seed(42)

# Medical imaging augmentation pipeline
# Note: vertical flips are included because dermoscopic images can be
# captured at any viewing angle, so lesion orientation is not diagnostic
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),
    transforms.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.3,
        hue=0.05,  # Conservative hue shift - color is diagnostic
    ),
    transforms.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
    ),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

Key augmentation choices for medical imaging:

  • Conservative color jitter: Color information (redness, pigmentation) is diagnostically important, so hue shifts are kept small.
  • Rotation up to 90 degrees: Dermoscopic images can be captured at any orientation.
  • Random erasing: Simulates partial occlusion by hair, bubbles, or artifacts.
  • No aggressive cropping: Important diagnostic features can appear anywhere in the image.
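
To apply these pipelines, each transform is attached to its dataset split. A minimal sketch assuming an ImageFolder-style directory layout (the paths and batch size are placeholders, not the team's actual setup):

from torch.utils.data import DataLoader
from torchvision import datasets

train_dataset = datasets.ImageFolder("data/train", transform=train_transform)
val_dataset = datasets.ImageFolder("data/val", transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)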

Results with augmentation added:

| Metric | Value | Change from V1 |
|--------|-------|----------------|
| Training accuracy | 89.3% | -4.8% |
| Validation accuracy | 86.1% | +3.3% |
| Generalization gap | 3.2% | -8.1% |

Step 4: Weight Decay and Label Smoothing

Next, the team added weight decay via AdamW and label smoothing:

torch.manual_seed(42)

# Differential learning rates with weight decay
optimizer_v3 = torch.optim.AdamW(
    [
        {"params": model_v1.layer4.parameters(), "lr": 1e-4},
        {"params": model_v1.fc.parameters(), "lr": 1e-3},
    ],
    weight_decay=0.05,
)

# Label smoothing - important for medical imaging where labels
# can be noisy (inter-annotator disagreement is common)
criterion_v3 = nn.CrossEntropyLoss(label_smoothing=0.1)
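
With label_smoothing=0.1 and four classes, each one-hot target becomes a softened distribution: the true class receives 1 - 0.1 + 0.1/4 = 0.925 and each incorrect class receives 0.1/4 = 0.025, so the model is never rewarded for producing fully saturated probabilities on potentially noisy labels.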

Results:

| Metric | Value | Change from V2 |
|--------|-------|----------------|
| Training accuracy | 87.6% | -1.7% |
| Validation accuracy | 87.4% | +1.3% |
| Generalization gap | 0.2% | -3.0% |

The generalization gap is now essentially zero, indicating the model generalizes well.

Step 5: Mixup for Robustness

To further improve robustness, the team added mixup:

import torch

torch.manual_seed(42)


def mixup_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 0.3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Apply mixup to a batch of dermoscopy images.

    Args:
        x: Batch of images.
        y: Batch of labels.
        alpha: Mixup strength.

    Returns:
        Mixed data, original labels, shuffled labels, mix coefficient.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam
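
In the training loop, the loss is computed against both sets of labels and weighted by the mix coefficient. A sketch of a single training step under this convention (reusing model_v1, optimizer_v3, and criterion_v3 from the earlier snippets; train_loader is assumed):

model_v1.train()
for images, labels in train_loader:
    mixed_images, y_a, y_b, lam = mixup_data(images, labels, alpha=0.3)
    logits = model_v1(mixed_images)
    loss = lam * criterion_v3(logits, y_a) + (1 - lam) * criterion_v3(logits, y_b)
    optimizer_v3.zero_grad()
    loss.backward()
    optimizer_v3.step()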

Results:

| Metric | Value | Change from V3 |
|--------|-------|----------------|
| Training accuracy | 85.2% | -2.4% |
| Validation accuracy | 88.7% | +1.3% |
| Generalization gap | -3.5% | (val > train, healthy sign) |

Step 6: Early Stopping and Learning Rate Scheduling

Finally, the team added cosine annealing with warm restarts and early stopping:

torch.manual_seed(42)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer_v3, T_0=20, T_mult=2
)

early_stopping = EarlyStopping(
    patience=15,
    min_delta=0.001,
    restore_best_weights=True,
)
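
EarlyStopping here is the helper built earlier in the chapter. A minimal sketch consistent with the constructor arguments used above (an illustration, not necessarily the chapter's exact implementation):

import copy


class EarlyStopping:
    """Stop training once validation loss stops improving."""

    def __init__(self, patience: int = 15, min_delta: float = 0.001,
                 restore_best_weights: bool = True) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float("inf")
        self.best_state = None
        self.counter = 0

    def step(self, val_loss: float, model: nn.Module) -> bool:
        """Record the epoch's validation loss; return True to stop training."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                self.best_state = copy.deepcopy(model.state_dict())
            return False
        self.counter += 1
        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_state is not None:
                model.load_state_dict(self.best_state)
            return True
        return False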

Final results:

| Metric | Value |
|--------|-------|
| Training accuracy | 86.1% |
| Validation accuracy | 89.3% |
| Test accuracy | 88.7% |
| Generalization gap (test) | -2.6% |
| Best validation epoch | 67 |

Ablation Study

The team performed an ablation to quantify the contribution of each technique:

| Configuration | Val Acc | Gap |
|---------------|---------|-----|
| Baseline (no regularization, from scratch) | 71.2% | 27.3% |
| + Transfer learning | 82.8% | 11.3% |
| + Data augmentation | 86.1% | 3.2% |
| + Weight decay + label smoothing | 87.4% | 0.2% |
| + Mixup | 88.7% | -3.5% |
| + Early stopping + LR scheduling | 89.3% | -2.6% |

Lessons Learned

  1. Transfer learning is the single most impactful technique for small medical imaging datasets. On its own it contributed +11.6 percentage points of validation accuracy.

  2. Data augmentation must be domain-aware. The team's initial attempt with aggressive color jitter (hue=0.3) actually hurt performance because it destroyed diagnostic color information. Reducing hue jitter to 0.05 fixed this.

  3. Label smoothing is especially valuable for medical imaging because inter-annotator agreement is rarely 100%. Different dermatologists may disagree on diagnoses, and label smoothing accounts for this inherent uncertainty.

  4. The generalization gap is the key metric to watch, not just validation accuracy. A negative gap (validation > training) indicates healthy regularization.

  5. Regularization techniques compound. Each individual technique provided a modest improvement, but together they transformed a failing model (71.2% validation accuracy) into a clinically useful one (89.3%).

  6. Do not over-regularize. The team tried adding Dropout2d(0.3) to the convolutional layers on top of all other techniques, and this actually reduced validation accuracy by 1.2%. When the gap is already near zero, adding more regularization pushes toward underfitting.

Deployment Considerations

Before deployment, the team:

  • Used Monte Carlo dropout (Exercise 13.23) to estimate prediction uncertainty and flag cases requiring human review (sketched after this list).
  • Validated on an external dataset from a different hospital to check for distribution shift.
  • Set a confidence threshold: predictions with softmax probability below 0.7 are routed to a human dermatologist.
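
A sketch of how the Monte Carlo dropout pass and the 0.7 confidence threshold might be combined at inference time (the helper name and the number of stochastic passes are illustrative assumptions):

import torch
import torch.nn as nn


@torch.no_grad()
def predict_with_uncertainty(model, images, n_passes=20, threshold=0.7):
    """Average softmax over stochastic passes; flag low-confidence cases for review."""
    model.eval()
    # Re-enable dropout layers only; batch norm stays in eval mode.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    probs = torch.stack(
        [torch.softmax(model(images), dim=1) for _ in range(n_passes)]
    ).mean(dim=0)
    confidence, prediction = probs.max(dim=1)
    return prediction, confidence, confidence < threshold

Any case where the returned flag is True is routed to a dermatologist rather than acted on automatically.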

Code Reference

The complete implementation for this case study is available in code/case-study-code.py.