> "The only reason for time is so that everything doesn't happen at once."
In This Chapter
- 15.1 The Nature of Sequential Data
- 15.2 Vanilla Recurrent Neural Networks
- 15.3 The Vanishing and Exploding Gradient Problems
- 15.4 Long Short-Term Memory (LSTM)
- 15.5 Gated Recurrent Units (GRU)
- 15.6 Bidirectional RNNs
- 15.7 Deep RNNs and Stacking Layers
- 15.8 Sequence-to-Sequence Models
- 15.9 Teacher Forcing
- 15.10 Beam Search
- 15.10.5 Practical Beam Search Settings
- 15.11 Attention Mechanism Preview
- 15.12 Comparison with 1D Convolutions for Sequences
- 15.13 Practical Considerations
- 15.14 RNNs in the Modern Landscape
- 15.15 Summary
- References
Chapter 15: Recurrent Neural Networks and Sequence Modeling
"The only reason for time is so that everything doesn't happen at once." --- Albert Einstein
In the previous chapters, we explored feedforward networks and convolutional architectures that process fixed-size inputs. But the world is fundamentally sequential. Language unfolds word by word. Stock prices evolve tick by tick. Sensor readings stream continuously. Music progresses note by note. To model these phenomena, we need architectures that can process sequences of arbitrary length while maintaining memory of what came before.
Recurrent Neural Networks (RNNs) introduced a revolutionary idea: networks with loops. By feeding a hidden state from one time step to the next, RNNs create a form of memory that allows information to persist across time. This simple concept unlocked an entire category of problems that feedforward networks could not address, from machine translation to speech recognition to time series forecasting.
In this chapter, we will build RNNs from the ground up. We will start with the vanilla RNN, understand why it struggles with long sequences, and then explore the gated architectures---Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU)---that solved those problems. We will cover bidirectional processing, sequence-to-sequence models, teacher forcing, and beam search. By the end, you will have a thorough understanding of how to model sequential data and the tools to implement these models in PyTorch.
15.1 The Nature of Sequential Data
Before diving into architectures, let us understand what makes sequential data fundamentally different from the static data we have processed so far.
15.1.1 Why Sequences Are Different
Consider the sentence "The cat sat on the mat." Each word's meaning is influenced by the words that precede it. The word "sat" makes sense because we know the subject is "cat." If we shuffled the words randomly---"mat the on cat sat the"---the meaning disintegrates. Order matters.
Sequential data appears everywhere in engineering and science:
- Natural language: Words form sentences; sentences form documents
- Time series: Stock prices, sensor readings, weather data
- Audio: Waveforms sampled at regular intervals
- Video: Sequences of image frames
- Genomics: DNA sequences of nucleotide bases (A, T, G, C)
- User behavior: Clickstreams, purchase histories, navigation paths
What these domains share is temporal or positional dependence: the meaning or value at position $t$ depends on what happened at positions $t-1, t-2, \ldots$
15.1.2 Modeling Requirements for Sequences
An architecture designed for sequences must satisfy several requirements:
- Variable-length input handling: Sequences can be 5 tokens or 5,000 tokens long.
- Parameter sharing across time: The same transformation should apply at each time step; we should not learn separate weights for position 1 versus position 100.
- Memory: The model must maintain a summary of past inputs that it can use when processing the current input.
- Compositionality: The model should be able to build complex representations from simpler ones over time.
A naive approach might flatten a sequence into a single long vector and feed it to a feedforward network. This fails for several reasons: it cannot handle variable-length inputs, it requires separate parameters for each position (no sharing), and it scales poorly. RNNs address all of these issues elegantly.
15.1.3 Sequence Modeling Tasks
RNNs support several task configurations:
| Task Type | Input | Output | Example |
|---|---|---|---|
| Many-to-one | Sequence | Single value | Sentiment analysis |
| One-to-many | Single value | Sequence | Image captioning |
| Many-to-many (aligned) | Sequence | Sequence (same length) | Part-of-speech tagging |
| Many-to-many (unaligned) | Sequence | Sequence (different length) | Machine translation |
This flexibility makes RNNs one of the most versatile architectural families in deep learning.
15.2 Vanilla Recurrent Neural Networks
15.2.1 The Recurrence Relation
The core idea of an RNN is breathtakingly simple. At each time step $t$, the network takes two inputs: the current input $\mathbf{x}_t$ and the previous hidden state $\mathbf{h}_{t-1}$. It produces a new hidden state $\mathbf{h}_t$:
$$\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t + \mathbf{b}_h)$$
The output at time step $t$ is computed from the hidden state:
$$\mathbf{y}_t = \mathbf{W}_{hy}\mathbf{h}_t + \mathbf{b}_y$$
Here:
- $\mathbf{x}_t \in \mathbb{R}^d$ is the input at time step $t$
- $\mathbf{h}_t \in \mathbb{R}^n$ is the hidden state at time step $t$
- $\mathbf{W}_{xh} \in \mathbb{R}^{n \times d}$ maps inputs to hidden states
- $\mathbf{W}_{hh} \in \mathbb{R}^{n \times n}$ maps previous hidden states to current hidden states
- $\mathbf{W}_{hy} \in \mathbb{R}^{m \times n}$ maps hidden states to outputs
- $\mathbf{b}_h, \mathbf{b}_y$ are bias vectors
- $\mathbf{h}_0$ is typically initialized to zeros
The critical insight is that the same weights $\mathbf{W}_{xh}$, $\mathbf{W}_{hh}$, and $\mathbf{W}_{hy}$ are used at every time step. This parameter sharing is what allows RNNs to generalize across sequence positions and handle variable-length inputs.
15.2.2 Unrolling Through Time
To understand how an RNN processes a sequence, we "unroll" it through time. Given a sequence $(\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_T)$, the unrolled computation graph looks like a chain of identical transformations:
x_1 x_2 x_3 x_T
| | | |
v v v v
[RNN] --> [RNN] --> [RNN] --> ... [RNN]
| | | |
v v v v
h_1 h_2 h_3 h_T
| | | |
v v v v
y_1 y_2 y_3 y_T
Each [RNN] block applies the same transformation with the same weights. The hidden state $\mathbf{h}_t$ serves as the network's "memory," carrying information from all previous time steps.
15.2.3 Backpropagation Through Time (BPTT)
Training an RNN requires computing gradients through the unrolled computation graph. This algorithm is called Backpropagation Through Time (BPTT). It is not a new algorithm per se---it is simply standard backpropagation (as we derived in Chapter 11) applied to the unrolled computation graph of the RNN. However, the temporal structure introduces unique challenges that deserve careful analysis.
Deriving BPTT Step by Step.
Given a loss function $\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t$, where $\mathcal{L}_t$ is the loss at time step $t$, we need the gradient of the total loss with respect to the shared weight matrices. Let us derive $\frac{\partial \mathcal{L}}{\partial \mathbf{W}_{hh}}$.
Since $\mathbf{W}_{hh}$ is used at every time step, the total gradient is the sum of contributions from each time step:
$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^{T} \frac{\partial \mathcal{L}_t}{\partial \mathbf{W}_{hh}}$$
For a single time step's contribution, we need to account for the fact that $\mathbf{W}_{hh}$ affects $\mathcal{L}_t$ through every hidden state from $\mathbf{h}_1$ to $\mathbf{h}_t$. Applying the chain rule:
$$\frac{\partial \mathcal{L}_t}{\partial \mathbf{W}_{hh}} = \sum_{k=1}^{t} \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t} \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k} \frac{\partial^+ \mathbf{h}_k}{\partial \mathbf{W}_{hh}}$$
where $\frac{\partial^+ \mathbf{h}_k}{\partial \mathbf{W}_{hh}}$ denotes the "immediate" derivative of $\mathbf{h}_k$ with respect to $\mathbf{W}_{hh}$ (treating $\mathbf{h}_{k-1}$ as a constant). This immediate derivative is straightforward:
$$\frac{\partial^+ \mathbf{h}_k}{\partial \mathbf{W}_{hh}} = \text{diag}(1 - \mathbf{h}_k^2) \cdot \mathbf{h}_{k-1}^T$$
where $\text{diag}(1 - \mathbf{h}_k^2)$ is the derivative of tanh applied element-wise.
The critical term is the Jacobian product that propagates the error signal backward through time:
$$\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k} = \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}$$
Each factor in this product is the Jacobian of one time step's hidden state with respect to the previous:
$$\frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} = \text{diag}(1 - \mathbf{h}_j^2) \cdot \mathbf{W}_{hh}$$
Combining everything:
$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t} \left(\prod_{j=k+1}^{t} \text{diag}(1 - \mathbf{h}_j^2) \cdot \mathbf{W}_{hh}\right) \text{diag}(1 - \mathbf{h}_k^2) \cdot \mathbf{h}_{k-1}^T$$
This double sum over $t$ and $k$ is computationally expensive, scaling as $O(T^2)$ in the worst case. In practice, we implement BPTT by performing a forward pass to compute and cache all hidden states, then sweeping backward from $t = T$ to $t = 1$, accumulating gradients as we go. This is algorithmically efficient: $O(T)$ time and $O(T)$ memory (for caching hidden states).
Truncated BPTT. For very long sequences, storing all $T$ hidden states and backpropagating through all of them is prohibitive. Truncated BPTT limits backpropagation to a fixed window of $\tau$ time steps. The sequence is divided into chunks of length $\tau$; for each chunk, the forward pass uses the hidden state from the previous chunk, but gradients are only computed within the chunk. This introduces a bias (gradients from dependencies longer than $\tau$ steps are ignored) but dramatically reduces memory and computation. Typical values are $\tau = 35$ to $\tau = 200$.
# Truncated BPTT: process sequence in chunks
chunk_size = 35 # Truncation length
for i in range(0, seq_len, chunk_size):
chunk = sequence[:, i:i+chunk_size, :]
output, hidden = model(chunk, hidden)
loss = criterion(output, targets[:, i:i+chunk_size])
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Detach hidden state to prevent backprop beyond this chunk
hidden = hidden.detach()
This product of Jacobians $\prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}$ is the source of the vanishing and exploding gradient problems that we will discuss next.
15.2.4 Implementing a Vanilla RNN Cell
Let us implement a vanilla RNN from scratch to solidify our understanding:
import torch
import torch.nn as nn
torch.manual_seed(42)
class VanillaRNNCell(nn.Module):
"""A single vanilla RNN cell.
Implements the recurrence: h_t = tanh(W_xh @ x_t + W_hh @ h_{t-1} + b_h)
Args:
input_size: Dimensionality of the input vectors.
hidden_size: Dimensionality of the hidden state.
"""
def __init__(self, input_size: int, hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.W_xh = nn.Linear(input_size, hidden_size)
self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)
def forward(
self, x_t: torch.Tensor, h_prev: torch.Tensor
) -> torch.Tensor:
"""Compute the next hidden state.
Args:
x_t: Input at current time step, shape (batch_size, input_size).
h_prev: Previous hidden state, shape (batch_size, hidden_size).
Returns:
New hidden state, shape (batch_size, hidden_size).
"""
return torch.tanh(self.W_xh(x_t) + self.W_hh(h_prev))
To process a full sequence:
class VanillaRNN(nn.Module):
"""A vanilla RNN that processes entire sequences.
Args:
input_size: Dimensionality of input vectors.
hidden_size: Dimensionality of hidden state.
output_size: Dimensionality of output vectors.
"""
def __init__(
self, input_size: int, hidden_size: int, output_size: int
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.rnn_cell = VanillaRNNCell(input_size, hidden_size)
self.output_layer = nn.Linear(hidden_size, output_size)
def forward(
self, x: torch.Tensor, h_0: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process a sequence and return outputs and final hidden state.
Args:
x: Input sequence, shape (batch_size, seq_len, input_size).
h_0: Initial hidden state. Defaults to zeros.
Returns:
Tuple of (outputs, h_final) where outputs has shape
(batch_size, seq_len, output_size) and h_final has shape
(batch_size, hidden_size).
"""
batch_size, seq_len, _ = x.shape
if h_0 is None:
h_0 = torch.zeros(batch_size, self.hidden_size, device=x.device)
h_t = h_0
outputs = []
for t in range(seq_len):
h_t = self.rnn_cell(x[:, t, :], h_t)
y_t = self.output_layer(h_t)
outputs.append(y_t)
outputs = torch.stack(outputs, dim=1)
return outputs, h_t
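As a quick usage sketch (the sizes here are arbitrary, chosen only for illustration), we can run the model on a random batch of sequences and check the output shapes:
# Toy usage of VanillaRNN: batch of 8 sequences, 20 steps, 10 input features
model = VanillaRNN(input_size=10, hidden_size=32, output_size=5)
x = torch.randn(8, 20, 10)
outputs, h_final = model(x)
print(outputs.shape)  # torch.Size([8, 20, 5])
print(h_final.shape)  # torch.Size([8, 32])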
15.3 The Vanishing and Exploding Gradient Problems
15.3.1 Mathematical Analysis
The vanilla RNN has an Achilles' heel. Recall the product of Jacobians from BPTT:
$$\prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}$$
Each Jacobian $\frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} = \text{diag}(1 - \mathbf{h}_j^2) \cdot \mathbf{W}_{hh}$ where $\text{diag}(1 - \mathbf{h}_j^2)$ is the derivative of tanh.
Bounding the Jacobian norm. Since $\tanh'(z) = 1 - \tanh^2(z) \leq 1$ for all $z$, we have $\|\text{diag}(1 - \mathbf{h}_j^2)\| \leq 1$. Therefore:
$$\left\| \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} \right\| \leq \|\text{diag}(1 - \mathbf{h}_j^2)\| \cdot \|\mathbf{W}_{hh}\| \leq \|\mathbf{W}_{hh}\|$$
For the product over $t - k$ time steps:
$$\left\| \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} \right\| \leq \|\mathbf{W}_{hh}\|^{t-k}$$
If the largest singular value $\sigma_{\max}(\mathbf{W}_{hh}) < 1$, then:
$$\left\| \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} \right\| \leq \gamma^{t-k} \to 0 \text{ as } t-k \to \infty$$
where $\gamma = \sigma_{\max}(\mathbf{W}_{hh}) < 1$. This is the vanishing gradient problem: gradients from distant time steps shrink exponentially to zero, making it impossible for the network to learn long-range dependencies.
Conversely, if $\sigma_{\max}(\mathbf{W}_{hh}) > 1$, gradients can grow exponentially---the exploding gradient problem.
How fast do gradients vanish? Let us put numbers to this. Suppose $\|\mathbf{W}_{hh}\| = 0.9$ and the tanh derivative contributes an average factor of 0.5 (since tanh saturates for large activations). The effective decay rate per time step is $\gamma \approx 0.9 \times 0.5 = 0.45$. Over 20 time steps:
$$\gamma^{20} = 0.45^{20} \approx 1.2 \times 10^{-7}$$
The gradient signal from 20 steps ago is attenuated by a factor of 10 million. After 50 steps, it is essentially zero in floating-point arithmetic. This calculation, first presented rigorously by Bengio, Simard, and Frasconi (1994), explains why vanilla RNNs cannot learn dependencies spanning more than about 10--20 time steps.
The exploding side. If $\gamma = 1.1$ (only slightly above 1), then $\gamma^{50} \approx 117$. This makes the gradient 100x too large, causing the weight update to overshoot catastrophically. Gradient norms can reach $10^{10}$ or higher, producing NaN values and completely destroying the model. The exploding gradient problem is easier to detect (the training diverges obviously) and easier to fix (gradient clipping, as discussed in Section 15.3.3 and Chapter 12). The vanishing gradient problem is more insidious because training appears to proceed normally---the loss decreases---but the model silently fails to learn long-range patterns.
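These decay and growth rates are easy to verify directly; a two-line check of the numbers used above:
# Vanishing: per-step factor 0.45 over 20 steps; exploding: 1.1 over 50 steps
print(0.45 ** 20)  # ~1.2e-07, i.e., attenuation by roughly 10^7
print(1.1 ** 50)   # ~117, i.e., amplification by about two orders of magnitude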
Connection to eigenvalue analysis. A more precise analysis uses the eigendecomposition of $\mathbf{W}_{hh}$. If $\mathbf{W}_{hh} = \mathbf{P} \boldsymbol{\Lambda} \mathbf{P}^{-1}$ where $\boldsymbol{\Lambda}$ contains the eigenvalues $\lambda_1, \ldots, \lambda_n$, then $\mathbf{W}_{hh}^k = \mathbf{P} \boldsymbol{\Lambda}^k \mathbf{P}^{-1}$. The behavior is dominated by the largest eigenvalue: if $|\lambda_{\max}| < 1$, the product vanishes; if $|\lambda_{\max}| > 1$, it explodes. The transition between vanishing and exploding occurs at $|\lambda_{\max}| = 1$, a knife-edge that is impossible to maintain during training. This is the fundamental reason why vanilla RNNs fail on long sequences---there is no stable equilibrium for gradient flow.
15.3.2 Practical Consequences
The vanishing gradient problem means that a vanilla RNN effectively has a limited memory horizon. In practice, vanilla RNNs struggle to learn dependencies spanning more than about 10--20 time steps. Consider a language modeling task where the relevant context is 50 words earlier: "The woman who grew up in France ... spoke fluent French." A vanilla RNN would fail to connect "France" to "French" across the intervening text.
15.3.3 Gradient Clipping
While gradient clipping does not solve the vanishing gradient problem, it provides a practical solution to exploding gradients. The idea is simple: if the gradient norm exceeds a threshold $\theta$, rescale it:
$$\hat{\mathbf{g}} = \begin{cases} \mathbf{g} & \text{if } \|\mathbf{g}\| \leq \theta \\ \frac{\theta}{\|\mathbf{g}\|} \mathbf{g} & \text{if } \|\mathbf{g}\| > \theta \end{cases}$$
In PyTorch:
# After loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
15.3.4 Other Mitigations
Several techniques partially address gradient flow issues in vanilla RNNs:
- Proper initialization: Initializing $\mathbf{W}_{hh}$ as an orthogonal matrix ensures all singular values are exactly 1 at the start of training, preventing both vanishing and exploding gradients initially. As training proceeds and $\mathbf{W}_{hh}$ changes, this property is not maintained, but it provides a healthier starting point.
- Skip connections: Adding residual connections across time steps, analogous to the skip connections in ResNets (Chapter 13). The hidden state update becomes $\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t) + \mathbf{h}_{t-1}$, creating an additive path for gradient flow.
- Truncated BPTT: Limiting backpropagation to a fixed number of time steps, as described in Section 15.2.3. This does not solve the vanishing gradient problem but limits its impact.
- Echo State Networks / Reservoir Computing: Fix the recurrent weights (no training) and only train the output layer. This avoids gradient issues entirely but limits expressiveness.
- Norm-preserving architectures: Design $\mathbf{W}_{hh}$ to be unitary (complex-valued) or orthogonal, maintaining gradient norm exactly. The Unitary RNN (Arjovsky et al., 2016) uses this approach.
However, the definitive and most practical solution came with gated architectures: LSTM and GRU. Their gating mechanisms provide a learned mechanism for controlling gradient flow, which is far more flexible than any fixed architectural constraint.
15.4 Long Short-Term Memory (LSTM)
The LSTM, introduced by Hochreiter and Schmidhuber in 1997, is one of the most important innovations in deep learning. It solved the vanishing gradient problem by introducing a cell state that acts as an information highway, allowing gradients to flow unchanged across many time steps.
15.4.1 The Cell State: An Information Highway
The key innovation of LSTM is the cell state $\mathbf{c}_t$, a vector that runs through the entire sequence with only minor linear interactions. Think of it as a conveyor belt: information can be placed on it, read from it, or removed from it, but it flows forward with minimal degradation.
The cell state is regulated by three gates---the forget gate, input gate, and output gate---each of which is a sigmoid-activated layer that outputs values between 0 and 1, controlling how much information flows through.
15.4.2 The Forget Gate
The forget gate decides what information to discard from the cell state. It looks at the previous hidden state $\mathbf{h}_{t-1}$ and the current input $\mathbf{x}_t$, and outputs a value between 0 (completely forget) and 1 (completely remember) for each element of the cell state:
$$\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$$
where $\sigma$ is the sigmoid function and $[\mathbf{h}_{t-1}, \mathbf{x}_t]$ denotes concatenation.
For example, in a language model tracking the subject's gender, when the model encounters a new subject, the forget gate should erase the old gender information from the cell state.
15.4.3 The Input Gate
The input gate decides what new information to store in the cell state. This has two parts:
- The input gate layer decides which values to update:
$$\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$$
- A candidate cell state is created with tanh:
$$\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$$
15.4.4 Cell State Update
The old cell state $\mathbf{c}_{t-1}$ is updated by:
- Multiplying by the forget gate (erasing what should be forgotten)
- Adding the new candidate values scaled by the input gate
$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$
where $\odot$ denotes element-wise multiplication. This is the critical equation: the cell state update is additive, not multiplicative. This additive structure is what prevents gradient vanishing, because the gradient $\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \mathbf{f}_t$, which can remain close to 1 for long periods.
15.4.5 The Output Gate
The output gate determines what part of the cell state to expose as the hidden state:
$$\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$$
$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$
The cell state is passed through tanh (to push values between $-1$ and $1$) and then filtered by the output gate. This allows the LSTM to maintain information in the cell state that is not immediately relevant to the current output but may be useful later.
15.4.6 LSTM Summary
Putting it all together, the complete LSTM equations at each time step are:
$$\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$$
$$\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$$
$$\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$$
$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$
$$\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$$
$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$
The LSTM has roughly 4x the parameters of a vanilla RNN for the same hidden size, because it has four sets of weight matrices (one for each of the three gates plus the candidate cell state).
15.4.7 Why LSTMs Solve the Vanishing Gradient Problem
The gradient of the loss with respect to the cell state at time $k$ flows through the chain:
$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_k} = \prod_{j=k+1}^{t} \frac{\partial \mathbf{c}_j}{\partial \mathbf{c}_{j-1}} = \prod_{j=k+1}^{t} \mathbf{f}_j$$
Each forget gate $\mathbf{f}_j$ has values between 0 and 1. When the forget gate is close to 1 (remembering), the gradient flows through nearly unchanged. The network learns to set the forget gate appropriately: close to 1 when information should be preserved, and close to 0 when it should be erased. This learned gating mechanism is what enables LSTMs to capture dependencies spanning hundreds of time steps.
15.4.8 LSTM Implementation
import torch
import torch.nn as nn
torch.manual_seed(42)
class LSTMCell(nn.Module):
"""A single LSTM cell with explicit gate computations.
Args:
input_size: Dimensionality of input vectors.
hidden_size: Dimensionality of hidden state and cell state.
"""
def __init__(self, input_size: int, hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
# Combined weight matrix for efficiency
self.gates = nn.Linear(
input_size + hidden_size, 4 * hidden_size
)
def forward(
self,
x_t: torch.Tensor,
h_prev: torch.Tensor,
c_prev: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute next hidden state and cell state.
Args:
x_t: Input at current time step, shape (batch, input_size).
h_prev: Previous hidden state, shape (batch, hidden_size).
c_prev: Previous cell state, shape (batch, hidden_size).
Returns:
Tuple of (h_t, c_t), each shape (batch, hidden_size).
"""
combined = torch.cat([h_prev, x_t], dim=1)
gates = self.gates(combined)
# Split into four gate values
i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)
i_t = torch.sigmoid(i_t) # Input gate
f_t = torch.sigmoid(f_t) # Forget gate
g_t = torch.tanh(g_t) # Candidate cell state
o_t = torch.sigmoid(o_t) # Output gate
c_t = f_t * c_prev + i_t * g_t
h_t = o_t * torch.tanh(c_t)
return h_t, c_t
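To process a full sequence with this cell, we unroll it manually and carry both states forward. A minimal sketch (dimensions are arbitrary):
# Unroll the LSTMCell over a toy sequence
batch_size, seq_len, input_size, hidden_size = 4, 15, 8, 16
cell = LSTMCell(input_size, hidden_size)
x = torch.randn(batch_size, seq_len, input_size)
h_t = torch.zeros(batch_size, hidden_size)
c_t = torch.zeros(batch_size, hidden_size)
states = []
for t in range(seq_len):
    h_t, c_t = cell(x[:, t, :], h_t, c_t)
    states.append(h_t)
states = torch.stack(states, dim=1)
print(states.shape)  # torch.Size([4, 15, 16])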
15.5 Gated Recurrent Units (GRU)
15.5.1 A Streamlined Alternative
The Gated Recurrent Unit, introduced by Cho et al. in 2014, is a simplified variant of the LSTM that achieves comparable performance with fewer parameters. The GRU merges the cell state and hidden state into a single vector and uses two gates instead of three.
15.5.2 GRU Equations
The GRU uses an update gate $\mathbf{z}_t$ and a reset gate $\mathbf{r}_t$:
$$\mathbf{z}_t = \sigma(\mathbf{W}_z [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_z)$$
$$\mathbf{r}_t = \sigma(\mathbf{W}_r [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_r)$$
The candidate hidden state uses the reset gate to control how much of the previous state to consider:
$$\tilde{\mathbf{h}}_t = \tanh(\mathbf{W}_h [\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_h)$$
The final hidden state is a linear interpolation between the previous state and the candidate:
$$\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$$
15.5.3 Understanding the Gates
The update gate $\mathbf{z}_t$ controls the balance between keeping the old state and accepting the new candidate. When $\mathbf{z}_t \approx 0$, the hidden state is simply copied forward (like an LSTM forget gate close to 1). When $\mathbf{z}_t \approx 1$, the hidden state is replaced by the candidate.
The reset gate $\mathbf{r}_t$ controls how much of the previous hidden state to use when computing the candidate. When $\mathbf{r}_t \approx 0$, the candidate is computed as if the previous state were zero, allowing the network to "forget" and start fresh. When $\mathbf{r}_t \approx 1$, the candidate considers the full previous state.
15.5.4 LSTM vs. GRU
| Feature | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| State vectors | 2 (hidden + cell) | 1 (hidden) |
| Parameters | 4 weight matrices | 3 weight matrices |
| Performance | Slightly better on long sequences | Comparable overall |
| Training speed | Slower (more parameters) | Faster |
| When to use | Long dependencies, complex patterns | Smaller datasets, faster training |
In practice, the performance difference between LSTM and GRU is often marginal. The GRU trains faster due to fewer parameters, while the LSTM may have a slight edge on tasks requiring very long-range memory. When in doubt, try both and compare.
Detailed Comparison. The GRU achieves comparable performance to the LSTM with fewer parameters because:
- Combined forget and input gates. In the GRU, the update gate $\mathbf{z}_t$ controls both forgetting and inputting. When $\mathbf{z}_t \approx 0$, the old state is kept (analogous to the LSTM forget gate being close to 1 and the input gate being close to 0). This coupling means the GRU cannot independently control forgetting and inputting, which is a mild limitation but saves an entire gate's worth of parameters.
- No separate cell state. The LSTM maintains two state vectors ($\mathbf{h}_t$ and $\mathbf{c}_t$), while the GRU uses only $\mathbf{h}_t$. The LSTM's cell state provides a dedicated "memory highway" that is distinct from the output. The GRU merges these roles, making its hidden state do double duty.
- No output gate. The LSTM's output gate controls how much of the cell state to expose. The GRU exposes its entire hidden state at every step. This means the GRU cannot selectively hide information from downstream layers, which is occasionally useful for tasks where the relevant memory content changes over time.
Empirically, Chung et al. (2014) found that both architectures outperform vanilla RNNs by a large margin, but neither consistently outperforms the other across all tasks. Greff et al. (2017) conducted an extensive ablation study of LSTM variants and found that the forget gate and output gate are the most critical components---removing either significantly degrades performance. The GRU's update gate functions similarly to the forget gate, which may explain its competitive performance.
15.5.5 GRU Implementation
class GRUCell(nn.Module):
"""A single GRU cell with explicit gate computations.
Args:
input_size: Dimensionality of input vectors.
hidden_size: Dimensionality of hidden state.
"""
def __init__(self, input_size: int, hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
def forward(
self, x_t: torch.Tensor, h_prev: torch.Tensor
) -> torch.Tensor:
"""Compute next hidden state.
Args:
x_t: Input at current time step, shape (batch, input_size).
h_prev: Previous hidden state, shape (batch, hidden_size).
Returns:
New hidden state, shape (batch, hidden_size).
"""
combined = torch.cat([h_prev, x_t], dim=1)
z_t = torch.sigmoid(self.W_z(combined))
r_t = torch.sigmoid(self.W_r(combined))
reset_combined = torch.cat([r_t * h_prev, x_t], dim=1)
h_candidate = torch.tanh(self.W_h(reset_combined))
h_t = (1 - z_t) * h_prev + z_t * h_candidate
return h_t
15.6 Bidirectional RNNs
15.6.1 Motivation
In many tasks, the context needed to interpret a position comes from both directions. Consider named entity recognition: in "He went to Washington to meet the president," identifying "Washington" as a place (rather than a person) benefits from seeing "to meet the president" after it.
A standard RNN processes the sequence left-to-right, so $\mathbf{h}_t$ only captures information from $\mathbf{x}_1, \ldots, \mathbf{x}_t$. A bidirectional RNN adds a second RNN that processes the sequence right-to-left:
$$\overrightarrow{\mathbf{h}}_t = \text{RNN}_{\text{fwd}}(\mathbf{x}_t, \overrightarrow{\mathbf{h}}_{t-1})$$
$$\overleftarrow{\mathbf{h}}_t = \text{RNN}_{\text{bwd}}(\mathbf{x}_t, \overleftarrow{\mathbf{h}}_{t+1})$$
The final representation at each time step concatenates both:
$$\mathbf{h}_t = [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]$$
15.6.2 When to Use Bidirectional RNNs
Bidirectional RNNs are appropriate when the entire input sequence is available before making predictions. This includes:
- Text classification (sentiment analysis)
- Named entity recognition
- Part-of-speech tagging
- Machine translation (encoder side)
- Speech recognition (with sufficient buffer)
Bidirectional RNNs are not appropriate for tasks requiring real-time, left-to-right generation, such as language modeling or online speech synthesis, because the backward pass requires seeing the full sequence first.
Extracting representations from bidirectional RNNs. For sequence classification tasks (many-to-one), you need to reduce the sequence of hidden states into a single vector. Common strategies include:
- Concatenate the final hidden states from both directions: $\mathbf{h} = [\overrightarrow{\mathbf{h}}_T; \overleftarrow{\mathbf{h}}_1]$. This captures the end of the forward pass and the end of the backward pass (which corresponds to the start of the sequence).
- Mean pooling over all time steps: $\mathbf{h} = \frac{1}{T} \sum_{t=1}^{T} [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]$. This gives equal weight to all positions.
- Max pooling over each dimension: $h_j = \max_t [\overrightarrow{h}_{t,j}; \overleftarrow{h}_{t,j}]$. This captures the strongest activation at any position.
- Attention pooling: Learn a weighted average, where the weights depend on the hidden states (a precursor to the self-attention mechanism we will explore in Chapter 18).
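A minimal sketch of the first two strategies, using a throwaway bidirectional LSTM (all sizes here are illustrative):
import torch
import torch.nn as nn
rnn = nn.LSTM(input_size=16, hidden_size=32, batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 16)          # (batch, seq_len, features)
outputs, (h_n, _) = rnn(x)          # outputs: (4, 10, 64)
# Strategy 1: concatenate the final forward and backward states
# h_n has shape (num_directions, batch, hidden) for a single layer
final_concat = torch.cat([h_n[0], h_n[1]], dim=-1)  # (4, 64)
# Strategy 2: mean-pool the concatenated hidden states over time
mean_pooled = outputs.mean(dim=1)                   # (4, 64)
print(final_concat.shape, mean_pooled.shape)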
15.6.3 PyTorch Implementation
PyTorch makes bidirectional RNNs trivial:
import torch
import torch.nn as nn
torch.manual_seed(42)
# Bidirectional LSTM
bilstm = nn.LSTM(
input_size=128,
hidden_size=256,
num_layers=2,
batch_first=True,
bidirectional=True,
dropout=0.3,
)
# Input: (batch_size, seq_len, input_size)
x = torch.randn(32, 50, 128)
output, (h_n, c_n) = bilstm(x)
# output shape: (32, 50, 512) -- 256 * 2 due to bidirectional
# h_n shape: (4, 32, 256) -- 2 layers * 2 directions
print(f"Output shape: {output.shape}")
print(f"Hidden shape: {h_n.shape}")
15.7 Deep RNNs and Stacking Layers
15.7.1 Why Go Deeper?
Just as deeper CNNs learn increasingly abstract features, stacking multiple RNN layers creates a hierarchy of temporal representations. The first layer captures low-level patterns (e.g., phonemes in speech), while higher layers capture more abstract patterns (e.g., words, phrases).
In a deep RNN with $L$ layers, the hidden state of layer $l$ at time $t$ is:
$$\mathbf{h}_t^{(l)} = \text{RNN}^{(l)}(\mathbf{h}_t^{(l-1)}, \mathbf{h}_{t-1}^{(l)})$$
where $\mathbf{h}_t^{(0)} = \mathbf{x}_t$.
15.7.2 Dropout in RNNs
Applying dropout to RNNs requires care. Naive dropout at every time step would destroy the temporal information in the hidden state. The standard approach is to apply dropout between layers (vertically) but not across time steps (horizontally):
# PyTorch handles this automatically with the dropout parameter
deep_lstm = nn.LSTM(
input_size=128,
hidden_size=256,
num_layers=3, # Three stacked LSTM layers
batch_first=True,
dropout=0.5, # Applied between layers (not on last layer)
)
For more sophisticated dropout, variational dropout (Gal and Ghahramani, 2016) uses the same dropout mask at every time step within a sequence, which preserves more temporal information.
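A minimal sketch of the shared-mask idea (not the full variational RNN of Gal and Ghahramani; the helper name and the manual unrolling loop are illustrative assumptions):
import torch
def variational_mask(batch_size: int, hidden_size: int, p: float) -> torch.Tensor:
    """Sample one dropout mask per sequence and reuse it at every time step."""
    keep = 1.0 - p
    return torch.bernoulli(torch.full((batch_size, hidden_size), keep)) / keep
# Inside a manual unrolling loop (sketch):
# mask = variational_mask(batch_size, hidden_size, p=0.3)
# for t in range(seq_len):
#     h_t = rnn_cell(x[:, t, :], h_t * mask)  # same mask at every step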
15.7.3 Residual Connections in Deep RNNs
For deep RNNs with more than 2--3 layers, training can become difficult due to gradient degradation across the layer stack (a vertical analogue of the temporal vanishing gradient problem). Residual connections, inspired by ResNets (Chapter 13), help by adding a skip connection from the input of each layer to its output:
$$\mathbf{h}_t^{(l)} = \text{RNN}^{(l)}(\mathbf{h}_t^{(l-1)}, \mathbf{h}_{t-1}^{(l)}) + \mathbf{h}_t^{(l-1)}$$
This requires that the input and output dimensions of each RNN layer match. If they differ, a linear projection can be used on the skip path. Residual connections are especially important for deep bidirectional models like those used in speech recognition, where stacking 5--8 layers is common.
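A sketch of one residual recurrent layer under that matching-dimension assumption (a hypothetical wrapper, not a built-in PyTorch module):
import torch
import torch.nn as nn
class ResidualLSTMLayer(nn.Module):
    """One LSTM layer with an additive skip connection from input to output."""
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(x)
        return out + x  # residual path; input and output sizes must match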
15.7.4 Practical Considerations for Deep RNNs
- Depth: 2--4 layers is typical. Beyond 4 layers, training becomes difficult without residual connections.
- Hidden size: Common choices range from 128 to 1024. Larger hidden sizes increase capacity but also computation time and overfitting risk.
- Gradient clipping: Essential for deep RNNs. A max norm of 1.0--5.0 is standard.
- Layer normalization: Can stabilize training in deep RNNs. Apply it within each RNN cell, normalizing the pre-activation values before the nonlinearity.
- Regularization: Apply dropout between layers (not across time steps). For very deep RNNs, weight tying (sharing input and output embeddings) and variational dropout provide additional regularization.
15.8 Sequence-to-Sequence Models
15.8.1 The Encoder-Decoder Architecture
Many important tasks involve mapping one sequence to another of a potentially different length: machine translation (English to French), text summarization (long document to short summary), or speech recognition (audio to text). The sequence-to-sequence (seq2seq) framework, introduced by Sutskever et al. (2014), handles these tasks elegantly.
The architecture has two components:
- Encoder: Reads the input sequence and compresses it into a fixed-length context vector $\mathbf{c}$
- Decoder: Generates the output sequence one token at a time, conditioned on the context vector
Encoder
The encoder processes the input sequence $(\mathbf{x}_1, \ldots, \mathbf{x}_T)$ and produces hidden states. The final hidden state (or a transformation of it) becomes the context vector:
$$\mathbf{h}_t^{\text{enc}} = \text{LSTM}_{\text{enc}}(\mathbf{x}_t, \mathbf{h}_{t-1}^{\text{enc}})$$
$$\mathbf{c} = \mathbf{h}_T^{\text{enc}}$$
Decoder
The decoder is an RNN that generates the output sequence step by step. At each step, it takes the previous output token and the hidden state:
$$\mathbf{h}_t^{\text{dec}} = \text{LSTM}_{\text{dec}}(\mathbf{y}_{t-1}, \mathbf{h}_{t-1}^{\text{dec}})$$
$$P(\mathbf{y}_t \mid \mathbf{y}_{<t}, \mathbf{c}) = \text{softmax}(\mathbf{W}_{hy} \mathbf{h}_t^{\text{dec}} + \mathbf{b}_y)$$
The decoder is initialized with the context vector: $\mathbf{h}_0^{\text{dec}} = \mathbf{c}$.
15.8.2 The Information Bottleneck
A critical limitation of the basic seq2seq model is that the entire input sequence must be compressed into a single fixed-length vector $\mathbf{c}$. For short sequences, this works well. But for long sequences, this creates an information bottleneck: the context vector cannot retain all the nuances of a 100-word sentence.
Quantifying the bottleneck. Cho et al. (2014) demonstrated this empirically by measuring BLEU scores (a machine translation quality metric) as a function of input sentence length. For basic seq2seq models, performance degraded dramatically for sentences longer than about 20--30 words. With attention (Section 15.11), this degradation was largely eliminated, confirming that the fixed-length bottleneck was the primary cause.
Strategies to mitigate the bottleneck without attention:
1. Reverse the input sequence. Sutskever et al. (2014) found that reversing the source sentence improved translation quality. The intuition is that reversing places the first few source words closer to the first few target words in the unrolled computation graph, reducing the distance that gradients must travel for the most important early alignments.
2. Use a bidirectional encoder. The final context vector concatenates the forward and backward final states, giving the decoder information about both the beginning and end of the source sequence.
3. Use a deeper encoder. More layers increase the model's capacity to compress the sequence into a rich context vector.
However, none of these strategies fully solve the problem. The definitive solution is the attention mechanism, which we preview in Section 15.11 and cover in depth in Chapter 18.
15.8.3 Seq2Seq Implementation
import torch
import torch.nn as nn
torch.manual_seed(42)
class Encoder(nn.Module):
"""Seq2seq encoder using LSTM.
Args:
vocab_size: Size of the source vocabulary.
embed_dim: Dimensionality of word embeddings.
hidden_size: Dimensionality of LSTM hidden state.
num_layers: Number of stacked LSTM layers.
dropout: Dropout rate between LSTM layers.
"""
def __init__(
self,
vocab_size: int,
embed_dim: int,
hidden_size: int,
num_layers: int = 2,
dropout: float = 0.3,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
embed_dim,
hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0.0,
)
self.dropout = nn.Dropout(dropout)
def forward(
self, src: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""Encode source sequence.
Args:
src: Source token IDs, shape (batch_size, src_len).
Returns:
Tuple of (encoder_outputs, (h_n, c_n)).
"""
embedded = self.dropout(self.embedding(src))
outputs, (h_n, c_n) = self.lstm(embedded)
return outputs, (h_n, c_n)
class Decoder(nn.Module):
"""Seq2seq decoder using LSTM.
Args:
vocab_size: Size of the target vocabulary.
embed_dim: Dimensionality of word embeddings.
hidden_size: Dimensionality of LSTM hidden state.
num_layers: Number of stacked LSTM layers.
dropout: Dropout rate between LSTM layers.
"""
def __init__(
self,
vocab_size: int,
embed_dim: int,
hidden_size: int,
num_layers: int = 2,
dropout: float = 0.3,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
embed_dim,
hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0.0,
)
self.fc_out = nn.Linear(hidden_size, vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(
self,
trg_token: torch.Tensor,
hidden: tuple[torch.Tensor, torch.Tensor],
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""Decode one time step.
Args:
trg_token: Target token IDs, shape (batch_size, 1).
hidden: Previous (h, c) states from LSTM.
Returns:
Tuple of (prediction, (h_n, c_n)).
"""
embedded = self.dropout(self.embedding(trg_token))
output, (h_n, c_n) = self.lstm(embedded, hidden)
prediction = self.fc_out(output.squeeze(1))
return prediction, (h_n, c_n)
class Seq2Seq(nn.Module):
"""Complete sequence-to-sequence model.
Args:
encoder: The encoder module.
decoder: The decoder module.
"""
def __init__(self, encoder: Encoder, decoder: Decoder) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(
self,
src: torch.Tensor,
trg: torch.Tensor,
teacher_forcing_ratio: float = 0.5,
) -> torch.Tensor:
"""Run seq2seq forward pass.
Args:
src: Source token IDs, shape (batch_size, src_len).
trg: Target token IDs, shape (batch_size, trg_len).
teacher_forcing_ratio: Probability of using ground truth
as next input during training.
Returns:
Predictions, shape (batch_size, trg_len, vocab_size).
"""
batch_size, trg_len = trg.shape
vocab_size = self.decoder.fc_out.out_features
outputs = torch.zeros(
batch_size, trg_len, vocab_size, device=src.device
)
_, hidden = self.encoder(src)
input_token = trg[:, 0:1] # <SOS> token
for t in range(1, trg_len):
prediction, hidden = self.decoder(input_token, hidden)
outputs[:, t, :] = prediction
if torch.rand(1).item() < teacher_forcing_ratio:
input_token = trg[:, t:t+1]
else:
input_token = prediction.argmax(dim=1, keepdim=True)
return outputs
15.9 Teacher Forcing
15.9.1 The Training Dilemma
During training, the decoder must decide what input to use at each time step. There are two options:
- Teacher forcing: feed the ground-truth token $y_{t-1}$ from the training data as the decoder input at step $t$.
- Free running: feed the decoder's own prediction $\hat{y}_{t-1}$ from the previous step.
15.9.2 Why Teacher Forcing Helps
Teacher forcing dramatically accelerates training. In free-running mode, if the model makes a wrong prediction at time step $t$, the subsequent inputs are corrupted, leading to a cascade of errors. Early in training, when the model's predictions are nearly random, this makes learning extremely slow. With teacher forcing, each time step receives correct context, allowing the model to learn the conditional distribution $P(y_t \mid y_{<t}, \mathbf{x})$ directly.
15.9.3 The Exposure Bias Problem
Teacher forcing introduces a discrepancy between training and inference. During training, the model always sees correct previous tokens. During inference, it sees its own (potentially incorrect) predictions. This mismatch is called exposure bias. A common mitigation is to decay the teacher forcing ratio over the course of training (scheduled sampling), so that the model is gradually exposed to its own predictions as it improves. The helper below implements two simple decay schedules:
def get_teacher_forcing_ratio(
epoch: int, total_epochs: int, strategy: str = "linear"
) -> float:
"""Compute teacher forcing ratio with scheduled decay.
Args:
epoch: Current training epoch (0-indexed).
total_epochs: Total number of training epochs.
strategy: Decay strategy, one of "linear" or "exponential".
Returns:
Teacher forcing ratio between 0 and 1.
"""
if strategy == "linear":
return max(0.0, 1.0 - epoch / total_epochs)
elif strategy == "exponential":
return 0.99 ** epoch
else:
raise ValueError(f"Unknown strategy: {strategy}")
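A sketch of how this schedule plugs into a training loop for the Seq2Seq model above (criterion, optimizer, and train_loader are assumed to exist; this is illustrative wiring, not a complete training script):
total_epochs = 20
for epoch in range(total_epochs):
    ratio = get_teacher_forcing_ratio(epoch, total_epochs, strategy="linear")
    for src, trg in train_loader:
        optimizer.zero_grad()
        outputs = model(src, trg, teacher_forcing_ratio=ratio)
        # Skip position 0 (the <SOS> slot) and flatten for cross-entropy
        loss = criterion(
            outputs[:, 1:, :].reshape(-1, outputs.size(-1)),
            trg[:, 1:].reshape(-1),
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()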
15.10 Beam Search
15.10.1 Beyond Greedy Decoding
At inference time, the decoder generates tokens one at a time. The simplest strategy is greedy decoding: at each step, pick the token with the highest probability. But greedy decoding can miss better sequences, because a locally optimal choice at step $t$ may lead to a globally suboptimal sequence. For example, consider two partial sequences:
- "The dog" (probability 0.6) followed by "runs" (probability 0.1): total 0.06
- "A cat" (probability 0.4) followed by "sleeps" (probability 0.5): total 0.20
Greedy decoding would choose "The dog" at step 1, missing the higher-probability sequence.
15.10.2 The Beam Search Algorithm
Beam search maintains a set of $k$ best partial hypotheses (the "beam") at each time step, where $k$ is the beam width. At each step:
1. Expand each hypothesis in the beam with its most likely next tokens.
2. Score every expanded hypothesis by its cumulative log probability.
3. Keep only the $k$ highest-scoring hypotheses and discard the rest.
The probability of a sequence is:
$$\log P(\mathbf{y} \mid \mathbf{x}) = \sum_{t=1}^{T} \log P(y_t \mid y_{<t}, \mathbf{x})$$
15.10.3 Length Normalization
Beam search has a bias toward shorter sequences because each additional token multiplies a probability less than 1 (adds a negative log probability). To counteract this, we normalize by sequence length:
$$\text{score}(\mathbf{y}) = \frac{1}{|\mathbf{y}|^\alpha} \sum_{t=1}^{|\mathbf{y}|} \log P(y_t \mid y_{<t}, \mathbf{x})$$
where $\alpha \in [0, 1]$ controls the strength of length normalization. With $\alpha = 0$, there is no normalization; with $\alpha = 1$, we normalize fully by length.
15.10.4 Beam Search Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(42)
def beam_search(
decoder: nn.Module,
initial_hidden: tuple[torch.Tensor, torch.Tensor],
sos_token: int,
eos_token: int,
beam_width: int = 5,
max_len: int = 50,
length_penalty_alpha: float = 0.7,
) -> list[tuple[list[int], float]]:
"""Perform beam search decoding.
Args:
decoder: The decoder module (processes one token at a time).
initial_hidden: Initial decoder hidden state (h, c).
sos_token: Start-of-sequence token ID.
eos_token: End-of-sequence token ID.
beam_width: Number of hypotheses to maintain.
max_len: Maximum output sequence length.
length_penalty_alpha: Length normalization exponent.
Returns:
List of (token_ids, score) tuples sorted by score (best first).
"""
device = initial_hidden[0].device
# Each beam: (token_ids, log_prob, hidden_state, is_complete)
beams = [(
[sos_token],
0.0,
initial_hidden,
False,
)]
completed = []
for _ in range(max_len):
candidates = []
for tokens, log_prob, hidden, is_complete in beams:
if is_complete:
completed.append((tokens, log_prob))
continue
# Get decoder prediction for last token
input_token = torch.tensor(
[[tokens[-1]]], device=device
)
prediction, new_hidden = decoder(input_token, hidden)
log_probs = F.log_softmax(prediction, dim=-1)
# Get top-k tokens
topk_log_probs, topk_indices = log_probs.topk(beam_width)
for i in range(beam_width):
next_token = topk_indices[0, i].item()
next_log_prob = log_prob + topk_log_probs[0, i].item()
next_tokens = tokens + [next_token]
is_done = next_token == eos_token
candidates.append((
next_tokens,
next_log_prob,
new_hidden,
is_done,
))
if not candidates:
break
# Sort by log probability and keep top beam_width
candidates.sort(key=lambda x: x[1], reverse=True)
beams = candidates[:beam_width]
# Add any remaining incomplete beams
for tokens, log_prob, _, _ in beams:
completed.append((tokens, log_prob))
# Apply length normalization and sort
scored = []
for tokens, log_prob in completed:
length = len(tokens)
normalized_score = log_prob / (length ** length_penalty_alpha)
scored.append((tokens, normalized_score))
scored.sort(key=lambda x: x[1], reverse=True)
return scored
15.10.5 Practical Beam Search Settings
Typical beam widths range from 4 to 10; larger beams improve quality but increase computation linearly. Beyond a certain point, increasing the beam width provides diminishing returns and can even degrade quality (the "beam search curse"), where longer, lower-quality sequences are favored. Length normalization (Section 15.10.3) helps counteract this tendency.
15.11 Attention Mechanism Preview
15.11.1 The Core Idea
The attention mechanism, introduced by Bahdanau et al. (2015), addresses the information bottleneck of seq2seq models. Instead of compressing the entire input into a single context vector, attention allows the decoder to look back at all encoder hidden states and focus on the most relevant ones at each decoding step.
At each decoder time step $t$, attention computes:
$$e_{t,s} = \text{score}(\mathbf{h}_t^{\text{dec}}, \mathbf{h}_s^{\text{enc}})$$
$$\alpha_{t,s} = \frac{\exp(e_{t,s})}{\sum_{s'=1}^{S} \exp(e_{t,s'})}$$
$$\mathbf{c}_t = \sum_{s=1}^{S} \alpha_{t,s} \mathbf{h}_s^{\text{enc}}$$
This context vector $\mathbf{c}_t$ is then used alongside the decoder hidden state to predict the output.
15.11.2 Scoring Functions
The alignment score function $\text{score}(\mathbf{h}_t^{\text{dec}}, \mathbf{h}_s^{\text{enc}})$ can take several forms:
Additive attention (Bahdanau, 2015):
$$e_{t,s} = \mathbf{v}^T \tanh(\mathbf{W}_1 \mathbf{h}_t^{\text{dec}} + \mathbf{W}_2 \mathbf{h}_s^{\text{enc}})$$
where $\mathbf{W}_1$, $\mathbf{W}_2$, and $\mathbf{v}$ are learned parameters.
Multiplicative (dot-product) attention (Luong, 2015):
$$e_{t,s} = (\mathbf{h}_t^{\text{dec}})^T \mathbf{h}_s^{\text{enc}}$$
This is computationally cheaper (a simple dot product) and scales better to large hidden dimensions.
Scaled dot-product attention (Vaswani et al., 2017):
$$e_{t,s} = \frac{(\mathbf{h}_t^{\text{dec}})^T \mathbf{h}_s^{\text{enc}}}{\sqrt{d}}$$
where $d$ is the dimension of the hidden state. The scaling factor $\frac{1}{\sqrt{d}}$ prevents the dot products from growing too large in high dimensions, which would push the softmax into saturation and produce near-one-hot attention weights. This scaled dot-product attention is the foundation of the Transformer architecture that we will explore in Chapter 18.
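A minimal sketch of dot-product attention over a batch of encoder states (shapes are illustrative; this is only a preview of the full treatment in Chapter 18):
import torch
import torch.nn.functional as F
batch, src_len, hidden = 4, 12, 64
enc_states = torch.randn(batch, src_len, hidden)  # h_s^enc for s = 1..S
dec_state = torch.randn(batch, hidden)            # h_t^dec at one decoding step
# Dot-product scores e_{t,s}, one per encoder position
scores = torch.bmm(enc_states, dec_state.unsqueeze(-1)).squeeze(-1)  # (batch, src_len)
weights = F.softmax(scores, dim=-1)                                  # alpha_{t,s}
context = torch.bmm(weights.unsqueeze(1), enc_states).squeeze(1)     # (batch, hidden)
print(context.shape)  # torch.Size([4, 64])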
15.11.3 Impact of Attention
Attention transformed sequence modeling. The progression from seq2seq to seq2seq-with-attention to Transformers is one of the most important narrative arcs in modern deep learning. Each step removed a limitation: seq2seq removed the fixed-length bottleneck of simple RNNs; attention removed the information bottleneck of seq2seq; and Transformers removed the sequential processing bottleneck of RNN-based attention models. Understanding this progression---which you are now well positioned to do---is essential for appreciating why Transformers work and what they replaced. We will explore attention mechanisms thoroughly in Chapter 18.
15.12 Comparison with 1D Convolutions for Sequences
Before covering practical considerations, it is worth comparing RNNs with an alternative approach to sequence modeling: 1D convolutional neural networks (1D CNNs). As we saw in Chapter 13, convolutions are powerful for extracting local patterns from grid-structured data. They can also be applied to sequences.
15.12.1 1D Convolutions for Sequences
A 1D convolutional layer slides a filter of width $k$ across the sequence, computing a weighted sum at each position. For a sequence of embeddings $\mathbf{x}_1, \ldots, \mathbf{x}_T$, a 1D convolution with filter $\mathbf{w} \in \mathbb{R}^{k \times d}$ produces:
$$\mathbf{h}_t = f\left(\sum_{i=0}^{k-1} \mathbf{w}_i \cdot \mathbf{x}_{t+i} + b\right)$$
Stacking multiple convolutional layers with increasing dilation rates (as in WaveNet) allows the receptive field to grow exponentially with depth, enabling long-range dependencies without recurrence.
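A minimal sketch of a dilated 1D convolution stack over a sequence of embeddings (channel sizes and dilation rates are arbitrary choices for illustration):
import torch
import torch.nn as nn
# nn.Conv1d expects (batch, channels, seq_len), so embeddings are transposed first
conv_stack = nn.Sequential(
    nn.Conv1d(128, 128, kernel_size=3, padding=1, dilation=1),
    nn.ReLU(),
    nn.Conv1d(128, 128, kernel_size=3, padding=2, dilation=2),
    nn.ReLU(),
    nn.Conv1d(128, 128, kernel_size=3, padding=4, dilation=4),
    nn.ReLU(),
)
x = torch.randn(8, 50, 128)               # (batch, seq_len, embed_dim)
features = conv_stack(x.transpose(1, 2))  # (batch, channels, seq_len)
print(features.shape)                     # torch.Size([8, 128, 50])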
15.12.2 RNNs vs. 1D CNNs
| Aspect | RNNs (LSTM/GRU) | 1D CNNs |
|---|---|---|
| Receptive field | Theoretically infinite (entire history) | Fixed by kernel size and depth |
| Parallelism | Sequential (cannot parallelize across time) | Fully parallel across time |
| Training speed | Slow (sequential computation) | Fast (parallel computation) |
| Memory | Adaptive (hidden state evolves) | Fixed (captured by filter patterns) |
| Long-range dependencies | Good with LSTM/GRU gates | Requires deep stacks or dilated convolutions |
| Streaming/online | Natural (process one token at a time) | Requires buffering a window |
| Parameter count | Grows with hidden size, not sequence length | Grows with kernel size and number of filters |
When to prefer 1D CNNs. If your task involves detecting local patterns (e.g., n-gram features in text classification, motifs in genomic sequences), 1D CNNs are fast and effective. They also parallelize perfectly on GPUs, making training significantly faster than RNNs.
When to prefer RNNs. If your task requires maintaining long-term state (e.g., tracking the subject of a sentence across many clauses) or online processing (e.g., real-time speech recognition), RNNs are more natural. The hidden state provides a compact summary of the entire past that 1D CNNs lack.
Hybrid approaches. Some architectures combine both: use 1D convolutions for local feature extraction and feed the resulting features into an RNN for global temporal reasoning. This can be faster than a pure RNN approach while retaining the ability to model long-range dependencies.
In practice, both RNNs and 1D CNNs have been largely superseded by Transformers (Chapter 18) for most NLP tasks, but they remain important for time series, audio processing, and resource-constrained settings.
15.13 Practical Considerations
15.13.1 Input Representation
For text data, inputs are typically represented as word embeddings. Each token in the vocabulary is mapped to a dense vector:
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=300)
You can initialize embeddings randomly (trained from scratch) or use pretrained embeddings like GloVe or Word2Vec. Pretrained embeddings provide a strong starting point, especially with limited data. For numerical time series, inputs may be raw values, normalized values, or engineered features.
15.13.2 Padding and Packing
Sequences in a batch typically have different lengths. PyTorch handles this with padding and packing:
from torch.nn.utils.rnn import (
pack_padded_sequence,
pad_packed_sequence,
)
# Assume sequences is a padded tensor: (batch, max_len, features)
# lengths is a tensor of actual sequence lengths
packed = pack_padded_sequence(
sequences, lengths, batch_first=True, enforce_sorted=False
)
output, (h_n, c_n) = lstm(packed)
output, output_lengths = pad_packed_sequence(output, batch_first=True)
Packing ensures that the LSTM does not process padding tokens, which improves both correctness and efficiency.
15.13.3 Choosing Hyperparameters
Based on extensive empirical research and practical experience:
| Hyperparameter | Recommended Range | Notes |
|---|---|---|
| Hidden size | 128--1024 | Scale with data size |
| Number of layers | 1--4 | 2 is a strong default |
| Dropout | 0.2--0.5 | Higher for larger models |
| Learning rate | 1e-4 to 1e-2 | Adam optimizer |
| Gradient clip norm | 1.0--5.0 | Essential for stability |
| Embedding dim | 100--300 | 256 is a good default |
15.13.4 Practical Tips for Training RNNs
Based on years of accumulated experience from the research community and practitioners:
1. Always use gradient clipping. This is non-negotiable for RNN training. Set max_norm between 1.0 and 5.0. Without it, a single batch can produce gradients large enough to destroy your model.
2. Initialize forget gate biases to 1. For LSTMs, set the forget gate bias to 1.0 (or even 2.0) at initialization. This encourages the network to remember by default, allowing it to learn what to forget rather than what to remember. This simple trick, suggested by Jozefowicz et al. (2015), can significantly improve performance on tasks with long-range dependencies.
# After creating an LSTM layer, initialize forget gate biases
for name, param in lstm.named_parameters():
    if "bias" in name:
        n = param.size(0)
        # bias is structured as [input_gate, forget_gate, cell_gate, output_gate]
        param.data[n // 4 : n // 2].fill_(1.0)
3. Use orthogonal initialization for recurrent weights. Initializing $\mathbf{W}_{hh}$ as an orthogonal matrix (all singular values equal to 1) prevents both vanishing and exploding gradients at the start of training:
for name, param in model.named_parameters():
    if "weight_hh" in name:
        torch.nn.init.orthogonal_(param)
4. Sort sequences by length. When training with mini-batches, sort sequences by length within each batch (or use pack_padded_sequence). This minimizes the amount of padding and avoids wasting computation on padding tokens.
5. Use bidirectional RNNs for classification. If the full sequence is available and you are performing classification or tagging, bidirectional RNNs almost always outperform unidirectional ones. The improvement is essentially free (only 2x parameters).
6. Start with a single-layer LSTM. Do not jump to deep, multi-layer RNNs immediately. A single-layer LSTM with 256--512 hidden units is a strong baseline for most tasks. Add layers only if you have enough data and the single-layer model is clearly underfitting.
7. Monitor hidden state norms. During training, log the L2 norm of the hidden state at each time step. If the norm grows without bound over the sequence, you likely have exploding activations (even if gradients are clipped). Reduce the hidden size or add layer normalization.
15.13.5 Common Pitfalls
15.14 RNNs in the Modern Landscape
15.14.1 The Rise of Transformers
Since 2017, Transformer-based models have largely replaced RNNs for many sequence tasks, especially in NLP. Transformers process all positions in parallel (rather than sequentially), which enables much faster training and better scaling. However, RNNs remain relevant in several contexts, summarized in Section 15.14.3.
15.14.2 State Space Models
Recent architectures like Mamba and S4 (Structured State Spaces for Sequence Modeling) draw inspiration from both RNNs and linear dynamical systems. They maintain the sequential processing advantage of RNNs while achieving Transformer-like performance, hinting that the story of recurrence is far from over.
The core idea of state space models is to model the hidden state dynamics as a continuous-time linear system:
$$\frac{d\mathbf{h}(t)}{dt} = \mathbf{A}\mathbf{h}(t) + \mathbf{B}\mathbf{x}(t), \quad \mathbf{y}(t) = \mathbf{C}\mathbf{h}(t)$$
When discretized, this becomes a recurrence similar to an RNN, but with carefully structured weight matrices that enable both efficient sequential processing (like an RNN) and parallel training (like a CNN). The key innovation is that the matrix $\mathbf{A}$ is parameterized using special structures (such as diagonal plus low-rank, or the HiPPO framework) that enable long-range memory. Mamba (Gu and Dao, 2023) further adds input-dependent gating, making the state space model's dynamics depend on the content being processed---much like LSTM gates, but with the efficiency advantages of the state space formulation. These models represent an exciting convergence of ideas from control theory, signal processing, and deep learning.
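As a toy illustration of the recurrence that discretization produces (a simple Euler step with step size dt; real S4 and Mamba implementations use more careful discretizations and structured A matrices):
import torch
# Toy linear state space model: h_k = h_{k-1} + dt * (A h_{k-1} + B x_k), y_k = C h_k
state_dim, input_dim, output_dim, dt = 16, 4, 2, 0.1
A = -0.5 * torch.eye(state_dim)           # stable toy dynamics
B = 0.1 * torch.randn(state_dim, input_dim)
C = 0.1 * torch.randn(output_dim, state_dim)
h = torch.zeros(state_dim)
for x_t in torch.randn(20, input_dim):    # a length-20 input sequence
    h = h + dt * (A @ h + B @ x_t)        # discretized state update
    y = C @ h                             # readout at this step
print(y.shape)  # torch.Size([2])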
15.14.3 When to Choose RNNs
Choose RNNs when:
- You need to process truly streaming data
- Computational resources are limited
- Sequence lengths are moderate (< 1000 tokens)
- You are working on time series forecasting
- Interpretability of the hidden state dynamics matters
Choose Transformers when:
- You have large datasets and computational resources
- Sequences are long and require global attention
- You can benefit from pretrained models (BERT, GPT)
- Parallelism during training is important
15.15 Summary
This chapter has taken you on a journey through the evolution of recurrent architectures:
- Vanilla RNNs introduced the recurrence relation, enabling sequential processing with shared parameters, but suffered from vanishing and exploding gradients.
- LSTMs solved the vanishing gradient problem with gated cell states, using forget, input, and output gates to control information flow. The additive cell state update is the key innovation.
- GRUs simplified the LSTM architecture to two gates (update and reset) while maintaining comparable performance with fewer parameters.
- Bidirectional RNNs process sequences in both directions, capturing context from past and future for classification and labeling tasks.
- Sequence-to-sequence models use an encoder-decoder architecture to map variable-length input sequences to variable-length output sequences.
- Teacher forcing accelerates training by providing ground-truth tokens as decoder input, though it introduces exposure bias.
- Beam search improves inference quality over greedy decoding by maintaining multiple hypotheses.
- Attention mechanisms preview how the information bottleneck is resolved by allowing the decoder to focus on relevant parts of the input at each step. The progression from fixed context vectors to attention to self-attention (Transformers) is one of the most important evolutionary paths in deep learning.
- 1D convolutions offer a parallel-friendly alternative to RNNs for local pattern extraction in sequences, though they lack the adaptive memory that makes RNNs powerful for long-range dependencies.
- Practical training of RNNs requires gradient clipping, careful initialization (orthogonal weights, forget gate bias of 1), and attention to sequence padding and packing.
The key takeaway is that sequence modeling requires architectures that can maintain and update a representation of context over time. Vanilla RNNs introduced this idea but suffered from vanishing gradients. LSTMs and GRUs solved this with learned gating mechanisms that control information flow. While Transformers have largely replaced RNNs for many tasks, the concepts of sequential processing, hidden state dynamics, and gated information flow remain central to understanding modern deep learning.
In the next chapter, we will explore attention mechanisms in depth and build the foundation for understanding the Transformer architecture, which has become the dominant paradigm in modern deep learning for sequences.
References