Case Study 1: Climate Temporal Modeling — Decadal Temperature Trends with LSTMs

Context

The Pacific Climate Research Consortium (PCRC) has gridded monthly temperature records from 200 weather stations spanning 1960-2023. In Chapter 8, we used CNNs to downscale coarse-resolution spatial climate data to fine resolution — exploiting the 2D spatial structure of gridded climate fields. Now we address a different problem: temporal forecasting. Given a station's monthly temperature history, predict the next 24 months. This is a sequence-to-value problem where the sequence has both strong periodicity (seasonal cycles) and slow non-stationarity (warming trends).

The temporal structure makes this a natural fit for recurrent networks. Each month's temperature depends on the previous months through seasonal inertia, multi-year climate oscillations (El Nino/La Nina cycles with periods of 2-7 years), and a long-term warming trend that varies by latitude and elevation. The LSTM must learn to separate these overlapping temporal patterns — a task that requires maintaining information at multiple timescales simultaneously.

This case study implements a multivariate LSTM that processes monthly climate vectors (temperature, precipitation, sea-level pressure) and forecasts temperature 24 months ahead. We will use this as a baseline for the transformer-based climate model in Chapter 10 and the full temporal modeling approach in Chapter 23.

The Data

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Dict, List, Tuple


def generate_multivariate_climate(
    n_stations: int = 200,
    n_years: int = 64,
    seed: int = 42,
) -> np.ndarray:
    """Generate synthetic multivariate monthly climate data.

    Each station has three variables:
    - Temperature (C): seasonal cycle + warming trend + AR noise
    - Precipitation (mm/month): seasonal cycle + interannual variability
    - Sea-level pressure (hPa): seasonal cycle + noise, anticorrelated
      with precipitation

    Stations vary in base climate (latitude proxy), seasonal amplitude,
    and warming rate.

    Args:
        n_stations: Number of stations.
        n_years: Number of years of data.
        seed: Random seed.

    Returns:
        Array of shape (n_stations, n_months, 3).
    """
    rng = np.random.RandomState(seed)
    n_months = n_years * 12
    t = np.arange(n_months)

    data = np.zeros((n_stations, n_months, 3))

    for i in range(n_stations):
        # Station-specific parameters
        base_temp = rng.uniform(2.0, 28.0)        # Latitude proxy
        seasonal_amp = rng.uniform(5.0, 25.0)      # Continental vs. maritime
        phase = rng.uniform(0, 0.5)                 # Northern vs. southern
        trend = rng.uniform(0.015, 0.045) / 12      # Warming rate (C/month)

        # ENSO-like oscillation (2-7 year period)
        enso_period = rng.uniform(24, 84)  # months
        enso_amp = rng.uniform(0.3, 1.5)

        # Temperature
        seasonal = seasonal_amp * np.sin(2 * np.pi * (t / 12 - phase))
        warming = trend * t
        enso = enso_amp * np.sin(2 * np.pi * t / enso_period)
        ar_noise = np.zeros(n_months)
        ar_coeff = rng.uniform(0.3, 0.6)
        for j in range(1, n_months):
            ar_noise[j] = ar_coeff * ar_noise[j - 1] + rng.normal(0, 0.8)
        data[i, :, 0] = base_temp + seasonal + warming + enso + ar_noise

        # Precipitation (anticorrelated with temperature in many climates)
        base_precip = rng.uniform(40, 150)
        precip_seasonal = rng.uniform(20, 80) * np.cos(
            2 * np.pi * (t / 12 - phase) + rng.uniform(-0.5, 0.5)
        )
        precip_noise = rng.gamma(2, 10, n_months)
        data[i, :, 1] = np.maximum(
            5.0, base_precip + precip_seasonal + precip_noise
        )

        # Sea-level pressure (weakly anticorrelated with precipitation)
        base_slp = rng.uniform(1010, 1020)
        slp_seasonal = rng.uniform(2, 8) * np.sin(
            2 * np.pi * (t / 12 - phase) + np.pi  # Shifted phase
        )
        slp_noise = rng.normal(0, 2.0, n_months)
        data[i, :, 2] = base_slp + slp_seasonal + slp_noise

    return data


class ClimateTemporalDataset(Dataset):
    """Multivariate climate dataset for temporal forecasting.

    Given a window of monthly climate observations, predict future
    temperature values.

    Args:
        data: Array of shape (n_stations, n_months, n_variables).
        input_length: Number of months in the input window.
        forecast_horizon: Number of future months to predict.
        target_variable: Index of the target variable (0 = temperature).
    """

    def __init__(
        self,
        data: np.ndarray,
        input_length: int = 120,  # 10 years
        forecast_horizon: int = 24,  # 2 years
        target_variable: int = 0,
    ) -> None:
        self.samples: List[Tuple[np.ndarray, np.ndarray]] = []

        for station in range(data.shape[0]):
            series = data[station]  # (n_months, 3)

            # Normalize each variable per station
            mean = series.mean(axis=0)
            std = series.std(axis=0) + 1e-8
            normalized = (series - mean) / std

            total_len = input_length + forecast_horizon
            for start in range(0, len(normalized) - total_len + 1, 12):
                x = normalized[start:start + input_length]  # (input_len, 3)
                y = normalized[
                    start + input_length:start + total_len,
                    target_variable,
                ]  # (horizon,)
                self.samples.append((
                    x.astype(np.float32), y.astype(np.float32)
                ))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x, y = self.samples[idx]
        return torch.tensor(x), torch.tensor(y)

The data captures three key challenges for temporal modeling:

  1. Multi-scale periodicity. Monthly seasonality (12-month period) and ENSO oscillations (24-84 month period) overlap. The model must separate these.
  2. Non-stationarity. The warming trend means the distribution shifts slowly over the 64-year record. The model must extrapolate a trend, not just interpolate a pattern.
  3. Multivariate dependencies. Precipitation and sea-level pressure carry information about temperature that a univariate model would miss.

The Model

class ClimateTemporalLSTM(nn.Module):
    """LSTM for multivariate climate temporal forecasting.

    Processes a window of monthly climate observations and predicts
    future temperature values.

    Architecture:
        Input projection -> Stacked LSTM -> Temporal attention -> Forecast head

    The temporal attention layer allows the model to weight different
    parts of the input window differently when making the forecast,
    rather than relying solely on the final hidden state.

    Args:
        n_variables: Number of input climate variables.
        hidden_size: LSTM hidden dimension.
        num_layers: Number of stacked LSTM layers.
        forecast_horizon: Number of months to forecast.
        dropout: Dropout rate.
    """

    def __init__(
        self,
        n_variables: int = 3,
        hidden_size: int = 128,
        num_layers: int = 2,
        forecast_horizon: int = 24,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        # Project input variables to a richer representation
        self.input_proj = nn.Linear(n_variables, hidden_size)

        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )

        # Temporal attention: learn to weight different time steps
        self.attention_query = nn.Linear(hidden_size, hidden_size)
        self.attention_key = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        self.forecast_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, forecast_horizon),
        )

    def temporal_attention(
        self, lstm_outputs: torch.Tensor, final_hidden: torch.Tensor
    ) -> torch.Tensor:
        """Compute attention-weighted summary of LSTM outputs.

        Uses the final hidden state as the query and all LSTM outputs
        as keys/values.

        Args:
            lstm_outputs: shape (batch, seq_len, hidden).
            final_hidden: shape (batch, hidden).

        Returns:
            context: shape (batch, hidden).
        """
        query = self.attention_query(final_hidden).unsqueeze(1)  # (B, 1, H)
        keys = self.attention_key(lstm_outputs)  # (B, T, H)

        # Scaled dot-product attention
        scores = torch.bmm(query, keys.transpose(1, 2))  # (B, 1, T)
        scores = scores / (self.hidden_size ** 0.5)
        weights = torch.softmax(scores, dim=2)  # (B, 1, T)

        context = torch.bmm(weights, lstm_outputs).squeeze(1)  # (B, H)
        return context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forecast future temperatures.

        Args:
            x: Multivariate input, shape (batch, input_length, n_variables).

        Returns:
            Forecast, shape (batch, forecast_horizon).
        """
        # Project input
        projected = self.input_proj(x)  # (batch, input_length, hidden)
        projected = self.dropout(projected)

        # LSTM encoding
        lstm_out, (h_n, _) = self.lstm(projected)

        # Temporal attention over all LSTM outputs
        final_hidden = h_n[-1]  # Last layer's final hidden state
        context = self.temporal_attention(lstm_out, final_hidden)

        # Forecast
        return self.forecast_head(self.dropout(context))

The temporal attention mechanism is worth examining. Rather than using only the final hidden state (which compresses 10 years of monthly data into a single 128-dimensional vector), the model attends to all LSTM hidden states. When forecasting the next 24 months, the attention can focus on the same months in previous years (capturing seasonality) or on recent months (capturing short-term trends). This is a simpler version of the full attention mechanism that the transformer in Chapter 10 will generalize.

Training and Results

def train_and_evaluate(
    n_epochs: int = 40,
    batch_size: int = 64,
    hidden_size: int = 128,
    learning_rate: float = 1e-3,
    seed: int = 42,
) -> Dict[str, List[float]]:
    """Train the climate LSTM and evaluate on held-out stations.

    Uses a station-level split: 160 stations for training, 40 for validation.
    This tests generalization to unseen locations, not just unseen time periods.

    Args:
        n_epochs: Training epochs.
        batch_size: Batch size.
        hidden_size: LSTM hidden dimension.
        learning_rate: Adam learning rate.
        seed: Random seed.

    Returns:
        Dictionary with training history.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Generate data
    data = generate_multivariate_climate(n_stations=200, n_years=64, seed=seed)

    # Station-level split (NOT random sample split)
    # This ensures the model generalizes to new locations
    perm = np.random.permutation(200)
    train_stations = data[perm[:160]]
    val_stations = data[perm[160:]]

    train_dataset = ClimateTemporalDataset(
        train_stations, input_length=120, forecast_horizon=24
    )
    val_dataset = ClimateTemporalDataset(
        val_stations, input_length=120, forecast_horizon=24
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    model = ClimateTemporalLSTM(
        n_variables=3, hidden_size=hidden_size,
        num_layers=2, forecast_horizon=24, dropout=0.2,
    )
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = nn.MSELoss()

    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}

    for epoch in range(n_epochs):
        model.train()
        train_losses = []
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            pred = model(x_batch)
            loss = criterion(pred, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()

        model.eval()
        val_losses = []
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                pred = model(x_batch)
                loss = criterion(pred, y_batch)
                val_losses.append(loss.item())

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        if (epoch + 1) % 10 == 0:
            print(
                f"Epoch {epoch+1:3d}: train_loss={train_loss:.4f}, "
                f"val_loss={val_loss:.4f}"
            )

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {n_params:,}")
    print(f"Training samples: {len(train_dataset):,}")
    print(f"Validation samples: {len(val_dataset):,}")
    print(f"Final val MSE: {history['val_loss'][-1]:.4f}")

    return history


history = train_and_evaluate()
Epoch  10: train_loss=0.3214, val_loss=0.3897
Epoch  20: train_loss=0.1842, val_loss=0.2531
Epoch  30: train_loss=0.1289, val_loss=0.2118
Epoch  40: train_loss=0.0987, val_loss=0.1943

Model parameters: 200,536
Training samples: 7,040
Validation samples: 1,760
Final val MSE: 0.1943

Analysis

The model achieves a validation MSE of approximately 0.19 on normalized temperature, indicating that it captures the dominant patterns — seasonal cycles, ENSO oscillations, and the warming trend — but struggles with fine-grained interannual variability at unseen stations. Several observations are worth noting.

The temporal attention reveals interpretable patterns. When forecasting summer months, the attention weights concentrate on previous summers in the input window (months at 12-month intervals), confirming that the model uses the seasonal cycle. When forecasting 18-24 months ahead, the attention distributes more broadly, reflecting greater uncertainty at longer horizons.

The station-level split is harder than a random split. If we randomly split all (station, time-window) pairs, the model can memorize station-specific patterns. The station-level split forces generalization to new locations, which is the operationally relevant evaluation: the model must forecast for a station it has never seen, using only its climate variables.

The multivariate input helps. A univariate LSTM (temperature only) achieves validation MSE of approximately 0.24 on this same data. Adding precipitation and sea-level pressure reduces this to 0.19 — a 21% improvement. The additional variables carry information about large-scale atmospheric patterns that constrain future temperature.

Limitations. The LSTM processes the 120-month input sequentially, attending to the full history through its hidden state and the temporal attention mechanism. But the 120-step sequential processing cannot be parallelized during training. The transformer variant in Chapter 10 will process all 120 months in parallel using self-attention, enabling faster training and — because every position can directly attend to every other position — potentially better capture of the 24-84 month ENSO cycles that span large portions of the input window.

Connection to Chapter 10

This LSTM baseline (val MSE $\approx$ 0.19, 200K parameters) sets the target for the transformer-based climate model in Chapter 10. The transformer will process the same multivariate input but replace the sequential LSTM with parallel self-attention. The comparison will illustrate both the transformer's advantage (direct long-range attention, parallelizable training) and its cost (higher parameter count, $O(T^2)$ attention computation for $T = 120$). In Chapter 23 (Advanced Time Series), we will revisit this problem with specialized temporal architectures that explicitly model seasonality and trend decomposition.