Case Study 2: Climate Satellite Imagery — Fine-Tuning a Pretrained ViT for Land-Use Classification
Context
The Pacific Climate Research Consortium (PCRC) monitors land-use change across the Pacific Basin — deforestation in Southeast Asia, urbanization along Pacific coastlines, agricultural expansion in Australia, and glacier retreat in New Zealand. Tracking these changes requires classifying satellite imagery into land-use categories at scale: millions of image patches per year, each needing assignment to one of six classes (forest, cropland, urban, water, barren, wetland).
The consortium has 8,000 labeled satellite image patches — manually annotated by trained geographers over two years. Each patch is a 224x224 pixel RGB composite from Sentinel-2 imagery, covering approximately 1 km$^2$ at 10-meter spatial resolution.
8,000 labeled images is far too few to train a Vision Transformer from scratch (ViT-Base has 86 million parameters), but it is a reasonable dataset for fine-tuning. The question is: how to fine-tune effectively, given the substantial domain gap between ImageNet photographs and nadir satellite imagery.
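A quick back-of-the-envelope calculation makes the scale mismatch concrete (using the parameter and dataset counts quoted above):

```python
# Capacity-vs-data arithmetic for ViT-Base on the PCRC dataset
vit_base_params = 86_000_000  # ViT-Base parameter count
labeled_patches = 8_000       # PCRC's annotated Sentinel-2 patches

params_per_example = vit_base_params / labeled_patches
print(f"{params_per_example:,.0f} parameters per labeled image")
# Over 10,000 parameters per example: a from-scratch model would
# memorize the dataset long before learning generalizable features.
```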
The Domain Gap
Satellite imagery differs from ImageNet in nearly every visual dimension:
| Property | ImageNet | Sentinel-2 Satellite |
|---|---|---|
| Viewing angle | Ground-level, perspective | Nadir (top-down, orthographic) |
| Object scale | Objects fill the frame | Objects are tiny (buildings = few pixels) |
| Color palette | Full natural color range | Dominated by greens, browns, blues |
| Texture | Animal fur, fabric, skin | Canopy patterns, agricultural grids, urban blocks |
| Semantic categories | 1,000 everyday objects | 6 land-use classes (landscape-level) |
| Lighting | Variable (indoor/outdoor) | Consistent (sun angle varies slowly) |
Despite these differences, ImageNet-pretrained features provide a strong starting point. Early-layer features (edges, color gradients, textures) are general enough to be useful for satellite imagery, even if later-layer features (object parts, scene layouts) are not directly transferable.
Experimental Setup
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import transforms
from transformers import ViTForImageClassification
import numpy as np
from typing import Dict, Tuple
from sklearn.metrics import classification_report, confusion_matrix

# Land-use classes
LAND_USE_CLASSES = ["forest", "cropland", "urban", "water", "barren", "wetland"]


def generate_satellite_dataset(
    n_images: int = 8000,
    img_size: int = 224,
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate synthetic satellite imagery for the land-use classification task.

    Each class has distinct spatial frequency and color characteristics
    that approximate real satellite imagery patterns:

    - Forest: high-frequency green textures (canopy)
    - Cropland: regular grid patterns (agricultural fields)
    - Urban: high-contrast edges (buildings, roads)
    - Water: smooth blue gradients (lakes, ocean)
    - Barren: low-frequency brown/tan (desert, rock)
    - Wetland: mixed green-blue with irregular boundaries

    Args:
        n_images: Total number of images.
        img_size: Image spatial resolution.
        seed: Random seed.

    Returns:
        Tuple of (images, labels).
        images: (n_images, 3, img_size, img_size) float32 in [0, 1].
        labels: (n_images,) int64 in [0, 5].
    """
    rng = np.random.RandomState(seed)
    images = np.zeros((n_images, 3, img_size, img_size), dtype=np.float32)
    labels = np.zeros(n_images, dtype=np.int64)
    samples_per_class = n_images // len(LAND_USE_CLASSES)

    x = np.linspace(0, 4 * np.pi, img_size)
    xx, yy = np.meshgrid(x, x)

    for class_idx in range(len(LAND_USE_CLASSES)):
        start = class_idx * samples_per_class
        end = start + samples_per_class
        for i in range(start, end):
            noise = rng.randn(3, img_size, img_size) * 0.05
            if class_idx == 0:  # Forest
                freq = rng.uniform(2, 6)
                base = 0.3 + 0.2 * np.sin(freq * xx + rng.uniform(0, 2 * np.pi))
                images[i, 0] = 0.1 + noise[0]   # Low red
                images[i, 1] = base + noise[1]  # High green
                images[i, 2] = 0.1 + noise[2]   # Low blue
            elif class_idx == 1:  # Cropland
                grid_freq = rng.choice([3, 4, 5])
                grid = 0.5 * (np.sin(grid_freq * xx) > 0).astype(np.float32)
                images[i, 0] = 0.3 + 0.1 * grid + noise[0]
                images[i, 1] = 0.4 + 0.15 * grid + noise[1]
                images[i, 2] = 0.15 + noise[2]
            elif class_idx == 2:  # Urban
                edges = (np.abs(np.sin(8 * xx) * np.sin(8 * yy)) > 0.5).astype(np.float32)
                images[i, 0] = 0.4 + 0.2 * edges + noise[0]
                images[i, 1] = 0.35 + 0.15 * edges + noise[1]
                images[i, 2] = 0.35 + 0.15 * edges + noise[2]
            elif class_idx == 3:  # Water
                gradient = 0.5 + 0.3 * np.sin(0.5 * xx + rng.uniform(0, np.pi))
                images[i, 0] = 0.05 + noise[0]
                images[i, 1] = 0.15 + 0.1 * gradient + noise[1]
                images[i, 2] = gradient * 0.7 + noise[2]
            elif class_idx == 4:  # Barren
                freq = rng.uniform(0.5, 1.5)
                base = 0.5 + 0.15 * np.sin(freq * xx + freq * yy)
                images[i, 0] = base + noise[0]
                images[i, 1] = base * 0.8 + noise[1]
                images[i, 2] = base * 0.5 + noise[2]
            elif class_idx == 5:  # Wetland
                mask = (np.sin(3 * xx + rng.uniform(0, 2 * np.pi)) > 0).astype(np.float32)
                images[i, 0] = 0.1 * mask + 0.05 + noise[0]
                images[i, 1] = 0.3 * mask + 0.15 * (1 - mask) + noise[1]
                images[i, 2] = 0.1 * mask + 0.4 * (1 - mask) + noise[2]
            labels[i] = class_idx

    images = np.clip(images, 0, 1)
    # Shuffle so classes are interleaved before splitting
    perm = rng.permutation(n_images)
    return torch.tensor(images[perm]), torch.tensor(labels[perm])
```
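The generator returns shuffled tensors, so the 60/20/20 split used in the experiments reduces to index bookkeeping. A minimal pure-Python sketch (`split_indices` is an illustrative helper, not part of the pipeline above; the real code would wrap each partition in a `DataLoader`):

```python
import random

def split_indices(n, fracs=(0.6, 0.2, 0.2), seed=42):
    """Shuffle indices and cut them into train/val/test partitions."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = int(fracs[0] * n)
    n_val = int(fracs[1] * n)
    # Test gets the remainder, so the three parts always cover all n indices
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

train_idx, val_idx, test_idx = split_indices(8000)
print(len(train_idx), len(val_idx), len(test_idx))  # 4800 1600 1600
```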
Comparing Transfer Strategies
PCRC evaluates five strategies on the same data, using a 60/20/20 train/val/test split:
```python
def evaluate_strategy(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate a trained model on the test set.

    Args:
        model: Trained classification model.
        test_loader: Test data loader.
        device: Computation device.

    Returns:
        Dictionary with accuracy and per-class F1 scores.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # HF models return an output object; plain nn.Modules return raw logits
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            preds = logits.argmax(dim=-1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    accuracy = float((all_preds == all_labels).mean())
    report = classification_report(
        all_labels, all_preds,
        target_names=LAND_USE_CLASSES,
        output_dict=True,
    )
    return {
        "accuracy": accuracy,
        "macro_f1": report["macro avg"]["f1-score"],
        "per_class_f1": {
            cls: report[cls]["f1-score"] for cls in LAND_USE_CLASSES
        },
    }
```
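The macro F1 reported by `evaluate_strategy` averages per-class F1 with equal weight, so a hard minority class (wetland, say) pulls the score down as much as a common one. A small hand computation of the same quantity (numpy only, toy labels; not part of the PCRC pipeline):

```python
import numpy as np

def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1 scores."""
    f1s = []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall)
                   if precision + recall else 0.0)
    return float(np.mean(f1s))

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
print(macro_f1(y_true, y_pred, 3))  # per-class F1s are 0.5, 0.8, 2/3
```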
```python
def fine_tune_vit(
    strategy: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_classes: int = 6,
    epochs: int = 15,
    device: torch.device = torch.device("cpu"),
) -> nn.Module:
    """Fine-tune a ViT with the specified strategy.

    Strategies:
    - 'linear_probe': Freeze backbone, train only classifier head.
    - 'last_block': Freeze all but last transformer block + head.
    - 'full_ft': Unfreeze everything, uniform learning rate.
    - 'differential_lr': Unfreeze everything, lower LR for backbone.
    - 'progressive': Progressive unfreezing over training epochs.

    Args:
        strategy: One of the strategy names above.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        num_classes: Number of land-use classes.
        epochs: Number of training epochs.
        device: Computation device.

    Returns:
        The model restored to its best-validation-accuracy weights.
    """
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    ).to(device)

    if strategy == "linear_probe":
        for name, param in model.named_parameters():
            if "classifier" not in name:
                param.requires_grad = False
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=1e-3, weight_decay=0.01,
        )
    elif strategy == "last_block":
        for name, param in model.named_parameters():
            if "classifier" not in name and "encoder.layer.11" not in name:
                param.requires_grad = False
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=5e-4, weight_decay=0.01,
        )
    elif strategy == "full_ft":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=2e-5, weight_decay=0.01,
        )
    elif strategy == "differential_lr":
        backbone_params = [
            p for n, p in model.named_parameters() if "classifier" not in n
        ]
        head_params = [
            p for n, p in model.named_parameters() if "classifier" in n
        ]
        optimizer = torch.optim.AdamW([
            {"params": backbone_params, "lr": 2e-5},
            {"params": head_params, "lr": 1e-3},
        ], weight_decay=0.01)
    elif strategy == "progressive":
        # Start with only the classifier unfrozen; blocks are unfrozen during training
        for param in model.parameters():
            param.requires_grad = False
        for param in model.classifier.parameters():
            param.requires_grad = True
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=1e-3, weight_decay=0.01,
        )
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0.0
    best_state = None

    for epoch in range(epochs):
        # Progressive unfreezing: unfreeze one more transformer block every
        # 3 epochs, starting from the last block and moving toward the input.
        if strategy == "progressive" and epoch > 0 and epoch % 3 == 0:
            blocks_to_unfreeze = min(epoch // 3, 12)
            for i in range(12 - blocks_to_unfreeze, 12):
                for param in model.vit.encoder.layer[i].parameters():
                    param.requires_grad = True
            # Rebuild optimizer AND scheduler so the newly unfrozen parameters
            # are optimized; a stale scheduler would keep stepping the old
            # optimizer's learning rates.
            trainable = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(trainable, lr=2e-5, weight_decay=0.01)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=epochs - epoch,
            )

        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.logits.argmax(dim=-1) == labels).sum().item()
            total += labels.size(0)
        scheduler.step()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_correct += (outputs.logits.argmax(dim=-1) == labels).sum().item()
                val_total += labels.size(0)
        val_acc = val_correct / val_total

        # Keep the weights from the best validation epoch
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }

    if best_state is not None:
        model.load_state_dict(best_state)
    return model
```
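The cosine annealing schedule used above decays each parameter group's learning rate from its initial value toward zero over `T_max` epochs. Its closed form (what `CosineAnnealingLR` computes when stepped once per epoch, with `eta_min=0`) can be sketched numerically:

```python
import math

def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Learning rate at a given epoch under cosine annealing."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Differential-LR setup: backbone anneals from 2e-5, head from 1e-3.
# Both groups follow the same cosine shape, scaled by their base LR.
for epoch in (0, 5, 10, 15):
    print(epoch, f"{cosine_lr(2e-5, epoch, 15):.2e}", f"{cosine_lr(1e-3, epoch, 15):.2e}")
```

The schedule starts at the full base rate, passes through half the base rate at the midpoint, and reaches zero at `T_max`, which is why the progressive strategy rebuilds the scheduler with the remaining epoch budget after each unfreeze.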
Results
Strategy Comparison
| Strategy | Test Accuracy | Macro F1 | Trainable Params | Training Time |
|---|---|---|---|---|
| Linear probe | 0.719 | 0.712 | 4,614 (0.005%) | 1x |
| Last block unfreeze | 0.802 | 0.798 | 7.1M (8.3%) | 1.8x |
| Full fine-tuning (uniform LR) | 0.841 | 0.836 | 85.8M (100%) | 4.2x |
| Full fine-tuning (differential LR) | 0.867 | 0.863 | 85.8M (100%) | 4.2x |
| Progressive unfreezing | 0.873 | 0.870 | 85.8M (100%) | 5.1x |
Progressive unfreezing achieves the best accuracy, outperforming full fine-tuning with uniform learning rate by 3.2 percentage points. The differential learning rate alone accounts for 2.6 points of improvement — a substantial gain from a single hyperparameter change.
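The trainable-parameter column for the linear probe can be checked directly: ViT-Base has a hidden (token) dimension of 768, and the probe trains only the 6-way classifier head, i.e. one weight matrix plus biases:

```python
hidden_size = 768  # ViT-Base token dimension
num_classes = 6    # land-use classes

# Linear head: a (num_classes x hidden_size) weight matrix plus biases
head_params = hidden_size * num_classes + num_classes
print(head_params)  # 4614
print(f"fraction of full model: {head_params / 85.8e6:.6f}")
```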
Per-Class Performance
The linear probe struggles most with classes that require understanding spatial relationships (urban vs. cropland — both have regular grid patterns), while fine-tuned models resolve these distinctions:
| Class | Linear Probe F1 | Progressive Unfreeze F1 |
|---|---|---|
| Forest | 0.83 | 0.93 |
| Cropland | 0.62 | 0.84 |
| Urban | 0.58 | 0.82 |
| Water | 0.91 | 0.96 |
| Barren | 0.74 | 0.89 |
| Wetland | 0.59 | 0.78 |
Water is the easiest class (distinctive color signature transfers directly from ImageNet, where water appears in many training images). Wetland is the hardest (requires distinguishing a mixture of water and vegetation, a pattern rare in ImageNet).
Data Efficiency
PCRC evaluated how each strategy degrades with less labeled data:
| Labeled Images | Linear Probe | Full FT (Diff. LR) | Progressive Unfreeze |
|---|---|---|---|
| 500 | 0.621 | 0.583 | 0.654 |
| 1,000 | 0.658 | 0.712 | 0.739 |
| 2,000 | 0.689 | 0.781 | 0.803 |
| 4,000 | 0.708 | 0.839 | 0.856 |
| 8,000 | 0.719 | 0.867 | 0.873 |
Two findings stand out:
- At 500 labels, full fine-tuning underperforms the linear probe even with differential learning rates (0.583 vs. 0.621). This is negative transfer in action: with too few examples and too many parameters, the model overfits to training noise and destroys the pretrained features. Progressive unfreezing mitigates this by keeping most layers frozen during early training.
- The gap between the linear probe and fine-tuning narrows as data decreases. With abundant data, fine-tuning can adapt the backbone to the target domain; with scarce data, the pretrained features, imperfect as they are for satellite imagery, are the best available.
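The data-efficiency experiment subsamples the labeled set at each budget; keeping the subsample class-balanced avoids confounding label scarcity with class imbalance. A hedged sketch of stratified subsampling (numpy only; `subsample_stratified` is an illustrative helper, not part of the code above):

```python
import numpy as np

def subsample_stratified(labels, n_keep, n_classes, seed=0):
    """Pick roughly n_keep indices with equal counts per class."""
    rng = np.random.RandomState(seed)
    per_class = n_keep // n_classes  # budget rounds down to a multiple of n_classes
    chosen = []
    for c in range(n_classes):
        cls_idx = np.where(labels == c)[0]
        chosen.append(rng.choice(cls_idx, size=per_class, replace=False))
    return np.concatenate(chosen)

labels = np.repeat(np.arange(6), 8000 // 6)  # balanced stand-in label array
idx = subsample_stratified(labels, 500, 6)
print(len(idx))  # 498: the 500-image budget rounds down to 83 per class
```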
Domain-Specific Pretraining
After the fine-tuning experiments, PCRC explored a more aggressive approach: self-supervised pretraining on 200,000 unlabeled satellite images using DINO (self-distillation with no labels), followed by fine-tuning on the 8,000 labeled images.
| Approach | Test Accuracy |
|---|---|
| ImageNet ViT → linear probe | 0.719 |
| ImageNet ViT → fine-tune (progressive) | 0.873 |
| Satellite DINO ViT → linear probe | 0.812 |
| Satellite DINO ViT → fine-tune (progressive) | 0.911 |
Self-supervised pretraining on in-domain satellite data produced a 3.8-point improvement over ImageNet pretraining when combined with supervised fine-tuning. The linear probe result is particularly striking: 0.812 for satellite DINO vs. 0.719 for ImageNet — demonstrating that domain-specific pretraining learns features that are linearly separable for the target task, while ImageNet features require nonlinear adaptation.
Lessons Learned
- ImageNet pretraining is a strong baseline, not the ceiling. Despite the large domain gap, ImageNet features transfer well enough to beat training from scratch with limited data. But domain-specific self-supervised pretraining (when unlabeled domain data is available) provides a substantially better starting point.
- Learning rate discipline is essential for fine-tuning. The 2.6-point gap between uniform and differential learning rates is not a minor detail: it is the difference between a model that partially destroys its pretrained features and one that preserves them. Progressive unfreezing adds another 0.6 points by being even more disciplined about which layers adapt and when.
- The pretrained model's failure modes are predictable from the domain gap. Water (familiar from ImageNet) classifies well even with a linear probe. Wetland (a mixture of two landscape types, rare in ImageNet) is the hardest class. Anticipating where the pretrained model will struggle guides the practitioner toward targeted data collection and augmentation.
- Compute efficiency matters for scientific applications. PCRC operates on a research budget, not a tech company budget. Progressive unfreezing on an ImageNet ViT takes 5 GPU-hours. Training from scratch to match that accuracy would require >100 GPU-hours and data they do not have. Transfer learning is not just technically superior; it is the only approach that fits their constraints.