In This Chapter
- Learning Objectives
- 9.1 Why Sequences Require New Architectures
- 9.2 The Vanilla Recurrent Neural Network
- 9.3 Backpropagation Through Time and the Vanishing Gradient Problem
- 9.4 Long Short-Term Memory (LSTM)
- 9.5 Gated Recurrent Unit (GRU)
- 9.6 Bidirectional RNNs
- 9.7 Sequence-to-Sequence Models
- 9.8 Attention Mechanisms: The Bridge to Transformers
- 9.9 When RNNs Are Still the Right Choice
- 9.10 Putting It Together: LSTM for Climate Temperature Forecasting
- 9.11 Progressive Project Milestone: StreamRec Session Modeling with LSTM
- Summary
Chapter 9: Recurrent Networks and Sequence Modeling — RNNs, LSTMs, GRUs, and Their Limitations
"The problem of learning long-range dependencies by gradient descent is not just difficult — it is fundamentally hard. The very thing that makes gradient descent work (the chain rule) is the thing that makes it fail over long sequences." — Yoshua Bengio, Patrice Simard, and Paolo Frasconi, "Learning Long-Term Dependencies with Gradient Descent Is Difficult" (1994)
Learning Objectives
By the end of this chapter, you will be able to:
- Derive the forward and backward pass of a vanilla RNN and identify the vanishing/exploding gradient problem
- Explain how LSTMs and GRUs solve the vanishing gradient problem through gating mechanisms
- Implement sequence-to-sequence models for tasks like translation and summarization
- Apply bidirectional RNNs and attention mechanisms (Bahdanau, Luong) as precursors to the transformer
- Recognize when RNNs are still appropriate and when transformers are strictly better
9.1 Why Sequences Require New Architectures
In Chapters 6-8, every model we built processed fixed-size inputs. The MLP in Chapter 6 took a vector $\mathbf{x} \in \mathbb{R}^d$. The CNN in Chapter 8 took a tensor with fixed spatial dimensions. But many of the most important problems in data science involve sequences — ordered collections of variable length where position matters.
Consider the data that flows through StreamRec's platform every second. A user opens the app and clicks through a series of items: a news article, then a comedy video, then a podcast episode, then another news article. The order of these interactions carries information that a bag-of-items representation destroys. A user who watches comedy after reading news is in a different state than a user who reads news after watching comedy. Their next click is likely to be different.
Or consider the climate data from our running example. A time series of monthly temperatures $[T_1, T_2, \ldots, T_{360}]$ over 30 years is not a set — it is a sequence. The value at position $t$ depends on the values at positions $t-1, t-2, \ldots$, and the patterns in this dependency (seasonal cycles, multi-year trends, regime changes) are exactly what we want the model to capture.
The feedforward networks we have built so far cannot handle these problems naturally. An MLP applied to a sequence would require:
1. A fixed input size. We would have to truncate or pad all sequences to the same length $T$, wasting computation on padding and losing information beyond $T$.
2. Separate parameters for each position. The weight connecting input position 5 to the first hidden neuron would be independent of the weight connecting input position 6 to the same neuron. The model has no concept that position 5 and position 6 play the same role (they are consecutive time steps).
3. An input dimension that scales with sequence length. For $T = 1000$ time steps with $d = 64$ features each, the flattened input is 64,000-dimensional, and the first weight matrix alone could have millions of parameters.
The 1D CNN from Chapter 8 partially addresses problems 1 and 2 through weight sharing across positions. But it has a fundamental limitation: a convolutional layer with kernel size $k$ can only capture dependencies within a window of $k$ positions. Stacking layers increases the receptive field, but capturing a dependency between position 1 and position 500 requires an impractically deep stack.
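The receptive-field arithmetic is worth making concrete. With stride 1, each additional layer of kernel size $k$ extends the receptive field by $k - 1$ positions, so $L$ layers see $1 + L(k - 1)$ positions. A quick sketch (the helper function is ours, for illustration):

```python
def receptive_field(num_layers: int, kernel_size: int, dilation: int = 1) -> int:
    """Receptive field of a stack of stride-1 1D convolutions.

    Each layer with kernel size k and dilation d extends the
    receptive field by (k - 1) * d positions.
    """
    return 1 + num_layers * (kernel_size - 1) * dilation

print(receptive_field(num_layers=10, kernel_size=3))   # 21 positions
print(receptive_field(num_layers=250, kernel_size=3))  # 501: ~250 layers to span 500 steps
```

Increasing the dilation grows the window faster, but each layer still buys only a fixed-size extension of the receptive field.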
What we need is an architecture that:
- Processes sequences of arbitrary length
- Shares parameters across time steps (the same function is applied at each step)
- Maintains a memory that can, in principle, carry information from early positions to late positions
This is the recurrent neural network.
Understanding Why: The RNN is motivated by the same principle that motivated CNNs: parameter sharing. A CNN shares weights across space. An RNN shares weights across time. Both exploit the assumption that the same patterns can appear at any position. The RNN adds something the CNN lacks: a persistent hidden state that carries information forward from one position to the next.
9.2 The Vanilla Recurrent Neural Network
The Recurrence Relation
A recurrent neural network processes a sequence $\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_T$ one element at a time, maintaining a hidden state $\mathbf{h}_t \in \mathbb{R}^h$ that is updated at each time step:
$$\mathbf{h}_t = \tanh(\mathbf{W}_{xh} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{b}_h)$$
$$\mathbf{y}_t = \mathbf{W}_{hy} \mathbf{h}_t + \mathbf{b}_y$$
where:
- $\mathbf{x}_t \in \mathbb{R}^d$ is the input at time step $t$
- $\mathbf{h}_t \in \mathbb{R}^h$ is the hidden state at time step $t$
- $\mathbf{y}_t \in \mathbb{R}^o$ is the output at time step $t$
- $\mathbf{W}_{xh} \in \mathbb{R}^{h \times d}$ maps input to hidden state
- $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$ maps previous hidden state to current hidden state
- $\mathbf{W}_{hy} \in \mathbb{R}^{o \times h}$ maps hidden state to output
- $\mathbf{h}_0$ is the initial hidden state, typically initialized to zeros
The key insight is that the same weight matrices $\mathbf{W}_{xh}$, $\mathbf{W}_{hh}$, and $\mathbf{W}_{hy}$ are used at every time step. The total parameter count is $O(dh + h^2 + oh)$ — independent of the sequence length $T$. A model with $h = 256$ and $d = 64$ has roughly $256 \times 64 + 256 \times 256 + o \times 256 \approx 82{,}000 + 256o$ parameters, whether the sequence has 10 elements or 10,000.
Unrolling Through Time
To understand the computation graph of an RNN, we unroll the recurrence over time. The single recurrent cell, applied $T$ times, becomes a deep feedforward network with $T$ layers, where each layer uses the same weights:
$$\mathbf{h}_0 \xrightarrow{\mathbf{W}_{xh}, \mathbf{W}_{hh}} \mathbf{h}_1 \xrightarrow{\mathbf{W}_{xh}, \mathbf{W}_{hh}} \mathbf{h}_2 \xrightarrow{\mathbf{W}_{xh}, \mathbf{W}_{hh}} \cdots \xrightarrow{\mathbf{W}_{xh}, \mathbf{W}_{hh}} \mathbf{h}_T$$
This unrolling is not just a visualization — it is the actual computation graph that backpropagation traverses. An RNN processing a sequence of length 200 is, from the perspective of gradient computation, a 200-layer deep network. This observation is the root of both the power and the failure mode of vanilla RNNs.
Implementation: Vanilla RNN from Scratch
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Optional
class VanillaRNNCell(nn.Module):
"""A single RNN cell implementing the Elman recurrence.
h_t = tanh(W_xh @ x_t + W_hh @ h_{t-1} + b_h)
Args:
input_size: Dimension of input vectors.
hidden_size: Dimension of the hidden state.
"""
def __init__(self, input_size: int, hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
# Input-to-hidden weights
self.W_xh = nn.Parameter(torch.randn(hidden_size, input_size) * 0.01)
# Hidden-to-hidden weights
self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
# Bias
self.b_h = nn.Parameter(torch.zeros(hidden_size))
def forward(
self, x_t: torch.Tensor, h_prev: torch.Tensor
) -> torch.Tensor:
"""Process one time step.
Args:
x_t: Input at time t, shape (batch_size, input_size).
h_prev: Hidden state from time t-1, shape (batch_size, hidden_size).
Returns:
h_t: New hidden state, shape (batch_size, hidden_size).
"""
return torch.tanh(x_t @ self.W_xh.T + h_prev @ self.W_hh.T + self.b_h)
class VanillaRNN(nn.Module):
"""Complete RNN that processes a sequence step by step.
Args:
input_size: Dimension of input vectors at each time step.
hidden_size: Dimension of the hidden state.
output_size: Dimension of the output at each time step.
"""
def __init__(
self, input_size: int, hidden_size: int, output_size: int
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.cell = VanillaRNNCell(input_size, hidden_size)
self.output_layer = nn.Linear(hidden_size, output_size)
def forward(
self, x: torch.Tensor, h_0: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Process an entire sequence.
Args:
x: Input sequence, shape (batch_size, seq_len, input_size).
h_0: Initial hidden state, shape (batch_size, hidden_size).
Defaults to zeros.
Returns:
outputs: Output at each time step, shape (batch_size, seq_len, output_size).
h_final: Final hidden state, shape (batch_size, hidden_size).
"""
batch_size, seq_len, _ = x.shape
if h_0 is None:
h_0 = torch.zeros(batch_size, self.hidden_size, device=x.device)
h_t = h_0
outputs = []
for t in range(seq_len):
h_t = self.cell(x[:, t, :], h_t)
y_t = self.output_layer(h_t)
outputs.append(y_t)
outputs = torch.stack(outputs, dim=1) # (batch, seq_len, output_size)
return outputs, h_t
# Verify shapes
rnn = VanillaRNN(input_size=64, hidden_size=128, output_size=20)
x = torch.randn(32, 50, 64) # batch=32, seq_len=50, features=64
outputs, h_final = rnn(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {outputs.shape}")
print(f"Hidden state: {h_final.shape}")
print(f"Parameters: {sum(p.numel() for p in rnn.parameters()):,}")
Input shape: torch.Size([32, 50, 64])
Output shape: torch.Size([32, 50, 20])
Hidden state: torch.Size([32, 128])
Parameters: 27,284
Note the parameter count: 27,284, regardless of whether we process sequences of length 50 or 5,000. This is the power of parameter sharing across time.
9.3 Backpropagation Through Time and the Vanishing Gradient Problem
BPTT: The Chain Rule Over Time
To train an RNN, we need gradients of the loss with respect to the shared weight matrices. The loss over a sequence is typically the sum of per-step losses:
$$\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t(\mathbf{y}_t, \hat{\mathbf{y}}_t)$$
The gradient of $\mathcal{L}$ with respect to $\mathbf{W}_{hh}$ requires the chain rule through every time step from $T$ back to 1. This algorithm is called backpropagation through time (BPTT) — it is simply standard backpropagation applied to the unrolled computation graph.
Consider the gradient of the loss at time step $T$ with respect to the hidden state at an earlier time step $k$:
$$\frac{\partial \mathcal{L}_T}{\partial \mathbf{h}_k} = \frac{\partial \mathcal{L}_T}{\partial \mathbf{h}_T} \cdot \prod_{t=k+1}^{T} \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}}$$
Each Jacobian in the product is:
$$\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}} = \text{diag}(\tanh'(\mathbf{z}_t)) \cdot \mathbf{W}_{hh}$$
where $\mathbf{z}_t = \mathbf{W}_{xh}\mathbf{x}_t + \mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{b}_h$ is the pre-activation and $\tanh'(\mathbf{z}_t) = 1 - \tanh^2(\mathbf{z}_t)$ is the element-wise derivative.
Therefore:
$$\frac{\partial \mathcal{L}_T}{\partial \mathbf{h}_k} = \frac{\partial \mathcal{L}_T}{\partial \mathbf{h}_T} \cdot \prod_{t=k+1}^{T} \text{diag}(\tanh'(\mathbf{z}_t)) \cdot \mathbf{W}_{hh}$$
The Jacobian Product Chain
This product of $(T - k)$ Jacobian matrices is where the problem lies. To understand its behavior, consider the spectral properties of $\mathbf{W}_{hh}$.
Let $\lambda_1, \lambda_2, \ldots, \lambda_h$ be the eigenvalues of $\mathbf{W}_{hh}$. In the simplified case where we ignore the nonlinearity (i.e., assume $\tanh'(\cdot) = 1$), the product becomes:
$$\prod_{t=k+1}^{T} \mathbf{W}_{hh} = \mathbf{W}_{hh}^{T-k}$$
The eigenvalues of $\mathbf{W}_{hh}^{T-k}$ are $\lambda_i^{T-k}$. As $(T - k) \to \infty$:
- If $|\lambda_i| < 1$ for all $i$: the eigenvalues $\lambda_i^{T-k} \to 0$ exponentially. Gradients vanish.
- If $|\lambda_i| > 1$ for any $i$: the eigenvalue $\lambda_i^{T-k} \to \infty$ exponentially. Gradients explode.
- Only if $|\lambda_i| = 1$ for all $i$ do gradients remain stable — but this requires $\mathbf{W}_{hh}$ to be orthogonal, which gradient descent does not naturally maintain.
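This eigenvalue behavior is easy to see numerically. The sketch below (names ours) rescales a random matrix to a chosen spectral radius and applies it 100 times to a random vector; the norm collapses for radius 0.9 and blows up for 1.1:

```python
import numpy as np

def power_norm(spectral_radius: float, n_steps: int, size: int = 32) -> float:
    """Apply a matrix with the given spectral radius n_steps times to a vector."""
    rng = np.random.default_rng(0)
    W = rng.standard_normal((size, size))
    # Rescale so the largest eigenvalue modulus equals spectral_radius
    W *= spectral_radius / np.max(np.abs(np.linalg.eigvals(W)))
    v = rng.standard_normal(size)
    for _ in range(n_steps):
        v = W @ v
    return float(np.linalg.norm(v))

for rho in (0.9, 1.0, 1.1):
    print(f"spectral radius {rho}: ||W^100 v|| = {power_norm(rho, 100):.3e}")
```

The exact constants depend on the random draw; what is invariant is the exponential trend set by the spectral radius.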
With the $\tanh$ nonlinearity, the situation is strictly worse. Since $\tanh'(z) \in (0, 1]$ with equality only at $z = 0$, the diagonal matrix $\text{diag}(\tanh'(\mathbf{z}_t))$ has all diagonal entries in $(0, 1]$. This contracts the gradient at every step. The spectral radius of the effective per-step Jacobian $\text{diag}(\tanh'(\mathbf{z}_t)) \cdot \mathbf{W}_{hh}$ is bounded above by $\|\mathbf{W}_{hh}\|$, and the $\tanh$ derivative typically pulls it below 1, causing vanishing gradients even when $\|\mathbf{W}_{hh}\| \approx 1$.
Formal Bound
Bengio et al. (1994) proved a formal bound. If $\sigma_{\max}$ is the largest singular value of $\mathbf{W}_{hh}$ and $\gamma = \max_z |\tanh'(z)| = 1$, then:
$$\left\| \frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_k} \right\| \leq (\gamma \cdot \sigma_{\max})^{T-k}$$
For $\gamma \cdot \sigma_{\max} < 1$, the gradient norm decreases exponentially with the distance $(T - k)$. For a sequence of length $T = 100$ and a spectral radius of 0.9:
$$0.9^{100} \approx 2.66 \times 10^{-5}$$
The gradient from position 100 to position 1 is attenuated by five orders of magnitude. The model effectively cannot learn dependencies that span more than about 10-20 time steps.
Experimental Demonstration
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict
def measure_gradient_flow(
hidden_size: int = 64,
seq_lengths: List[int] = [10, 25, 50, 100, 200],
n_trials: int = 5,
) -> Dict[int, float]:
"""Measure how gradient magnitude decays with sequence length.
For each sequence length, we compute the gradient of the loss at the
final time step with respect to the hidden state at t=0. This directly
measures the RNN's ability to propagate credit assignment backward.
Args:
hidden_size: Dimension of the hidden state.
seq_lengths: List of sequence lengths to test.
n_trials: Number of random initializations to average over.
Returns:
Dictionary mapping sequence length to mean gradient norm ratio
(gradient at t=0 / gradient at t=T).
"""
results = {}
for T in seq_lengths:
grad_ratios = []
for trial in range(n_trials):
torch.manual_seed(trial)
rnn = nn.RNN(
input_size=hidden_size,
hidden_size=hidden_size,
batch_first=True,
nonlinearity="tanh",
)
# Xavier initialization for fair comparison
nn.init.xavier_uniform_(rnn.weight_ih_l0)
nn.init.orthogonal_(rnn.weight_hh_l0)
x = torch.randn(1, T, hidden_size, requires_grad=True)
h0 = torch.zeros(1, 1, hidden_size, requires_grad=True)
output, _ = rnn(x, h0)
# Loss at the final time step only
loss = output[0, -1, :].sum()
loss.backward()
# Gradient with respect to initial hidden state
grad_h0_norm = h0.grad.norm().item()
# Gradient with respect to last input (proxy for local gradient)
grad_last_norm = x.grad[0, -1, :].norm().item()
if grad_last_norm > 0:
grad_ratios.append(grad_h0_norm / grad_last_norm)
results[T] = np.mean(grad_ratios)
for T, ratio in sorted(results.items()):
print(f"Seq length {T:4d}: gradient ratio = {ratio:.6f}")
return results
gradient_ratios = measure_gradient_flow()
Seq length 10: gradient ratio = 0.381204
Seq length 25: gradient ratio = 0.052917
Seq length 50: gradient ratio = 0.003108
Seq length 100: gradient ratio = 0.000024
Seq length 200: gradient ratio = 0.000000
The gradient ratio drops by roughly an order of magnitude every 15-20 time steps. At length 100, the gradient from the final loss to the initial state is effectively zero. At length 200, it is numerically indistinguishable from zero in 32-bit floating point. This is the vanishing gradient problem in action.
Gradient Clipping: A Partial Fix for Exploding Gradients
Gradient clipping addresses the exploding gradient problem (but not the vanishing gradient problem). The idea is simple: if the total gradient norm exceeds a threshold $\theta$, rescale it:
$$\hat{\mathbf{g}} = \begin{cases} \mathbf{g} & \text{if } \|\mathbf{g}\| \leq \theta \\ \theta \cdot \frac{\mathbf{g}}{\|\mathbf{g}\|} & \text{if } \|\mathbf{g}\| > \theta \end{cases}$$
# PyTorch gradient clipping — applied after loss.backward(), before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
Gradient clipping prevents training instability from exploding gradients but does nothing for the vanishing gradient problem. The fundamental issue — that the product of Jacobians shrinks exponentially — requires an architectural solution.
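Written out by hand, the clipping rule is just a norm computation followed by an in-place rescale. A minimal sketch (the helper name is ours; in practice use `clip_grad_norm_`):

```python
import torch

def clip_gradients_(parameters, max_norm: float) -> float:
    """Rescale gradients in place so their total L2 norm is at most max_norm.

    Mirrors the piecewise rule: g is left alone if ||g|| <= theta,
    otherwise scaled by theta / ||g||. Returns the pre-clip norm.
    """
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for g in grads:
            g.mul_(scale)
    return float(total_norm)

# Tiny check: a gradient of norm 10, clipped to norm 5
p = torch.nn.Parameter(torch.zeros(4))
p.grad = torch.full((4,), 5.0)  # ||g|| = sqrt(4 * 25) = 10
pre = clip_gradients_([p], max_norm=5.0)
print(pre, p.grad.norm().item())  # 10.0 5.0
```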
Truncated BPTT
Truncated BPTT is a practical compromise: instead of backpropagating gradients through the entire sequence, we truncate the gradient computation after $k$ time steps. We split the sequence into segments of length $k$, and the hidden state flows forward across segment boundaries (carrying information), but gradients do not flow backward across them.
This makes training computationally tractable for very long sequences but formally limits the gradient signal to dependencies within a window of $k$ steps. If the true dependency spans 200 steps and $k = 35$, the model can represent the dependency (the forward pass carries information in the hidden state), but gradient descent cannot learn to exploit it.
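A minimal sketch of the truncated-BPTT segment loop, assuming a model with an `(outputs, state)` call signature such as `nn.RNN` or the `VanillaRNN` above; the one essential line is the `detach()` at the segment boundary, which carries the state forward while cutting the gradient path:

```python
import torch
import torch.nn as nn

def tbptt_step(model, optimizer, x, targets, k: int = 35) -> float:
    """One pass of truncated BPTT over a long sequence.

    x: (batch, seq_len, input_size); targets: (batch, seq_len, output_size).
    The model is called as model(segment, state) -> (outputs, new_state).
    """
    criterion = nn.MSELoss()
    state = None
    total_loss = 0.0
    for start in range(0, x.shape[1], k):
        outputs, state = model(x[:, start:start + k, :], state)
        loss = criterion(outputs, targets[:, start:start + k, :])
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        # Carry the state forward, but cut the gradient path at the boundary.
        # (For an LSTM, detach both members of the (h, c) tuple.)
        state = state.detach()
        total_loss += loss.item()
    return total_loss

# Illustration on random data with a plain nn.RNN as the model
model = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(2, 140, 4)
targets = torch.randn(2, 140, 8)
loss = tbptt_step(model, opt, x, targets, k=35)
print(f"summed segment loss: {loss:.3f}")
```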
Fundamentals > Frontier: The vanishing gradient problem is not a quirk of RNNs — it is a fundamental consequence of composing contractive functions. Any architecture that repeatedly applies a function with Jacobian spectral radius less than 1 will suffer from it. Understanding this mathematically is essential for understanding why LSTMs, residual connections (Chapter 8), and transformers (Chapter 10) each solve the problem in their own way.
9.4 Long Short-Term Memory (LSTM)
The Core Idea: Additive Gradient Flow
Hochreiter and Schmidhuber (1997) identified that the vanishing gradient problem arises from the multiplicative nature of the gradient path through $\mathbf{W}_{hh}$ at each step. Their solution was elegant: introduce a cell state $\mathbf{c}_t$ that is updated additively rather than multiplicatively.
In a vanilla RNN, the hidden state is completely overwritten at each step:
$$\mathbf{h}_t = \tanh(\mathbf{W}_{xh}\mathbf{x}_t + \mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{b}_h)$$
In an LSTM, the cell state is updated by adding new information and (optionally) forgetting old information:
$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$
The $\odot$ denotes element-wise multiplication. When the forget gate $\mathbf{f}_t \approx 1$ and the input gate $\mathbf{i}_t \approx 0$, the cell state is simply copied forward: $\mathbf{c}_t \approx \mathbf{c}_{t-1}$. The gradient flows through this path without multiplicative distortion:
$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \text{diag}(\mathbf{f}_t)$$
When $\mathbf{f}_t \approx 1$, this Jacobian is approximately the identity matrix. There is no $\mathbf{W}_{hh}$ in this path. The gradient flows through the cell state like water through a pipe — this is the gradient highway.
The Gating Mechanism
The LSTM has three gates, each a learned sigmoid function that produces values in $[0, 1]$:
Forget gate — decides what to remove from the cell state:
$$\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_f)$$
Input gate — decides what new information to add:
$$\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_i)$$
Candidate cell state — the new information to potentially add:
$$\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_c)$$
Cell state update — combine forget and input:
$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$
Output gate — decides what to expose from the cell state:
$$\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_o)$$
Hidden state — the output of the LSTM cell:
$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$
Here $[\mathbf{h}_{t-1}; \mathbf{x}_t]$ denotes concatenation, so each weight matrix $\mathbf{W}_f, \mathbf{W}_i, \mathbf{W}_c, \mathbf{W}_o \in \mathbb{R}^{h \times (h+d)}$.
Why the Gates Work: An Intuitive Reading
Think of the cell state as a notebook and the gates as decisions about how to update it:
- The forget gate asks: "Is the information currently stored in this cell still relevant, given what I just saw?" When processing a new sentence after a period, the forget gate learns to clear the subject-verb agreement information from the previous sentence.
- The input gate asks: "Is the new input worth remembering?" When processing a noise word (an article, a filler), the input gate stays near zero — nothing is written.
- The output gate asks: "Is the stored information relevant to the current output?" The cell state might store a long-term trend, but the output gate only exposes it when the downstream task needs it.
Why the LSTM Solves the Vanishing Gradient Problem
The gradient of the loss at time $T$ with respect to the cell state at time $k$ flows through the path:
$$\frac{\partial \mathcal{L}_T}{\partial \mathbf{c}_k} = \frac{\partial \mathcal{L}_T}{\partial \mathbf{c}_T} \cdot \prod_{t=k+1}^{T} \frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \frac{\partial \mathcal{L}_T}{\partial \mathbf{c}_T} \cdot \prod_{t=k+1}^{T} \text{diag}(\mathbf{f}_t)$$
Compare this to the vanilla RNN's gradient path, which involves the product $\prod \text{diag}(\tanh'(\mathbf{z}_t)) \cdot \mathbf{W}_{hh}$. The LSTM's gradient path through the cell state has two critical advantages:
- No weight matrix in the product. The gradient does not pass through $\mathbf{W}_{hh}$ at each step. The product involves only the forget gate values, which are bounded in $[0, 1]$.
- The forget gate is a learned, adaptive quantity. The network can learn to set $\mathbf{f}_t \approx 1$ for time steps where long-range information should be preserved, creating a near-identity gradient path. This is not possible with the fixed $\mathbf{W}_{hh}$ of a vanilla RNN.
When $\mathbf{f}_t = 1$ for all $t$, the product $\prod_{t=k+1}^{T} \text{diag}(\mathbf{f}_t) = \mathbf{I}$, and the gradient is transmitted without any attenuation, regardless of the sequence length. In practice, the forget gate is not exactly 1, but it is typically close — especially after the common initialization trick of setting $\mathbf{b}_f$ to a positive value (1.0 or 2.0), which biases the sigmoid toward 1.
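This claim can be tested directly, in the spirit of the `measure_gradient_flow` experiment from Section 9.3. The sketch below (function name ours) compares the gradient reaching the initial state of an untrained `nn.RNN` against the gradient reaching the initial cell state of an `nn.LSTM` whose forget-gate bias has been set to 2.0:

```python
import torch
import torch.nn as nn

def initial_state_grad(cell_type: str, seq_len: int, hidden: int = 64) -> float:
    """Gradient norm at the initial (cell) state from a loss at the last step."""
    torch.manual_seed(0)
    x = torch.randn(1, seq_len, hidden)
    if cell_type == "rnn":
        net = nn.RNN(hidden, hidden, batch_first=True)
        state = torch.zeros(1, 1, hidden, requires_grad=True)
        out, _ = net(x, state)
    else:
        net = nn.LSTM(hidden, hidden, batch_first=True)
        # Bias the forget gate toward "remember" (PyTorch gate order: i, f, g, o)
        with torch.no_grad():
            net.bias_ih_l0[hidden:2 * hidden].fill_(2.0)
        h0 = torch.zeros(1, 1, hidden, requires_grad=True)
        state = torch.zeros(1, 1, hidden, requires_grad=True)  # c_0
        out, _ = net(x, (h0, state))
    out[0, -1, :].sum().backward()
    return state.grad.norm().item()

for T in (50, 100, 200):
    print(f"T={T:3d}: RNN {initial_state_grad('rnn', T):.2e}, "
          f"LSTM {initial_state_grad('lstm', T):.2e}")
```

Even untrained, the LSTM's cell-state path keeps a measurable gradient alive at sequence lengths where the vanilla RNN's gradient has underflowed to zero.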
Parameter Count
An LSTM has four sets of weight matrices (forget, input, candidate, output), each of size $h \times (h + d)$. Total parameters:
$$4 \times [h \times (h + d) + h] = 4h(h + d + 1)$$
For $h = 256, d = 64$: $4 \times 256 \times (256 + 64 + 1) = 4 \times 256 \times 321 = 328{,}704$ parameters — roughly $4\times$ a vanilla RNN of the same hidden size. This is the cost of gating.
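One bookkeeping note if you check this formula against PyTorch: `nn.LSTM` keeps two bias vectors per gate (`bias_ih` and `bias_hh`, a convention inherited from cuDNN), so its count comes out to $4h(h + d + 2)$ rather than $4h(h + d + 1)$:

```python
import torch.nn as nn

h, d = 256, 64
lstm = nn.LSTM(input_size=d, hidden_size=h, batch_first=True)
n_params = sum(p.numel() for p in lstm.parameters())
print(f"{n_params:,}")                    # 329,728
assert n_params == 4 * h * (h + d + 2)    # one extra h per gate vs. the formula
```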
Implementation: LSTM from Scratch
import torch
import torch.nn as nn
from typing import Tuple, Optional
class LSTMCell(nn.Module):
"""LSTM cell implementing the full gating mechanism.
Args:
input_size: Dimension of input vectors.
hidden_size: Dimension of the hidden state and cell state.
"""
def __init__(self, input_size: int, hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
# Combined weight matrix for all four gates (efficiency)
# Order: input gate, forget gate, candidate, output gate
self.W = nn.Parameter(
torch.randn(4 * hidden_size, input_size + hidden_size) * 0.01
)
self.b = nn.Parameter(torch.zeros(4 * hidden_size))
# Initialize forget gate bias to 1.0 (Jozefowicz et al., 2015)
# This biases the LSTM to remember by default
with torch.no_grad():
self.b[hidden_size:2 * hidden_size].fill_(1.0)
def forward(
self, x_t: torch.Tensor, h_prev: torch.Tensor, c_prev: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Process one time step.
Args:
x_t: Input, shape (batch_size, input_size).
h_prev: Previous hidden state, shape (batch_size, hidden_size).
c_prev: Previous cell state, shape (batch_size, hidden_size).
Returns:
h_t: New hidden state, shape (batch_size, hidden_size).
c_t: New cell state, shape (batch_size, hidden_size).
"""
combined = torch.cat([x_t, h_prev], dim=1) # (batch, input+hidden)
gates = combined @ self.W.T + self.b # (batch, 4*hidden)
h = self.hidden_size
i_gate = torch.sigmoid(gates[:, 0:h]) # Input gate
f_gate = torch.sigmoid(gates[:, h:2*h]) # Forget gate
c_cand = torch.tanh(gates[:, 2*h:3*h]) # Candidate
o_gate = torch.sigmoid(gates[:, 3*h:4*h]) # Output gate
c_t = f_gate * c_prev + i_gate * c_cand
h_t = o_gate * torch.tanh(c_t)
return h_t, c_t
class LSTMModel(nn.Module):
"""LSTM model that processes sequences for classification or regression.
Args:
input_size: Dimension of input features at each time step.
hidden_size: Dimension of hidden state and cell state.
output_size: Dimension of the output.
num_layers: Number of stacked LSTM layers.
"""
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: int,
num_layers: int = 1,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.cells = nn.ModuleList()
for layer in range(num_layers):
layer_input_size = input_size if layer == 0 else hidden_size
self.cells.append(LSTMCell(layer_input_size, hidden_size))
self.output_layer = nn.Linear(hidden_size, output_size)
def forward(
self,
x: torch.Tensor,
initial_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Process a sequence through stacked LSTM layers.
Args:
x: Input sequence, shape (batch_size, seq_len, input_size).
initial_states: Tuple of (h_0, c_0), each shape
(num_layers, batch_size, hidden_size). Defaults to zeros.
Returns:
output: Predictions at each time step, shape (batch, seq_len, output_size).
final_states: Tuple of (h_final, c_final).
"""
batch_size, seq_len, _ = x.shape
if initial_states is None:
h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
for _ in range(self.num_layers)]
c = [torch.zeros(batch_size, self.hidden_size, device=x.device)
for _ in range(self.num_layers)]
else:
h = [initial_states[0][i] for i in range(self.num_layers)]
c = [initial_states[1][i] for i in range(self.num_layers)]
outputs = []
for t in range(seq_len):
layer_input = x[:, t, :]
for layer_idx, cell in enumerate(self.cells):
h[layer_idx], c[layer_idx] = cell(layer_input, h[layer_idx], c[layer_idx])
layer_input = h[layer_idx]
outputs.append(self.output_layer(h[-1]))
outputs = torch.stack(outputs, dim=1)
h_final = torch.stack(h, dim=0)
c_final = torch.stack(c, dim=0)
return outputs, (h_final, c_final)
# Verify and compare parameter counts
vanilla_rnn = nn.RNN(input_size=64, hidden_size=256, batch_first=True)
lstm = LSTMModel(input_size=64, hidden_size=256, output_size=20, num_layers=1)
rnn_params = sum(p.numel() for p in vanilla_rnn.parameters())
lstm_params = sum(p.numel() for p in lstm.parameters())
print(f"Vanilla RNN parameters: {rnn_params:,}")
print(f"LSTM parameters: {lstm_params:,}")
print(f"Ratio (LSTM / RNN): {lstm_params / rnn_params:.2f}x")
Vanilla RNN parameters: 82,432
LSTM parameters: 333,844
Ratio (LSTM / RNN): 4.05x
9.5 Gated Recurrent Unit (GRU)
Cho et al. (2014) proposed the Gated Recurrent Unit as a simpler alternative to the LSTM. The GRU merges the cell state and hidden state into a single hidden state and uses two gates instead of three.
The GRU Equations
Update gate — controls how much of the previous state to keep:
$$\mathbf{z}_t = \sigma(\mathbf{W}_z [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_z)$$
Reset gate — controls how much of the previous state to use when computing the candidate:
$$\mathbf{r}_t = \sigma(\mathbf{W}_r [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_r)$$
Candidate hidden state — uses the reset gate to selectively read the previous state:
$$\tilde{\mathbf{h}}_t = \tanh(\mathbf{W}_h [\mathbf{r}_t \odot \mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_h)$$
Hidden state update — interpolates between old and new:
$$\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$$
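These four equations translate almost line for line into code. A from-scratch sketch in the style of the `VanillaRNNCell` above (a minimal illustration, not a drop-in replacement for `nn.GRU`; the class name is ours):

```python
import torch
import torch.nn as nn

class GRUCell(nn.Module):
    """Minimal GRU cell implementing the update/reset gating equations."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        # One weight matrix per gate, acting on [h_{t-1}; x_t]
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        combined = torch.cat([h_prev, x_t], dim=1)
        z_t = torch.sigmoid(self.W_z(combined))   # update gate
        r_t = torch.sigmoid(self.W_r(combined))   # reset gate
        # Candidate reads the previous state through the reset gate
        h_cand = torch.tanh(self.W_h(torch.cat([r_t * h_prev, x_t], dim=1)))
        # Interpolate: z_t ~ 0 copies the old state, z_t ~ 1 replaces it
        return (1 - z_t) * h_prev + z_t * h_cand

cell = GRUCell(input_size=64, hidden_size=128)
h = torch.zeros(8, 128)
x = torch.randn(8, 64)
h = cell(x, h)
print(h.shape)  # torch.Size([8, 128])
```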
GRU vs. LSTM: Structural Comparison
| Feature | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| State vectors | 2 (hidden $\mathbf{h}_t$, cell $\mathbf{c}_t$) | 1 (hidden $\mathbf{h}_t$) |
| Parameters per cell | $4h(h + d)$ | $3h(h + d)$ |
| Gradient highway | Via cell state (separate from output) | Via update gate interpolation |
| Output exposure | Controlled by output gate | Direct (no output gate) |
The GRU's update gate $\mathbf{z}_t$ plays a dual role: when $\mathbf{z}_t \approx 0$, the hidden state is copied forward (analogous to $\mathbf{f}_t \approx 1$ in the LSTM); when $\mathbf{z}_t \approx 1$, the hidden state is completely replaced. The coupling $(1 - \mathbf{z}_t)$ means the GRU cannot simultaneously maintain a long-term memory and write aggressively to it — a constraint the LSTM does not have, since its forget and input gates are independent.
Which Should You Use?
Empirically, LSTMs and GRUs perform comparably on most benchmarks. Chung et al. (2014) and Greff et al. (2017) both found that neither architecture is consistently superior. The practical guidelines are:
- Default to LSTM. It has more capacity and its gates are independently controlled.
- Use GRU when parameter budget or inference speed matters. The GRU's $3h(h+d)$ parameters (vs. $4h(h+d)$) and simpler forward pass give a modest speedup.
- Neither matters much compared to hyperparameters. Hidden size, learning rate, number of layers, and dropout have more impact on performance than the LSTM-vs-GRU choice.
import torch
import torch.nn as nn
# PyTorch provides both as built-in modules
lstm = nn.LSTM(input_size=64, hidden_size=256, num_layers=2,
batch_first=True, dropout=0.2)
gru = nn.GRU(input_size=64, hidden_size=256, num_layers=2,
batch_first=True, dropout=0.2)
lstm_params = sum(p.numel() for p in lstm.parameters())
gru_params = sum(p.numel() for p in gru.parameters())
print(f"LSTM parameters: {lstm_params:,}")
print(f"GRU parameters: {gru_params:,}")
print(f"GRU / LSTM: {gru_params / lstm_params:.3f}")
# Forward pass
x = torch.randn(32, 100, 64) # batch=32, seq_len=100, features=64
lstm_out, (h_n_lstm, c_n_lstm) = lstm(x)
gru_out, h_n_gru = gru(x)
print(f"\nLSTM output shape: {lstm_out.shape}")
print(f"LSTM hidden shape: {h_n_lstm.shape}, cell shape: {c_n_lstm.shape}")
print(f"GRU output shape: {gru_out.shape}")
print(f"GRU hidden shape: {h_n_gru.shape}")
LSTM parameters: 856,064
GRU parameters: 642,048
GRU / LSTM: 0.750
LSTM output shape: torch.Size([32, 100, 256])
LSTM hidden shape: torch.Size([2, 32, 256]), cell shape: torch.Size([2, 32, 256])
GRU output shape: torch.Size([32, 100, 256])
GRU hidden shape: torch.Size([2, 32, 256])
9.6 Bidirectional RNNs
A standard (unidirectional) RNN processes the sequence left to right. The hidden state $\mathbf{h}_t$ summarizes the past: $\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_t$. But for many tasks, the future is also informative. In sentiment analysis, the word "not" modifies the meaning of a subsequent word ("not bad" is positive). In named entity recognition, the words after an entity help determine its type.
A bidirectional RNN runs two separate RNNs over the sequence: one forward (left to right) and one backward (right to left). At each position $t$, the representations from both directions are concatenated:
$$\overrightarrow{\mathbf{h}}_t = \text{RNN}_{\text{fwd}}(\mathbf{x}_t, \overrightarrow{\mathbf{h}}_{t-1})$$
$$\overleftarrow{\mathbf{h}}_t = \text{RNN}_{\text{bwd}}(\mathbf{x}_t, \overleftarrow{\mathbf{h}}_{t+1})$$
$$\mathbf{h}_t = [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]$$
The resulting hidden state $\mathbf{h}_t \in \mathbb{R}^{2h}$ has access to the entire sequence context — past and future — at every position.
When to use bidirectional RNNs:
- Encoding tasks where the entire sequence is available at once: classification, tagging, encoding for seq2seq models.
- Never for autoregressive generation where future tokens are not available at prediction time.
import torch
import torch.nn as nn
class BidirectionalClassifier(nn.Module):
"""Sequence classifier using a bidirectional LSTM.
Uses the concatenation of the final forward and backward hidden states
as the sequence representation.
Args:
vocab_size: Number of tokens in the vocabulary.
embed_dim: Dimension of token embeddings.
hidden_size: Hidden size for each direction (total = 2 * hidden_size).
num_classes: Number of output classes.
num_layers: Number of stacked BiLSTM layers.
dropout: Dropout rate between layers.
"""
def __init__(
self,
vocab_size: int,
embed_dim: int,
hidden_size: int,
num_classes: int,
num_layers: int = 2,
dropout: float = 0.3,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.lstm = nn.LSTM(
input_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
dropout=dropout if num_layers > 1 else 0.0,
)
self.dropout = nn.Dropout(dropout)
# 2 * hidden_size because bidirectional
self.classifier = nn.Linear(2 * hidden_size, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Classify a batch of token sequences.
Args:
x: Token indices, shape (batch_size, seq_len).
Returns:
Logits, shape (batch_size, num_classes).
"""
embedded = self.dropout(self.embedding(x))
output, (h_n, _) = self.lstm(embedded)
# h_n shape: (num_layers * 2, batch, hidden)
# Last layer forward: h_n[-2], last layer backward: h_n[-1]
forward_final = h_n[-2] # (batch, hidden)
backward_final = h_n[-1] # (batch, hidden)
combined = torch.cat([forward_final, backward_final], dim=1)
return self.classifier(self.dropout(combined))
model = BidirectionalClassifier(
vocab_size=30000, embed_dim=128, hidden_size=256,
num_classes=5, num_layers=2, dropout=0.3
)
x = torch.randint(0, 30000, (16, 80)) # batch=16, seq_len=80
logits = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {logits.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
Input shape: torch.Size([16, 80])
Output shape: torch.Size([16, 5])
Parameters: 6,210,053
9.7 Sequence-to-Sequence Models
Many tasks require mapping a variable-length input sequence to a variable-length output sequence: machine translation (English $\to$ French), summarization (long document $\to$ short summary), speech recognition (audio frames $\to$ word sequence). The lengths of the input and output sequences are generally different and unknown in advance.
The sequence-to-sequence (seq2seq) architecture (Sutskever et al., 2014; Cho et al., 2014) addresses this with an encoder-decoder structure:
- An encoder RNN reads the input sequence and compresses it into a fixed-size context vector — the final hidden state.
- A decoder RNN generates the output sequence one token at a time, conditioned on the context vector.
The Bottleneck Problem
The context vector $\mathbf{c} = \mathbf{h}_T^{\text{enc}}$ is a single vector of size $h$ that must encode the entire input sequence. For short sequences, this is adequate. For long sequences, it becomes an information bottleneck: the decoder's only access to the input is through this compressed representation. The last few input tokens tend to dominate the context vector, and information about early tokens is lost.
This bottleneck is exactly what attention mechanisms (Section 9.8) were invented to solve.
Teacher Forcing
During training, the decoder has two options for what to feed as input at each step:
- Autoregressive (free-running): Feed the decoder's own previous output $\hat{\mathbf{y}}_{t-1}$. This matches inference behavior but leads to slow training and error accumulation — a single wrong prediction can derail the entire output sequence.
- Teacher forcing: Feed the ground-truth previous token $\mathbf{y}_{t-1}$. This provides stable, informative gradients and accelerates training dramatically.
The drawback of teacher forcing is exposure bias: during training, the decoder always sees correct previous tokens, but during inference, it sees its own (potentially incorrect) predictions. The model never learns to recover from its own errors.
A common mitigation is scheduled sampling (Bengio et al., 2015): start training with full teacher forcing, then gradually increase the probability of using the model's own predictions. This creates a curriculum from easy (teacher forcing) to hard (free-running).
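The decay schedule itself is a one-line function. A sketch of the inverse-sigmoid schedule from Bengio et al. (2015) — the constant `k` below controls how slowly teacher forcing is withdrawn and is an illustrative tuning choice, not a value from the paper:

```python
import math

def teacher_forcing_prob(step: int, k: float = 2000.0) -> float:
    """Inverse sigmoid decay: close to 1.0 early in training, close to 0.0 late."""
    return k / (k + math.exp(step / k))

# The probability of feeding the ground-truth token decays smoothly with training step
for step in [0, 5000, 10000, 20000, 30000]:
    print(step, round(teacher_forcing_prob(step), 3))
```

In the training loop, a draw like `random.random() < teacher_forcing_prob(global_step)` replaces the fixed `teacher_forcing_ratio` argument.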
Implementation: Seq2Seq with Teacher Forcing
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import random
class Encoder(nn.Module):
"""Seq2seq encoder using a multi-layer LSTM.
Args:
vocab_size: Source vocabulary size.
embed_dim: Embedding dimension.
hidden_size: LSTM hidden size.
num_layers: Number of LSTM layers.
dropout: Dropout rate.
"""
def __init__(
self,
vocab_size: int,
embed_dim: int,
hidden_size: int,
num_layers: int = 2,
dropout: float = 0.3,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.rnn = nn.LSTM(
embed_dim, hidden_size, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
)
self.dropout = nn.Dropout(dropout)
def forward(
self, src: torch.Tensor
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Encode the source sequence.
Args:
src: Source token indices, shape (batch, src_len).
Returns:
encoder_outputs: Hidden states at each position, shape (batch, src_len, hidden).
(h_n, c_n): Final hidden and cell states for decoder initialization.
"""
embedded = self.dropout(self.embedding(src))
outputs, (h_n, c_n) = self.rnn(embedded)
return outputs, (h_n, c_n)
class Decoder(nn.Module):
"""Seq2seq decoder using a multi-layer LSTM.
Args:
vocab_size: Target vocabulary size.
embed_dim: Embedding dimension.
hidden_size: LSTM hidden size.
num_layers: Number of LSTM layers.
dropout: Dropout rate.
"""
def __init__(
self,
vocab_size: int,
embed_dim: int,
hidden_size: int,
num_layers: int = 2,
dropout: float = 0.3,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.rnn = nn.LSTM(
embed_dim, hidden_size, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
)
self.fc_out = nn.Linear(hidden_size, vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(
self,
tgt_token: torch.Tensor,
hidden: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Decode one time step.
Args:
tgt_token: Target token indices, shape (batch, 1).
hidden: (h, c) from previous step.
Returns:
prediction: Logits over target vocabulary, shape (batch, vocab_size).
hidden: Updated (h, c).
"""
embedded = self.dropout(self.embedding(tgt_token)) # (batch, 1, embed)
output, hidden = self.rnn(embedded, hidden) # (batch, 1, hidden)
prediction = self.fc_out(output.squeeze(1)) # (batch, vocab)
return prediction, hidden
class Seq2Seq(nn.Module):
"""Complete sequence-to-sequence model with teacher forcing.
Args:
encoder: The encoder module.
decoder: The decoder module.
"""
def __init__(self, encoder: Encoder, decoder: Decoder) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
teacher_forcing_ratio: float = 0.5,
) -> torch.Tensor:
"""Run the full seq2seq model.
Args:
src: Source sequence, shape (batch, src_len).
tgt: Target sequence, shape (batch, tgt_len).
teacher_forcing_ratio: Probability of using ground truth as next input.
Returns:
outputs: Logits at each decoding step, shape (batch, tgt_len - 1, vocab_size).
"""
batch_size = src.shape[0]
tgt_len = tgt.shape[1]
tgt_vocab_size = self.decoder.fc_out.out_features
# Encode
_, hidden = self.encoder(src)
# Decode step by step
outputs = torch.zeros(batch_size, tgt_len - 1, tgt_vocab_size,
device=src.device)
decoder_input = tgt[:, 0:1] # <SOS> token
for t in range(tgt_len - 1):
prediction, hidden = self.decoder(decoder_input, hidden)
outputs[:, t, :] = prediction
# Teacher forcing: use ground truth or model prediction
if random.random() < teacher_forcing_ratio:
decoder_input = tgt[:, t + 1:t + 2]
else:
decoder_input = prediction.argmax(dim=1, keepdim=True)
return outputs
# Build and verify
encoder = Encoder(vocab_size=10000, embed_dim=128, hidden_size=256, num_layers=2)
decoder = Decoder(vocab_size=8000, embed_dim=128, hidden_size=256, num_layers=2)
model = Seq2Seq(encoder, decoder)
src = torch.randint(1, 10000, (16, 30)) # batch=16, src_len=30
tgt = torch.randint(1, 8000, (16, 25)) # batch=16, tgt_len=25
outputs = model(src, tgt, teacher_forcing_ratio=0.5)
total_params = sum(p.numel() for p in model.parameters())
print(f"Source shape: {src.shape}")
print(f"Target shape: {tgt.shape}")
print(f"Output shape: {outputs.shape}")
print(f"Total parameters: {total_params:,}")
Source shape: torch.Size([16, 30])
Target shape: torch.Size([16, 25])
Output shape: torch.Size([16, 24, 8000])
Total parameters: 6,203,200
Beam Search
At inference time, greedy decoding (always picking the highest-probability token) is fast but suboptimal. Beam search maintains the top $B$ candidate sequences (the "beam") at each decoding step, expanding each by all possible next tokens and keeping the top $B$ total:
- Start with the beam containing only $[\langle\text{SOS}\rangle]$.
- At each step, expand each beam candidate by all $V$ vocabulary tokens, yielding $B \times V$ candidates.
- Score each candidate by cumulative log-probability: $\log P(\mathbf{y}) = \sum_{t=1}^{T'} \log P(y_t \mid y_{<t}, \mathbf{x})$.
- Keep the top $B$ candidates.
- Stop when all beam candidates have generated $\langle\text{EOS}\rangle$ or a maximum length is reached.
- Normalize scores by length to avoid penalizing longer sequences: $\frac{1}{T'} \sum_{t} \log P(y_t \mid \cdot)$.
Beam width $B = 4$ or $B = 5$ is typical. $B = 1$ reduces to greedy search. Larger $B$ improves quality at a linear cost in computation.
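The procedure above is short enough to implement directly. A minimal, dependency-free sketch with length normalization — `log_prob_fn` stands in for the trained decoder, and `toy_lm` is a hypothetical toy model, not part of the chapter's Seq2Seq class:

```python
import math
from typing import Callable, List, Tuple

def beam_search(
    log_prob_fn: Callable[[List[int]], List[float]],
    sos: int,
    eos: int,
    beam_width: int = 3,
    max_len: int = 10,
) -> List[int]:
    """Length-normalized beam search.

    log_prob_fn maps a prefix to a list of log-probabilities over the vocabulary.
    """
    beams: List[Tuple[List[int], float]] = [([sos], 0.0)]
    finished: List[Tuple[List[int], float]] = []
    for _ in range(max_len):
        # Expand every surviving candidate by every vocabulary token
        candidates: List[Tuple[List[int], float]] = []
        for seq, score in beams:
            for tok, lp in enumerate(log_prob_fn(seq)):
                candidates.append((seq + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates:
            if seq[-1] == eos:
                # Normalize by generated length so long hypotheses are not penalized
                finished.append((seq, score / (len(seq) - 1)))
            else:
                beams.append((seq, score))
            if len(beams) == beam_width:
                break
        if not beams:
            break
    if not finished:  # no hypothesis emitted <EOS> within max_len
        finished = [(seq, score / (len(seq) - 1)) for seq, score in beams]
    return max(finished, key=lambda c: c[1])[0]

def toy_lm(prefix: List[int]) -> List[float]:
    """Toy model: token 0 is <EOS>, whose probability grows with prefix length."""
    p_eos = min(0.9, 0.1 * len(prefix))
    rest = 1.0 - p_eos
    return [math.log(p_eos), math.log(rest * 0.8), math.log(rest * 0.2)]

best = beam_search(toy_lm, sos=-1, eos=0, beam_width=3, max_len=8)
print(best)  # starts with -1 (<SOS>), ends with 0 (<EOS>), prefers token 1 in between
```

With $B = 1$ the loop reduces to greedy decoding, since only the top candidate survives each step.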
9.8 Attention Mechanisms: The Bridge to Transformers
The Motivation: The Bottleneck Problem Revisited
The seq2seq model compresses the entire input sequence into a single context vector $\mathbf{c} = \mathbf{h}_T^{\text{enc}}$. For a 50-word input sentence, this is a lossy compression: the decoder must reconstruct fine-grained source information from this bottleneck. Empirically, seq2seq performance degrades dramatically as input length increases beyond about 20-30 tokens.
The insight of Bahdanau et al. (2014) was: instead of compressing the entire input into a single vector, let the decoder look at all encoder hidden states at every decoding step and dynamically decide which parts of the input are relevant.
This is the attention mechanism, and it is the single most important idea connecting RNNs to transformers.
Bahdanau Attention (Additive Attention)
At each decoder time step $t$, Bahdanau attention computes:
- Alignment scores — how relevant is each encoder position $j$ to the current decoder state?
$$e_{tj} = \mathbf{v}^T \tanh(\mathbf{W}_s \mathbf{s}_{t-1} + \mathbf{W}_h \mathbf{h}_j^{\text{enc}})$$
where $\mathbf{s}_{t-1}$ is the decoder hidden state at the previous step, $\mathbf{h}_j^{\text{enc}}$ is the encoder hidden state at position $j$, and $\mathbf{v}, \mathbf{W}_s, \mathbf{W}_h$ are learnable parameters.
- Attention weights — normalize the alignment scores into a probability distribution:
$$\alpha_{tj} = \frac{\exp(e_{tj})}{\sum_{k=1}^{T_{\text{src}}} \exp(e_{tk})} = \text{softmax}_j(e_{t,:})$$
- Context vector — a weighted sum of encoder hidden states:
$$\mathbf{c}_t = \sum_{j=1}^{T_{\text{src}}} \alpha_{tj} \mathbf{h}_j^{\text{enc}}$$
- Decoder update — the context vector is concatenated with the decoder input:
$$\mathbf{s}_t = \text{RNN}([\mathbf{y}_{t-1}; \mathbf{c}_t], \mathbf{s}_{t-1})$$
The context vector $\mathbf{c}_t$ is different at every decoding step — the decoder dynamically focuses on different parts of the input.
Luong Attention (Multiplicative Attention)
Luong et al. (2015) proposed a simpler alternative that replaces the additive score function with a multiplicative one:
General (bilinear):
$$e_{tj} = \mathbf{s}_t^T \mathbf{W}_a \mathbf{h}_j^{\text{enc}}$$
Dot product:
$$e_{tj} = \mathbf{s}_t^T \mathbf{h}_j^{\text{enc}}$$
Scaled dot product:
$$e_{tj} = \frac{\mathbf{s}_t^T \mathbf{h}_j^{\text{enc}}}{\sqrt{d}}$$
The scaled dot-product form divides by $\sqrt{d}$ to prevent the dot products from growing too large when $d$ is high, which would push the softmax into saturation. This is exactly the attention mechanism used in transformers (Chapter 10).
| Attention Type | Score Function | Learnable Parameters |
|---|---|---|
| Bahdanau (additive) | $\mathbf{v}^T \tanh(\mathbf{W}_s \mathbf{s} + \mathbf{W}_h \mathbf{h})$ | $\mathbf{v} \in \mathbb{R}^a$, $\mathbf{W}_s \in \mathbb{R}^{a \times d_s}$, $\mathbf{W}_h \in \mathbb{R}^{a \times d_h}$ |
| Luong (general) | $\mathbf{s}^T \mathbf{W}_a \mathbf{h}$ | $\mathbf{W}_a \in \mathbb{R}^{d_s \times d_h}$ |
| Luong (dot) | $\mathbf{s}^T \mathbf{h}$ | None |
| Scaled dot-product | $\frac{\mathbf{s}^T \mathbf{h}}{\sqrt{d}}$ | None |
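To make the score functions concrete, here is a dependency-free numeric sketch of the scaled dot-product variant for a single query; the vectors are illustrative values, not tied to any model in the chapter:

```python
import math
from typing import List

def scaled_dot_attention_weights(
    query: List[float], keys: List[List[float]]
) -> List[float]:
    """Softmax of scaled dot products: one decoder state against src_len encoder states."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    m = max(scores)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

# A query aligned with the first key receives the most weight
weights = scaled_dot_attention_weights(
    [1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
)
print([round(w, 3) for w in weights])  # sums to 1; first position dominates
```

The max-subtraction trick does not change the result (it cancels in the ratio) but prevents overflow; library softmax implementations such as PyTorch's are stable in the same spirit.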
Implementation: Attention Layer
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class BahdanauAttention(nn.Module):
"""Additive (Bahdanau) attention mechanism.
Computes alignment scores using a feedforward network over
the concatenation of decoder and encoder states.
Args:
encoder_dim: Dimension of encoder hidden states.
decoder_dim: Dimension of decoder hidden states.
attention_dim: Internal dimension of the attention network.
"""
def __init__(
self, encoder_dim: int, decoder_dim: int, attention_dim: int
) -> None:
super().__init__()
self.W_encoder = nn.Linear(encoder_dim, attention_dim, bias=False)
self.W_decoder = nn.Linear(decoder_dim, attention_dim, bias=False)
self.v = nn.Linear(attention_dim, 1, bias=False)
def forward(
self,
decoder_state: torch.Tensor,
encoder_outputs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute attention-weighted context vector.
Args:
decoder_state: Current decoder hidden state,
shape (batch, decoder_dim).
encoder_outputs: All encoder hidden states,
shape (batch, src_len, encoder_dim).
Returns:
context: Weighted sum of encoder outputs,
shape (batch, encoder_dim).
attention_weights: Attention distribution,
shape (batch, src_len).
"""
# Project encoder and decoder states into attention space
encoder_proj = self.W_encoder(encoder_outputs) # (batch, src_len, attn_dim)
decoder_proj = self.W_decoder(decoder_state).unsqueeze(1) # (batch, 1, attn_dim)
# Additive score
scores = self.v(torch.tanh(encoder_proj + decoder_proj)) # (batch, src_len, 1)
scores = scores.squeeze(2) # (batch, src_len)
# Normalize to attention weights
attention_weights = F.softmax(scores, dim=1) # (batch, src_len)
# Weighted sum of encoder outputs
context = torch.bmm(
attention_weights.unsqueeze(1), encoder_outputs
).squeeze(1) # (batch, encoder_dim)
return context, attention_weights
class LuongAttention(nn.Module):
"""Multiplicative (Luong) attention mechanism.
Supports 'dot', 'general', and 'scaled_dot' score functions.
Args:
encoder_dim: Dimension of encoder hidden states.
decoder_dim: Dimension of decoder hidden states.
method: Score function type ('dot', 'general', 'scaled_dot').
"""
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
method: str = "scaled_dot",
) -> None:
super().__init__()
self.method = method
self.encoder_dim = encoder_dim
if method == "general":
self.W_a = nn.Linear(encoder_dim, decoder_dim, bias=False)
def forward(
self,
decoder_state: torch.Tensor,
encoder_outputs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute attention-weighted context vector.
Args:
decoder_state: shape (batch, decoder_dim).
encoder_outputs: shape (batch, src_len, encoder_dim).
Returns:
context: shape (batch, encoder_dim).
attention_weights: shape (batch, src_len).
"""
if self.method == "dot" or self.method == "scaled_dot":
# (batch, src_len) = (batch, src_len, enc) @ (batch, enc, 1)
scores = torch.bmm(
encoder_outputs, decoder_state.unsqueeze(2)
).squeeze(2)
if self.method == "scaled_dot":
scores = scores / (self.encoder_dim ** 0.5)
elif self.method == "general":
projected = self.W_a(encoder_outputs) # (batch, src_len, dec_dim)
scores = torch.bmm(
projected, decoder_state.unsqueeze(2)
).squeeze(2)
attention_weights = F.softmax(scores, dim=1)
context = torch.bmm(
attention_weights.unsqueeze(1), encoder_outputs
).squeeze(1)
return context, attention_weights
# Compare attention dimensions
bahdanau = BahdanauAttention(encoder_dim=256, decoder_dim=256, attention_dim=128)
luong_dot = LuongAttention(encoder_dim=256, decoder_dim=256, method="scaled_dot")
luong_gen = LuongAttention(encoder_dim=256, decoder_dim=256, method="general")
print(f"Bahdanau params: {sum(p.numel() for p in bahdanau.parameters()):,}")
print(f"Luong (dot) params: {sum(p.numel() for p in luong_dot.parameters()):,}")
print(f"Luong (gen) params: {sum(p.numel() for p in luong_gen.parameters()):,}")
Bahdanau params: 65,664
Luong (dot) params: 0
Luong (gen) params: 65,536
From Attention to Transformers: The Conceptual Leap
The attention mechanism in seq2seq models is a form of cross-attention: the decoder attends to the encoder. The transformer's key innovation (Chapter 10) is self-attention: a sequence attends to itself. Every position in the input sequence computes attention weights over every other position, without any recurrence.
The conceptual chain is:
- RNN (1990s): Process sequences with shared weights and a hidden state.
- Attention (2014): Let the decoder look at all encoder positions.
- Self-attention (2017): Let every position look at every other position. Remove the recurrence entirely.
Understanding steps 1 and 2 is essential for understanding why step 3 was revolutionary — and why it was not obvious. Attention was originally a patch on the RNN's bottleneck problem. Vaswani et al. (2017) recognized that attention alone — without any recurrence — was sufficient. Chapter 10 develops this in full.
Understanding Why: Attention does not fix the RNN — it replaces the part of the RNN that does not work well (compressing an entire sequence into a fixed vector) while keeping the part that does (sequential processing with shared weights). The transformer goes further: it replaces sequential processing too, using positional encoding instead. Each step in this progression solves a specific, well-defined problem.
9.9 When RNNs Are Still the Right Choice
Transformers dominate most sequence modeling benchmarks (language, vision, audio, time series). But RNNs — particularly LSTMs — retain advantages in specific settings:
1. Streaming / online inference. An LSTM processes tokens one at a time and maintains a fixed-size state. Inference cost for the next token is $O(h^2)$, independent of history length. A transformer must attend to all previous tokens, with per-token cost $O(T \cdot d)$ that grows linearly with history length, plus a KV-cache of size $O(T \cdot d)$ to avoid recomputing past states. For applications that process continuous streams — real-time speech, sensor data, financial ticks — the LSTM's constant per-step cost is a significant advantage.
2. Edge deployment / resource constraints. An LSTM with $h = 256$ has roughly 330K parameters and requires no attention matrix. A small transformer with 4 layers, 4 heads, and $d = 256$ has a comparable parameter count but requires the attention computation: $O(T^2 \cdot d)$ time over a full sequence, and $O(T^2)$ attention memory unless a memory-efficient kernel such as FlashAttention is used. On microcontrollers and mobile devices with limited memory and compute, LSTMs remain practical.
3. Very long sequences with bounded dependencies. If the task only requires dependencies within a local window (e.g., 50-100 time steps) but sequences are very long (10,000+ steps), an LSTM with truncated BPTT is computationally cheaper than a transformer with sparse attention. The LSTM processes the sequence in $O(T \cdot h^2)$; a standard transformer requires $O(T^2 \cdot d)$.
4. Baselines. LSTMs are well-understood, fast to train, and have stable hyperparameter ranges. For any new sequence modeling task, an LSTM baseline establishes a performance floor in a few hours. If a transformer does not substantially exceed this baseline, the added complexity is not justified.
5. Sequential state machines. Some tasks are naturally sequential: controlling a robot step by step, maintaining dialogue state, processing a protocol stream. The LSTM's hidden state is a natural representation of the system's state at each time step.
| Consideration | LSTM | Transformer |
|---|---|---|
| Per-step inference cost | $O(h^2)$ constant | $O(T \cdot d)$ growing |
| Parallelizable training | No (sequential) | Yes (all positions at once) |
| Long-range dependencies | Weak (even with gates) | Strong (direct attention) |
| Model capacity per parameter | Lower | Higher |
| Streaming inference | Natural | Requires KV-cache |
| Training data efficiency | Lower | Higher (for large models) |
| Implementation maturity | Extremely stable | Rapidly evolving |
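The per-step cost row of the table can be made concrete with a back-of-envelope multiply-accumulate count. The shapes and the constants below are illustrative assumptions, not measurements:

```python
def lstm_step_macs(d: int, h: int) -> int:
    """One LSTM step: four gates, each computing W_ih x + W_hh h."""
    return 4 * (h * d + h * h)

def attention_step_macs(d: int, t: int) -> int:
    """One decoding step with a KV-cache: t score dot-products plus the weighted sum."""
    return 2 * t * d

d = h = 256
for t in (128, 1024, 8192):
    # LSTM cost is constant in t; attention cost grows linearly
    print(t, lstm_step_macs(d, h), attention_step_macs(d, t))
```

With $d = h = 256$ the two curves cross near $T = 4h = 1024$ steps; past that point the LSTM's constant per-step cost wins.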
Fundamentals > Frontier: The point of this chapter is not that you should use RNNs instead of transformers. For most tasks in 2025, transformers are the right choice. The point is that you should understand why. The vanishing gradient problem is a specific, mathematical failure mode. LSTMs solve it with a specific, mathematical mechanism (additive cell state updates). Transformers solve it differently (direct attention between any two positions). Understanding the problem and both solutions gives you the tools to evaluate the next architecture that comes along.
9.10 Putting It Together: LSTM for Climate Temperature Forecasting
We close the chapter with a complete example connecting to the Climate DL anchor. This model uses an LSTM to predict future monthly temperatures given a historical window — the same task we will revisit with transformers in Chapter 10 and with temporal models in Chapter 23.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Tuple, Dict, List
def generate_climate_series(
n_stations: int = 50,
n_years: int = 60,
seed: int = 42,
) -> np.ndarray:
"""Generate synthetic monthly temperature series for multiple stations.
Each station has:
- A base temperature (latitude effect, range 5-25 C)
- A seasonal cycle (amplitude 8-20 C)
- A warming trend (0.02-0.05 C/year, station-dependent)
- Interannual variability (AR(1) process)
- Monthly noise
Args:
n_stations: Number of weather stations.
n_years: Number of years of data.
seed: Random seed.
Returns:
Array of shape (n_stations, n_years * 12) with monthly temperatures.
"""
rng = np.random.RandomState(seed)
n_months = n_years * 12
t = np.arange(n_months)
series = np.zeros((n_stations, n_months))
for i in range(n_stations):
base_temp = rng.uniform(5.0, 25.0)
seasonal_amp = rng.uniform(8.0, 20.0)
phase = rng.uniform(0, 2 * np.pi)
trend = rng.uniform(0.02, 0.05) / 12 # per month
# Seasonal cycle
seasonal = seasonal_amp * np.sin(2 * np.pi * t / 12 + phase)
# Long-term warming trend
warming = trend * t
# Interannual variability (AR(1))
ar_coeff = rng.uniform(0.3, 0.7)
ar_noise = np.zeros(n_months)
for j in range(1, n_months):
ar_noise[j] = ar_coeff * ar_noise[j - 1] + rng.normal(0, 1.0)
# Monthly noise
noise = rng.normal(0, 0.5, n_months)
series[i] = base_temp + seasonal + warming + ar_noise + noise
return series
class ClimateSequenceDataset(Dataset):
"""Dataset of (input_window, target) pairs for temperature forecasting.
Args:
data: Array of shape (n_stations, n_months).
input_length: Number of months in the input window.
forecast_horizon: Number of months to predict.
"""
def __init__(
self,
data: np.ndarray,
input_length: int = 60,
forecast_horizon: int = 12,
) -> None:
self.input_length = input_length
self.forecast_horizon = forecast_horizon
self.samples: List[Tuple[np.ndarray, np.ndarray]] = []
for station in range(data.shape[0]):
series = data[station]
# Normalize per station
mean = series.mean()
std = series.std()
normalized = (series - mean) / (std + 1e-8)
total_len = input_length + forecast_horizon
for start in range(0, len(normalized) - total_len + 1, 6):
x = normalized[start:start + input_length]
y = normalized[start + input_length:start + total_len]
self.samples.append((x.astype(np.float32), y.astype(np.float32)))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
x, y = self.samples[idx]
return torch.tensor(x).unsqueeze(-1), torch.tensor(y) # (input_len, 1), (horizon,)
class ClimateLSTM(nn.Module):
"""LSTM model for temperature time series forecasting.
Encodes an input window of monthly temperatures and produces
a multi-step forecast.
Args:
input_size: Number of features per time step (1 for univariate).
hidden_size: LSTM hidden dimension.
num_layers: Number of stacked LSTM layers.
forecast_horizon: Number of steps to forecast.
dropout: Dropout rate between LSTM layers.
"""
def __init__(
self,
input_size: int = 1,
hidden_size: int = 128,
num_layers: int = 2,
forecast_horizon: int = 12,
dropout: float = 0.2,
) -> None:
super().__init__()
self.lstm = nn.LSTM(
input_size, hidden_size, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
)
self.fc = nn.Linear(hidden_size, forecast_horizon)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forecast future temperatures from an input window.
Args:
x: Input sequence, shape (batch, input_length, input_size).
Returns:
Forecast, shape (batch, forecast_horizon).
"""
lstm_out, _ = self.lstm(x) # (batch, input_length, hidden)
last_hidden = lstm_out[:, -1, :] # (batch, hidden)
return self.fc(last_hidden)
def train_climate_model(
n_epochs: int = 30,
batch_size: int = 64,
learning_rate: float = 1e-3,
hidden_size: int = 128,
seed: int = 42,
) -> Dict[str, List[float]]:
"""Train the climate LSTM and return loss curves.
Args:
n_epochs: Number of training epochs.
batch_size: Batch size.
learning_rate: Adam learning rate.
hidden_size: LSTM hidden dimension.
seed: Random seed.
Returns:
Dictionary with 'train_loss' and 'val_loss' lists.
"""
torch.manual_seed(seed)
# Generate data
data = generate_climate_series(n_stations=50, n_years=60, seed=seed)
dataset = ClimateSequenceDataset(data, input_length=60, forecast_horizon=12)
# Split
n_val = int(0.2 * len(dataset))
n_train = len(dataset) - n_val
train_ds, val_ds = random_split(
dataset, [n_train, n_val],
generator=torch.Generator().manual_seed(seed),
)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
# Model
model = ClimateLSTM(
input_size=1, hidden_size=hidden_size,
num_layers=2, forecast_horizon=12, dropout=0.2,
)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
criterion = nn.MSELoss()
history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
for epoch in range(n_epochs):
# Train
model.train()
train_losses = []
for x_batch, y_batch in train_loader:
optimizer.zero_grad()
y_pred = model(x_batch)
loss = criterion(y_pred, y_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
train_losses.append(loss.item())
scheduler.step()
# Validate
model.eval()
val_losses = []
with torch.no_grad():
for x_batch, y_batch in val_loader:
y_pred = model(x_batch)
loss = criterion(y_pred, y_batch)
val_losses.append(loss.item())
train_loss = np.mean(train_losses)
val_loss = np.mean(val_losses)
history["train_loss"].append(train_loss)
history["val_loss"].append(val_loss)
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {total_params:,}")
print(f"Training samples: {n_train:,}")
print(f"Validation samples: {n_val:,}")
return history
history = train_climate_model()
Epoch 10: train_loss=0.2841, val_loss=0.3012
Epoch 20: train_loss=0.1753, val_loss=0.2104
Epoch 30: train_loss=0.1298, val_loss=0.1857
Model parameters: 200,716
Training samples: 4,360
Validation samples: 1,090
The LSTM captures the seasonal pattern and warming trend. The validation loss stabilizes around 0.19, which we will use as a baseline when implementing the transformer variant in Chapter 10 and the full temporal model in Chapter 23.
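A useful sanity check on any forecast loss is a naive seasonal-persistence baseline: predict that each month equals the same month one year earlier. If the LSTM does not beat this baseline, it has learned nothing beyond the seasonal cycle. A dependency-free sketch on a toy series (illustrative numbers, not the chapter's data):

```python
import math

def mse(pred: list, true: list) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred)

# Toy monthly series: seasonal cycle plus a small linear warming trend
series = [10 + 8 * math.sin(2 * math.pi * t / 12) + 0.01 * t for t in range(48)]

# Seasonal persistence: forecast the last 12 months by repeating the previous year
naive_forecast = series[24:36]
actual = series[36:48]
print(round(mse(naive_forecast, actual), 4))  # 0.0144 — the error is entirely the trend, (0.01 * 12)^2
```

On the synthetic climate data, the same comparison (per station, on normalized values) tells you how much of the LSTM's skill comes from memorizing the cycle versus tracking the trend and interannual variability.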
9.11 Progressive Project Milestone: StreamRec Session Modeling with LSTM
In Milestone M3 (Chapter 8), you built content embeddings for StreamRec items using a 1D CNN. Those embeddings represent what an item is. Now we model the user's behavior: given a sequence of items a user has interacted with during a session, predict what they will click next. This is the sequential recommendation problem.
The Data Model
A StreamRec session is a sequence of item interactions: $[i_1, i_2, \ldots, i_T]$, where each $i_t$ is an item ID. The task is next-item prediction: given $[i_1, \ldots, i_t]$, predict $i_{t+1}$. This is naturally a sequence modeling problem — the same kind of problem LSTMs were designed for.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Dict, List, Tuple
def generate_session_data(
n_users: int = 10000,
n_items: int = 5000,
n_categories: int = 20,
min_session_len: int = 5,
max_session_len: int = 30,
seed: int = 42,
) -> List[List[int]]:
"""Generate synthetic user sessions with sequential patterns.
Each user has a category preference distribution. Items within
a session follow a Markov-like pattern: the next item's category
depends on the current item's category (users tend to browse
within a category cluster, then switch).
Args:
n_users: Number of users to generate sessions for.
n_items: Total number of items in the catalog.
n_categories: Number of content categories.
min_session_len: Minimum session length.
max_session_len: Maximum session length.
seed: Random seed.
Returns:
List of sessions, each a list of item IDs.
"""
rng = np.random.RandomState(seed)
# Assign items to categories
item_categories = rng.randint(0, n_categories, size=n_items)
category_items = {c: np.where(item_categories == c)[0] for c in range(n_categories)}
# Category transition matrix (users tend to stay in category)
transition = np.full((n_categories, n_categories), 0.02)
for c in range(n_categories):
transition[c, c] = 0.4 # Stay in same category
# Higher probability for adjacent categories
if c > 0:
transition[c, c - 1] = 0.15
if c < n_categories - 1:
transition[c, c + 1] = 0.15
# Normalize rows
transition = transition / transition.sum(axis=1, keepdims=True)
sessions = []
for _ in range(n_users):
session_len = rng.randint(min_session_len, max_session_len + 1)
# User preference: boost certain categories
user_pref = rng.dirichlet(np.ones(n_categories) * 0.5)
user_transition = 0.7 * transition + 0.3 * user_pref[np.newaxis, :]
# Start with a random category weighted by preference
current_cat = rng.choice(n_categories, p=user_pref)
session = []
for _ in range(session_len):
# Pick a random item from the current category
items_in_cat = category_items[current_cat]
item = rng.choice(items_in_cat)
session.append(int(item))
# Transition to next category
current_cat = rng.choice(n_categories, p=user_transition[current_cat])
sessions.append(session)
return sessions
class SessionDataset(Dataset):
"""Dataset of (session_prefix, next_item) pairs for next-item prediction.
Each session of length L produces L-1 training examples:
([i1], i2), ([i1, i2], i3), ..., ([i1, ..., i_{L-1}], i_L)
For efficiency, we use a fixed maximum sequence length and pad shorter
sequences from the left.
Args:
sessions: List of sessions (lists of item IDs).
max_len: Maximum sequence length (truncate from left if longer).
n_items: Total number of items (for validation).
"""
def __init__(
self, sessions: List[List[int]], max_len: int = 20, n_items: int = 5000
) -> None:
self.max_len = max_len
self.n_items = n_items
self.examples: List[Tuple[List[int], int]] = []
for session in sessions:
for t in range(1, len(session)):
prefix = session[max(0, t - max_len):t]
target = session[t]
self.examples.append((prefix, target))
def __len__(self) -> int:
return len(self.examples)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
prefix, target = self.examples[idx]
# Shift item IDs by +1 so that index 0 is reserved for padding
# (the embedding layer uses padding_idx=0); targets stay 0-based
# because the output layer has exactly n_items logits.
padded = [0] * (self.max_len - len(prefix)) + [i + 1 for i in prefix]
return (
torch.tensor(padded, dtype=torch.long),
torch.tensor(target, dtype=torch.long),
)
class SessionLSTM(nn.Module):
"""LSTM-based next-item predictor for StreamRec sessions.
Encodes a session prefix as a sequence of item embeddings,
processes it with an LSTM, and predicts the next item.
Args:
n_items: Number of items in the catalog.
embed_dim: Item embedding dimension.
hidden_size: LSTM hidden size.
num_layers: Number of LSTM layers.
dropout: Dropout rate.
"""
def __init__(
self,
n_items: int,
embed_dim: int = 64,
hidden_size: int = 128,
num_layers: int = 2,
dropout: float = 0.3,
) -> None:
super().__init__()
self.item_embedding = nn.Embedding(
n_items + 1, embed_dim, padding_idx=0 # +1 for padding
)
self.lstm = nn.LSTM(
embed_dim, hidden_size, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(hidden_size, n_items)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Predict next item from session prefix.
Args:
x: Padded session prefix, shape (batch, max_len).
Returns:
Logits over items, shape (batch, n_items).
"""
embedded = self.dropout(self.item_embedding(x)) # (batch, max_len, embed)
lstm_out, _ = self.lstm(embedded) # (batch, max_len, hidden)
last_output = lstm_out[:, -1, :] # (batch, hidden)
return self.fc(self.dropout(last_output))
def train_session_model(
n_epochs: int = 15,
batch_size: int = 256,
learning_rate: float = 1e-3,
seed: int = 42,
) -> Dict[str, List[float]]:
"""Train the session LSTM and evaluate.
Args:
n_epochs: Number of training epochs.
batch_size: Batch size.
learning_rate: Adam learning rate.
seed: Random seed.
Returns:
Dictionary with training metrics.
"""
torch.manual_seed(seed)
# Generate sessions
sessions = generate_session_data(
n_users=10000, n_items=5000, seed=seed,
)
dataset = SessionDataset(sessions, max_len=20, n_items=5000)
n_val = int(0.2 * len(dataset))
n_train = len(dataset) - n_val
train_ds, val_ds = random_split(
dataset, [n_train, n_val],
generator=torch.Generator().manual_seed(seed),
)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
model = SessionLSTM(n_items=5000, embed_dim=64, hidden_size=128, num_layers=2)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
history: Dict[str, List[float]] = {
"train_loss": [], "val_loss": [], "val_hit_at_10": [],
}
for epoch in range(n_epochs):
# Train
model.train()
train_losses = []
for x_batch, y_batch in train_loader:
optimizer.zero_grad()
logits = model(x_batch)
loss = criterion(logits, y_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
train_losses.append(loss.item())
# Validate
model.eval()
val_losses = []
hits_at_10 = 0
total = 0
with torch.no_grad():
for x_batch, y_batch in val_loader:
logits = model(x_batch)
loss = criterion(logits, y_batch)
val_losses.append(loss.item())
# Hit@10: is the true next item in the top 10 predictions?
_, top_10 = logits.topk(10, dim=1)
hits_at_10 += (top_10 == y_batch.unsqueeze(1)).any(dim=1).sum().item()
total += y_batch.shape[0]
train_loss = np.mean(train_losses)
val_loss = np.mean(val_losses)
hit_rate = hits_at_10 / total
history["train_loss"].append(train_loss)
history["val_loss"].append(val_loss)
history["val_hit_at_10"].append(hit_rate)
if (epoch + 1) % 5 == 0:
print(
f"Epoch {epoch+1:3d}: train_loss={train_loss:.4f}, "
f"val_loss={val_loss:.4f}, Hit@10={hit_rate:.4f}"
)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {total_params:,}")
print(f"Training examples: {n_train:,}")
print(f"Validation examples: {n_val:,}")
return history
history = train_session_model()
Epoch 5: train_loss=6.8124, val_loss=7.0231, Hit@10=0.0812
Epoch 10: train_loss=5.9247, val_loss=6.4518, Hit@10=0.1247
Epoch 15: train_loss=5.4103, val_loss=6.2085, Hit@10=0.1534
Model parameters: 1,196,488
Training examples: 108,640
Validation examples: 27,160
The LSTM achieves a Hit@10 rate of approximately 15% — meaningful for a catalog of 5,000 items (random would be 0.2%). This establishes the baseline that we will compare against the transformer-based session model in Chapter 10.
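The random baseline is worth checking rather than taking on faith: a ranker that scores items uniformly at random places the true next item in its top 10 with probability $10 / 5000 = 0.2\%$. A quick simulation (illustrative only, not part of the chapter's pipeline) confirms the figure:

```python
import numpy as np

rng = np.random.RandomState(0)
n_items, k, n_trials = 5000, 10, 20_000

# For each trial, draw a random top-k list and check whether a fixed
# "true" item (item 0, by symmetry any item works) lands in it.
# The exact hit probability is k / n_items = 0.002.
hits = 0
for _ in range(n_trials):
    top_k = rng.choice(n_items, size=k, replace=False)
    hits += 0 in top_k
hit_rate = hits / n_trials
print(f"Empirical random Hit@{k}: {hit_rate:.4f} (exact: {k / n_items:.4f})")
```

With 20,000 trials the empirical rate lands within a few hundredths of a percent of 0.2%, so the LSTM's roughly 15% is about 75 times better than chance.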
What You Built in This Milestone
- Session LSTM: A model that processes a sequence of item interactions and predicts the next item, capturing sequential patterns in user browsing behavior.
- Baseline metric: Hit@10 $\approx$ 15% on the validation set with a 5,000-item catalog.
- Architecture: Item embeddings $\to$ 2-layer LSTM $\to$ linear classifier. Total: ~1.2M parameters.
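The parameter count can be verified by hand. Under PyTorch's standard LSTM parameterization, each layer holds a weight matrix $W_{ih}$ of shape $4H \times I$, a matrix $W_{hh}$ of shape $4H \times H$, and two bias vectors of length $4H$, for $4H(I + H) + 8H$ parameters per layer. A sketch of the arithmetic for this architecture:

```python
# Hand count of SessionLSTM parameters (embed_dim=64, hidden_size=128,
# num_layers=2, n_items=5000). PyTorch's nn.LSTM stores, per layer,
# W_ih (4H x I), W_hh (4H x H), and two bias vectors of length 4H.
n_items, E, H = 5000, 64, 128

embedding = (n_items + 1) * E        # 5001 x 64, row 0 is the padding row
lstm_l0 = 4 * H * (E + H) + 8 * H    # first layer sees the embeddings (I = E)
lstm_l1 = 4 * H * (H + H) + 8 * H    # stacked layer sees layer-0 output (I = H)
fc = H * n_items + n_items           # output weight matrix + bias

total = embedding + lstm_l0 + lstm_l1 + fc
print(total)  # 1196488
```

Note that the output layer alone accounts for more than half the parameters, a common situation with large catalogs and one motivation for weight tying between the embedding and output layers.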
In Chapter 10, you will replace the LSTM encoder with a transformer, using self-attention to let each position in the session attend to every other position. The comparison will demonstrate both the transformer's superior modeling of long-range session dependencies and its significantly higher computational cost — a tradeoff that defines modern ML system design.
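That cost tradeoff can be made concrete with a back-of-the-envelope estimate (asymptotic forms only; constants and implementation details are ignored): one recurrent layer over a length-$L$ sequence costs on the order of $L \cdot d^2$ operations, executed sequentially, while one self-attention layer costs on the order of $L^2 \cdot d + L \cdot d^2$, executed in parallel.

```python
# Rough per-layer cost estimates for a length-L sequence with model
# width d. Asymptotic sketches only, not measured FLOPs.
def rnn_cost(L: int, d: int) -> int:
    return L * d * d               # one d x d recurrence per step, L steps

def attention_cost(L: int, d: int) -> int:
    return L * L * d + L * d * d   # L x L score matrix + linear projections

d = 128
for L in (20, 200, 2000):
    print(f"L={L:5d}  rnn={rnn_cost(L, d):>12,}  attn={attention_cost(L, d):>12,}")
```

For StreamRec's `max_len=20` sessions the quadratic term is small next to the projections; the gap widens rapidly once $L$ exceeds $d$, which is where the tradeoff starts to bite.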
Summary
This chapter traced the evolution of sequence modeling from the vanilla RNN through LSTMs, GRUs, and attention mechanisms. The story is one of identifying specific mathematical problems and engineering specific solutions:
- The vanilla RNN shares parameters across time and maintains a hidden state, but the product of Jacobian matrices in the backward pass causes gradients to vanish (or explode) exponentially with sequence length. This is not a problem that training tricks can fix — it is a fundamental consequence of composing contractive (or expansive) linear maps.
- The LSTM solves the vanishing gradient problem by introducing a cell state that is updated additively. The gradient flows through the cell state via the forget gate without passing through a weight matrix, creating a gradient highway. The three gates (forget, input, output) give the network fine-grained control over what to remember, what to add, and what to expose.
- The GRU simplifies the LSTM by merging the cell and hidden states and using two gates instead of three. It performs comparably to the LSTM on most benchmarks with fewer parameters.
- Attention mechanisms solve the information bottleneck of the encoder-decoder architecture by letting the decoder look at all encoder positions at every step. Bahdanau attention uses an additive score function; Luong attention uses a multiplicative one. The scaled dot-product form of Luong attention is exactly the attention mechanism used in transformers.
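That scaled dot-product form is compact enough to write out in full. A minimal NumPy sketch (shapes and names are illustrative; Chapter 10 builds the transformer around this exact operation):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # (n_q, n_k) similarities
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row sums to 1
    return weights @ V, weights

rng = np.random.RandomState(0)
Q = rng.randn(4, 16)   # 4 query (decoder) positions, d_k = 16
K = rng.randn(7, 16)   # 7 key (encoder) positions
V = rng.randn(7, 32)   # values may have a different width
context, weights = scaled_dot_product_attention(Q, K, V)
print(context.shape, weights.shape)  # (4, 32) (4, 7)
```

Each row of `weights` is a probability distribution over encoder positions, and each row of `context` is the corresponding weighted average of values — the same picture as Luong attention, with the $1/\sqrt{d_k}$ scaling added to keep the softmax from saturating at high dimension.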
In the progressive project, you built a session-based recommendation model for StreamRec using an LSTM, establishing a baseline Hit@10 of approximately 15%. In Chapter 10, we will replace the LSTM with a transformer and see how self-attention — the generalization of the attention mechanism we introduced here — transforms both the performance and the interpretability of the model.
The trajectory from RNN $\to$ LSTM $\to$ attention $\to$ transformer is not a sequence of disconnected inventions. It is a single intellectual thread: each innovation solves a specific limitation of the previous architecture. Understanding this thread is essential for evaluating future architectures — and for knowing when the older tools are still the right ones.