Case Study 1: Climate Downscaling with Convolutional Neural Networks

Context

The Pacific Climate Research Consortium (PCRC) needs to translate coarse-resolution global climate projections into high-resolution regional predictions. Global climate models from the CMIP6 ensemble produce outputs on grids of approximately $100 \text{ km}$ resolution — useful for understanding continental-scale trends, but far too coarse for regional planning. A state transportation department needs predictions at $\sim 1 \text{ km}$ resolution to assess road infrastructure risk from extreme temperatures. A water utility needs precipitation projections at watershed scale. A public health agency needs urban heat island estimates at neighborhood resolution.

This spatial resolution gap is the downscaling problem. Traditional statistical downscaling uses linear regression to map coarse-resolution predictors (temperature, pressure, humidity) to fine-resolution targets (local temperature), calibrated against historical observations from weather stations. CNNs offer a natural upgrade: the coarse-resolution climate field is a gridded spatial map — exactly the kind of data where convolution's locality and weight sharing provide the right inductive bias.
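The traditional approach can be sketched in a few lines. This is a toy, illustrative version of per-station statistical downscaling (all names and numbers below are invented for the sketch, not from the case study): one ordinary least-squares model mapping daily coarse-grid predictors to one local observation.

```python
import numpy as np

# Toy statistical downscaling: OLS from coarse predictors to a station.
rng = np.random.RandomState(0)
X = rng.randn(365, 3)                  # daily coarse predictors (T, p, q)
w_true = np.array([2.0, -0.5, 1.0])    # assumed "physical" sensitivities
y = X @ w_true + 0.1 * rng.randn(365)  # local station temperature
w, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.round(w, 2))  # recovers roughly [ 2.  -0.5  1. ]
```

The limitation is built in: one linear map per location, no use of spatial context. The CNN below replaces both assumptions at once.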

This case study implements a CNN-based downscaling model that takes a $16 \times 16$ patch of coarse-resolution CMIP6 data (6 climate variables) and produces a $64 \times 64$ high-resolution daily maximum temperature field — a $4\times$ spatial upscaling.

The Data

We simulate data that captures the essential structure of the downscaling problem: coarse-resolution inputs with smooth, large-scale patterns, and fine-resolution targets with local details driven by topography and land use.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Tuple, List, Dict


def generate_topography(size: int, seed: int = 42) -> np.ndarray:
    """Generate a synthetic topographic height field.

    Combines low-frequency terrain features (mountain ranges)
    with high-frequency local variation (ridges, valleys).

    Args:
        size: Spatial resolution (size x size grid).
        seed: Random seed.

    Returns:
        Height field of shape (size, size), normalized to [0, 1].
    """
    rng = np.random.RandomState(seed)
    x = np.linspace(0, 4 * np.pi, size)
    y = np.linspace(0, 4 * np.pi, size)
    xx, yy = np.meshgrid(x, y)

    # Large-scale terrain (mountains)
    terrain = (
        0.4 * np.sin(0.5 * xx) * np.cos(0.3 * yy)
        + 0.3 * np.sin(1.2 * xx + 0.5) * np.cos(0.8 * yy + 1.0)
    )
    # Local variation (ridges, valleys)
    local = 0.15 * np.sin(3.0 * xx) * np.cos(2.5 * yy)
    noise = 0.05 * rng.randn(size, size)

    height = terrain + local + noise
    height = (height - height.min()) / (height.max() - height.min())
    return height


class ClimateDownscalingDataset(Dataset):
    """Synthetic climate downscaling dataset.

    Generates paired (coarse input, fine target) samples where:
    - Coarse input: 16x16 grid with 6 climate variables
    - Fine target: 64x64 grid of daily max temperature

    The fine-resolution temperature depends on coarse-resolution
    climate variables plus local topographic effects that exist
    only at the fine resolution — this is what the CNN must learn.

    Args:
        n_samples: Number of (day, region) samples.
        coarse_size: Coarse grid resolution.
        fine_size: Fine grid resolution.
        seed: Random seed.
    """

    def __init__(
        self,
        n_samples: int = 10000,
        coarse_size: int = 16,
        fine_size: int = 64,
        seed: int = 42,
    ) -> None:
        super().__init__()
        rng = np.random.RandomState(seed)
        self.scale_factor = fine_size // coarse_size  # 4x upscaling

        # Static topography at fine resolution. The model never sees it
        # directly, but because it is fixed across all samples the network
        # can learn its imprint from the targets alone.
        topo_fine = generate_topography(fine_size, seed=seed)

        # zoom is used below to upscale the coarse cloud-cover field
        from scipy.ndimage import zoom

        self.coarse_inputs = []
        self.fine_targets = []

        for i in range(n_samples):
            # Generate coarse-resolution climate fields (6 variables)
            # Smooth random fields to simulate large-scale climate patterns
            base_temp = rng.uniform(5, 35)  # Base temperature (seasonal)
            spatial_trend_x = rng.uniform(-0.3, 0.3)
            spatial_trend_y = rng.uniform(-0.3, 0.3)

            cx = np.linspace(-1, 1, coarse_size)
            cy = np.linspace(-1, 1, coarse_size)
            cxx, cyy = np.meshgrid(cx, cy)

            # 6 coarse variables: temperature, pressure, humidity,
            # u-wind, v-wind, cloud cover
            coarse = np.zeros((6, coarse_size, coarse_size), dtype=np.float32)
            coarse[0] = base_temp + spatial_trend_x * cxx + spatial_trend_y * cyy
            coarse[0] += 0.5 * rng.randn(coarse_size, coarse_size)  # Temperature
            coarse[1] = 1013 + 5 * rng.randn(coarse_size, coarse_size)  # Pressure
            coarse[2] = np.clip(
                50 + 20 * rng.randn(coarse_size, coarse_size), 5, 100
            )  # Humidity
            coarse[3] = 3 * rng.randn(coarse_size, coarse_size)  # U-wind
            coarse[4] = 3 * rng.randn(coarse_size, coarse_size)  # V-wind
            coarse[5] = np.clip(
                0.5 + 0.3 * rng.randn(coarse_size, coarse_size), 0, 1
            )  # Cloud cover

            # Fine-resolution temperature: upscale coarse + topographic effect
            fx = np.linspace(-1, 1, fine_size)
            fy = np.linspace(-1, 1, fine_size)
            fxx, fyy = np.meshgrid(fx, fy)

            fine_temp = (
                base_temp
                + spatial_trend_x * fxx
                + spatial_trend_y * fyy
                # Lapse rate: temperature decreases with elevation
                - 6.5 * topo_fine  # ~6.5 C/km lapse rate (height field spans ~1 km)
                # Cloud modulation (smooth field)
                - 2.0 * zoom(coarse[5], self.scale_factor, order=1)
                # Local noise (weather variability)
                + 0.3 * rng.randn(fine_size, fine_size)
            ).astype(np.float32)

            self.coarse_inputs.append(coarse)
            self.fine_targets.append(fine_temp)

        self.coarse_inputs = np.array(self.coarse_inputs)
        self.fine_targets = np.array(self.fine_targets)

        # Normalize inputs and targets
        self.input_mean = self.coarse_inputs.mean(axis=(0, 2, 3), keepdims=True)
        self.input_std = self.coarse_inputs.std(axis=(0, 2, 3), keepdims=True) + 1e-8
        self.target_mean = self.fine_targets.mean()
        self.target_std = self.fine_targets.std() + 1e-8

        self.coarse_inputs = (self.coarse_inputs - self.input_mean) / self.input_std
        self.fine_targets = (self.fine_targets - self.target_mean) / self.target_std

    def __len__(self) -> int:
        return len(self.fine_targets)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return (
            torch.tensor(self.coarse_inputs[idx], dtype=torch.float32),
            torch.tensor(self.fine_targets[idx], dtype=torch.float32).unsqueeze(0),
        )
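The per-channel normalization above relies on NumPy broadcasting: reducing over the sample and spatial axes with `keepdims=True` leaves statistics of shape `(1, C, 1, 1)` that broadcast back over the full `(N, C, H, W)` array. A minimal standalone sketch:

```python
import numpy as np

# Reduce over (samples, H, W); keepdims preserves a (1, C, 1, 1) shape
# that broadcasts against the (N, C, H, W) data.
x = np.random.RandomState(0).randn(100, 6, 16, 16)
mean = x.mean(axis=(0, 2, 3), keepdims=True)
std = x.std(axis=(0, 2, 3), keepdims=True) + 1e-8
z = (x - mean) / std
print(mean.shape)                                # (1, 6, 1, 1)
print(np.abs(z.mean(axis=(0, 2, 3))).max() < 1e-9)  # True: each channel ~zero-mean
```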

The Model

The downscaling CNN has two components: (1) an encoder that extracts features from the coarse-resolution input, and (2) a decoder that upsamples to fine resolution. We use residual blocks throughout and sub-pixel convolution (pixel shuffle) for artifact-free upsampling.

class ResBlock(nn.Module):
    """Residual block for the downscaling CNN."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.block(x) + x)


class DownscalingCNN(nn.Module):
    """CNN for climate statistical downscaling (4x upscaling).

    Takes coarse-resolution climate fields and produces fine-resolution
    temperature predictions, using residual blocks and sub-pixel upsampling.

    Args:
        in_channels: Number of input climate variables.
        base_channels: Base number of feature channels.
        n_res_blocks: Number of residual blocks in the encoder.
        upscale_factor: Spatial upscaling factor.
    """

    def __init__(
        self,
        in_channels: int = 6,
        base_channels: int = 64,
        n_res_blocks: int = 8,
        upscale_factor: int = 4,
    ) -> None:
        super().__init__()

        # Encoder: extract features at coarse resolution
        self.encoder_head = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
        )
        self.encoder_body = nn.Sequential(
            *[ResBlock(base_channels) for _ in range(n_res_blocks)]
        )

        # Upsampling via sub-pixel convolution (pixel shuffle)
        # 4x upscaling = two stages of 2x upscaling
        self.upsample = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 4, 3, padding=1),
            nn.PixelShuffle(2),  # (B, C*4, H, W) -> (B, C, 2H, 2W)
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, base_channels * 4, 3, padding=1),
            nn.PixelShuffle(2),  # (B, C*4, 2H, 2W) -> (B, C, 4H, 4W)
            nn.ReLU(inplace=True),
        )

        # Final refinement at fine resolution
        self.decoder = nn.Sequential(
            ResBlock(base_channels),
            ResBlock(base_channels),
            nn.Conv2d(base_channels, 1, 3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: coarse input -> fine output.

        Args:
            x: Coarse-resolution input, shape (B, 6, 16, 16).

        Returns:
            Fine-resolution temperature, shape (B, 1, 64, 64).
        """
        h = self.encoder_head(x)
        h = self.encoder_body(h) + h  # Global residual
        h = self.upsample(h)
        return self.decoder(h)
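The upsampling path is worth verifying in isolation. `PixelShuffle(r)` rearranges groups of $r^2$ channels into $r \times r$ spatial blocks, mapping `(B, C*r^2, H, W)` to `(B, C, r*H, r*W)`, so stacking two conv + `PixelShuffle(2)` stages yields the $4\times$ factor. A quick standalone check (the channel count 8 here is arbitrary):

```python
import torch
import torch.nn as nn

# Each group of r^2 = 4 channels becomes one 2x2 spatial block, so every
# output pixel is written exactly once.
x = torch.arange(4.0).reshape(1, 4, 1, 1)  # channel values [0, 1, 2, 3]
print(nn.PixelShuffle(2)(x)[0, 0])  # tensor([[0., 1.], [2., 3.]])

# Two conv + PixelShuffle(2) stages reproduce the model's 4x path.
up = nn.Sequential(
    nn.Conv2d(8, 32, 3, padding=1), nn.PixelShuffle(2),
    nn.Conv2d(8, 32, 3, padding=1), nn.PixelShuffle(2),
)
print(up(torch.randn(2, 8, 16, 16)).shape)  # torch.Size([2, 8, 64, 64])
```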

Training and Evaluation

def train_downscaling_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 50,
    lr: float = 1e-3,
) -> Dict[str, List[float]]:
    """Train the downscaling CNN.

    Uses L1 loss (mean absolute error), which produces sharper
    spatial predictions than L2 loss for dense prediction tasks.

    Args:
        model: Downscaling CNN.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        epochs: Number of epochs.
        lr: Learning rate.

    Returns:
        Dictionary of training and validation loss histories.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=5, factor=0.5
    )
    criterion = nn.L1Loss()

    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}

    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        for coarse, fine in train_loader:
            coarse, fine = coarse.to(device), fine.to(device)
            optimizer.zero_grad()
            pred = model(coarse)
            loss = criterion(pred, fine)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # Validation
        model.eval()
        val_losses = []
        with torch.no_grad():
            for coarse, fine in val_loader:
                coarse, fine = coarse.to(device), fine.to(device)
                pred = model(coarse)
                val_losses.append(criterion(pred, fine).item())

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        scheduler.step(val_loss)

        if (epoch + 1) % 10 == 0:
            print(
                f"Epoch {epoch+1:3d} | "
                f"Train MAE: {train_loss:.4f} | "
                f"Val MAE: {val_loss:.4f}"
            )

    return history


# Run the experiment
dataset = ClimateDownscalingDataset(n_samples=8000, seed=42)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

model = DownscalingCNN(in_channels=6, base_channels=64, n_res_blocks=8)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

history = train_downscaling_model(model, train_loader, val_loader, epochs=50)

Why CNNs Are Natural Here

Three properties of the downscaling problem align precisely with CNN inductive biases:

  1. Spatial locality. Local temperature depends primarily on nearby climate conditions and local topography. A $3 \times 3$ kernel at coarse resolution covers a $\sim 300 \text{ km}$ region — the right scale for mesoscale weather patterns.

  2. Translation equivariance. The physical relationship between coarse-resolution predictors and fine-resolution temperature is the same regardless of geographic position (modulo boundary effects). A cold front in the northwest should be downscaled the same way as a cold front in the southeast. Weight sharing enforces this.

  3. Hierarchical features. The encoder learns progressively more abstract representations: early layers capture gradients and boundaries between climate zones; deeper layers capture interactions between variables (e.g., how humidity modulates the effect of cloud cover on surface temperature).
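Translation equivariance is not just a slogan; it can be checked numerically. With circular padding a convolution commutes exactly with a spatial shift (zero padding gives the same property up to edge effects). A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Circular padding makes the convolution exactly shift-equivariant:
# conv(roll(x)) == roll(conv(x)).
conv = nn.Conv2d(1, 1, 3, padding=1, padding_mode="circular", bias=False)
x = torch.randn(1, 1, 16, 16)
shifted_then_conv = conv(torch.roll(x, shifts=(3, 5), dims=(2, 3)))
conv_then_shifted = torch.roll(conv(x), shifts=(3, 5), dims=(2, 3))
print(torch.allclose(shifted_then_conv, conv_then_shifted, atol=1e-6))  # True
```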

Comparison with Linear Baselines

A standard benchmark is bilinear interpolation of the coarse-resolution temperature field:

from scipy.ndimage import zoom

def bilinear_baseline(
    dataset: ClimateDownscalingDataset,
) -> float:
    """Compute MAE of bilinear interpolation baseline.

    Simply upscales the coarse temperature channel (channel 0)
    to fine resolution using bilinear interpolation.

    Args:
        dataset: Climate downscaling dataset.

    Returns:
        Mean absolute error of bilinear interpolation.
    """
    errors = []
    for i in range(len(dataset)):
        coarse, fine = dataset[i]
        # Upscale coarse temperature (channel 0) to fine resolution
        coarse_temp = coarse[0].numpy()
        upscaled = zoom(coarse_temp, dataset.scale_factor, order=1)  # Bilinear

        # Map from input-normalized to target-normalized units so the
        # MAE is directly comparable to the CNN's validation loss
        temp_c = (
            upscaled * dataset.input_std[0, 0, 0, 0]
            + dataset.input_mean[0, 0, 0, 0]
        )
        pred = (temp_c - dataset.target_mean) / dataset.target_std
        target = fine[0].numpy()
        errors.append(np.abs(pred - target).mean())

    return float(np.mean(errors))


baseline_mae = bilinear_baseline(dataset)
print(f"Bilinear interpolation MAE: {baseline_mae:.4f}")
print(f"CNN MAE (final validation):  {history['val_loss'][-1]:.4f}")

The CNN substantially outperforms bilinear interpolation because it learns the nonlinear relationships between multiple climate variables and fine-resolution temperature — particularly the topographic lapse rate effect that cannot be captured by simple spatial interpolation of a single variable.

Lessons for Practice

  1. Sub-pixel convolution avoids artifacts. Transposed convolutions introduce checkerboard patterns when the stride does not evenly divide the kernel size, because adjacent output pixels then receive contributions from different numbers of kernel positions. Sub-pixel convolution (PixelShuffle) rearranges channels into spatial resolution, producing each output pixel exactly once.

  2. L1 loss preserves spatial sharpness. L2 (MSE) loss penalizes large errors quadratically, which encourages the model to predict smooth, blurred fields (the mean of possible outcomes). L1 loss penalizes errors linearly, producing sharper spatial gradients and more realistic local detail.

  3. Global residual connections matter for upsampling. The global skip from encoder input to encoder output ensures that the large-scale temperature pattern passes through even if the residual blocks focus on refining details.

  4. Domain knowledge as architecture. The 6-channel input (temperature, pressure, humidity, wind, cloud cover) is not arbitrary — these are the variables that physical models identify as relevant for surface temperature. Encoding this domain knowledge in the channel structure is a form of feature engineering that complements the CNN's learned feature extraction.
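The L1-versus-L2 argument in lesson 2 can be made concrete with a toy calculation. Suppose the true fine-scale field is bimodal: half the pixels sit at $-1$ and half at $+1$ (say, the two sides of a sharp ridge). Compare two constant predictions, the mean (0) and one of the modes (+1):

```python
import numpy as np

# Symmetric bimodal targets: half -1, half +1.
targets = np.array([-1.0] * 500 + [1.0] * 500)

l2_mean = ((targets - 0.0) ** 2).mean()
l2_mode = ((targets - 1.0) ** 2).mean()
l1_mean = np.abs(targets - 0.0).mean()
l1_mode = np.abs(targets - 1.0).mean()
print(l2_mean, l2_mode)  # 1.0 2.0 -> L2 doubles when committing to a mode
print(l1_mean, l1_mode)  # 1.0 1.0 -> L1 does not penalize the sharp choice
```

Under L2, committing to a mode doubles the loss, so the optimizer retreats to the blurry mean; under L1 the sharp prediction costs nothing extra, which is why L1-trained models keep their spatial gradients.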