Case Study 2: Climate ViT — Vision Transformer for Gridded Climate Data
Context
In Chapter 8's Case Study 1, the Pacific Climate Research Consortium (PCRC) built a CNN-based downscaling model that maps coarse-resolution CMIP6 climate data ($16 \times 16$ grid, 6 variables) to high-resolution daily maximum temperature fields ($64 \times 64$ grid). The CNN exploited locality and translation equivariance — precisely the right inductive biases for spatial data with local correlations.
But the climate team has identified a limitation. Some climate phenomena are inherently non-local: teleconnections like the El Niño-Southern Oscillation (ENSO) create long-range spatial dependencies where Pacific Ocean temperatures influence rainfall patterns in East Africa, thousands of kilometers away. A CNN captures such long-range patterns only through deep stacks of convolutional layers, and each stride-1 $3 \times 3$ layer extends the receptive field by just two pixels. A 10-layer stack of these layers therefore has a theoretical receptive field of $1 + 10 \times 2 = 21$ pixels per side, barely larger than the $16 \times 16$ input. Its effective receptive field (the central region whose inputs actually dominate each output) is considerably smaller.
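The receptive-field arithmetic can be checked with the standard recurrence for stacked convolutions: each layer with kernel size $k$ adds $(k-1)$ times the current jump (the cumulative stride) to the receptive field. The sketch below is a minimal illustration; the function name `receptive_field` is hypothetical.

```python
from typing import List, Tuple


def receptive_field(layers: List[Tuple[int, int]]) -> int:
    """Theoretical receptive field (pixels per side) of stacked conv layers.

    Each layer is (kernel_size, stride). Standard recurrence, starting from a
    single pixel with unit jump: r += (k - 1) * jump; jump *= stride.
    """
    r, jump = 1, 1
    for k, s in layers:
        r += (k - 1) * jump
        jump *= s
    return r


print(receptive_field([(3, 1)] * 10))     # 21: the 10-layer example above
print(receptive_field([(3, 1)] * 3))      # 7: three stride-1 3x3 layers
print(receptive_field([(3, 2), (3, 1)]))  # 7: a stride-2 layer doubles later growth
```

The CNN baseline used later in this case study has exactly three stride-1 $3 \times 3$ layers in its encoder, so its coarse-level features see a $7 \times 7$ neighborhood; the transposed-convolution decoder widens that only modestly.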
The team proposes applying a Vision Transformer (ViT) to the downscaling task. Unlike a convolution, self-attention computes pairwise interactions between all spatial positions in a single layer: every patch of grid cells can attend to every other patch, regardless of distance. If teleconnections matter for downscaling, the ViT should capture them more effectively than the CNN.
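To make "all pairs in a single layer" concrete, here is a single-head scaled dot-product self-attention sketch in NumPy over 16 patch tokens (a $4 \times 4$ patch grid). Random matrices stand in for the learned query, key, and value projections, and the width of 8 is an arbitrary toy choice, not the model's `d_model`.

```python
import numpy as np


def self_attention(tokens: np.ndarray, w_q: np.ndarray, w_k: np.ndarray, w_v: np.ndarray):
    """Single-head scaled dot-product self-attention over a sequence of tokens.

    tokens: (num_tokens, d_model). Returns (outputs, weights), where
    weights[i, j] is how much query token i attends to key token j.
    """
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])         # (num_tokens, num_tokens): all pairs
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability for the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row is a distribution over keys
    return weights @ v, weights


rng = np.random.default_rng(0)
num_patches, d_model = 16, 8  # 4x4 patch grid; toy width
tokens = rng.standard_normal((num_patches, d_model))
w_q, w_k, w_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

out, weights = self_attention(tokens, w_q, w_k, w_v)
print(weights.shape)       # (16, 16)
print(weights[15, 0] > 0)  # True: patch 15 (lower-right) sees patch 0 (upper-left) in one layer
```

The cost is quadratic in the number of patches ($16^2 = 256$ pairs here), which is cheap for a $4 \times 4$ patch grid. A convolution, by contrast, would need several stacked layers before patch 15's output could depend on patch 0's input at all.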
This case study implements the Climate ViT, compares it against the CNN from Chapter 8, and analyzes what the attention maps reveal about the spatial structure of climate data.
From Images to Climate Grids
The Vision Transformer (Dosovitskiy et al., 2020) adapts the transformer to spatial data by treating an image as a sequence of patches:
- Divide the input image into non-overlapping patches.
- Flatten each patch into a vector.
- Project each flattened patch to $d_{\text{model}}$ dimensions.
- Prepend a learnable [CLS] token (for classification) or process all patches (for dense prediction).
- Add positional embeddings.
- Process with a standard transformer encoder.
For climate data, the "image" is a $16 \times 16$ grid with 6 channels (temperature, pressure, humidity, wind-u, wind-v, geopotential height). Using $4 \times 4$ patches, we get $16$ patches, each of dimension $4 \times 4 \times 6 = 96$.
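The patch arithmetic can be checked with a NumPy sketch that splits a (channels, H, W) grid using `reshape` and `transpose` instead of `unfold`. The function name `patchify` is illustrative. It orders patches row-major over the $4 \times 4$ patch grid and flattens each patch channel-first, the same layout the `PatchEmbedding` module that follows produces.

```python
import numpy as np


def patchify(grid: np.ndarray, patch_size: int) -> np.ndarray:
    """Split a (C, H, W) grid into flattened, non-overlapping patches.

    Returns shape (num_patches, C * patch_size * patch_size), with patches in
    row-major order and each patch flattened as (channel, row, col).
    """
    c, h, w = grid.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "grid must divide evenly into patches"
    nh, nw = h // p, w // p
    # (C, nh, p, nw, p) -> (nh, nw, C, p, p)
    x = grid.reshape(c, nh, p, nw, p).transpose(1, 3, 0, 2, 4)
    return x.reshape(nh * nw, c * p * p)


grid = np.arange(6 * 16 * 16, dtype=np.float32).reshape(6, 16, 16)
patches = patchify(grid, 4)
print(patches.shape)  # (16, 96)
```

Patch 5, for example, is the second patch in the second patch row, so its first entry is channel 0 at grid cell (4, 4).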
```python
import math
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split


class PatchEmbedding(nn.Module):
    """Convert a gridded climate field into a sequence of patch embeddings.

    Splits the spatial grid into non-overlapping patches, flattens each,
    and projects to d_model dimensions.
    """

    def __init__(
        self,
        grid_size: int = 16,
        patch_size: int = 4,
        in_channels: int = 6,
        d_model: int = 128,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (grid_size // patch_size) ** 2
        patch_dim = patch_size * patch_size * in_channels
        self.projection = nn.Linear(patch_dim, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convert grid to patch embeddings.

        Args:
            x: Climate field of shape (batch, channels, H, W).

        Returns:
            Patch embeddings of shape (batch, num_patches, d_model).
        """
        batch_size, c, h, w = x.shape
        p = self.patch_size
        # Reshape into patches: (batch, num_patches, patch_dim)
        x = x.unfold(2, p, p).unfold(3, p, p)  # (batch, c, h/p, w/p, p, p)
        x = x.contiguous().view(batch_size, c, -1, p, p)  # (batch, c, num_patches, p, p)
        x = x.permute(0, 2, 1, 3, 4)  # (batch, num_patches, c, p, p)
        x = x.reshape(batch_size, self.num_patches, -1)  # (batch, num_patches, patch_dim)
        x = self.projection(x)
        x = self.norm(x)
        return x


class ClimateViT(nn.Module):
    """Vision Transformer for climate downscaling.

    Takes a coarse-resolution climate grid and produces a
    high-resolution temperature field through:
    1. Patch embedding of the coarse grid
    2. Transformer encoder with self-attention
    3. Dense upsampling head to produce the fine grid
    """

    def __init__(
        self,
        grid_size: int = 16,
        patch_size: int = 4,
        in_channels: int = 6,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 4,
        d_ff: int = 512,
        output_size: int = 64,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.num_patches = (grid_size // patch_size) ** 2
        self.d_model = d_model
        self.patch_size = patch_size
        self.grid_patches_per_side = grid_size // patch_size
        self.output_size = output_size

        # Patch embedding
        self.patch_embed = PatchEmbedding(grid_size, patch_size, in_channels, d_model)

        # Learned positional embedding (one per patch, i.e. per spatial location)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches, d_model) * 0.02
        )
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder (pre-norm)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.final_norm = nn.LayerNorm(d_model)

        # Upsampling head: each patch predicts one square tile of the output.
        # With output_size=64 and 4 patches per side, each tile is 16x16 = 256 values.
        self.upsample_factor = output_size // self.grid_patches_per_side
        tile_size = self.upsample_factor * self.upsample_factor
        self.upsample_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, tile_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downscale coarse climate grid to fine temperature field.

        Args:
            x: Coarse grid of shape (batch, in_channels, grid_size, grid_size).

        Returns:
            Fine temperature field of shape (batch, 1, output_size, output_size).
        """
        batch_size = x.size(0)

        # Embed patches
        patches = self.patch_embed(x)  # (batch, num_patches, d_model)
        patches = patches + self.pos_embed
        patches = self.dropout(patches)

        # Transformer encoding
        encoded = self.transformer(patches)  # (batch, num_patches, d_model)
        encoded = self.final_norm(encoded)

        # Upsample: each patch produces a tile of the output
        tiles = self.upsample_head(encoded)  # (batch, num_patches, tile_size)
        tiles = tiles.view(
            batch_size,
            self.grid_patches_per_side,
            self.grid_patches_per_side,
            self.upsample_factor,
            self.upsample_factor,
        )  # (batch, patch_row, patch_col, tile_row, tile_col)

        # Interleave patch and tile axes: (batch, patch_row, tile_row, patch_col, tile_col)
        output = tiles.permute(0, 1, 3, 2, 4).contiguous()
        output = output.view(batch_size, 1, self.output_size, self.output_size)
        return output
```
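The tile reassembly at the end of `ClimateViT.forward` is the inverse of patchifying. A NumPy sketch of the same reshape and axis permutation (the function name `untile` is hypothetical) shows which part of the $64 \times 64$ output each patch controls; in particular, patch 15's tile lands in the lower-right $16 \times 16$ block.

```python
import numpy as np


def untile(tiles: np.ndarray, patches_per_side: int, tile_side: int) -> np.ndarray:
    """Assemble per-patch tiles (num_patches, tile_side**2) into one square field.

    Mirrors tiles.view(g, g, u, u).permute(0, 2, 1, 3) without the batch axis.
    """
    g, u = patches_per_side, tile_side
    x = tiles.reshape(g, g, u, u).transpose(0, 2, 1, 3)  # (patch_row, tile_row, patch_col, tile_col)
    return x.reshape(g * u, g * u)


g, u = 4, 64 // 4                # 4 patches per side, 16x16 tiles
tiles = np.zeros((g * g, u * u))
tiles[15] = 1.0                  # light up only the last patch's tile
field = untile(tiles, g, u)
print(field.shape)               # (64, 64)
print(field[48:, 48:].all())     # True
print(field.sum())               # 256.0
```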
The Climate Downscaling Dataset
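We reuse the synthetic climate data generator from Chapter 8, which produces coarse-resolution inputs with smooth large-scale patterns and fine-resolution targets with topographic detail. This version also includes a simulated teleconnection: the mean of the upper-left $4 \times 4$ corner of coarse channel 0 shifts the lower-right $16 \times 16$ corner of the fine target.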
```python
from scipy.ndimage import zoom


class ClimateDownscalingDataset(Dataset):
    """Synthetic climate downscaling dataset.

    Generates coarse-resolution multi-variable climate fields and
    corresponding high-resolution temperature targets, including
    both local topographic effects and a non-local teleconnection.
    """

    def __init__(
        self,
        n_samples: int = 5000,
        coarse_size: int = 16,
        fine_size: int = 64,
        n_channels: int = 6,
        seed: int = 42,
    ) -> None:
        super().__init__()
        rng = np.random.RandomState(seed)

        # Generate topography (shared across samples)
        x_fine = np.linspace(0, 4 * np.pi, fine_size)
        y_fine = np.linspace(0, 4 * np.pi, fine_size)
        xx_f, yy_f = np.meshgrid(x_fine, y_fine)
        topography = (
            0.4 * np.sin(0.5 * xx_f) * np.cos(0.3 * yy_f)
            + 0.3 * np.sin(1.2 * xx_f + 0.5) * np.cos(0.8 * yy_f + 1.0)
            + 0.15 * np.sin(3.0 * xx_f) * np.cos(2.5 * yy_f)
        )

        self.coarse_inputs = np.zeros(
            (n_samples, n_channels, coarse_size, coarse_size), dtype=np.float32
        )
        self.fine_targets = np.zeros((n_samples, 1, fine_size, fine_size), dtype=np.float32)

        x_coarse = np.linspace(0, 4 * np.pi, coarse_size)
        y_coarse = np.linspace(0, 4 * np.pi, coarse_size)
        xx_c, yy_c = np.meshgrid(x_coarse, y_coarse)

        # Teleconnection geometry (defaults in brackets): source is the upper-left
        # corner of coarse channel 0 [:4, :4]; target is the lower-right corner of
        # the fine grid [48:, 48:]
        src = coarse_size // 4
        dst = fine_size - fine_size // 4

        for i in range(n_samples):
            # Large-scale temperature pattern (varies by sample)
            phase_x = rng.uniform(0, 2 * np.pi)
            phase_y = rng.uniform(0, 2 * np.pi)
            amplitude = rng.uniform(0.5, 1.5)

            # Coarse temperature
            temp_coarse = amplitude * np.sin(0.4 * xx_c + phase_x) * np.cos(0.3 * yy_c + phase_y)

            # Input channels (temperature, pressure, humidity, wind, geopotential):
            # all six are scaled, noisy copies of the temperature pattern
            for ch in range(n_channels):
                self.coarse_inputs[i, ch] = (
                    temp_coarse * (0.5 + 0.1 * ch)
                    + 0.2 * rng.randn(coarse_size, coarse_size)
                )

            # Non-local teleconnection: the upper-left corner of channel 0
            # shifts the lower-right corner of the target (an ENSO-like remote effect)
            teleconnection_signal = self.coarse_inputs[i, 0, :src, :src].mean()

            # Fine temperature: upscaled coarse + topographic detail + teleconnection
            temp_fine = zoom(temp_coarse, fine_size / coarse_size, order=3)
            temp_fine = temp_fine + 0.3 * topography  # Local topographic effect
            temp_fine[dst:, dst:] += 0.2 * teleconnection_signal  # Teleconnection
            temp_fine += 0.05 * rng.randn(fine_size, fine_size)  # Noise
            self.fine_targets[i, 0] = temp_fine.astype(np.float32)

        self.coarse_inputs = torch.from_numpy(self.coarse_inputs)
        self.fine_targets = torch.from_numpy(self.fine_targets)

    def __len__(self) -> int:
        return len(self.coarse_inputs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.coarse_inputs[idx], self.fine_targets[idx]


# Build dataset (seeded split for reproducibility)
climate_dataset = ClimateDownscalingDataset(n_samples=5000)
train_set, val_set = random_split(
    climate_dataset, [4000, 1000], generator=torch.Generator().manual_seed(42)
)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
```
Training and Comparison with CNN
We train the ViT model and compare it against a CNN baseline that mirrors the architecture from Chapter 8.
```python
class ClimateCNN(nn.Module):
    """CNN baseline for climate downscaling (from Chapter 8)."""

    def __init__(self, in_channels: int = 6) -> None:
        super().__init__()
        # Three 3x3 conv layers at the coarse 16x16 resolution
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        # Two stride-2 transposed convolutions: 16x16 -> 32x32 -> 64x64
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, 3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.encoder(x)  # (batch, 256, 16, 16)
        return self.upsample(features)  # (batch, 1, 64, 64)


def train_and_evaluate(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    model_name: str,
    epochs: int = 30,
    lr: float = 3e-4,
) -> Dict:
    """Train a model and return validation metrics.

    Args:
        model: The model to train.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        model_name: Name for display.
        epochs: Number of epochs.
        lr: Learning rate.

    Returns:
        Dictionary with final metrics.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    criterion = nn.MSELoss()
    n_params = sum(p.numel() for p in model.parameters())
    best_val_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for coarse, fine in train_loader:
            optimizer.zero_grad()
            pred = model(coarse)
            loss = criterion(pred, fine)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss, val_mae = 0.0, 0.0
        with torch.no_grad():
            for coarse, fine in val_loader:
                pred = model(coarse)
                val_loss += criterion(pred, fine).item()
                val_mae += (pred - fine).abs().mean().item()
        # Average of per-batch means (the smaller final batch counts like the others)
        val_loss /= len(val_loader)
        val_mae /= len(val_loader)
        best_val_loss = min(best_val_loss, val_loss)

        if (epoch + 1) % 10 == 0:
            print(f"[{model_name}] Epoch {epoch+1:3d} | "
                  f"Train MSE: {train_loss/len(train_loader):.6f} | "
                  f"Val MSE: {val_loss:.6f} | Val MAE: {val_mae:.4f}")

    return {
        "model": model_name,
        "params": n_params,
        "best_val_mse": best_val_loss,
        "final_val_mae": val_mae,
    }


# Train both models
cnn_model = ClimateCNN(in_channels=6)
vit_model = ClimateViT(
    grid_size=16, patch_size=4, in_channels=6,
    d_model=128, num_heads=4, num_layers=4, d_ff=512, output_size=64,
)
print(f"CNN parameters: {sum(p.numel() for p in cnn_model.parameters()):,}")
print(f"ViT parameters: {sum(p.numel() for p in vit_model.parameters()):,}")
print()
cnn_results = train_and_evaluate(cnn_model, train_loader, val_loader, "CNN", epochs=30)
print()
vit_results = train_and_evaluate(vit_model, train_loader, val_loader, "ViT", epochs=30)
```
```text
CNN parameters: 1,029,569
ViT parameters: 857,600

[CNN] Epoch  10 | Train MSE: 0.012847 | Val MSE: 0.014523 | Val MAE: 0.0891
[CNN] Epoch  20 | Train MSE: 0.008234 | Val MSE: 0.010187 | Val MAE: 0.0743
[CNN] Epoch  30 | Train MSE: 0.005918 | Val MSE: 0.008432 | Val MAE: 0.0672

[ViT] Epoch  10 | Train MSE: 0.015234 | Val MSE: 0.016891 | Val MAE: 0.0962
[ViT] Epoch  20 | Train MSE: 0.007456 | Val MSE: 0.009234 | Val MAE: 0.0698
[ViT] Epoch  30 | Train MSE: 0.004812 | Val MSE: 0.007198 | Val MAE: 0.0618
```
Analyzing the Results
The comparison reveals a nuanced story:
| Metric | CNN | ViT |
|---|---|---|
| Parameters | 1,029,569 | 857,600 |
| Val MSE (epoch 30) | 0.00843 | 0.00720 |
| Val MAE (epoch 30) | 0.0672 | 0.0618 |
| Convergence speed (epochs to Val MSE < 0.01) | ~22 | ~18 |
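Parameter counts can be cross-checked by hand from the layer shapes. The sketch below assumes PyTorch defaults: every `Linear`, `Conv2d`, and `ConvTranspose2d` has a bias; `LayerNorm` and `BatchNorm2d` have an affine weight and bias (BatchNorm running statistics are buffers, not parameters); and `nn.MultiheadAttention` holds a fused $3d \times d$ input projection plus a $d \times d$ output projection, both with biases. The helper names are illustrative.

```python
def linear(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight + bias


def conv(n_in: int, n_out: int, k: int) -> int:
    return n_in * n_out * k * k + n_out  # same count for Conv2d and ConvTranspose2d


def norm(n: int) -> int:
    return 2 * n  # affine weight + bias (LayerNorm or BatchNorm2d)


def vit_params(d=128, d_ff=512, layers=4, patch_dim=96, num_patches=16, tile=256) -> int:
    embed = linear(patch_dim, d) + norm(d) + num_patches * d  # projection, norm, pos_embed
    attn = 3 * d * d + 3 * d + linear(d, d)                   # fused QKV + output projection
    block = attn + linear(d, d_ff) + linear(d_ff, d) + 2 * norm(d)
    head = linear(d, d) + linear(d, tile)
    return embed + layers * block + norm(d) + head            # norm(d) is final_norm


def cnn_params(c_in=6) -> int:
    enc = (conv(c_in, 64, 3) + norm(64) + conv(64, 128, 3) + norm(128)
           + conv(128, 256, 3) + norm(256))
    dec = conv(256, 128, 4) + conv(128, 64, 4) + conv(64, 1, 3)
    return enc + dec


print(f"{vit_params():,}")  # 857,600
print(f"{cnn_params():,}")  # 1,029,569
```

Most of the CNN's budget sits in the first transposed convolution ($256 \times 128 \times 4 \times 4$ weights), while the ViT's is spread across four encoder blocks of about 198K parameters each.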
The ViT achieves lower error with fewer parameters, but the advantage is modest (15% lower MSE). To understand where the ViT wins, we examine the spatial error distribution:
```python
def compute_spatial_error(
    model: nn.Module,
    val_loader: DataLoader,
) -> np.ndarray:
    """Compute per-pixel mean absolute error across the validation set.

    Args:
        model: Trained model.
        val_loader: Validation data loader.

    Returns:
        Spatial error map of shape (64, 64), averaged over batches
        (matching how val MAE is computed in train_and_evaluate).
    """
    model.eval()
    batch_errors = []
    with torch.no_grad():
        for coarse, fine in val_loader:
            pred = model(coarse)
            # Mean over the batch, then drop the channel axis: (64, 64)
            batch_errors.append((pred - fine).abs().mean(dim=0).squeeze(0).numpy())
    return np.mean(batch_errors, axis=0)


cnn_error = compute_spatial_error(cnn_model, val_loader)
vit_error = compute_spatial_error(vit_model, val_loader)

# Where does ViT beat CNN?
improvement = cnn_error - vit_error
print("Spatial error analysis:")
print(f"  Avg CNN error: {cnn_error.mean():.4f}")
print(f"  Avg ViT error: {vit_error.mean():.4f}")
print(f"  Avg improvement: {improvement.mean():.4f}")
print()


def report_region(label: str, region: Tuple[slice, slice]) -> None:
    """Print CNN vs. ViT error over one block of the 64x64 output grid."""
    cnn_r = cnn_error[region].mean()
    vit_r = vit_error[region].mean()
    print(f"  {label}:")
    print(f"    CNN error: {cnn_r:.4f}")
    print(f"    ViT error: {vit_r:.4f}")
    print(f"    ViT advantage: {(cnn_r - vit_r) / cnn_r * 100:.1f}%")


# The teleconnection shifts fine cells [48:, 48:]; cells [:16, :16] see only local effects
report_region("Lower-right corner block (teleconnection target)", np.s_[48:, 48:])
print()
report_region("Upper-left corner block (local topography only)", np.s_[:16, :16])
```

```text
Spatial error analysis:
  Avg CNN error: 0.0672
  Avg ViT error: 0.0618
  Avg improvement: 0.0054

  Lower-right corner block (teleconnection target):
    CNN error: 0.0834
    ViT error: 0.0643
    ViT advantage: 22.9%

  Upper-left corner block (local topography only):
    CNN error: 0.0589
    ViT error: 0.0574
    ViT advantage: 2.5%
```
The result is striking. In the lower-right corner block, the region shifted by the simulated teleconnection from the upper-left corner of the coarse input, the ViT reduces error by 23%. In the upper-left corner block, which is governed only by local effects, the two models differ by just 2.5%. The transformer's advantage is concentrated precisely on the non-local dependency that the CNN's limited receptive field cannot reach efficiently.
Attention Maps Reveal Teleconnection Structure
By extracting attention weights from the trained ViT, we can check whether it has learned the teleconnection pattern, that is, whether lower-right patches attend to upper-left patches.
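The planted teleconnection is defined in grid coordinates, so it helps to translate it into patch indices before reading the attention maps. A small helper (hypothetical, not part of the original pipeline) maps a rectangular block of cells to the row-major indices of the patches covering it. On the coarse input each patch spans $4 \times 4$ cells; on the fine output each patch owns a $16 \times 16$ tile.

```python
from typing import List


def patches_covering(
    rows: range, cols: range, cells_per_patch: int, patches_per_side: int
) -> List[int]:
    """Row-major indices of the patches that overlap a rectangular block of cells."""
    first_r, last_r = rows.start // cells_per_patch, (rows.stop - 1) // cells_per_patch
    first_c, last_c = cols.start // cells_per_patch, (cols.stop - 1) // cells_per_patch
    return [
        r * patches_per_side + c
        for r in range(first_r, last_r + 1)
        for c in range(first_c, last_c + 1)
    ]


# Coarse input: each patch spans 4x4 cells of the 16x16 grid
print(patches_covering(range(0, 4), range(0, 4), 4, 4))       # [0]  teleconnection source
print(patches_covering(range(0, 8), range(0, 8), 4, 4))       # [0, 1, 4, 5]
# Fine output: each patch owns a 16x16 tile of the 64x64 grid
print(patches_covering(range(48, 64), range(48, 64), 16, 4))  # [15] teleconnection target
print(patches_covering(range(32, 64), range(32, 64), 16, 4))  # [10, 11, 14, 15]
```

The source is the single patch 0 and the affected tile belongs to patch 15. Both sit inside the $2 \times 2$-patch quadrants `[0, 1, 4, 5]` and `[10, 11, 14, 15]` used in the analysis, so quadrant averages dilute the specific patch-15-to-patch-0 signal somewhat.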
```python
def extract_patch_attention(
    model: ClimateViT,
    sample_input: torch.Tensor,
) -> List[torch.Tensor]:
    """Extract per-head attention weights from every encoder layer of the ViT.

    Forward hooks on ``layer.self_attn`` do not work here:
    ``nn.TransformerEncoderLayer`` calls its attention module with
    ``need_weights=False``, and in eval mode under ``torch.no_grad()`` it may
    take a fused fast path that bypasses ``self_attn`` entirely. Instead, we
    replay the forward pass layer by layer and query each attention module
    directly.

    Args:
        model: Trained ClimateViT.
        sample_input: Single input of shape (1, channels, 16, 16).

    Returns:
        List of attention weight tensors per layer,
        each of shape (num_heads, num_patches, num_patches).
    """
    model.eval()  # dropout is off, so the replay matches model(sample_input)
    attention_maps = []
    with torch.no_grad():
        x = model.patch_embed(sample_input) + model.pos_embed
        for layer in model.transformer.layers:
            # norm_first=True: self-attention sees the pre-normalized input
            h = layer.norm1(x)
            _, weights = layer.self_attn(
                h, h, h, need_weights=True, average_attn_weights=False
            )  # (1, num_heads, num_patches, num_patches)
            attention_maps.append(weights.squeeze(0))
            x = layer(x)  # advance to the next layer's input
    return attention_maps


# Analyze attention for a sample input
sample = climate_dataset.coarse_inputs[0:1]
attn_maps = extract_patch_attention(vit_model, sample)

# Patches are numbered row-major on a 4x4 grid:
#   patches 0-3:   top row    (patch-grid positions [0,0] to [0,3])
#   patches 12-15: bottom row (patch-grid positions [3,0] to [3,3])
# Upper-left quadrant (2x2 patches): 0, 1, 4, 5
# Lower-right quadrant (2x2 patches): 10, 11, 14, 15
upper_left = [0, 1, 4, 5]
lower_right = [10, 11, 14, 15]

print("Cross-quadrant attention (lower-right attending to upper-left):")
for layer_idx, attn in enumerate(attn_maps):
    for head_idx in range(attn.size(0)):
        head = attn[head_idx]  # rows are queries, columns are keys
        # Average attention from lower-right patches to upper-left patches
        cross_attn = head[lower_right][:, upper_left].mean().item()
        # Average attention from lower-right patches to lower-right patches (local)
        local_attn = head[lower_right][:, lower_right].mean().item()
        print(f"  Layer {layer_idx}, Head {head_idx}: "
              f"cross-quadrant={cross_attn:.4f}, "
              f"local={local_attn:.4f}, "
              f"ratio={cross_attn/local_attn:.2f}")
```

```text
Cross-quadrant attention (lower-right attending to upper-left):
  Layer 0, Head 0: cross-quadrant=0.0612, local=0.0743, ratio=0.82
  Layer 0, Head 1: cross-quadrant=0.0534, local=0.0891, ratio=0.60
  Layer 0, Head 2: cross-quadrant=0.0823, local=0.0687, ratio=1.20
  Layer 0, Head 3: cross-quadrant=0.0578, local=0.0812, ratio=0.71
  Layer 1, Head 0: cross-quadrant=0.0912, local=0.0654, ratio=1.39
  Layer 1, Head 1: cross-quadrant=0.0467, local=0.0923, ratio=0.51
  Layer 1, Head 2: cross-quadrant=0.1134, local=0.0598, ratio=1.90
  Layer 1, Head 3: cross-quadrant=0.0543, local=0.0856, ratio=0.63
  ...
```
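Layer 1, Head 2 shows a cross-quadrant-to-local ratio of 1.90: lower-right patches attend almost twice as strongly to upper-left patches as to their own quadrant. For reference, a head attending uniformly over all 16 patches would assign $1/16 = 0.0625$ to every pair, giving a ratio of exactly 1. This pattern is consistent with the head having learned the teleconnection, although it comes from a single sample and quadrant-level averages; the planted dependency runs specifically from patch 0 (coarse cells `[:4, :4]`) to the output tile of patch 15 (fine cells `[48:, 48:]`). Heads with ratios below 1 focus on local spatial patterns, providing the same kind of local information that the CNN captures effectively.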
Lessons for Climate Deep Learning
- Transformers are not universally better than CNNs for spatial data. For purely local effects (topographic downscaling), the CNN's inductive bias is well-matched: the two models differ by only 2.5% in the upper-left corner block, and the CNN is ahead early in training (validation MSE 0.0145 vs. 0.0169 at epoch 10). The ViT's advantage appears specifically where non-local dependencies exist.
- Attention maps have scientific value. Here the attention mechanism recovers a dependency planted in the synthetic data to mimic a known physical phenomenon. In a real climate application, unexpected attention patterns (patches attending to distant regions) could point to previously unrecognized teleconnections worth investigating, making the ViT a tool for scientific discovery, not just prediction.
- Hybrid architectures are promising. The PCRC team's next step is a hybrid CNN-ViT: use CNN layers for local feature extraction (where their inductive bias helps), then apply transformer attention over the CNN features for non-local interactions. This combines the CNN's data efficiency for local patterns with the transformer's expressiveness for long-range dependencies, a design that reflects the physical structure of the problem.
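One way to sketch such a hybrid in PyTorch, under explicit assumptions: the class name `HybridCNNViT`, the two-layer convolutional stem, the strided convolution that turns each $4 \times 4$ block of CNN features into one token, and all layer sizes are illustrative choices, not the PCRC team's design. The output head reuses the tile-per-patch idea from `ClimateViT`.

```python
import torch
import torch.nn as nn


class HybridCNNViT(nn.Module):
    """Hypothetical hybrid CNN-ViT for 16x16 -> 64x64 downscaling (illustrative sizes)."""

    def __init__(
        self,
        in_channels: int = 6,
        grid_size: int = 16,
        patch_size: int = 4,
        cnn_channels: int = 64,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        output_size: int = 64,
    ) -> None:
        super().__init__()
        # Local feature extraction at full coarse resolution (CNN inductive bias)
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, cnn_channels, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(cnn_channels, cnn_channels, 3, padding=1),
            nn.GELU(),
        )
        # Strided conv: each patch_size x patch_size block of CNN features becomes one token
        self.to_tokens = nn.Conv2d(cnn_channels, d_model, patch_size, stride=patch_size)
        self.side = grid_size // patch_size  # patches per side (4)
        self.pos_embed = nn.Parameter(torch.randn(1, self.side ** 2, d_model) * 0.02)
        # Global self-attention over the CNN-derived tokens for non-local interactions
        layer = nn.TransformerEncoderLayer(
            d_model, num_heads, 4 * d_model, batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Each token predicts one square tile of the fine output
        self.tile = output_size // self.side  # tile side (16)
        self.head = nn.Linear(d_model, self.tile ** 2)
        self.output_size = output_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.size(0)
        feats = self.stem(x)                                        # (b, C, 16, 16)
        tokens = self.to_tokens(feats).flatten(2).transpose(1, 2)   # (b, 16, d_model)
        tokens = self.transformer(tokens + self.pos_embed)
        tiles = self.head(tokens).view(b, self.side, self.side, self.tile, self.tile)
        # (b, patch_row, tile_row, patch_col, tile_col) -> (b, 1, 64, 64)
        return tiles.permute(0, 1, 3, 2, 4).reshape(b, 1, self.output_size, self.output_size)


model = HybridCNNViT()
y = model(torch.randn(2, 6, 16, 16))
print(y.shape)  # torch.Size([2, 1, 64, 64])
```

The design choice is that convolutions see each coarse cell's neighborhood before tokenization, so attention operates on locally enriched features rather than raw patch pixels.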
Understanding Why: The ViT's advantage in the teleconnection region is not because transformers are "better" — it is because the task requires a capability (long-range spatial attention) that the transformer provides natively and the CNN provides only through depth. This is a case where understanding the architecture's computational properties leads to a principled model choice. When the inductive bias matches the data structure, simpler models work. When it does not, you need more expressive models — and you need to understand why they are more expressive to know when to use them.