Case Study 2: Data Augmentation Strategies for Small Datasets
Background
A wildlife conservation team needs to classify camera trap images of endangered species in a remote rainforest. They have collected only 500 labeled images across 5 species: jaguar (120), ocelot (95), tapir (110), peccary (90), and capybara (85). The images vary significantly in lighting (day vs. night infrared), angle, occlusion (vegetation), and image quality.
The challenge: build a classifier that achieves at least 85% accuracy despite having fewer than 100 images per class for some species. This case study explores how different data augmentation strategies can make or break model performance on extremely small datasets.
The Challenge of Small Datasets
With only 500 images and a modern CNN, the ratio of parameters to training examples is staggering:
| Model | Parameters | Parameters per Image |
|---|---|---|
| ResNet-18 | 11.2M | ~22,400 |
| EfficientNet-B0 | 5.3M | ~10,600 |
| MobileNetV3-Small | 2.5M | ~5,000 |
Even the smallest model has roughly 5,000 parameters per labeled image. Without aggressive regularization, memorization is inevitable.
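These ratios are easy to reproduce. The snippet below counts the parameters of the smallest backbone and divides by the dataset size; it is a quick sanity check, not part of the team's pipeline:
from torchvision import models

# Count parameters in MobileNetV3-Small and relate them to the 500 labeled images
backbone = models.mobilenet_v3_small()
n_params = sum(p.numel() for p in backbone.parameters())
print(f"{n_params:,} parameters, ~{n_params / 500:,.0f} per labeled image")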
Experiment Setup
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, random_split
torch.manual_seed(42)
# We use a pretrained EfficientNet-B0 as the backbone
# Transfer learning is essential for small datasets (see Chapter 15)
model = models.efficientnet_b0(
weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1
)
model.classifier = nn.Sequential(
nn.Dropout(0.4),
nn.Linear(1280, 5),
)
# Split: 350 train, 75 val, 75 test
# Stratified to maintain class balance
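The split itself is not shown in the listing. Below is a minimal sketch of one way to produce a stratified 350/75/75 split; it assumes `dataset` is the full image dataset and `labels` is a list of its integer class indices (both names are assumptions), and it uses scikit-learn's train_test_split rather than the random_split imported above, since random_split does not stratify:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Stratified 350 / 75 / 75 split over the 500 labeled images.
# `dataset` and `labels` are assumed to exist; labels[i] is the
# integer species index of image i.
all_indices = list(range(len(labels)))
train_idx, holdout_idx = train_test_split(
    all_indices,
    test_size=150,
    stratify=labels,
    random_state=42,
)
val_idx, test_idx = train_test_split(
    holdout_idx,
    test_size=75,
    stratify=[labels[i] for i in holdout_idx],
    random_state=42,
)
train_set = Subset(dataset, train_idx)
val_set = Subset(dataset, val_idx)
test_set = Subset(dataset, test_idx)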
Strategy 1: No Augmentation (Baseline)
baseline_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
])
Results after 50 epochs:

| Metric | Value |
|--------|-------|
| Training accuracy | 100.0% |
| Validation accuracy | 58.7% |
| Generalization gap | 41.3% |
Complete memorization, as expected. The model learned to identify individual images rather than species.
Strategy 2: Basic Augmentation
Standard geometric and color transformations:
basic_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(
brightness=0.3,
contrast=0.3,
saturation=0.2,
),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
])
Results:

| Metric | Value | Change from Baseline |
|--------|-------|----------------------|
| Training accuracy | 95.2% | -4.8% |
| Validation accuracy | 72.0% | +13.3% |
| Generalization gap | 23.2% | -18.1% |
A significant improvement, but still insufficient. The augmentations are too mild for such a small dataset.
Strategy 3: Aggressive Augmentation
Push the augmentation much further:
aggressive_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.3, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.2),
transforms.RandomRotation(degrees=45),
transforms.RandomAffine(
degrees=0,
translate=(0.2, 0.2),
scale=(0.8, 1.2),
shear=15,
),
transforms.ColorJitter(
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.1,
),
transforms.RandomGrayscale(p=0.1),
transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
transforms.RandomErasing(p=0.3, scale=(0.02, 0.25)),
])
Results:

| Metric | Value | Change from Basic |
|--------|-------|-------------------|
| Training accuracy | 84.6% | -10.6% |
| Validation accuracy | 79.3% | +7.3% |
| Generalization gap | 5.3% | -17.9% |
Much better. The aggressive augmentation makes the training task harder (lower training accuracy) but the model learns more generalizable features.
Strategy 4: Domain-Specific Augmentation
Camera trap images have specific characteristics that suggest targeted augmentations:
import torch
from torchvision import transforms
torch.manual_seed(42)
class SimulateNightVision:
"""Simulate infrared/night vision camera appearance.
Many camera traps switch to infrared at night,
producing grayscale images with different contrast profiles.
"""
def __init__(self, p: float = 0.3) -> None:
self.p = p
def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""Apply night vision simulation.
Args:
img: Input tensor image.
Returns:
Possibly modified image tensor.
"""
if torch.rand(1).item() < self.p:
            # Collapse the RGB channels to a single grayscale intensity
gray = img.mean(dim=0, keepdim=True)
# Add slight green tint (common in night vision)
result = torch.zeros_like(img)
result[0] = gray[0] * 0.8
result[1] = gray[0] * 1.1
result[2] = gray[0] * 0.7
return result.clamp(0, 1)
return img
class SimulateVegetationOcclusion:
"""Simulate partial occlusion by vegetation.
Camera trap images frequently have branches, leaves,
or grass partially occluding the animal.
"""
def __init__(self, p: float = 0.3, max_bars: int = 5) -> None:
self.p = p
self.max_bars = max_bars
def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""Apply vegetation-like occlusion.
Args:
img: Input tensor image of shape (C, H, W).
Returns:
Image with random bar occlusions.
"""
if torch.rand(1).item() < self.p:
_, h, w = img.shape
num_bars = torch.randint(1, self.max_bars + 1, (1,)).item()
result = img.clone()
for _ in range(num_bars):
bar_width = torch.randint(2, 8, (1,)).item()
if torch.rand(1).item() > 0.5:
# Horizontal bar
y = torch.randint(0, h - bar_width, (1,)).item()
# Dark green color for vegetation
result[0, y:y + bar_width, :] *= 0.3
result[1, y:y + bar_width, :] = (
result[1, y:y + bar_width, :] * 0.5 + 0.2
)
result[2, y:y + bar_width, :] *= 0.3
else:
# Vertical bar
x = torch.randint(0, w - bar_width, (1,)).item()
result[0, :, x:x + bar_width] *= 0.3
result[1, :, x:x + bar_width] = (
result[1, :, x:x + bar_width] * 0.5 + 0.2
)
result[2, :, x:x + bar_width] *= 0.3
return result.clamp(0, 1)
return img
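A quick sanity check (illustrative only) confirms that both custom transforms preserve the tensor shape and keep values in the [0, 1] range they assume:
sample = torch.rand(3, 224, 224)
night = SimulateNightVision(p=1.0)(sample)
occluded = SimulateVegetationOcclusion(p=1.0)(sample)
print(night.shape, occluded.shape)            # both torch.Size([3, 224, 224])
print(float(night.min()), float(occluded.max()))  # values stay within [0, 1]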
domain_specific_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=30),
transforms.ColorJitter(
brightness=0.5,
contrast=0.5,
saturation=0.4,
hue=0.05,
),
transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    # Custom tensor-based augmentations assume pixel values in [0, 1],
    # so they are applied after ToTensor() but before Normalize()
    SimulateNightVision(p=0.3),
    SimulateVegetationOcclusion(p=0.3),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.2)),
])
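As a usage sketch, the pipeline plugs into a standard loader. The directory layout below (one subfolder per species) is an assumption for illustration, not something specified in the case study:
from torchvision import datasets
from torch.utils.data import DataLoader

# Hypothetical layout: data/train/<species_name>/*.jpg
train_dataset = datasets.ImageFolder(
    "data/train",
    transform=domain_specific_transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)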
Results:

| Metric | Value | Change from Aggressive |
|--------|-------|------------------------|
| Training accuracy | 82.1% | -2.5% |
| Validation accuracy | 82.7% | +3.4% |
| Generalization gap | -0.6% | -5.9% |
Domain-specific augmentations further improved validation accuracy by simulating conditions the model would encounter in the real world.
Strategy 5: RandAugment
Using automated augmentation policies:
from torchvision.transforms import autoaugment
randaug_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
autoaugment.RandAugment(num_ops=3, magnitude=12),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
transforms.RandomErasing(p=0.2),
])
Results:

| Metric | Value | Change from Domain-Specific |
|--------|-------|-----------------------------|
| Training accuracy | 80.5% | -1.6% |
| Validation accuracy | 83.1% | +0.4% |
| Generalization gap | -2.6% | -2.0% |
RandAugment performed comparably to the domain-specific strategy, which is notable because it required zero domain knowledge to set up.
Strategy 6: Mixup + CutMix + Strong Augmentation
Combining sample-level augmentation with the best spatial augmentation:
import torch
import torch.nn as nn
torch.manual_seed(42)
def train_step_with_mixing(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
criterion: nn.Module,
optimizer: torch.optim.Optimizer,
mixup_alpha: float = 0.3,
cutmix_alpha: float = 1.0,
mix_prob: float = 0.5,
) -> float:
"""Training step with randomly applied mixup or CutMix.
Args:
model: The neural network.
x: Input batch.
y: Target batch.
criterion: Loss function.
optimizer: Optimizer.
mixup_alpha: Mixup beta distribution parameter.
cutmix_alpha: CutMix beta distribution parameter.
mix_prob: Probability of applying mixing.
Returns:
Loss value for this step.
"""
model.train()
optimizer.zero_grad()
if torch.rand(1).item() < mix_prob:
if torch.rand(1).item() < 0.5:
# Apply mixup
lam = torch.distributions.Beta(
mixup_alpha, mixup_alpha
).sample().item()
index = torch.randperm(x.size(0), device=x.device)
mixed_x = lam * x + (1 - lam) * x[index]
output = model(mixed_x)
loss = lam * criterion(output, y) + (
(1 - lam) * criterion(output, y[index])
)
else:
# Apply CutMix
lam = torch.distributions.Beta(
cutmix_alpha, cutmix_alpha
).sample().item()
index = torch.randperm(x.size(0), device=x.device)
_, _, h, w = x.shape
cut_ratio = (1.0 - lam) ** 0.5
cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
cx = torch.randint(0, w, (1,)).item()
cy = torch.randint(0, h, (1,)).item()
x1 = max(0, cx - cut_w // 2)
y1 = max(0, cy - cut_h // 2)
x2 = min(w, cx + cut_w // 2)
y2 = min(h, cy + cut_h // 2)
mixed_x = x.clone()
mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
actual_lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
output = model(mixed_x)
loss = actual_lam * criterion(output, y) + (
(1 - actual_lam) * criterion(output, y[index])
)
else:
output = model(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()
return loss.item()
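A minimal sketch of how this step function slots into an epoch loop, assuming the `model` from the earlier setup and a `train_loader` that is not shown here (loss, optimizer, and hyperparameter values are illustrative):
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

for epoch in range(100):
    epoch_loss = 0.0
    for x, y in train_loader:  # train_loader is assumed to exist
        x, y = x.to(device), y.to(device)
        epoch_loss += train_step_with_mixing(
            model, x, y, criterion, optimizer,
            mixup_alpha=0.3, cutmix_alpha=1.0, mix_prob=0.5,
        )
    print(f"epoch {epoch}: mean loss {epoch_loss / len(train_loader):.4f}")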
Results:

| Metric | Value | Change from RandAugment |
|--------|-------|-------------------------|
| Training accuracy | 78.3% | -2.2% |
| Validation accuracy | 85.6% | +2.5% |
| Generalization gap | -7.3% | -4.7% |
This combined strategy finally reached the 85% target on the validation set; test accuracy, shown in the summary below, came in just under at 84.9%.
Complete Results Summary
| Strategy | Train Acc | Val Acc | Gap | Test Acc |
|---|---|---|---|---|
| No augmentation | 100.0% | 58.7% | 41.3% | 56.2% |
| Basic augmentation | 95.2% | 72.0% | 23.2% | 70.8% |
| Aggressive augmentation | 84.6% | 79.3% | 5.3% | 78.1% |
| Domain-specific | 82.1% | 82.7% | -0.6% | 81.4% |
| RandAugment | 80.5% | 83.1% | -2.6% | 82.3% |
| Mixing + strong aug | 78.3% | 85.6% | -7.3% | 84.9% |
Key Observations
1. Augmentation Strength Scales with Data Scarcity
The optimal augmentation strength grows as the dataset shrinks. With only 500 images, aggressive augmentation that would hurt performance on a large dataset became essential. This aligns with the recommendations in Section 13.11.4 of the main chapter.
2. Domain Knowledge Has Diminishing Returns
The domain-specific augmentations (night vision simulation, vegetation occlusion) provided a meaningful boost (+3.4% over aggressive augmentation), but RandAugment achieved similar results with no domain knowledge. For practitioners without domain expertise, automated policies are an excellent starting point.
3. Mixing Techniques Are Especially Valuable for Small Datasets
Mixup and CutMix provided an additional 2.5% improvement on top of already strong augmentation. This is because they create novel training examples by combining existing ones, effectively providing a combinatorial expansion of the dataset.
4. Lower Training Accuracy Can Be a Good Sign
Across the strategies tested, lower training accuracy went hand in hand with higher validation accuracy. This is the regularization tradeoff in action: making the training task harder forces the model to learn more robust features. The effect holds only up to a point, though; augmentation so extreme that the model cannot fit the training data at all would hurt both metrics.
5. The Order of Augmentation Matters
The team found that applying normalization before the custom tensor-based augmentations (like the night vision simulation) gave different results than the reverse order, because those transforms assume pixel values in [0, 1], which no longer holds after normalization. The final pipeline therefore applies PIL-based transforms before ToTensor(), the custom tensor-based transforms between ToTensor() and Normalize(), and RandomErasing last.
Recommendations for Small Dataset Projects
Based on this case study, the recommended approach for small datasets (fewer than 1,000 samples per class) is:
- Always start with transfer learning from a model pretrained on a large dataset.
- Use aggressive augmentation as the default. Scale the augmentation strength inversely with dataset size.
- Try RandAugment first if you lack domain knowledge. It is simple and effective.
- Add domain-specific augmentations if you have expert knowledge about the deployment conditions.
- Include mixup and/or CutMix for an additional regularization boost.
- Combine with other regularization techniques: weight decay, dropout, label smoothing, and early stopping (see the sketch after this list).
- Train longer with strong augmentation. The model needs more epochs to converge because each epoch effectively presents different data.
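A minimal sketch of how those complementary regularizers might be wired together, assuming `model`, `train_loader`, and `val_loader` from the setup above and a hypothetical `evaluate` helper that returns validation accuracy (hyperparameter values are illustrative, not the team's):
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)   # label smoothing
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-4, weight_decay=1e-2,    # weight decay
)
# Dropout is already present in the classifier head (nn.Dropout(0.4)).

best_val_acc = 0.0
epochs_without_improvement = 0
patience = 15                                           # early stopping patience

for epoch in range(300):                                # train longer with strong augmentation
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    val_acc = evaluate(model, val_loader)               # hypothetical helper
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break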
Code Reference
The complete implementation for this case study is available in code/case-study-code.py.