Case Study 1: Fine-Tuning ViT for Medical Image Classification

Overview

Medical image classification presents unique challenges: limited labeled data, high class imbalance, subtle inter-class differences, and the critical need for interpretable predictions. In this case study, you will learn how to fine-tune a pre-trained Vision Transformer for classifying chest X-ray images, applying techniques from the chapter including transfer learning, data augmentation, layer-wise learning rate decay, and attention visualization for interpretability.

Problem Statement

You are tasked with building a classifier that detects pneumonia in chest X-ray images and distinguishes three classes: Normal, Bacterial Pneumonia, and Viral Pneumonia. The dataset contains approximately 5,856 images with significant class imbalance: normal cases are underrepresented relative to pneumonia cases.

Dataset

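We use a chest X-ray dataset structured as follows. The folders separate only NORMAL from PNEUMONIA, so the bacterial/viral labels needed for the three-class task must be derived per image. The validation split is tiny: the counts below cover all but 16 of the 5,856 images.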

chest_xray/
    train/
        NORMAL/        (~1,341 images)
        PNEUMONIA/     (~3,875 images, mixed bacterial/viral)
    val/
        NORMAL/
        PNEUMONIA/
    test/
        NORMAL/        (~234 images)
        PNEUMONIA/     (~390 images)
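Because the folder structure only separates NORMAL from PNEUMONIA, the three-class label has to come from somewhere else. A minimal sketch, assuming (as in the widely distributed Kaggle release of this dataset) that pneumonia filenames contain the substring `bacteria` or `virus`; the names `CLASSES`, `three_class_label`, and `count_classes` are hypothetical helpers, not part of any library:

```python
from collections import Counter
from pathlib import Path

# Hypothetical label order used for illustration.
CLASSES = ("Normal", "Bacterial", "Viral")
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}

def three_class_label(path: str) -> int:
    """Map an image path to 0 (Normal), 1 (Bacterial) or 2 (Viral).

    Assumption: the parent folder is NORMAL or PNEUMONIA, and pneumonia
    filenames contain 'bacteria' or 'virus' (check this for your copy).
    """
    p = Path(path)
    folder = p.parent.name.upper()
    name = p.name.lower()
    if folder == "NORMAL":
        return 0
    if folder == "PNEUMONIA":
        if "bacteria" in name:
            return 1
        if "virus" in name:
            return 2
    raise ValueError(f"cannot infer a three-class label for {path!r}")

def count_classes(root: str, split: str = "train") -> Counter:
    """Count images per three-class label under root/split/<FOLDER>/."""
    files = [f for f in Path(root, split).glob("*/*") if f.suffix.lower() in IMAGE_SUFFIXES]
    return Counter(CLASSES[three_class_label(str(f))] for f in files)
```

Running `count_classes("chest_xray")` before training makes the real per-class imbalance visible, which is what the class weights later in this case study should be computed from.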

Approach

Step 1: Data Analysis and Preprocessing

Medical images require careful preprocessing:

  • Resolution handling: Chest X-rays come in varying resolutions (often much larger than 224x224). We resize to 384x384 to preserve diagnostic details.
  • Grayscale to RGB: X-rays are grayscale, but pre-trained ViTs expect 3-channel input. We replicate the grayscale channel three times.
  • Normalization: Use the ViT's pre-training normalization statistics rather than dataset-specific ones.
  • Class imbalance: Apply weighted sampling or class-weighted loss to handle the imbalance.
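The first three preprocessing steps can be sketched with NumPy alone. This is a dependency-light illustration, not the case study's actual pipeline: the nearest-neighbour resize is used only to avoid extra libraries (a real pipeline would use a smoother bicubic or area filter), and the mean = std = 0.5 statistics are an assumption that matches Google's original ViT checkpoints, which scale inputs to [-1, 1]. Verify them against your checkpoint's configuration.

```python
import numpy as np

# Assumed pre-training normalization (mean = std = 0.5 per channel).
VIT_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
VIT_STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)

def resize_nearest(img: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D array to (size, size)."""
    h, w = img.shape
    rows = np.arange(size) * h // size   # source row for each output row
    cols = np.arange(size) * w // size   # source column for each output column
    return img[rows[:, None], cols[None, :]]

def preprocess_xray(gray: np.ndarray, size: int = 384) -> np.ndarray:
    """uint8 grayscale (H, W) -> float32 (3, size, size) ViT input."""
    x = resize_nearest(gray, size).astype(np.float32) / 255.0   # scale to [0, 1]
    x = np.repeat(x[None, :, :], 3, axis=0)                     # replicate to 3 channels
    # Normalize each channel with the pre-training statistics, not dataset ones.
    return (x - VIT_MEAN[:, None, None]) / VIT_STD[:, None, None]
```

With these statistics, a black pixel maps to -1 and a white pixel to +1 in all three (identical) channels.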

Step 2: Data Augmentation Strategy

For medical images, augmentation must be clinically plausible:

  • Acceptable: Random horizontal flip (the lung fields are roughly left-right symmetric, although a flip moves the heart to the right side, so some pipelines omit it), random rotation (up to 15 degrees, for positioning variation), random brightness/contrast adjustment (for exposure variation), and random affine transforms (for patient positioning).
  • Avoid: Vertical flip (anatomically implausible), heavy color jittering (X-rays are grayscale), CutMix/Mixup (may create clinically meaningless combinations).

Step 3: Model Configuration

We use ViT-Base/16 pre-trained on ImageNet-21K and fine-tuned on ImageNet-1K, then adapt it for our 3-class classification:

  • Replace the classification head with a new linear layer (768 -> 3).
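  • Use higher resolution (384x384). With 16x16 patches this yields a 24x24 patch grid instead of the 14x14 grid of a 224x224 checkpoint, so the pre-trained position embeddings are interpolated to the new grid.
  • Apply layer-wise learning rate decay (factor 0.75): each layer's learning rate is the base rate multiplied by 0.75 once per layer of depth below the head, so the early layers, which hold the most general pre-trained features, change the least.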
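The two less obvious configuration steps can be sketched as follows. This assumes a 224x224 checkpoint with a single [CLS] prefix token and timm-style parameter names (`cls_token`, `pos_embed`, `patch_embed.*`, `blocks.<i>.*`, `head.*`); `resize_pos_embed` and `llrd_scale` are hypothetical helpers, and the layer numbering follows the common BEiT/MAE convention rather than anything stated in this case study.

```python
import math
import torch
import torch.nn.functional as F

PATCH = 16
OLD_GRID = 224 // PATCH   # 14 -> 196 patch tokens (assumed pre-training resolution)
NEW_GRID = 384 // PATCH   # 24 -> 576 patch tokens

def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int, num_prefix_tokens: int = 1) -> torch.Tensor:
    """Bicubically resize (1, prefix + old_grid**2, dim) position embeddings to a new grid."""
    prefix = pos_embed[:, :num_prefix_tokens]            # [CLS] embedding is kept as-is
    patch = pos_embed[:, num_prefix_tokens:]
    old_grid = math.isqrt(patch.shape[1])
    assert old_grid * old_grid == patch.shape[1], "patch tokens must form a square grid"
    dim = patch.shape[-1]
    # (1, N, D) -> (1, D, H, W) so that F.interpolate treats D as channels
    patch = patch.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch = F.interpolate(patch, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    patch = patch.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([prefix, patch], dim=1)

def llrd_scale(param_name: str, num_layers: int = 12, decay: float = 0.75) -> float:
    """Learning-rate multiplier for one parameter under layer-wise decay.

    Embeddings are layer 0, blocks.<i> is layer i + 1, and everything else
    (final norm, new classification head) is layer num_layers + 1 (scale 1.0).
    """
    if param_name.startswith(("cls_token", "pos_embed", "patch_embed")):
        layer_id = 0
    elif param_name.startswith("blocks."):
        layer_id = int(param_name.split(".")[1]) + 1
    else:
        layer_id = num_layers + 1
    return decay ** (num_layers + 1 - layer_id)
```

With a timm-style model, the head replacement is `model.head = torch.nn.Linear(768, 3)`, and AdamW parameter groups can be built with `lr = base_lr * llrd_scale(name)` for each entry of `model.named_parameters()`. Under this convention the last block trains at 0.75x the base rate and the patch embedding at 0.75^13, roughly 2% of it.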

Step 4: Training Configuration

  • Optimizer: AdamW with base learning rate 2e-5, weight decay 0.05.
  • Scheduler: Cosine annealing with 500-step warmup.
  • Loss: Cross-entropy with class weights inversely proportional to class frequency.
  • Epochs: 20 with early stopping (patience 5).
  • Batch size: 16 (due to 384x384 resolution memory requirements).
  • Mixed precision: FP16 for memory efficiency.
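The schedule and the loss weights reduce to a little arithmetic. A minimal sketch, assuming linear warmup (the text does not specify the warmup shape) and the "balanced" weighting w_c = N / (K * n_c), one common way to make weights inversely proportional to class frequency; the function names are hypothetical.

```python
import math
from collections import Counter

def warmup_cosine_lr(step: int, total_steps: int, base_lr: float = 2e-5,
                     warmup_steps: int = 500, min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr over warmup_steps, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def inverse_frequency_weights(labels: list, num_classes: int) -> list:
    """w_c = N / (K * n_c); a perfectly balanced dataset gives every class weight 1.

    Raises ZeroDivisionError if some class has no examples.
    """
    counts = Counter(labels)
    n = len(labels)
    return [n / (num_classes * counts[c]) for c in range(num_classes)]
```

In PyTorch the weights go to `torch.nn.CrossEntropyLoss(weight=torch.tensor(weights))`, or, for weighted sampling instead, each sample's class weight can be passed to `torch.utils.data.WeightedRandomSampler`. The schedule can be wrapped in `torch.optim.lr_scheduler.LambdaLR` with `lr_lambda=lambda s: warmup_cosine_lr(s, total_steps) / 2e-5`; because `LambdaLR` multiplies each parameter group's initial learning rate, any layer-wise decay ratios between groups are preserved.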

Step 5: Interpretability with Attention Maps

For medical AI, interpretability is crucial. We extract attention maps from the last transformer layer to visualize which regions the model focuses on:

  • Take the [CLS] token's attention to the patch tokens and aggregate it across heads using the mean or max.
  • Overlay attention maps on the original X-ray.
  • Verify that the model attends to clinically relevant regions (lung fields, not image borders or annotations).
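Turning last-layer attention into an overlay can be sketched in NumPy. This assumes you already have one image's attention probabilities with shape (heads, tokens, tokens), token 0 being [CLS] followed by 24x24 patch tokens and no extra register or distillation tokens; for example, Hugging Face `ViTModel` returns per-layer attentions when called with `output_attentions=True`. `cls_attention_map` and `overlay` are hypothetical helpers, and the nearest-neighbour upsampling is a simplification.

```python
import numpy as np

def cls_attention_map(attn: np.ndarray, grid: int = 24, patch: int = 16,
                      reduce: str = "mean") -> np.ndarray:
    """Last-layer [CLS]-to-patch attention as a [0, 1] heat map at image resolution."""
    cls_to_patches = attn[:, 0, 1:]                 # (heads, grid * grid)
    if reduce == "mean":
        agg = cls_to_patches.mean(axis=0)
    elif reduce == "max":
        agg = cls_to_patches.max(axis=0)
    else:
        raise ValueError("reduce must be 'mean' or 'max'")
    heat = agg.reshape(grid, grid)                  # row-major patch order
    heat = np.kron(heat, np.ones((patch, patch)))   # nearest upsample to grid * patch pixels
    heat = heat - heat.min()
    return heat / heat.max() if heat.max() > 0 else heat

def overlay(gray01: np.ndarray, heat: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Alpha-blend a [0, 1] heat map onto a [0, 1] grayscale X-ray of the same size."""
    return (1.0 - alpha) * gray01 + alpha * heat
```

Comparing `reduce="mean"` with `reduce="max"` is a cheap check: a single head fixating on a border marker or text annotation shows up far more clearly under `max`.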

Results

With careful fine-tuning, the model achieves:

Metric       Normal   Bacterial   Viral   Macro Avg
Precision    0.93     0.91        0.87    0.90
Recall       0.90     0.93        0.85    0.89
F1-Score     0.91     0.92        0.86    0.90

Overall accuracy: 90.4%
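The F1 and macro-average entries follow from the per-class precision and recall. This short check recomputes them from the table values only; on real predictions, `sklearn.metrics.classification_report` produces the same summary. Macro F1 is taken here as the mean of per-class F1 scores, the scikit-learn convention, not the F1 of the macro precision and recall.

```python
# Per-class precision/recall copied from the results table.
precision = {"Normal": 0.93, "Bacterial": 0.91, "Viral": 0.87}
recall = {"Normal": 0.90, "Bacterial": 0.93, "Viral": 0.85}

def f1(p: float, r: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * p * r / (p + r) if p + r else 0.0

f1_scores = {c: f1(precision[c], recall[c]) for c in precision}
macro = {
    "precision": sum(precision.values()) / len(precision),
    "recall": sum(recall.values()) / len(recall),
    "f1": sum(f1_scores.values()) / len(f1_scores),   # mean of per-class F1
}
for c, s in f1_scores.items():
    print(f"{c:<10} F1 = {s:.2f}")
print({k: round(v, 2) for k, v in macro.items()})
```

The recomputed values (F1 of 0.91, 0.92, and 0.86; macro averages 0.90, 0.89, 0.90) match the table. Overall accuracy cannot be recovered this way, because it also depends on how many test images each class contains.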

Attention Map Analysis

The attention visualizations reveal that the model:

  • Focuses on lung opacity regions for pneumonia cases.
  • Attends to clear lung fields for normal cases.
  • Shows distinct patterns for bacterial (lobar consolidation) vs. viral (diffuse interstitial pattern) pneumonia.

Key Lessons

  1. Resolution matters: Fine-tuning at 384x384 improved accuracy by 3.2% over 224x224 for this task, consistent with fine diagnostic details in X-rays being lost at lower resolution.

  2. Layer-wise LR decay is essential: Without it, accuracy dropped by 2.1%, suggesting that preserving early-layer features is important for domain transfer.

  3. Augmentation must be domain-appropriate: Clinically implausible augmentations (vertical flip, heavy color jitter) actually hurt performance by 1.5%.

  4. Class weighting prevents bias: Without class weighting, the model achieved 92% accuracy but only 68% recall on the minority Normal class.

  5. Attention maps enable clinical validation: Radiologists confirmed that the model's attention patterns aligned with diagnostic criteria, increasing trust in the model's predictions.

Ethical Considerations

  • This model is a research demonstration, not a clinical diagnostic tool.
  • Medical AI systems require extensive validation, regulatory approval (e.g., FDA 510(k)), and should be used to assist, not replace, clinical judgment.
  • Dataset biases (institution-specific imaging protocols, demographic imbalances) can significantly affect model generalization.
  • Attention maps provide insight but are not rigorous explanations; use gradient-based methods (Grad-CAM) for more reliable attribution.

Code Reference

The complete implementation is available in code/case-study-code.py.