Case Study 2: Data Augmentation Strategies for Small Datasets

Background

A wildlife conservation team needs to classify camera trap images of endangered species in a remote rainforest. They have collected only 500 labeled images across 5 species: jaguar (120), ocelot (95), tapir (110), peccary (90), and capybara (85). The images vary significantly in lighting (day vs. night infrared), angle, occlusion (vegetation), and image quality.

The challenge: build a classifier that achieves at least 85% accuracy despite having fewer than 100 images per class for some species. This case study explores how different data augmentation strategies can make or break model performance on extremely small datasets.

The Challenge of Small Datasets

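With only 500 labeled images and a modern CNN, the ratio of parameters to images is staggering:

| Model | Parameters (ImageNet-1k model) | Params per Image (500 images) |
|-------|--------------------------------|-------------------------------|
| ResNet-18 | 11.7M | 23,400x |
| EfficientNet-B0 | 5.3M | 10,600x |
| MobileNetV3-Small | 2.5M | 5,000x |

Even the smallest model has about 5,000 parameters per labeled image. Once the 150 validation and test images are held out, it has more than 7,000 per training image. Without aggressive regularization, memorization is inevitable.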
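
The ratios in the table are plain division. The sketch below reproduces them, assuming the parameter counts torchvision publishes for the ImageNet-1k versions of these models. Those counts include 1,000-class heads, so swapping in a 5-class head changes the totals somewhat. `count_parameters` is a common idiom that works on any `torch.nn.Module`. `params_per_image` and `PARAM_COUNTS` are illustrative names.

```python
def count_parameters(model) -> int:
    """Total number of parameters in a torch.nn.Module (or any object
    whose .parameters() yields tensors with a .numel() method)."""
    return sum(p.numel() for p in model.parameters())


def params_per_image(num_params: int, num_images: int) -> float:
    """Parameters available per image; higher means easier memorization."""
    return num_params / num_images


# Assumed counts for the ImageNet-1k torchvision models (1,000-class heads)
PARAM_COUNTS = {
    "ResNet-18": 11_689_512,
    "EfficientNet-B0": 5_288_548,
    "MobileNetV3-Small": 2_542_856,
}

for name, n in PARAM_COUNTS.items():
    # 500 labeled images in total; 350 remain for training after the split
    print(f"{name:18s} {params_per_image(n, 500):>9,.0f}x "
          f"{params_per_image(n, 350):>9,.0f}x")
```

The second column divides by 350 instead of 500. That is the ratio the optimizer actually faces, because it never sees the validation and test images.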

Experiment Setup

import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, random_split

torch.manual_seed(42)

# We use a pretrained EfficientNet-B0 as the backbone
# Transfer learning is essential for small datasets (see Chapter 15)
model = models.efficientnet_b0(
    weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1
)
model.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(1280, 5),
)

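# Split: 350 train, 75 val, 75 test (70/15/15 of the 500 images),
# stratified so each split keeps the per-species proportions.
# Note: random_split shuffles the whole dataset without regard to
# class, so it does not stratify; split each class separately and
# wrap the resulting index lists in torch.utils.data.Subset instead.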
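
The split above is meant to be stratified, which `random_split` does not provide. Below is a minimal pure-Python sketch, assuming the labels are available as a list aligned with the dataset; the helper name `stratified_split` is hypothetical. It shuffles each class's indices separately and slices off the validation and test shares, so every split keeps roughly the 120/95/110/90/85 species mix.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_frac=0.15, test_frac=0.15, seed=42):
    """Split dataset indices into train/val/test, preserving class ratios.

    Args:
        labels: Sequence of class labels, one per dataset item.
        val_frac: Fraction of each class assigned to validation.
        test_frac: Fraction of each class assigned to test.
        seed: Seed for the per-class shuffles.

    Returns:
        (train_idx, val_idx, test_idx) lists of integer indices.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx, test_idx = [], [], []
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)  # shuffle within the class only
        n_val = round(len(idxs) * val_frac)
        n_test = round(len(idxs) * test_frac)
        val_idx += idxs[:n_val]
        test_idx += idxs[n_val:n_val + n_test]
        train_idx += idxs[n_val + n_test:]
    return train_idx, val_idx, test_idx


# Class sizes from the case study: jaguar, ocelot, tapir, peccary, capybara
counts = [120, 95, 110, 90, 85]
labels = [c for c, n in enumerate(counts) for _ in range(n)]
train_idx, val_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(val_idx), len(test_idx))
```

Each index list can then be wrapped with `torch.utils.data.Subset(dataset, train_idx)` (and likewise for validation and test).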

Strategy 1: No Augmentation (Baseline)

baseline_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

Results after 50 epochs:

| Metric | Value |
|--------|-------|
| Training accuracy | 100.0% |
| Validation accuracy | 58.7% |
| Generalization gap | 41.3% |

Complete memorization, as expected. The model fits every training image but generalizes poorly, having learned image-specific details rather than features that identify each species.

Strategy 2: Basic Augmentation

Standard geometric and color transformations:

basic_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.2,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

Results:

| Metric | Value | Change from Baseline |
|--------|-------|---------------------|
| Training accuracy | 95.2% | -4.8% |
| Validation accuracy | 72.0% | +13.3% |
| Generalization gap | 23.2% | -18.1% |

A substantial improvement (+13.3 points of validation accuracy), but still well short of the 85% target. For a dataset this small, these augmentations are too mild.

Strategy 3: Aggressive Augmentation

Push the augmentation much further:

aggressive_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.3, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(degrees=45),
    transforms.RandomAffine(
        degrees=0,
        translate=(0.2, 0.2),
        scale=(0.8, 1.2),
        shear=15,
    ),
    transforms.ColorJitter(
        brightness=0.5,
        contrast=0.5,
        saturation=0.5,
        hue=0.1,
    ),
    transforms.RandomGrayscale(p=0.1),
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.25)),
])

Results:

| Metric | Value | Change from Basic |
|--------|-------|-------------------|
| Training accuracy | 84.6% | -10.6% |
| Validation accuracy | 79.3% | +7.3% |
| Generalization gap | 5.3% | -17.9% |

Much better. The aggressive augmentation makes the training task harder (lower training accuracy) but the model learns more generalizable features.
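
The crop setting alone shows how much harder the aggressive pipeline is. `RandomResizedCrop`'s `scale` bounds the crop's *area* as a fraction of the image, so the crop's side length shrinks with the square root of that fraction. The sketch below assumes a square image and, by default, a square crop. torchvision also samples an aspect ratio, by default between 3/4 and 4/3. The helper name is hypothetical.

```python
import math


def min_crop_side_fraction(min_area_fraction: float, aspect_ratio: float = 1.0):
    """Width and height of the smallest crop, as fractions of the image side.

    RandomResizedCrop samples an area fraction a and an aspect ratio r, then
    uses width = sqrt(a * r) and height = sqrt(a / r) (for a square image).
    """
    width = math.sqrt(min_area_fraction * aspect_ratio)
    height = math.sqrt(min_area_fraction / aspect_ratio)
    return width, height


for name, lo in [("basic", 0.6), ("aggressive", 0.3)]:
    w, h = min_crop_side_fraction(lo)
    print(f"{name:10s} smallest square crop spans {w:.0%} x {h:.0%} of the frame")
```

At `scale=(0.3, 1.0)`, the smallest crops keep only about 55% of each side. The model therefore regularly sees partial animals, much like the occluded and off-center subjects typical of camera traps.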

Strategy 4: Domain-Specific Augmentation

Camera trap images have specific characteristics that suggest targeted augmentations:

import torch
from torchvision import transforms

torch.manual_seed(42)


class SimulateNightVision:
    """Simulate infrared/night vision camera appearance.

    Many camera traps switch to infrared at night,
    producing grayscale images with different contrast profiles.
    """

    def __init__(self, p: float = 0.3) -> None:
        self.p = p

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Apply night vision simulation.

        Args:
            img: Input tensor image.

        Returns:
            Possibly modified image tensor.
        """
        if torch.rand(1).item() < self.p:
            # Collapse to grayscale by averaging the RGB channels
            gray = img.mean(dim=0, keepdim=True)
            # Add slight green tint (common in night vision)
            result = torch.zeros_like(img)
            result[0] = gray[0] * 0.8
            result[1] = gray[0] * 1.1
            result[2] = gray[0] * 0.7
            return result.clamp(0, 1)
        return img


class SimulateVegetationOcclusion:
    """Simulate partial occlusion by vegetation.

    Camera trap images frequently have branches, leaves,
    or grass partially occluding the animal.
    """

    def __init__(self, p: float = 0.3, max_bars: int = 5) -> None:
        self.p = p
        self.max_bars = max_bars

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Apply vegetation-like occlusion.

        Args:
            img: Input tensor image of shape (C, H, W).

        Returns:
            Image with random bar occlusions.
        """
        if torch.rand(1).item() < self.p:
            _, h, w = img.shape
            num_bars = torch.randint(1, self.max_bars + 1, (1,)).item()
            result = img.clone()
            for _ in range(num_bars):
                bar_width = torch.randint(2, 8, (1,)).item()
                if torch.rand(1).item() > 0.5:
                    # Horizontal bar
                    y = torch.randint(0, h - bar_width, (1,)).item()
                    # Dark green color for vegetation
                    result[0, y:y + bar_width, :] *= 0.3
                    result[1, y:y + bar_width, :] = (
                        result[1, y:y + bar_width, :] * 0.5 + 0.2
                    )
                    result[2, y:y + bar_width, :] *= 0.3
                else:
                    # Vertical bar
                    x = torch.randint(0, w - bar_width, (1,)).item()
                    result[0, :, x:x + bar_width] *= 0.3
                    result[1, :, x:x + bar_width] = (
                        result[1, :, x:x + bar_width] * 0.5 + 0.2
                    )
                    result[2, :, x:x + bar_width] *= 0.3
            return result.clamp(0, 1)
        return img


domain_specific_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(
        brightness=0.5,
        contrast=0.5,
        saturation=0.4,
        hue=0.05,
    ),
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
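    transforms.ToTensor(),
    # Domain-specific augmentations run after ToTensor but BEFORE
    # Normalize: they assume pixel values in [0, 1] (both clamp to
    # that range, and the occlusion adds a fixed 0.2 green offset)
    SimulateNightVision(p=0.3),
    SimulateVegetationOcclusion(p=0.3),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    # RandomErasing operates on tensors and is conventionally applied last
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.2)),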
])

Results:

| Metric | Value | Change from Aggressive |
|--------|-------|----------------------|
| Training accuracy | 82.1% | -2.5% |
| Validation accuracy | 82.7% | +3.4% |
| Generalization gap | -0.6% | -5.9% |

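Domain-specific augmentations raised validation accuracy by a further 3.4 points, plausibly because they simulate conditions the model will meet in deployment: infrared night shots and foreground vegetation. The generalization gap is now slightly negative. This is expected with strong regularization: training accuracy is most likely measured on augmented inputs with dropout active, whereas validation images are clean and center-cropped.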

Strategy 5: RandAugment

Using automated augmentation policies:

from torchvision.transforms import autoaugment

randaug_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    autoaugment.RandAugment(num_ops=3, magnitude=12),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    transforms.RandomErasing(p=0.2),
])

Results:

| Metric | Value | Change from Domain-specific |
|--------|-------|-------------------|
| Training accuracy | 80.5% | -1.6% |
| Validation accuracy | 83.1% | +0.4% |
| Generalization gap | -2.6% | -2.0% |

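RandAugment performed comparably to the domain-specific strategy. This is notable because it needs no domain-specific transforms. Its only settings are the number of operations applied per image (num_ops) and their shared strength (magnitude).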
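
RandAugment replaces a searched augmentation policy with two settings. The sketch below shows how one image's operations are chosen, modeled loosely on torchvision's implementation. `num_ops` operations are drawn uniformly, with replacement, from a fixed list. All of them share one strength: `magnitude` indexes one of `num_magnitude_bins` evenly spaced values (31 by default). The operation names and maximum strengths are an illustrative subset, not torchvision's full list. The sketch only samples the policy and does not touch pixels.

```python
import random

NUM_MAGNITUDE_BINS = 31  # torchvision's default

# Illustrative subset: op name -> (maximum strength, whether sign is random)
OPS = {
    "Identity": (0.0, False),
    "Rotate": (30.0, True),      # degrees
    "ShearX": (0.3, True),       # shear factor
    "Brightness": (0.9, True),   # added to an enhancement factor of 1.0
    "Contrast": (0.9, True),
    "Equalize": (0.0, False),
}


def sample_randaugment_policy(num_ops: int, magnitude: int, rng: random.Random):
    """Pick num_ops (op, strength) pairs that would be applied in sequence."""
    if not 0 <= magnitude < NUM_MAGNITUDE_BINS:
        raise ValueError("magnitude must index one of the magnitude bins")
    # Position of this magnitude along each op's range: 12 / 30 = 0.4
    fraction = magnitude / (NUM_MAGNITUDE_BINS - 1)
    policy = []
    for _ in range(num_ops):
        name = rng.choice(list(OPS))  # uniform, with replacement
        max_strength, signed = OPS[name]
        strength = fraction * max_strength
        if signed and rng.random() < 0.5:
            strength = -strength  # e.g. rotate left or right
        policy.append((name, strength))
    return policy


rng = random.Random(42)
for _ in range(3):
    print(sample_randaugment_policy(num_ops=3, magnitude=12, rng=rng))
```

With `num_ops=3, magnitude=12`, each training image passes through three randomly chosen operations, each at 40% of that operation's maximum strength.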

Strategy 6: Mixup + CutMix + Strong Augmentation

Combining batch-level mixing (mixup and CutMix, applied inside the training step rather than in the transform pipeline) with strong per-image augmentation:

import torch
import torch.nn as nn

torch.manual_seed(42)


def train_step_with_mixing(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    mixup_alpha: float = 0.3,
    cutmix_alpha: float = 1.0,
    mix_prob: float = 0.5,
) -> float:
    """Training step with randomly applied mixup or CutMix.

    Args:
        model: The neural network.
        x: Input batch.
        y: Target batch.
        criterion: Loss function.
        optimizer: Optimizer.
        mixup_alpha: Mixup beta distribution parameter.
        cutmix_alpha: CutMix beta distribution parameter.
        mix_prob: Probability of applying mixing.

    Returns:
        Loss value for this step.
    """
    model.train()
    optimizer.zero_grad()

    if torch.rand(1).item() < mix_prob:
        if torch.rand(1).item() < 0.5:
            # Apply mixup
            lam = torch.distributions.Beta(
                mixup_alpha, mixup_alpha
            ).sample().item()
            index = torch.randperm(x.size(0), device=x.device)
            mixed_x = lam * x + (1 - lam) * x[index]
            output = model(mixed_x)
            loss = lam * criterion(output, y) + (
                (1 - lam) * criterion(output, y[index])
            )
        else:
            # Apply CutMix
            lam = torch.distributions.Beta(
                cutmix_alpha, cutmix_alpha
            ).sample().item()
            index = torch.randperm(x.size(0), device=x.device)
            _, _, h, w = x.shape
            cut_ratio = (1.0 - lam) ** 0.5
            cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
            cx = torch.randint(0, w, (1,)).item()
            cy = torch.randint(0, h, (1,)).item()
            x1 = max(0, cx - cut_w // 2)
            y1 = max(0, cy - cut_h // 2)
            x2 = min(w, cx + cut_w // 2)
            y2 = min(h, cy + cut_h // 2)
            mixed_x = x.clone()
            mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
            actual_lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
            output = model(mixed_x)
            loss = actual_lam * criterion(output, y) + (
                (1 - actual_lam) * criterion(output, y[index])
            )
    else:
        output = model(x)
        loss = criterion(output, y)

    loss.backward()
    optimizer.step()
    return loss.item()
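
The CutMix branch recomputes `actual_lam` because clipping the box at the image border shrinks the pasted area. Below is a standalone sketch of just the box arithmetic from `train_step_with_mixing`. The helper `cutmix_box` is hypothetical: it mirrors the code above but takes the box center as arguments, which makes the border effect easy to check.

```python
def cutmix_box(lam: float, h: int, w: int, cx: int, cy: int):
    """Return the pasted box (x1, y1, x2, y2) and the area-corrected lambda.

    Mirrors the CutMix branch of train_step_with_mixing: the box side is
    scaled by sqrt(1 - lam), centered at (cx, cy), and clipped to the image.
    """
    cut_ratio = (1.0 - lam) ** 0.5
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    x1 = max(0, cx - cut_w // 2)
    y1 = max(0, cy - cut_h // 2)
    x2 = min(w, cx + cut_w // 2)
    y2 = min(h, cy + cut_h // 2)
    # Fraction of pixels that still come from the original sample
    actual_lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    return (x1, y1, x2, y2), actual_lam


# Same sampled lambda, two box centers on a 224 x 224 image
print(cutmix_box(0.75, 224, 224, cx=112, cy=112))  # -> ((56, 56, 168, 168), 0.75)
print(cutmix_box(0.75, 224, 224, cx=0, cy=0))      # -> ((0, 0, 56, 56), 0.9375)
```

Without the correction, the corner case would weight the second image's label at 0.25 even though that image supplies only 6.25% of the pixels.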

Results:

| Metric | Value | Change from RandAugment |
|--------|-------|------------------------|
| Training accuracy | 78.3% | -2.2% |
| Validation accuracy | 85.6% | +2.5% |
| Generalization gap | -7.3% | -4.7% |

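This combined strategy finally reached the 85% target on the validation set (85.6%). Accuracy on the held-out test set (84.9%, in the summary below) landed just below it.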

Complete Results Summary

| Strategy | Train Acc | Val Acc | Gap | Test Acc |
|----------|-----------|---------|-----|----------|
| No augmentation | 100.0% | 58.7% | 41.3% | 56.2% |
| Basic augmentation | 95.2% | 72.0% | 23.2% | 70.8% |
| Aggressive augmentation | 84.6% | 79.3% | 5.3% | 78.1% |
| Domain-specific | 82.1% | 82.7% | -0.6% | 81.4% |
| RandAugment | 80.5% | 83.1% | -2.6% | 82.3% |
| Mixing + strong aug | 78.3% | 85.6% | -7.3% | 84.9% |

Key Observations

1. Augmentation Strength Scales with Data Scarcity

The right augmentation strength depends on how much data is available: the scarcer the data, the stronger the augmentation the model can benefit from. With only 500 images, augmentation aggressive enough to hurt accuracy on a large dataset became essential. This aligns with the recommendations in Section 13.11.4 of the main chapter.

2. Domain Knowledge Has Diminishing Returns

The domain-specific augmentations (night vision simulation, vegetation occlusion) provided a meaningful boost (+3.4% over aggressive augmentation), but RandAugment achieved similar results with no domain knowledge. For practitioners without domain expertise, automated policies are an excellent starting point.

3. Mixing Techniques Are Especially Valuable for Small Datasets

Mixup and CutMix added a further 2.5 points of validation accuracy on top of already strong augmentation. A plausible reason is that they create novel training examples by combining pairs of existing ones, effectively giving a combinatorial expansion of the dataset.
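
The worked example below shows what a single mixup sample contributes, in pure Python for clarity. The pixel values and model probabilities are made up, and the label order follows the species list in the Background. The mixed input is a convex combination of two images. The loss `lam * CE(y_a) + (1 - lam) * CE(y_b)` used in `train_step_with_mixing` equals cross-entropy against the correspondingly mixed soft target. The pair count at the end covers only the distinct image pairs; each pair also gets a freshly sampled, continuous `lam`.

```python
import math


def cross_entropy(probs, target):
    """Cross-entropy of predicted probabilities against a (soft) target."""
    return -sum(t * math.log(p) for p, t in zip(probs, target) if t > 0)


lam = 0.7
# Two tiny "images" (flattened pixels) from different classes
x_a, x_b = [0.2, 0.8, 0.5], [0.6, 0.1, 0.9]
mixed_x = [lam * a + (1 - lam) * b for a, b in zip(x_a, x_b)]

# One-hot targets over 5 classes: class 0 (jaguar) and class 3 (peccary)
y_a = [1, 0, 0, 0, 0]
y_b = [0, 0, 0, 1, 0]
soft_target = [lam * a + (1 - lam) * b for a, b in zip(y_a, y_b)]

# Hypothetical model output probabilities for the mixed input
probs = [0.5, 0.1, 0.1, 0.25, 0.05]

weighted = lam * cross_entropy(probs, y_a) + (1 - lam) * cross_entropy(probs, y_b)
soft = cross_entropy(probs, soft_target)
print(mixed_x, soft_target)
print(weighted, soft)  # equal up to floating-point rounding

# Distinct unordered pairs available from 350 training images
print(math.comb(350, 2))
```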

4. Lower Training Accuracy Is a Good Sign

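Across these six strategies, lower training accuracy went hand in hand with higher validation accuracy. This is the regularization tradeoff in action: making the training task harder forces the model to learn more robust features. The signal is good only while validation accuracy keeps improving. If validation accuracy drops too, the regularization has become too strong and the model is underfitting.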

5. The Order of Augmentation Matters

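The team found that applying normalization before the custom tensor-based augmentations (like the night vision simulation) gave different results than the reverse order. This is expected: both custom transforms clamp their output to [0, 1] and use fixed color values, so they assume unnormalized pixel values. The ordering that respects each transform's assumptions is:

- PIL-based transforms before ToTensor().
- Custom tensor transforms after ToTensor() but before Normalize().
- RandomErasing, which operates on tensors and is conventionally applied after Normalize(), last.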
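
The clamp in both custom classes is why their position relative to Normalize() matters. After normalization with the ImageNet statistics used in every pipeline here, values span roughly -2.1 to 2.6. Clamping at that point therefore collapses every below-mean value to 0. The minimal numeric check below uses the same mean and std; the helper names are hypothetical.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize(pixel):
    """Apply per-channel (x - mean) / std to one RGB pixel in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(pixel, MEAN, STD)]


def clamp01(pixel):
    """Clamp each channel to [0, 1], as the custom transforms do."""
    return [min(max(x, 0.0), 1.0) for x in pixel]


# Range of normalized values for each channel
for c in range(3):
    lo, hi = normalize([0.0] * 3)[c], normalize([1.0] * 3)[c]
    print(f"channel {c}: [{lo:.2f}, {hi:.2f}]")

dark = [0.05, 0.05, 0.05]   # a dark night-time pixel
black = [0.0, 0.0, 0.0]
# Clamping AFTER Normalize maps both to the same value: information lost
print(clamp01(normalize(dark)) == clamp01(normalize(black)))
# Clamping BEFORE Normalize keeps them distinct
print(normalize(clamp01(dark)) == normalize(clamp01(black)))
```

Night-time camera trap images are dominated by dark pixels, so a clamp after normalization would erase most of the information the night vision simulation is meant to preserve.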

Recommendations for Small Dataset Projects

Based on this case study, the recommended approach for small datasets (fewer than 1,000 samples per class) is:

  1. Always start with transfer learning from a model pretrained on a large dataset.
  2. Use aggressive augmentation as the default. Scale the augmentation strength inversely with dataset size.
  3. Try RandAugment first if you lack domain knowledge. It is simple and effective.
  4. Add domain-specific augmentations if you have expert knowledge about the deployment conditions.
  5. Include mixup and/or CutMix for an additional regularization boost.
  6. Combine with other regularization techniques: weight decay, dropout, label smoothing, and early stopping.
  7. Train longer with strong augmentation. The model needs more epochs to converge because each epoch effectively presents different data.
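
Items 6 and 7 pull in opposite directions. Longer training gives strongly augmented models time to converge, while early stopping guards against overfitting once validation accuracy stops improving. Below is a minimal sketch of a patience-based early-stopping tracker. The class name, the patience value, and the accuracy curve are assumptions for illustration, not from the case study. The PyTorch settings for the other regularizers appear only in the docstring.

```python
class EarlyStopping:
    """Stop when validation accuracy has not improved for `patience` epochs.

    In PyTorch, the other regularizers from the list are typically set via
    nn.CrossEntropyLoss(label_smoothing=...) (e.g. 0.1) and
    torch.optim.AdamW(model.parameters(), lr=..., weight_decay=...);
    the specific values are left to tuning.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, val_acc: float) -> bool:
        """Record one epoch's validation accuracy; return True to stop."""
        if val_acc > self.best + self.min_delta:
            self.best = val_acc
            self.best_epoch = epoch
            self.bad_epochs = 0  # improvement: a checkpoint would be saved here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
history = [0.60, 0.70, 0.78, 0.77, 0.79, 0.78, 0.78, 0.76]  # made-up curve
for epoch, acc in enumerate(history):
    if stopper.step(epoch, acc):
        print(f"stop at epoch {epoch}; best {stopper.best:.2f} at epoch {stopper.best_epoch}")
        break
```

With strong augmentation, validation accuracy tends to be noisy from epoch to epoch. A generous patience lets training run long enough for item 7 without giving up the protection of early stopping.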

Code Reference

The complete implementation for this case study is available in code/case-study-code.py.