Case Study 2: Climate Satellite Imagery — Fine-Tuning a Pretrained ViT for Land-Use Classification

Context

The Pacific Climate Research Consortium (PCRC) monitors land-use change across the Pacific Basin — deforestation in Southeast Asia, urbanization along Pacific coastlines, agricultural expansion in Australia, and glacier retreat in New Zealand. Tracking these changes requires classifying satellite imagery into land-use categories at scale: millions of image patches per year, each needing assignment to one of six classes (forest, cropland, urban, water, barren, wetland).

The consortium has 8,000 labeled satellite image patches — manually annotated by trained geographers over two years. Each patch is a 224×224 pixel RGB composite from Sentinel-2 imagery; at 10-meter spatial resolution, a 224-pixel side spans roughly 2.2 km, so each patch covers approximately 5 km².

8,000 labeled images is far too few to train a Vision Transformer from scratch (ViT-Base has 86 million parameters), but it is a reasonable dataset for fine-tuning. The question is: how to fine-tune effectively, given the substantial domain gap between ImageNet photographs and nadir satellite imagery.

The Domain Gap

Satellite imagery differs from ImageNet in nearly every visual dimension:

| Property | ImageNet | Sentinel-2 Satellite |
| --- | --- | --- |
| Viewing angle | Ground-level, perspective | Nadir (top-down, orthographic) |
| Object scale | Objects fill the frame | Objects are tiny (buildings span a few pixels) |
| Color palette | Full natural color range | Dominated by greens, browns, blues |
| Texture | Animal fur, fabric, skin | Canopy patterns, agricultural grids, urban blocks |
| Semantic categories | 1,000 everyday objects | 6 land-use classes (landscape-level) |
| Lighting | Variable (indoor/outdoor) | Consistent (sun angle varies slowly) |

Despite these differences, ImageNet-pretrained features provide a strong starting point. Early-layer features (edges, color gradients, textures) are general enough to be useful for satellite imagery, even if later-layer features (object parts, scene layouts) are not directly transferable.
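That intuition can be made concrete with depth-based freezing. The toy sketch below uses a generic stack of transformer blocks — not the HF ViT used later; the layer count and dimensions are illustrative stand-ins — and freezes the early, more general layers:

```python
import torch.nn as nn

# Illustrative encoder: 6 identical transformer blocks (dimensions are
# arbitrary stand-ins, not ViT-Base's actual configuration)
encoder = nn.Sequential(*[
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
    for _ in range(6)
])

k = 3  # freeze the first k blocks: their edge/texture features transfer as-is
for block in list(encoder)[:k]:
    for p in block.parameters():
        p.requires_grad = False

trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
total = sum(p.numel() for p in encoder.parameters())
```

Because all six blocks here are identical, freezing half of them leaves exactly half the parameters trainable; with a real ViT the split is uneven, since the patch embedding and classifier head differ in size from the encoder blocks.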

Experimental Setup

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
import numpy as np
from typing import Dict, Tuple
from sklearn.metrics import classification_report


# Land-use classes
LAND_USE_CLASSES = ["forest", "cropland", "urban", "water", "barren", "wetland"]


def generate_satellite_dataset(
    n_images: int = 8000,
    img_size: int = 224,
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate synthetic satellite imagery for the land-use classification task.

    Each class has distinct spatial frequency and color characteristics
    that approximate real satellite imagery patterns:
    - Forest: high-frequency green textures (canopy)
    - Cropland: regular grid patterns (agricultural fields)
    - Urban: high-contrast edges (buildings, roads)
    - Water: smooth blue gradients (lakes, ocean)
    - Barren: low-frequency brown/tan (desert, rock)
    - Wetland: mixed green-blue with irregular boundaries

    Args:
        n_images: Total number of images.
        img_size: Image spatial resolution.
        seed: Random seed.

    Returns:
        Tuple of (images, labels).
        images: (n_images, 3, img_size, img_size) float32 in [0, 1].
        labels: (n_images,) int64 in [0, 5].
    """
    rng = np.random.RandomState(seed)
    images = np.zeros((n_images, 3, img_size, img_size), dtype=np.float32)
    labels = np.zeros(n_images, dtype=np.int64)

    samples_per_class = n_images // len(LAND_USE_CLASSES)
    x = np.linspace(0, 4 * np.pi, img_size)
    xx, yy = np.meshgrid(x, x)

    for class_idx in range(len(LAND_USE_CLASSES)):
        start = class_idx * samples_per_class
        end = start + samples_per_class
        # 8000 is not divisible by 6; fold the remainder into the last class
        # so no patch is left as an all-zero image labeled "forest"
        if class_idx == len(LAND_USE_CLASSES) - 1:
            end = n_images

        for i in range(start, end):
            noise = rng.randn(3, img_size, img_size) * 0.05

            if class_idx == 0:  # Forest
                freq = rng.uniform(2, 6)
                base = 0.3 + 0.2 * np.sin(freq * xx + rng.uniform(0, 2*np.pi))
                images[i, 0] = 0.1 + noise[0]  # Low red
                images[i, 1] = base + noise[1]   # High green
                images[i, 2] = 0.1 + noise[2]   # Low blue
            elif class_idx == 1:  # Cropland
                grid_freq = rng.choice([3, 4, 5])
                grid = 0.5 * (np.sin(grid_freq * xx) > 0).astype(np.float32)
                images[i, 0] = 0.3 + 0.1 * grid + noise[0]
                images[i, 1] = 0.4 + 0.15 * grid + noise[1]
                images[i, 2] = 0.15 + noise[2]
            elif class_idx == 2:  # Urban
                edges = (np.abs(np.sin(8 * xx) * np.sin(8 * yy)) > 0.5).astype(np.float32)
                images[i, 0] = 0.4 + 0.2 * edges + noise[0]
                images[i, 1] = 0.35 + 0.15 * edges + noise[1]
                images[i, 2] = 0.35 + 0.15 * edges + noise[2]
            elif class_idx == 3:  # Water
                gradient = 0.5 + 0.3 * np.sin(0.5 * xx + rng.uniform(0, np.pi))
                images[i, 0] = 0.05 + noise[0]
                images[i, 1] = 0.15 + 0.1 * gradient + noise[1]
                images[i, 2] = gradient * 0.7 + noise[2]
            elif class_idx == 4:  # Barren
                freq = rng.uniform(0.5, 1.5)
                base = 0.5 + 0.15 * np.sin(freq * xx + freq * yy)
                images[i, 0] = base + noise[0]
                images[i, 1] = base * 0.8 + noise[1]
                images[i, 2] = base * 0.5 + noise[2]
            elif class_idx == 5:  # Wetland
                mask = (np.sin(3 * xx + rng.uniform(0, 2*np.pi)) > 0).astype(np.float32)
                images[i, 0] = 0.1 * mask + 0.05 + noise[0]
                images[i, 1] = 0.3 * mask + 0.15 * (1 - mask) + noise[1]
                images[i, 2] = 0.1 * mask + 0.4 * (1 - mask) + noise[2]

            labels[i] = class_idx

    images = np.clip(images, 0, 1)

    # Shuffle
    perm = rng.permutation(n_images)
    return torch.tensor(images[perm]), torch.tensor(labels[perm])

Comparing Transfer Strategies

PCRC evaluates five strategies on the same data, using a 60/20/20 train/val/test split:
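A minimal sketch of that split, assuming `torch.utils.data.random_split` with a fixed generator and a batch size of 32 (both assumptions; the exact loader settings are not stated). Small random tensors stand in for the generated dataset to keep the sketch light:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in tensors; in the experiments these come from generate_satellite_dataset
images = torch.rand(100, 3, 8, 8)
labels = torch.randint(0, 6, (100,))
dataset = TensorDataset(images, labels)

n_train = int(0.6 * len(dataset))        # 60
n_val = int(0.2 * len(dataset))          # 20
n_test = len(dataset) - n_train - n_val  # 20
train_set, val_set, test_set = random_split(
    dataset, [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
test_loader = DataLoader(test_set, batch_size=32)
```

Fixing the generator seed keeps the split reproducible across runs, which matters when comparing five strategies on the same data.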

def evaluate_strategy(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate a trained model on the test set.

    Args:
        model: Trained classification model.
        test_loader: Test data loader.
        device: Computation device.

    Returns:
        Dictionary with accuracy and per-class F1 scores.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if hasattr(outputs, "logits"):
                logits = outputs.logits
            else:
                logits = outputs
            preds = logits.argmax(dim=-1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    accuracy = float((all_preds == all_labels).mean())
    report = classification_report(
        all_labels, all_preds,
        target_names=LAND_USE_CLASSES,
        output_dict=True,
    )

    return {
        "accuracy": accuracy,
        "macro_f1": report["macro avg"]["f1-score"],
        "per_class_f1": {
            cls: report[cls]["f1-score"] for cls in LAND_USE_CLASSES
        },
    }


def fine_tune_vit(
    strategy: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_classes: int = 6,
    epochs: int = 15,
    device: torch.device = torch.device("cpu"),
) -> nn.Module:
    """Fine-tune a ViT with the specified strategy.

    Strategies:
        - 'linear_probe': Freeze backbone, train only classifier head.
        - 'last_block': Freeze all but last transformer block + head.
        - 'full_ft': Unfreeze everything, uniform learning rate.
        - 'differential_lr': Unfreeze everything, lower LR for backbone.
        - 'progressive': Progressive unfreezing over training epochs.

    Args:
        strategy: One of the strategy names above.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        num_classes: Number of land-use classes.
        epochs: Number of training epochs.
        device: Computation device.

    Returns:
        Trained model.
    """
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    ).to(device)

    if strategy == "linear_probe":
        for name, param in model.named_parameters():
            if "classifier" not in name:
                param.requires_grad = False
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=1e-3, weight_decay=0.01,
        )

    elif strategy == "last_block":
        for name, param in model.named_parameters():
            if "classifier" not in name and "encoder.layer.11" not in name:
                param.requires_grad = False
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=5e-4, weight_decay=0.01,
        )

    elif strategy == "full_ft":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=2e-5, weight_decay=0.01,
        )

    elif strategy == "differential_lr":
        backbone_params = [
            p for n, p in model.named_parameters() if "classifier" not in n
        ]
        head_params = [
            p for n, p in model.named_parameters() if "classifier" in n
        ]
        optimizer = torch.optim.AdamW([
            {"params": backbone_params, "lr": 2e-5},
            {"params": head_params, "lr": 1e-3},
        ], weight_decay=0.01)

    elif strategy == "progressive":
        # Start with only classifier unfrozen; unfreeze layers progressively
        for param in model.parameters():
            param.requires_grad = False
        for param in model.classifier.parameters():
            param.requires_grad = True
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=1e-3, weight_decay=0.01,
        )
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0.0

    for epoch in range(epochs):
        # Progressive unfreezing: unfreeze one transformer block every 3
        # epochs, starting from the top (layer 11) and working downward
        if strategy == "progressive":
            blocks_to_unfreeze = min(epoch // 3, 12)
            newly_unfrozen = False
            for i in range(12 - blocks_to_unfreeze, 12):
                for param in model.vit.encoder.layer[i].parameters():
                    if not param.requires_grad:
                        param.requires_grad = True
                        newly_unfrozen = True
            if newly_unfrozen:
                # Rebuild the optimizer *and* scheduler: the old scheduler
                # still points at the previous optimizer object, so stepping
                # it would no longer affect the learning rate actually in use
                trainable = [p for p in model.parameters() if p.requires_grad]
                optimizer = torch.optim.AdamW(trainable, lr=2e-5, weight_decay=0.01)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, T_max=max(epochs - epoch, 1)
                )

        model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.logits.argmax(dim=-1) == labels).sum().item()
            total += labels.size(0)

        scheduler.step()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_correct += (outputs.logits.argmax(dim=-1) == labels).sum().item()
                val_total += labels.size(0)

        val_acc = val_correct / val_total
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            # Snapshot the best weights so we return the best validation
            # checkpoint rather than whatever the final epoch produced
            best_state = {k: v.detach().cpu().clone()
                          for k, v in model.state_dict().items()}

    model.load_state_dict(best_state)
    return model

Results

Strategy Comparison

| Strategy | Test Accuracy | Macro F1 | Trainable Params | Training Time |
| --- | --- | --- | --- | --- |
| Linear probe | 0.719 | 0.712 | 4,614 (0.005%) | 1x |
| Last block unfreeze | 0.802 | 0.798 | 7.1M (8.3%) | 1.8x |
| Full fine-tuning (uniform LR) | 0.841 | 0.836 | 85.8M (100%) | 4.2x |
| Full fine-tuning (differential LR) | 0.867 | 0.863 | 85.8M (100%) | 4.2x |
| Progressive unfreezing | 0.873 | 0.870 | 85.8M (100%) | 5.1x |

Progressive unfreezing achieves the best accuracy, outperforming full fine-tuning with uniform learning rate by 3.2 percentage points. The differential learning rate alone accounts for 2.6 points of improvement — a substantial gain from a single hyperparameter change.
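The Trainable Params figure for the linear probe follows from quick arithmetic: the head is a single linear layer mapping ViT-Base's 768-dimensional [CLS] embedding to the six classes (the 86M total used below is approximate):

```python
hidden_dim = 768    # ViT-Base hidden size
num_classes = 6
head_params = hidden_dim * num_classes + num_classes  # weight matrix + biases
print(head_params)  # 4614

total_params = 86_000_000  # ViT-Base parameter count, approximately
print(round(100 * head_params / total_params, 4))  # 0.0054, i.e. ~0.005%
```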

Per-Class Performance

The linear probe struggles most with classes that require understanding spatial relationships (urban vs. cropland — both have regular grid patterns), while fine-tuned models resolve these distinctions:

| Class | Linear Probe F1 | Progressive Unfreeze F1 |
| --- | --- | --- |
| Forest | 0.83 | 0.93 |
| Cropland | 0.62 | 0.84 |
| Urban | 0.58 | 0.82 |
| Water | 0.91 | 0.96 |
| Barren | 0.74 | 0.89 |
| Wetland | 0.59 | 0.78 |

Water is the easiest class (distinctive color signature transfers directly from ImageNet, where water appears in many training images). Wetland is the hardest (requires distinguishing a mixture of water and vegetation, a pattern rare in ImageNet).

Data Efficiency

PCRC evaluated how each strategy degrades with less labeled data:

| Labeled Images | Linear Probe | Full FT (Diff. LR) | Progressive Unfreeze |
| --- | --- | --- | --- |
| 500 | 0.621 | 0.583 | 0.654 |
| 1,000 | 0.658 | 0.712 | 0.739 |
| 2,000 | 0.689 | 0.781 | 0.803 |
| 4,000 | 0.708 | 0.839 | 0.856 |
| 8,000 | 0.719 | 0.867 | 0.873 |

Two findings stand out:

  1. At 500 labels, full fine-tuning (even with differential learning rates) underperforms the linear probe: 0.583 vs. 0.621. This is negative transfer in action: with too few examples and too many trainable parameters, the model overfits to training noise and destroys the pretrained features. Progressive unfreezing mitigates this by keeping most layers frozen during early training.

  2. The gap between linear probe and fine-tuning narrows as data decreases, and at 500 labels it reverses. With abundant data, fine-tuning can adapt the backbone to the target domain; with scarce data, the pretrained features — imperfect as they are for satellite imagery — are the best available.
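One plausible way to build the shrinking label budgets is a class-balanced subsample of the training indices. The balancing scheme below is an assumption; the text does not describe PCRC's exact protocol:

```python
import torch

def subsample_balanced(labels: torch.Tensor, budget: int,
                       num_classes: int = 6, seed: int = 0) -> torch.Tensor:
    """Return a class-balanced subset of indices of size ~budget."""
    g = torch.Generator().manual_seed(seed)
    per_class = budget // num_classes
    keep = []
    for c in range(num_classes):
        idx = (labels == c).nonzero(as_tuple=True)[0]
        perm = idx[torch.randperm(len(idx), generator=g)]
        keep.append(perm[:per_class])  # first per_class shuffled indices
    return torch.cat(keep)

torch.manual_seed(0)
labels = torch.randint(0, 6, (8000,))             # stand-in labels
subset_idx = subsample_balanced(labels, budget=500)  # 83 per class -> 498 total
```

Integer division means a budget of 500 yields 498 images (83 per class); keeping the subset balanced avoids confounding the data-efficiency comparison with class imbalance.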

Domain-Specific Pretraining

After the fine-tuning experiments, PCRC explored a more aggressive approach: self-supervised pretraining on 200,000 unlabeled satellite images using DINO (self-distillation with no labels), followed by fine-tuning on the 8,000 labeled images.

| Approach | Test Accuracy |
| --- | --- |
| ImageNet ViT → linear probe | 0.719 |
| ImageNet ViT → fine-tune (progressive) | 0.873 |
| Satellite DINO ViT → linear probe | 0.812 |
| Satellite DINO ViT → fine-tune (progressive) | 0.911 |

Self-supervised pretraining on in-domain satellite data produced a 3.8-point improvement over ImageNet pretraining when combined with supervised fine-tuning. The linear probe result is particularly striking: 0.812 for satellite DINO vs. 0.719 for ImageNet — demonstrating that domain-specific pretraining learns features that are linearly separable for the target task, while ImageNet features require nonlinear adaptation.
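The linear-probe protocol behind both probe rows is worth spelling out: a single linear layer is trained on frozen embeddings, so probe accuracy directly measures how linearly separable those features are for the target task. A minimal sketch, with random vectors standing in for the frozen [CLS] embeddings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
feats = torch.randn(600, 768)            # stand-in for frozen [CLS] embeddings
labels = torch.randint(0, 6, (600,))

probe = nn.Linear(768, 6)                # the only trainable parameters
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(probe(feats), labels)
    loss.backward()
    optimizer.step()

acc = (probe(feats).argmax(dim=-1) == labels).float().mean()
```

Because nothing upstream of the probe changes, any accuracy difference between the ImageNet and satellite-DINO rows is attributable entirely to feature quality.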

Lessons Learned

  1. ImageNet pretraining is a strong baseline, not the ceiling. Despite the large domain gap, ImageNet features transfer well enough to beat training from scratch with limited data. But domain-specific self-supervised pretraining (when unlabeled domain data is available) provides a substantially better starting point.

  2. Learning rate discipline is essential for fine-tuning. The 2.6-point gap between uniform and differential learning rates is not a minor detail — it is the difference between a model that partially destroys its pretrained features and one that preserves them. Progressive unfreezing adds another 0.6 points by being even more disciplined about which layers adapt and when.

  3. The pretrained model's failure modes are predictable from the domain gap. Water (familiar from ImageNet) classifies well even with a linear probe. Wetland (a mixture of two landscape types, rare in ImageNet) is the hardest class. This pattern — understanding where the pretrained model will struggle — guides the practitioner toward targeted data collection and augmentation.

  4. Compute efficiency matters for scientific applications. PCRC operates on a research budget, not a tech company budget. Progressive unfreezing on an ImageNet ViT takes 5 GPU-hours. Training from scratch to match that accuracy would require >100 GPU-hours and data they do not have. Transfer learning is not just technically superior — it is the only approach that fits their constraints.