Case Study 2: Climate Stochastic Weather Generation — Diffusion Models for Ensemble Members

Context

The Pacific Climate Research Consortium (PCRC) faces a fundamental challenge in climate communication: a single forecast is almost always wrong. Weather and climate are chaotic systems — small perturbations in initial conditions lead to large differences in outcomes. The standard scientific approach is ensemble forecasting: run the same climate model many times with slightly perturbed initial conditions, producing a distribution of possible outcomes. Policymakers can then plan for the range of possibilities, not a single prediction.

Traditional ensemble generation is computationally brutal. A single global climate model run on a 100 km grid takes approximately 10,000 CPU-hours. A 50-member ensemble requires 500,000 CPU-hours — roughly $50,000 in cloud compute. And PCRC needs ensembles for multiple climate scenarios, multiple time horizons, and multiple regional domains.

The proposal: train a diffusion model on existing ensemble members to learn the distribution of plausible weather states, then generate additional ensemble members at a fraction of the computational cost. Instead of running the physics-based model 50 times, run it 10 times and use the diffusion model to generate 40 additional members that are statistically consistent with the physics-based ensemble.
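The budget arithmetic behind this proposal is easy to sanity-check. The sketch below uses only the figures quoted above; the diffusion model's own inference cost (discussed later) is small by comparison:

```python
# Back-of-envelope cost of a 50-member ensemble, using the figures above.
cpu_hours_per_run = 10_000      # one physics-based model run
cost_per_cpu_hour = 0.10        # implied by $50,000 / 500,000 CPU-hours

full_cost = 50 * cpu_hours_per_run * cost_per_cpu_hour
hybrid_cost = 10 * cpu_hours_per_run * cost_per_cpu_hour  # 10 physics + 40 generated

print(f"Full physics ensemble:   ${full_cost:,.0f}")
print(f"Hybrid ensemble (10+40): ${hybrid_cost:,.0f} plus diffusion inference")
```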

The Data

Weather states are spatial fields: at each grid point, multiple meteorological variables (temperature, pressure, wind, humidity) describe the atmosphere. An ensemble member is one realization of these fields. We simulate this as multi-channel 2D spatial data.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple


def generate_weather_ensemble(
    n_members: int = 500,
    grid_size: int = 32,
    n_variables: int = 4,
    seed: int = 42,
) -> np.ndarray:
    """Generate a synthetic weather ensemble dataset.

    Each ensemble member is a spatial field with 4 variables:
    temperature (T), pressure (P), u-wind, v-wind. Members share
    a common large-scale pattern (the climate signal) but differ
    in small-scale details (weather noise). This structure mirrors
    real climate ensembles where members agree on trends but
    disagree on specific weather events.

    Args:
        n_members: Number of ensemble members.
        grid_size: Spatial resolution (grid_size x grid_size).
        n_variables: Number of meteorological variables per grid point.
        seed: Random seed.

    Returns:
        Array of shape (n_members, n_variables, grid_size, grid_size).
    """
    rng = np.random.RandomState(seed)
    x = np.linspace(0, 2 * np.pi, grid_size)
    y = np.linspace(0, 2 * np.pi, grid_size)
    xx, yy = np.meshgrid(x, y)

    ensemble = np.zeros(
        (n_members, n_variables, grid_size, grid_size), dtype=np.float32
    )

    for i in range(n_members):
        # Large-scale climate signal, nearly identical across members
        # (only small random phase shifts distinguish them)
        phase_shift = rng.uniform(-0.1, 0.1, size=4)

        # Temperature: strong meridional gradient + wave pattern
        t_signal = (
            -5.0 * np.cos(yy + phase_shift[0])
            + 2.0 * np.sin(xx * 0.8 + phase_shift[1])
        )
        t_noise = 1.5 * _smooth_noise(rng, grid_size, length_scale=4)
        ensemble[i, 0] = t_signal + t_noise

        # Pressure: large-scale low/high pressure systems
        p_signal = (
            1013.0
            + 10.0 * np.sin(xx * 0.5 + phase_shift[2]) * np.cos(yy * 0.7)
        )
        p_noise = 3.0 * _smooth_noise(rng, grid_size, length_scale=6)
        ensemble[i, 1] = p_signal + p_noise

        # U-wind: geostrophic balance (proportional to pressure gradient)
        ensemble[i, 2] = (
            -np.gradient(ensemble[i, 1], axis=0) * 0.5
            + 0.8 * _smooth_noise(rng, grid_size, length_scale=3)
        )

        # V-wind: geostrophic balance
        ensemble[i, 3] = (
            np.gradient(ensemble[i, 1], axis=1) * 0.5
            + 0.8 * _smooth_noise(rng, grid_size, length_scale=3)
        )

    # Normalize each variable to zero mean, unit variance
    for v in range(n_variables):
        mu = ensemble[:, v].mean()
        sigma = ensemble[:, v].std() + 1e-8
        ensemble[:, v] = (ensemble[:, v] - mu) / sigma

    return ensemble


def _smooth_noise(
    rng: np.random.RandomState, size: int, length_scale: int = 4
) -> np.ndarray:
    """Generate spatially correlated noise via spectral filtering.

    Creates noise with spatial correlations by generating white noise
    in Fourier space and applying a Gaussian low-pass filter.

    Args:
        rng: Random state.
        size: Grid size.
        length_scale: Correlation length in grid points.

    Returns:
        Smooth noise field of shape (size, size).
    """
    white = rng.randn(size, size)
    # Fourier transform and apply Gaussian filter
    freq = np.fft.fftfreq(size)
    kx, ky = np.meshgrid(freq, freq)
    k2 = kx**2 + ky**2
    # Gaussian filter in frequency space
    filt = np.exp(-k2 * (2 * np.pi * length_scale) ** 2 / 2)
    filtered = np.fft.ifft2(np.fft.fft2(white) * filt).real
    # Normalize to unit variance
    filtered = filtered / (filtered.std() + 1e-8)
    return filtered

The Model: Convolutional DDPM for Weather Fields

Weather fields have spatial structure — temperature at one grid point is correlated with nearby grid points. A convolutional denoising network is the natural architecture. We use a simplified U-Net-like architecture with residual blocks and a small learned MLP that embeds the scalar timestep.

class ConvResBlock(nn.Module):
    """Convolutional residual block with timestep conditioning.

    Args:
        channels: Number of input and output channels.
        time_dim: Dimensionality of the timestep embedding.
    """

    def __init__(self, channels: int, time_dim: int = 128) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.time_proj = nn.Linear(time_dim, channels)

    def forward(
        self, x: torch.Tensor, t_emb: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass with timestep conditioning.

        Args:
            x: Feature maps, shape (B, C, H, W).
            t_emb: Timestep embedding, shape (B, time_dim).

        Returns:
            Output feature maps, shape (B, C, H, W).
        """
        h = F.silu(self.bn1(self.conv1(x)))
        # Add timestep information
        h = h + self.time_proj(t_emb)[:, :, None, None]
        h = self.bn2(self.conv2(h))
        return F.silu(h + x)


class WeatherDenoisingNet(nn.Module):
    """Convolutional denoising network for weather field generation.

    A simplified U-Net with residual blocks and timestep conditioning.
    Predicts the noise added to weather fields at each diffusion timestep.

    Args:
        in_channels: Number of meteorological variables.
        base_channels: Base number of feature channels.
        time_dim: Timestep embedding dimensionality.
    """

    def __init__(
        self,
        in_channels: int = 4,
        base_channels: int = 64,
        time_dim: int = 128,
    ) -> None:
        super().__init__()

        # Timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        # Encoder
        self.enc1 = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        self.res1 = ConvResBlock(base_channels, time_dim)
        self.down1 = nn.Conv2d(base_channels, base_channels * 2, 4, stride=2, padding=1)
        self.res2 = ConvResBlock(base_channels * 2, time_dim)

        # Bottleneck
        self.down2 = nn.Conv2d(base_channels * 2, base_channels * 4, 4, stride=2, padding=1)
        self.res_mid = ConvResBlock(base_channels * 4, time_dim)

        # Decoder
        self.up2 = nn.ConvTranspose2d(
            base_channels * 4, base_channels * 2, 4, stride=2, padding=1
        )
        self.res3 = ConvResBlock(base_channels * 4, time_dim)  # Skip connection doubles channels
        self.up1 = nn.ConvTranspose2d(
            base_channels * 4, base_channels, 4, stride=2, padding=1
        )
        self.res4 = ConvResBlock(base_channels * 2, time_dim)

        self.final = nn.Conv2d(base_channels * 2, in_channels, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict noise given noisy weather field and timestep.

        Args:
            x: Noisy weather field, shape (B, 4, 32, 32).
            t: Timesteps, shape (B,).

        Returns:
            Predicted noise, shape (B, 4, 32, 32).
        """
        t_emb = self.time_embed(t.float().unsqueeze(-1))

        # Encoder
        h1 = self.enc1(x)
        h1 = self.res1(h1, t_emb)
        h2 = self.down1(h1)
        h2 = self.res2(h2, t_emb)

        # Bottleneck
        h = self.down2(h2)
        h = self.res_mid(h, t_emb)

        # Decoder with skip connections
        h = self.up2(h)
        h = torch.cat([h, h2], dim=1)
        h = self.res3(h, t_emb)
        h = self.up1(h)
        h = torch.cat([h, h1], dim=1)
        h = self.res4(h, t_emb)

        return self.final(h)

Training and Ensemble Generation

class WeatherDDPM:
    """DDPM for weather ensemble generation.

    Wraps the convolutional denoising network with the DDPM
    forward/reverse process. Generates new ensemble members
    by sampling from noise and iteratively denoising.

    Args:
        model: Convolutional denoising network.
        n_timesteps: Number of diffusion timesteps.
        beta_start: Initial noise level.
        beta_end: Final noise level.
        device: Computation device.
    """

    def __init__(
        self,
        model: WeatherDenoisingNet,
        n_timesteps: int = 500,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        device: str = "cpu",
    ) -> None:
        self.model = model.to(device)
        self.n_timesteps = n_timesteps
        self.device = device

        betas = torch.linspace(beta_start, beta_end, n_timesteps, device=device)
        alphas = 1.0 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - alpha_cumprod)
        self.sqrt_recip_alpha = torch.sqrt(1.0 / alphas)
        self.posterior_variance = (
            betas * (1.0 - torch.cat([torch.tensor([1.0], device=device), alpha_cumprod[:-1]]))
            / (1.0 - alpha_cumprod)
        )

    def compute_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        batch_size = x_0.size(0)
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=self.device)
        epsilon = torch.randn_like(x_0)

        sqrt_ab = self.sqrt_alpha_cumprod[t][:, None, None, None]
        sqrt_omab = self.sqrt_one_minus_alpha_cumprod[t][:, None, None, None]
        x_t = sqrt_ab * x_0 + sqrt_omab * epsilon

        return F.mse_loss(self.model(x_t, t), epsilon)

    @torch.no_grad()
    def generate_ensemble(
        self, n_members: int, shape: Tuple[int, ...]
    ) -> np.ndarray:
        """Generate new ensemble members.

        Args:
            n_members: Number of ensemble members to generate.
            shape: Shape of each member (n_variables, H, W).

        Returns:
            Generated ensemble, shape (n_members, *shape).
        """
        self.model.eval()
        x = torch.randn(n_members, *shape, device=self.device)

        for t in reversed(range(self.n_timesteps)):
            t_batch = torch.full(
                (n_members,), t, device=self.device, dtype=torch.long
            )
            eps_pred = self.model(x, t_batch)

            mean = self.sqrt_recip_alpha[t] * (
                x - self.betas[t] / self.sqrt_one_minus_alpha_cumprod[t] * eps_pred
            )

            if t > 0:
                noise = torch.randn_like(x)
                x = mean + torch.sqrt(self.posterior_variance[t]) * noise
            else:
                x = mean

        return x.cpu().numpy()


# ---- Training pipeline ----

# Generate training ensemble
real_ensemble = generate_weather_ensemble(n_members=500, grid_size=32)
print(f"Training ensemble shape: {real_ensemble.shape}")
print(f"Variable means: {real_ensemble.mean(axis=(0, 2, 3)).round(4)}")
print(f"Variable stds:  {real_ensemble.std(axis=(0, 2, 3)).round(4)}")

# Train
dataset = TensorDataset(torch.tensor(real_ensemble, dtype=torch.float32))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
denoiser = WeatherDenoisingNet(in_channels=4, base_channels=64)
weather_ddpm = WeatherDDPM(denoiser, n_timesteps=500, device=device)
optimizer = torch.optim.Adam(weather_ddpm.model.parameters(), lr=2e-4)

print("\n=== Training Weather DDPM ===")
for epoch in range(200):
    weather_ddpm.model.train()
    epoch_loss = []
    for (batch,) in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = weather_ddpm.compute_loss(batch)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())

    if (epoch + 1) % 40 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {np.mean(epoch_loss):.6f}")

# Generate synthetic ensemble members
synthetic_ensemble = weather_ddpm.generate_ensemble(
    n_members=200, shape=(4, 32, 32)
)
print(f"\nGenerated ensemble shape: {synthetic_ensemble.shape}")

Evaluation: Are Synthetic Members Physically Plausible?

Synthetic ensemble members must satisfy two criteria: (1) they should be statistically indistinguishable from real members (same marginals, spatial correlations, cross-variable relationships), and (2) they should be physically plausible (wind fields should be approximately geostrophic, temperature gradients should be smooth).

def evaluate_weather_ensemble(
    real: np.ndarray, synthetic: np.ndarray
) -> Dict[str, float]:
    """Evaluate synthetic ensemble quality against real ensemble.

    Checks marginal distributions, spatial correlation structure,
    cross-variable relationships, and physical consistency.

    Args:
        real: Real ensemble, shape (N_real, 4, H, W).
        synthetic: Synthetic ensemble, shape (N_synth, 4, H, W).

    Returns:
        Dictionary of evaluation metrics.
    """
    from scipy import stats
    var_names = ["Temperature", "Pressure", "U-wind", "V-wind"]
    metrics = {}

    # 1. Marginal distributions (per variable, spatially averaged)
    print("=== Marginal Distribution Comparison (KS test) ===")
    for v, name in enumerate(var_names):
        real_flat = real[:, v].flatten()
        synth_flat = synthetic[:, v].flatten()
        ks_stat, p_val = stats.ks_2samp(
            real_flat[::10], synth_flat[::10]  # Subsample for speed
        )
        print(f"  {name:>12}: KS = {ks_stat:.4f}, p = {p_val:.4f}")
        metrics[f"ks_{name.lower()}"] = ks_stat

    # 2. Ensemble spread (standard deviation across members at each grid point)
    real_spread = real.std(axis=0).mean(axis=(1, 2))
    synth_spread = synthetic.std(axis=0).mean(axis=(1, 2))
    print(f"\n=== Ensemble Spread (std across members) ===")
    for v, name in enumerate(var_names):
        print(
            f"  {name:>12}: real = {real_spread[v]:.4f}, "
            f"synth = {synth_spread[v]:.4f}"
        )
    metrics["spread_ratio"] = (synth_spread / (real_spread + 1e-8)).mean()

    # 3. Spatial autocorrelation (mean correlation at lag 1)
    print(f"\n=== Spatial Autocorrelation (lag-1) ===")
    for v, name in enumerate(var_names):
        real_autocorr = np.mean([
            np.corrcoef(real[i, v, :-1, :].flatten(), real[i, v, 1:, :].flatten())[0, 1]
            for i in range(min(50, len(real)))
        ])
        synth_autocorr = np.mean([
            np.corrcoef(synthetic[i, v, :-1, :].flatten(), synthetic[i, v, 1:, :].flatten())[0, 1]
            for i in range(min(50, len(synthetic)))
        ])
        print(
            f"  {name:>12}: real = {real_autocorr:.4f}, "
            f"synth = {synth_autocorr:.4f}"
        )

    # 4. Cross-variable correlation (T-P correlation)
    real_tp_corr = np.mean([
        np.corrcoef(real[i, 0].flatten(), real[i, 1].flatten())[0, 1]
        for i in range(min(100, len(real)))
    ])
    synth_tp_corr = np.mean([
        np.corrcoef(synthetic[i, 0].flatten(), synthetic[i, 1].flatten())[0, 1]
        for i in range(min(100, len(synthetic)))
    ])
    print(f"\n=== Cross-Variable Correlation (T-P) ===")
    print(f"  Real: {real_tp_corr:.4f}, Synthetic: {synth_tp_corr:.4f}")
    metrics["tp_corr_error"] = abs(real_tp_corr - synth_tp_corr)

    return metrics


metrics = evaluate_weather_ensemble(real_ensemble, synthetic_ensemble)
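Criterion (2) above, physical plausibility, is not covered by these statistical checks. Since the synthetic generator builds winds from the pressure gradient, one hedged extra check is the geostrophic residual. Note that the per-variable normalization in `generate_weather_ensemble` rescales the coefficients, so the synthetic ensemble's residual should be compared against the real ensemble's residual rather than against zero:

```python
import numpy as np


def geostrophic_residual(ensemble: np.ndarray) -> float:
    """Mean squared residual between the wind channels and the geostrophic
    winds implied by the pressure channel, using the same convention as the
    synthetic data generator (u ~ -0.5 dP/dy, v ~ +0.5 dP/dx).

    Because each variable is normalized independently after construction,
    even real members will not have zero residual; compare synthetic
    against real rather than against zero.

    Args:
        ensemble: Array of shape (N, 4, H, W).

    Returns:
        Mean squared geostrophic residual across members.
    """
    residuals = []
    for member in ensemble:
        p = member[1]
        u_geo = -0.5 * np.gradient(p, axis=0)
        v_geo = 0.5 * np.gradient(p, axis=1)
        residuals.append(
            ((member[2] - u_geo) ** 2 + (member[3] - v_geo) ** 2).mean()
        )
    return float(np.mean(residuals))


# Usage: compare geostrophic_residual(synthetic_ensemble)
# against geostrophic_residual(real_ensemble).
```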

Why Diffusion Models for Weather Ensembles

Three properties make diffusion models ideal for this application:

  1. Mode coverage. Each generated ensemble member should be a distinct, plausible weather state — not a copy of a training member or a blurred average. Diffusion models excel at mode coverage: the stochastic reverse process naturally produces diverse samples. GANs risk mode collapse, which would produce an ensemble with too little spread — underestimating uncertainty.

  2. Spatial coherence. The convolutional denoising network preserves the spatial structure of weather fields: the learned reverse process respects the spatial correlations between nearby grid points and the cross-variable relationships (e.g., geostrophic balance between pressure gradients and wind). A pixel-independent model would produce spatially incoherent noise.

  3. Calibrated uncertainty. The ensemble spread (standard deviation across members) should match the true forecast uncertainty. Diffusion models, because they approximate the full data distribution (not just its mode), produce well-calibrated ensemble spreads — the generated members cover the range of plausible weather states with approximately correct probability.
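The mode-coverage claim in point 1 can be probed directly with a memorization check: for each generated member, measure the distance to its nearest training member. Distances near zero would indicate copying; distances comparable to the real ensemble's own nearest-neighbor spacing suggest genuinely novel samples. This is a heuristic sketch, not a formal test:

```python
import numpy as np


def nearest_training_distance(
    synthetic: np.ndarray, real: np.ndarray
) -> np.ndarray:
    """For each synthetic member, the L2 distance to its closest real member.

    Args:
        synthetic: Generated ensemble, shape (N_synth, C, H, W).
        real: Training ensemble, shape (N_real, C, H, W).

    Returns:
        Array of shape (N_synth,) with nearest-neighbor distances.
    """
    s = synthetic.reshape(len(synthetic), -1).astype(np.float64)
    r = real.reshape(len(real), -1).astype(np.float64)
    # ||s - r||^2 = ||s||^2 + ||r||^2 - 2 s.r, computed without building
    # a huge (N_synth, N_real, D) intermediate array
    d2 = (s**2).sum(1)[:, None] + (r**2).sum(1)[None, :] - 2.0 * s @ r.T
    return np.sqrt(np.maximum(d2, 0.0)).min(axis=1)
```

A useful baseline is `nearest_training_distance(real, real)` with self-matches excluded: generated members much closer to the training set than that baseline are suspect.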

From Prototype to Production

In production, several enhancements would be needed:

  1. Conditional generation. The model should be conditioned on the current atmospheric state (analysis fields from weather observations) to generate ensembles that are consistent with the observed weather, not arbitrary weather states. This requires classifier-free guidance or conditional architectures.

  2. Physical constraints. Post-processing steps can enforce known physical constraints (e.g., mass conservation, energy balance) that the diffusion model may not learn perfectly from data alone.

  3. Validation against physics-based ensembles. The ultimate test is whether decisions made using the ML-augmented ensemble are as good as decisions made using the full physics-based ensemble — e.g., do flood warning thresholds produce the same false alarm rates?

  4. Computational cost comparison. The value proposition depends on the ratio of diffusion-model inference cost to physics-based run cost. If a 500-step diffusion model on a GPU generates one member in 30 seconds, while a single physics-based member costs 200 CPU-hours (say, a regional configuration far cheaper than the 10,000 CPU-hour global run above), the speedup is approximately 24,000x — making 1000-member ensembles feasible for the first time.
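The conditioning in item 1 is commonly prototyped by feeding the analysis state to the denoiser as extra input channels. The sketch below is illustrative only (hypothetical layer sizes, timestep conditioning omitted for brevity) and is not PCRC's actual design:

```python
import torch
import torch.nn as nn


class ConditionalDenoiserSketch(nn.Module):
    """Minimal sketch of state-conditioned denoising: the observed analysis
    field enters as extra input channels alongside the noisy sample.

    Timestep conditioning (present in WeatherDenoisingNet) is omitted
    here to keep the conditioning pattern itself visible.
    """

    def __init__(self, n_vars: int = 4, hidden: int = 32) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2 * n_vars, hidden, 3, padding=1),  # noisy + condition
            nn.SiLU(),
            nn.Conv2d(hidden, n_vars, 3, padding=1),
        )

    def forward(
        self, x_t: torch.Tensor, condition: torch.Tensor
    ) -> torch.Tensor:
        # Concatenate along the channel axis so every conv layer sees
        # both the noisy state and the conditioning state
        return self.net(torch.cat([x_t, condition], dim=1))
```

For classifier-free guidance, the condition would additionally be zeroed out at random during training, so the same network learns both conditional and unconditional denoising.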

Lessons for Practice

  1. Generative models for uncertainty quantification. The ensemble application illustrates a broader principle: generative models are tools for uncertainty quantification, not just data augmentation. Any time you need the distribution of possible outcomes — not just a point prediction — a generative model is the right framework.

  2. Domain-specific evaluation is essential. FID and inception scores are meaningless for weather data. Evaluation must be grounded in domain-specific metrics: ensemble spread, rank histograms, continuous ranked probability scores (CRPS), and skill scores against observations. The generative model community's standard metrics must be adapted, not adopted.

  3. Diffusion models trade inference cost for training stability. Compared to GANs, diffusion models are trivial to train (just minimize MSE on noise prediction) but expensive to sample from (hundreds of denoising steps). For applications where you generate many samples offline (ensemble forecasting, synthetic dataset creation), this tradeoff is favorable.
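The CRPS mentioned in lesson 2 has a standard empirical estimator for an ensemble at a single location, CRPS ≈ mean|x_i - y| - 0.5 mean|x_i - x_j|; a minimal sketch:

```python
import numpy as np


def ensemble_crps(members: np.ndarray, observation: float) -> float:
    """Empirical CRPS for a scalar observation: E|X - y| - 0.5 E|X - X'|.

    Lower is better. The score generalizes absolute error to ensembles:
    the first term rewards members near the observation, the second
    rewards calibrated spread. A one-member ensemble reduces to plain
    absolute error.

    Args:
        members: Ensemble values at one grid point, shape (M,).
        observation: Verifying observed value.

    Returns:
        CRPS estimate (non-negative).
    """
    m = np.asarray(members, dtype=float).ravel()
    term1 = np.abs(m - observation).mean()
    term2 = 0.5 * np.abs(m[:, None] - m[None, :]).mean()
    return float(term1 - term2)
```

In practice CRPS would be averaged over grid points and verification times, for both the physics-based and the ML-augmented ensemble, and compared.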