Chapter 15: Recurrent Neural Networks and Sequence Modeling

"The only reason for time is so that everything doesn't happen at once." --- Albert Einstein

In the previous chapters, we explored feedforward networks and convolutional architectures that process fixed-size inputs. But the world is fundamentally sequential. Language unfolds word by word. Stock prices evolve tick by tick. Sensor readings stream continuously. Music progresses note by note. To model these phenomena, we need architectures that can process sequences of arbitrary length while maintaining memory of what came before.

Recurrent Neural Networks (RNNs) introduced a revolutionary idea: networks with loops. By feeding a hidden state from one time step to the next, RNNs create a form of memory that allows information to persist across time. This simple concept unlocked an entire category of problems that feedforward networks could not address, from machine translation to speech recognition to time series forecasting.

In this chapter, we will build RNNs from the ground up. We will start with the vanilla RNN, understand why it struggles with long sequences, and then explore the gated architectures---Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU)---that solved those problems. We will cover bidirectional processing, sequence-to-sequence models, teacher forcing, and beam search. By the end, you will have a thorough understanding of how to model sequential data and the tools to implement these models in PyTorch.


15.1 The Nature of Sequential Data

Before diving into architectures, let us understand what makes sequential data fundamentally different from the static data we have processed so far.

15.1.1 Why Sequences Are Different

Consider the sentence "The cat sat on the mat." Each word's meaning is influenced by the words that precede it. The word "sat" makes sense because we know the subject is "cat." If we shuffled the words randomly---"mat the on cat sat the"---the meaning disintegrates. Order matters.

Sequential data appears everywhere in engineering and science:

  • Natural language: Words form sentences; sentences form documents
  • Time series: Stock prices, sensor readings, weather data
  • Audio: Waveforms sampled at regular intervals
  • Video: Sequences of image frames
  • Genomics: DNA sequences of nucleotide bases (A, T, G, C)
  • User behavior: Clickstreams, purchase histories, navigation paths

What these domains share is temporal or positional dependence: the meaning or value at position $t$ depends on what happened at positions $t-1, t-2, \ldots$.

15.1.2 Modeling Requirements for Sequences

An architecture designed for sequences must satisfy several requirements:

  1. Variable-length input handling: Sequences can be 5 tokens or 5,000 tokens long.
  2. Parameter sharing across time: The same transformation should apply at each time step; we should not learn separate weights for position 1 versus position 100.
  3. Memory: The model must maintain a summary of past inputs that it can use when processing the current input.
  4. Compositionality: The model should be able to build complex representations from simpler ones over time.

A naive approach might flatten a sequence into a single long vector and feed it to a feedforward network. This fails for several reasons: it cannot handle variable-length inputs, it requires separate parameters for each position (no sharing), and it scales poorly. RNNs address all of these issues elegantly.

15.1.3 Sequence Modeling Tasks

RNNs support several task configurations:

| Task Type | Input | Output | Example |
|---|---|---|---|
| Many-to-one | Sequence | Single value | Sentiment analysis |
| One-to-many | Single value | Sequence | Image captioning |
| Many-to-many (aligned) | Sequence | Sequence (same length) | Part-of-speech tagging |
| Many-to-many (unaligned) | Sequence | Sequence (different length) | Machine translation |

This flexibility makes RNNs one of the most versatile architectural families in deep learning.


15.2 Vanilla Recurrent Neural Networks

15.2.1 The Recurrence Relation

The core idea of an RNN is breathtakingly simple. At each time step $t$, the network takes two inputs: the current input $\mathbf{x}_t$ and the previous hidden state $\mathbf{h}_{t-1}$. It produces a new hidden state $\mathbf{h}_t$:

$$\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t + \mathbf{b}_h)$$

The output at time step $t$ is computed from the hidden state:

$$\mathbf{y}_t = \mathbf{W}_{hy}\mathbf{h}_t + \mathbf{b}_y$$

Here:

  • $\mathbf{x}_t \in \mathbb{R}^d$ is the input at time step $t$
  • $\mathbf{h}_t \in \mathbb{R}^n$ is the hidden state at time step $t$
  • $\mathbf{W}_{xh} \in \mathbb{R}^{n \times d}$ maps inputs to hidden states
  • $\mathbf{W}_{hh} \in \mathbb{R}^{n \times n}$ maps previous hidden states to current hidden states
  • $\mathbf{W}_{hy} \in \mathbb{R}^{m \times n}$ maps hidden states to outputs
  • $\mathbf{b}_h, \mathbf{b}_y$ are bias vectors
  • $\mathbf{h}_0$ is typically initialized to zeros

The critical insight is that the same weights $\mathbf{W}_{xh}$, $\mathbf{W}_{hh}$, and $\mathbf{W}_{hy}$ are used at every time step. This parameter sharing is what allows RNNs to generalize across sequence positions and handle variable-length inputs.

15.2.2 Unrolling Through Time

To understand how an RNN processes a sequence, we "unroll" it through time. Given a sequence $(\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_T)$, the unrolled computation graph looks like a chain of identical transformations:

x_1       x_2       x_3           x_T
 |         |         |             |
 v         v         v             v
[RNN] --> [RNN] --> [RNN] --> ... [RNN]
 |         |         |             |
 v         v         v             v
h_1       h_2       h_3           h_T
 |         |         |             |
 v         v         v             v
y_1       y_2       y_3           y_T

Each [RNN] block applies the same transformation with the same weights. The hidden state $\mathbf{h}_t$ serves as the network's "memory," carrying information from all previous time steps.

15.2.3 Backpropagation Through Time (BPTT)

Training an RNN requires computing gradients through the unrolled computation graph. This algorithm is called Backpropagation Through Time (BPTT). It is not a new algorithm per se---it is simply standard backpropagation (as we derived in Chapter 11) applied to the unrolled computation graph of the RNN. However, the temporal structure introduces unique challenges that deserve careful analysis.

Deriving BPTT Step by Step.

Given a loss function $\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t$, where $\mathcal{L}_t$ is the loss at time step $t$, we need the gradient of the total loss with respect to the shared weight matrices. Let us derive $\frac{\partial \mathcal{L}}{\partial \mathbf{W}_{hh}}$.

Since $\mathbf{W}_{hh}$ is used at every time step, the total gradient is the sum of contributions from each time step:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^{T} \frac{\partial \mathcal{L}_t}{\partial \mathbf{W}_{hh}}$$

For a single time step's contribution, we need to account for the fact that $\mathbf{W}_{hh}$ affects $\mathcal{L}_t$ through every hidden state from $\mathbf{h}_1$ to $\mathbf{h}_t$. Applying the chain rule:

$$\frac{\partial \mathcal{L}_t}{\partial \mathbf{W}_{hh}} = \sum_{k=1}^{t} \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t} \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k} \frac{\partial^+ \mathbf{h}_k}{\partial \mathbf{W}_{hh}}$$

where $\frac{\partial^+ \mathbf{h}_k}{\partial \mathbf{W}_{hh}}$ denotes the "immediate" derivative of $\mathbf{h}_k$ with respect to $\mathbf{W}_{hh}$ (treating $\mathbf{h}_{k-1}$ as a constant). This immediate derivative is straightforward:

$$\frac{\partial^+ \mathbf{h}_k}{\partial \mathbf{W}_{hh}} = \text{diag}(1 - \mathbf{h}_k^2) \cdot \mathbf{h}_{k-1}^T$$

where $\text{diag}(1 - \mathbf{h}_k^2)$ is the derivative of tanh applied element-wise.

The critical term is the Jacobian product that propagates the error signal backward through time:

$$\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_k} = \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}$$

Each factor in this product is the Jacobian of one time step's hidden state with respect to the previous:

$$\frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} = \text{diag}(1 - \mathbf{h}_j^2) \cdot \mathbf{W}_{hh}$$

Combining everything:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t} \left(\prod_{j=k+1}^{t} \text{diag}(1 - \mathbf{h}_j^2) \cdot \mathbf{W}_{hh}\right) \text{diag}(1 - \mathbf{h}_k^2) \cdot \mathbf{h}_{k-1}^T$$

This double sum over $t$ and $k$ is computationally expensive, scaling as $O(T^2)$ in the worst case. In practice, we implement BPTT by performing a forward pass to compute and cache all hidden states, then sweeping backward from $t = T$ to $t = 1$, accumulating gradients as we go. This is algorithmically efficient: $O(T)$ time and $O(T)$ memory (for caching hidden states).

Truncated BPTT. For very long sequences, storing all $T$ hidden states and backpropagating through all of them is prohibitive. Truncated BPTT limits backpropagation to a fixed window of $\tau$ time steps. The sequence is divided into chunks of length $\tau$; for each chunk, the forward pass uses the hidden state from the previous chunk, but gradients are only computed within the chunk. This introduces a bias (gradients from dependencies longer than $\tau$ steps are ignored) but dramatically reduces memory and computation. Typical values are $\tau = 35$ to $\tau = 200$.

# Truncated BPTT: process sequence in chunks
chunk_size = 35  # Truncation length
hidden = None    # Initial hidden state (model defaults to zeros when None)

for i in range(0, seq_len, chunk_size):
    chunk = sequence[:, i:i+chunk_size, :]
    output, hidden = model(chunk, hidden)

    loss = criterion(output, targets[:, i:i+chunk_size])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Detach hidden state to prevent backprop beyond this chunk
    # (for an LSTM, detach each element of the (h, c) tuple instead)
    hidden = hidden.detach()

This product of Jacobians $\prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}$ is the source of the vanishing and exploding gradient problems that we will discuss next.

15.2.4 Implementing a Vanilla RNN Cell

Let us implement a vanilla RNN from scratch to solidify our understanding:

import torch
import torch.nn as nn

torch.manual_seed(42)


class VanillaRNNCell(nn.Module):
    """A single vanilla RNN cell.

    Implements the recurrence: h_t = tanh(W_xh @ x_t + W_hh @ h_{t-1} + b_h)

    Args:
        input_size: Dimensionality of the input vectors.
        hidden_size: Dimensionality of the hidden state.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.W_xh = nn.Linear(input_size, hidden_size)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(
        self, x_t: torch.Tensor, h_prev: torch.Tensor
    ) -> torch.Tensor:
        """Compute the next hidden state.

        Args:
            x_t: Input at current time step, shape (batch_size, input_size).
            h_prev: Previous hidden state, shape (batch_size, hidden_size).

        Returns:
            New hidden state, shape (batch_size, hidden_size).
        """
        return torch.tanh(self.W_xh(x_t) + self.W_hh(h_prev))

To process a full sequence:

class VanillaRNN(nn.Module):
    """A vanilla RNN that processes entire sequences.

    Args:
        input_size: Dimensionality of input vectors.
        hidden_size: Dimensionality of hidden state.
        output_size: Dimensionality of output vectors.
    """

    def __init__(
        self, input_size: int, hidden_size: int, output_size: int
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn_cell = VanillaRNNCell(input_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(
        self, x: torch.Tensor, h_0: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Process a sequence and return outputs and final hidden state.

        Args:
            x: Input sequence, shape (batch_size, seq_len, input_size).
            h_0: Initial hidden state. Defaults to zeros.

        Returns:
            Tuple of (outputs, h_final) where outputs has shape
            (batch_size, seq_len, output_size) and h_final has shape
            (batch_size, hidden_size).
        """
        batch_size, seq_len, _ = x.shape

        if h_0 is None:
            h_0 = torch.zeros(batch_size, self.hidden_size, device=x.device)

        h_t = h_0
        outputs = []

        for t in range(seq_len):
            h_t = self.rnn_cell(x[:, t, :], h_t)
            y_t = self.output_layer(h_t)
            outputs.append(y_t)

        outputs = torch.stack(outputs, dim=1)
        return outputs, h_t
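
As a quick sanity check, we can run the model on random data (the sizes below are arbitrary):

# Quick sanity check with arbitrary sizes
rnn = VanillaRNN(input_size=10, hidden_size=32, output_size=5)
x = torch.randn(8, 15, 10)              # (batch_size, seq_len, input_size)
outputs, h_final = rnn(x)
print(outputs.shape, h_final.shape)     # torch.Size([8, 15, 5]) torch.Size([8, 32])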

15.3 The Vanishing and Exploding Gradient Problems

15.3.1 Mathematical Analysis

The vanilla RNN has an Achilles' heel. Recall the product of Jacobians from BPTT:

$$\prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}$$

Each Jacobian $\frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} = \text{diag}(1 - \mathbf{h}_j^2) \cdot \mathbf{W}_{hh}$ where $\text{diag}(1 - \mathbf{h}_j^2)$ is the derivative of tanh.

Bounding the Jacobian norm. Since $\tanh'(z) = 1 - \tanh^2(z) \leq 1$ for all $z$, we have $\|\text{diag}(1 - \mathbf{h}_j^2)\| \leq 1$. Therefore:

$$\left\| \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} \right\| \leq \|\text{diag}(1 - \mathbf{h}_j^2)\| \cdot \|\mathbf{W}_{hh}\| \leq \|\mathbf{W}_{hh}\|$$

For the product over $t - k$ time steps:

$$\left\| \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} \right\| \leq \|\mathbf{W}_{hh}\|^{t-k}$$

If the largest singular value $\sigma_{\max}(\mathbf{W}_{hh}) < 1$, then:

$$\left\| \prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} \right\| \leq \gamma^{t-k} \to 0 \text{ as } t-k \to \infty$$

where $\gamma = \sigma_{\max}(\mathbf{W}_{hh}) < 1$. This is the vanishing gradient problem: gradients from distant time steps shrink exponentially to zero, making it impossible for the network to learn long-range dependencies.

Conversely, if $\sigma_{\max}(\mathbf{W}_{hh}) > 1$, gradients can grow exponentially---the exploding gradient problem.

How fast do gradients vanish? Let us put numbers to this. Suppose $\|\mathbf{W}_{hh}\| = 0.9$ and the tanh derivative contributes an average factor of 0.5 (since tanh saturates for large activations). The effective decay rate per time step is $\gamma \approx 0.9 \times 0.5 = 0.45$. Over 20 time steps:

$$\gamma^{20} = 0.45^{20} \approx 1.2 \times 10^{-7}$$

The gradient signal from 20 steps ago is attenuated by a factor of 10 million. After 50 steps, it is essentially zero in floating-point arithmetic. This calculation, first presented rigorously by Bengio, Simard, and Frasconi (1994), explains why vanilla RNNs cannot learn dependencies spanning more than about 10--20 time steps.
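
The same attenuation can be observed numerically. The sketch below (a minimal illustration, assuming a random 32-dimensional vanilla RNN with recurrent weights scaled so that $\|\mathbf{W}_{hh}\| \approx 0.9$) backpropagates a scalar readout of $\mathbf{h}_{50}$ back to $\mathbf{h}_0$ and prints the resulting gradient norm, which comes out vanishingly small:

import torch

torch.manual_seed(42)

# Backpropagate through 50 steps of a random vanilla RNN and measure
# the gradient with respect to the initial hidden state.
hidden_size = 32
W_hh = 0.9 * torch.nn.init.orthogonal_(torch.empty(hidden_size, hidden_size))
W_xh = 0.1 * torch.randn(hidden_size, hidden_size)

h_0 = torch.zeros(hidden_size, requires_grad=True)
h_t = h_0
for _ in range(50):
    x_t = torch.randn(hidden_size)
    h_t = torch.tanh(W_hh @ h_t + W_xh @ x_t)

h_t.sum().backward()
print(f"Gradient norm w.r.t. h_0 after 50 steps: {h_0.grad.norm().item():.2e}")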

The exploding side. If $\gamma = 1.1$ (only slightly above 1), then $\gamma^{50} \approx 117$: the gradient is amplified by more than two orders of magnitude, causing the weight update to overshoot catastrophically. Gradient norms can reach $10^{10}$ or higher, producing NaN values and completely destroying the model. The exploding gradient problem is easier to detect (training diverges obviously) and easier to fix (gradient clipping, as discussed in Section 15.3.3 and Chapter 12). The vanishing gradient problem is more insidious because training appears to proceed normally---the loss decreases---but the model silently fails to learn long-range patterns.

Connection to eigenvalue analysis. A more precise analysis uses the eigendecomposition of $\mathbf{W}_{hh}$. If $\mathbf{W}_{hh} = \mathbf{P} \boldsymbol{\Lambda} \mathbf{P}^{-1}$ where $\boldsymbol{\Lambda}$ contains the eigenvalues $\lambda_1, \ldots, \lambda_n$, then $\mathbf{W}_{hh}^k = \mathbf{P} \boldsymbol{\Lambda}^k \mathbf{P}^{-1}$. The behavior is dominated by the largest eigenvalue: if $|\lambda_{\max}| < 1$, the product vanishes; if $|\lambda_{\max}| > 1$, it explodes. The transition between vanishing and exploding occurs at $|\lambda_{\max}| = 1$, a knife-edge that is impossible to maintain during training. This is the fundamental reason why vanilla RNNs fail on long sequences---there is no stable equilibrium for gradient flow.

15.3.2 Practical Consequences

The vanishing gradient problem means that a vanilla RNN effectively has a limited memory horizon. In practice, vanilla RNNs struggle to learn dependencies spanning more than about 10--20 time steps. Consider a language modeling task where the relevant context is 50 words earlier: "The woman who grew up in France ... spoke fluent French." A vanilla RNN would fail to connect "France" to "French" across the intervening text.

15.3.3 Gradient Clipping

While gradient clipping does not solve the vanishing gradient problem, it provides a practical solution to exploding gradients. The idea is simple: if the gradient norm exceeds a threshold $\theta$, rescale it:

$$\hat{\mathbf{g}} = \begin{cases} \mathbf{g} & \text{if } \|\mathbf{g}\| \leq \theta \\ \frac{\theta}{\|\mathbf{g}\|} \mathbf{g} & \text{if } \|\mathbf{g}\| > \theta \end{cases}$$

In PyTorch:

# After loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()

15.3.4 Other Mitigations

Several techniques partially address gradient flow issues in vanilla RNNs:

  • Proper initialization: Initializing $\mathbf{W}_{hh}$ as an orthogonal matrix ensures all singular values are exactly 1 at the start of training, preventing both vanishing and exploding gradients initially. As training proceeds and $\mathbf{W}_{hh}$ changes, this property is not maintained, but it provides a healthier starting point.
  • Skip connections: Adding residual connections across time steps, analogous to the skip connections in ResNets (Chapter 13). The hidden state update becomes $\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t) + \mathbf{h}_{t-1}$, creating an additive path for gradient flow (see the sketch after this list).
  • Truncated BPTT: Limiting backpropagation to a fixed number of time steps, as described in Section 15.2.3. This does not solve the vanishing gradient problem but limits its impact.
  • Echo State Networks / Reservoir Computing: Fix the recurrent weights (no training) and only train the output layer. This avoids gradient issues entirely but limits expressiveness.
  • Norm-preserving architectures: Design $\mathbf{W}_{hh}$ to be unitary (complex-valued) or orthogonal, maintaining gradient norm exactly. The Unitary RNN (Arjovsky et al., 2016) uses this approach.
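
As an illustration of the skip-connection idea, here is a minimal sketch of a residual RNN cell that reuses the VanillaRNNCell from Section 15.2.4 (the class name is ours, not a standard PyTorch module):

class ResidualRNNCell(nn.Module):
    """Vanilla RNN cell with an additive skip connection across time.

    Implements h_t = tanh(W_hh h_{t-1} + W_xh x_t) + h_{t-1}.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.cell = VanillaRNNCell(input_size, hidden_size)

    def forward(
        self, x_t: torch.Tensor, h_prev: torch.Tensor
    ) -> torch.Tensor:
        # The additive h_prev term gives gradients a direct path backward
        return self.cell(x_t, h_prev) + h_prev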

However, the definitive and most practical solution came with gated architectures: LSTM and GRU. Their gating mechanisms provide a learned mechanism for controlling gradient flow, which is far more flexible than any fixed architectural constraint.


15.4 Long Short-Term Memory (LSTM)

The LSTM, introduced by Hochreiter and Schmidhuber in 1997, is one of the most important innovations in deep learning. It solved the vanishing gradient problem by introducing a cell state that acts as an information highway, allowing gradients to flow unchanged across many time steps.

15.4.1 The Cell State: An Information Highway

The key innovation of LSTM is the cell state $\mathbf{c}_t$, a vector that runs through the entire sequence with only minor linear interactions. Think of it as a conveyor belt: information can be placed on it, read from it, or removed from it, but it flows forward with minimal degradation.

The cell state is regulated by three gates---the forget gate, input gate, and output gate---each of which is a sigmoid-activated layer that outputs values between 0 and 1, controlling how much information flows through.

15.4.2 The Forget Gate

The forget gate decides what information to discard from the cell state. It looks at the previous hidden state $\mathbf{h}_{t-1}$ and the current input $\mathbf{x}_t$, and outputs a value between 0 (completely forget) and 1 (completely remember) for each element of the cell state:

$$\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$$

where $\sigma$ is the sigmoid function and $[\mathbf{h}_{t-1}, \mathbf{x}_t]$ denotes concatenation.

For example, in a language model tracking the subject's gender, when the model encounters a new subject, the forget gate should erase the old gender information from the cell state.

15.4.3 The Input Gate

The input gate decides what new information to store in the cell state. This has two parts:

  1. The input gate layer decides which values to update:

$$\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$$

  1. A candidate cell state is created with tanh:

$$\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$$

15.4.4 Cell State Update

The old cell state $\mathbf{c}_{t-1}$ is updated by:

  1. Multiplying by the forget gate (erasing what should be forgotten)
  2. Adding the new candidate values scaled by the input gate

$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$

where $\odot$ denotes element-wise multiplication. This is the critical equation: the cell state update is additive, not multiplicative. This additive structure is what prevents gradient vanishing, because the gradient $\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \mathbf{f}_t$, which can remain close to 1 for long periods.

15.4.5 The Output Gate

The output gate determines what part of the cell state to expose as the hidden state:

$$\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$$

$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$

The cell state is passed through tanh (to push values between $-1$ and $1$) and then filtered by the output gate. This allows the LSTM to maintain information in the cell state that is not immediately relevant to the current output but may be useful later.

15.4.6 LSTM Summary

Putting it all together, the complete LSTM equations at each time step are:

$$\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$$

$$\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$$

$$\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$$

$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$

$$\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$$

$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$

The LSTM has roughly 4x the parameters of a vanilla RNN for the same hidden size, because it has four sets of weight matrices (one for each of the three gates plus the candidate cell state).

15.4.7 Why LSTMs Solve the Vanishing Gradient Problem

The gradient of the loss with respect to the cell state at time $k$ flows through the chain:

$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_k} = \prod_{j=k+1}^{t} \frac{\partial \mathbf{c}_j}{\partial \mathbf{c}_{j-1}} = \prod_{j=k+1}^{t} \mathbf{f}_j$$

Each forget gate $\mathbf{f}_j$ has values between 0 and 1. When the forget gate is close to 1 (remembering), the gradient flows through nearly unchanged. The network learns to set the forget gate appropriately: close to 1 when information should be preserved, and close to 0 when it should be erased. This learned gating mechanism is what enables LSTMs to capture dependencies spanning hundreds of time steps.

15.4.8 LSTM Implementation

import torch
import torch.nn as nn

torch.manual_seed(42)


class LSTMCell(nn.Module):
    """A single LSTM cell with explicit gate computations.

    Args:
        input_size: Dimensionality of input vectors.
        hidden_size: Dimensionality of hidden state and cell state.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # Combined weight matrix for efficiency
        self.gates = nn.Linear(
            input_size + hidden_size, 4 * hidden_size
        )

    def forward(
        self,
        x_t: torch.Tensor,
        h_prev: torch.Tensor,
        c_prev: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute next hidden state and cell state.

        Args:
            x_t: Input at current time step, shape (batch, input_size).
            h_prev: Previous hidden state, shape (batch, hidden_size).
            c_prev: Previous cell state, shape (batch, hidden_size).

        Returns:
            Tuple of (h_t, c_t), each shape (batch, hidden_size).
        """
        combined = torch.cat([h_prev, x_t], dim=1)
        gates = self.gates(combined)

        # Split into four gate values
        i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)

        i_t = torch.sigmoid(i_t)  # Input gate
        f_t = torch.sigmoid(f_t)  # Forget gate
        g_t = torch.tanh(g_t)     # Candidate cell state
        o_t = torch.sigmoid(o_t)  # Output gate

        c_t = f_t * c_prev + i_t * g_t
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t
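
As with the vanilla RNN, processing a sequence means stepping the cell through time. A minimal sketch with arbitrary sizes:

# Run the cell over a short random sequence (arbitrary sizes)
cell = LSTMCell(input_size=10, hidden_size=32)
x = torch.randn(8, 15, 10)                     # (batch, seq_len, input_size)
h_t = torch.zeros(8, 32)
c_t = torch.zeros(8, 32)

for t in range(x.size(1)):
    h_t, c_t = cell(x[:, t, :], h_t, c_t)

print(h_t.shape, c_t.shape)  # torch.Size([8, 32]) torch.Size([8, 32])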

15.5 Gated Recurrent Units (GRU)

15.5.1 A Streamlined Alternative

The Gated Recurrent Unit, introduced by Cho et al. in 2014, is a simplified variant of the LSTM that achieves comparable performance with fewer parameters. The GRU merges the cell state and hidden state into a single vector and uses two gates instead of three.

15.5.2 GRU Equations

The GRU uses an update gate $\mathbf{z}_t$ and a reset gate $\mathbf{r}_t$:

$$\mathbf{z}_t = \sigma(\mathbf{W}_z [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_z)$$

$$\mathbf{r}_t = \sigma(\mathbf{W}_r [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_r)$$

The candidate hidden state uses the reset gate to control how much of the previous state to consider:

$$\tilde{\mathbf{h}}_t = \tanh(\mathbf{W}_h [\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_h)$$

The final hidden state is a linear interpolation between the previous state and the candidate:

$$\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$$

15.5.3 Understanding the Gates

The update gate $\mathbf{z}_t$ controls the balance between keeping the old state and accepting the new candidate. When $\mathbf{z}_t \approx 0$, the hidden state is simply copied forward (like an LSTM forget gate close to 1). When $\mathbf{z}_t \approx 1$, the hidden state is replaced by the candidate.

The reset gate $\mathbf{r}_t$ controls how much of the previous hidden state to use when computing the candidate. When $\mathbf{r}_t \approx 0$, the candidate is computed as if the previous state were zero, allowing the network to "forget" and start fresh. When $\mathbf{r}_t \approx 1$, the candidate considers the full previous state.

15.5.4 LSTM vs. GRU

| Feature | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| State vectors | 2 (hidden + cell) | 1 (hidden) |
| Parameters | 4 weight matrices | 3 weight matrices |
| Performance | Slightly better on long sequences | Comparable overall |
| Training speed | Slower (more parameters) | Faster |
| When to use | Long dependencies, complex patterns | Smaller datasets, faster training |

In practice, the performance difference between LSTM and GRU is often marginal. The GRU trains faster due to fewer parameters, while the LSTM may have a slight edge on tasks requiring very long-range memory. When in doubt, try both and compare.

Detailed Comparison. The GRU achieves comparable performance to the LSTM with fewer parameters because:

  1. Combined forget and input gates. In the GRU, the update gate $\mathbf{z}_t$ controls both forgetting and inputting. When $\mathbf{z}_t \approx 0$, the old state is kept (analogous to the LSTM forget gate being close to 1 and the input gate being close to 0). This coupling means the GRU cannot independently control forgetting and inputting, which is a mild limitation but saves an entire gate's worth of parameters.

  2. No separate cell state. The LSTM maintains two state vectors ($\mathbf{h}_t$ and $\mathbf{c}_t$), while the GRU uses only $\mathbf{h}_t$. The LSTM's cell state provides a dedicated "memory highway" that is distinct from the output. The GRU merges these roles, making its hidden state do double duty.

  3. No output gate. The LSTM's output gate controls how much of the cell state to expose. The GRU exposes its entire hidden state at every step. This means the GRU cannot selectively hide information from downstream layers, which is occasionally useful for tasks where the relevant memory content changes over time.

Empirically, Chung et al. (2014) found that both architectures outperform vanilla RNNs by a large margin, but neither consistently outperforms the other across all tasks. Greff et al. (2017) conducted an extensive ablation study of LSTM variants and found that the forget gate and output gate are the most critical components---removing either significantly degrades performance. The GRU's update gate functions similarly to the forget gate, which may explain its competitive performance.
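
The 4-versus-3 weight-matrix difference is easy to verify by counting the parameters of PyTorch's built-in modules (a quick check; the sizes here are arbitrary):

import torch.nn as nn

lstm = nn.LSTM(input_size=128, hidden_size=256, batch_first=True)
gru = nn.GRU(input_size=128, hidden_size=256, batch_first=True)

def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print(f"LSTM parameters: {count_params(lstm)}")  # 395,264 (4 weight matrices)
print(f"GRU parameters:  {count_params(gru)}")   # 296,448 (3 weight matrices)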

15.5.5 GRU Implementation

class GRUCell(nn.Module):
    """A single GRU cell with explicit gate computations.

    Args:
        input_size: Dimensionality of input vectors.
        hidden_size: Dimensionality of hidden state.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(
        self, x_t: torch.Tensor, h_prev: torch.Tensor
    ) -> torch.Tensor:
        """Compute next hidden state.

        Args:
            x_t: Input at current time step, shape (batch, input_size).
            h_prev: Previous hidden state, shape (batch, hidden_size).

        Returns:
            New hidden state, shape (batch, hidden_size).
        """
        combined = torch.cat([h_prev, x_t], dim=1)

        z_t = torch.sigmoid(self.W_z(combined))
        r_t = torch.sigmoid(self.W_r(combined))

        reset_combined = torch.cat([r_t * h_prev, x_t], dim=1)
        h_candidate = torch.tanh(self.W_h(reset_combined))

        h_t = (1 - z_t) * h_prev + z_t * h_candidate
        return h_t

15.6 Bidirectional RNNs

15.6.1 Motivation

In many tasks, the context needed to interpret a position comes from both directions. Consider named entity recognition: in "He went to Washington to meet the president," identifying "Washington" as a place (rather than a person) benefits from seeing "to meet the president" after it.

A standard RNN processes the sequence left-to-right, so $\mathbf{h}_t$ only captures information from $\mathbf{x}_1, \ldots, \mathbf{x}_t$. A bidirectional RNN adds a second RNN that processes the sequence right-to-left:

$$\overrightarrow{\mathbf{h}}_t = \text{RNN}_{\text{fwd}}(\mathbf{x}_t, \overrightarrow{\mathbf{h}}_{t-1})$$

$$\overleftarrow{\mathbf{h}}_t = \text{RNN}_{\text{bwd}}(\mathbf{x}_t, \overleftarrow{\mathbf{h}}_{t+1})$$

The final representation at each time step concatenates both:

$$\mathbf{h}_t = [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]$$

15.6.2 When to Use Bidirectional RNNs

Bidirectional RNNs are appropriate when the entire input sequence is available before making predictions. This includes:

  • Text classification (sentiment analysis)
  • Named entity recognition
  • Part-of-speech tagging
  • Machine translation (encoder side)
  • Speech recognition (with sufficient buffer)

Bidirectional RNNs are not appropriate for tasks requiring real-time, left-to-right generation, such as language modeling or online speech synthesis, because the backward pass requires seeing the full sequence first.

Extracting representations from bidirectional RNNs. For sequence classification tasks (many-to-one), you need to reduce the sequence of hidden states into a single vector. Common strategies include:

  1. Concatenate the final hidden states from both directions: $\mathbf{h} = [\overrightarrow{\mathbf{h}}_T; \overleftarrow{\mathbf{h}}_1]$. This captures the end of the forward pass and the end of the backward pass (which corresponds to the start of the sequence).

  2. Mean pooling over all time steps: $\mathbf{h} = \frac{1}{T} \sum_{t=1}^{T} [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]$. This gives equal weight to all positions.

  3. Max pooling over each dimension: $h_j = \max_t [\overrightarrow{h}_{t,j}; \overleftarrow{h}_{t,j}]$. This captures the strongest activation at any position.

  4. Attention pooling: Learn a weighted average, where the weights depend on the hidden states (a precursor to the self-attention mechanism we will explore in Chapter 18).

15.6.3 PyTorch Implementation

PyTorch makes bidirectional RNNs trivial:

import torch
import torch.nn as nn

torch.manual_seed(42)

# Bidirectional LSTM
bilstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=2,
    batch_first=True,
    bidirectional=True,
    dropout=0.3,
)

# Input: (batch_size, seq_len, input_size)
x = torch.randn(32, 50, 128)
output, (h_n, c_n) = bilstm(x)

# output shape: (32, 50, 512)  -- 256 * 2 due to bidirectional
# h_n shape: (4, 32, 256)  -- 2 layers * 2 directions
print(f"Output shape: {output.shape}")
print(f"Hidden shape: {h_n.shape}")

15.7 Deep RNNs and Stacking Layers

15.7.1 Why Go Deeper?

Just as deeper CNNs learn increasingly abstract features, stacking multiple RNN layers creates a hierarchy of temporal representations. The first layer captures low-level patterns (e.g., phonemes in speech), while higher layers capture more abstract patterns (e.g., words, phrases).

In a deep RNN with $L$ layers, the hidden state of layer $l$ at time $t$ is:

$$\mathbf{h}_t^{(l)} = \text{RNN}^{(l)}(\mathbf{h}_t^{(l-1)}, \mathbf{h}_{t-1}^{(l)})$$

where $\mathbf{h}_t^{(0)} = \mathbf{x}_t$.

15.7.2 Dropout in RNNs

Applying dropout to RNNs requires care. Naive dropout at every time step would destroy the temporal information in the hidden state. The standard approach is to apply dropout between layers (vertically) but not across time steps (horizontally):

# PyTorch handles this automatically with the dropout parameter
deep_lstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    num_layers=3,       # Three stacked LSTM layers
    batch_first=True,
    dropout=0.5,        # Applied between layers (not on last layer)
)

For more sophisticated dropout, variational dropout (Gal and Ghahramani, 2016) uses the same dropout mask at every time step within a sequence, which preserves more temporal information.

15.7.3 Residual Connections in Deep RNNs

For deep RNNs with more than 2--3 layers, training can become difficult due to gradient degradation across the layer stack (a vertical analogue of the temporal vanishing gradient problem). Residual connections, inspired by ResNets (Chapter 13), help by adding a skip connection from the input of each layer to its output:

$$\mathbf{h}_t^{(l)} = \text{RNN}^{(l)}(\mathbf{h}_t^{(l-1)}, \mathbf{h}_{t-1}^{(l)}) + \mathbf{h}_t^{(l-1)}$$

This requires that the input and output dimensions of each RNN layer match. If they differ, a linear projection can be used on the skip path. Residual connections are especially important for deep bidirectional models like those used in speech recognition, where stacking 5--8 layers is common.

15.7.4 Practical Considerations for Deep RNNs

  • Depth: 2--4 layers is typical. Beyond 4 layers, training becomes difficult without residual connections.
  • Hidden size: Common choices range from 128 to 1024. Larger hidden sizes increase capacity but also computation time and overfitting risk.
  • Gradient clipping: Essential for deep RNNs. A max norm of 1.0--5.0 is standard.
  • Layer normalization: Can stabilize training in deep RNNs. Apply it within each RNN cell, normalizing the pre-activation values before the nonlinearity.
  • Regularization: Apply dropout between layers (not across time steps). For very deep RNNs, weight tying (sharing input and output embeddings) and variational dropout provide additional regularization.

15.8 Sequence-to-Sequence Models

15.8.1 The Encoder-Decoder Architecture

Many important tasks involve mapping one sequence to another of a potentially different length: machine translation (English to French), text summarization (long document to short summary), or speech recognition (audio to text). The sequence-to-sequence (seq2seq) framework, introduced by Sutskever et al. (2014), handles these tasks elegantly.

The architecture has two components:

  1. Encoder: Reads the input sequence and compresses it into a fixed-length context vector $\mathbf{c}$
  2. Decoder: Generates the output sequence one token at a time, conditioned on the context vector

Encoder

The encoder processes the input sequence $(\mathbf{x}_1, \ldots, \mathbf{x}_T)$ and produces hidden states. The final hidden state (or a transformation of it) becomes the context vector:

$$\mathbf{h}_t^{\text{enc}} = \text{LSTM}_{\text{enc}}(\mathbf{x}_t, \mathbf{h}_{t-1}^{\text{enc}})$$

$$\mathbf{c} = \mathbf{h}_T^{\text{enc}}$$

Decoder

The decoder is an RNN that generates the output sequence step by step. At each step, it takes the previous output token and the hidden state:

$$\mathbf{h}_t^{\text{dec}} = \text{LSTM}_{\text{dec}}(\mathbf{y}_{t-1}, \mathbf{h}_{t-1}^{\text{dec}})$$

$$P(\mathbf{y}_t | \mathbf{y}_{<t}, \mathbf{c}) = \text{softmax}(\mathbf{W}_{hy}\mathbf{h}_t^{\text{dec}} + \mathbf{b}_y)$$

The decoder is initialized with the context vector: $\mathbf{h}_0^{\text{dec}} = \mathbf{c}$.

15.8.2 The Information Bottleneck

A critical limitation of the basic seq2seq model is that the entire input sequence must be compressed into a single fixed-length vector $\mathbf{c}$. For short sequences, this works well. But for long sequences, this creates an information bottleneck: the context vector cannot retain all the nuances of a 100-word sentence.

Quantifying the bottleneck. Cho et al. (2014) demonstrated this empirically by measuring BLEU scores (a machine translation quality metric) as a function of input sentence length. For basic seq2seq models, performance degraded dramatically for sentences longer than about 20--30 words. With attention (Section 15.11), this degradation was largely eliminated, confirming that the fixed-length bottleneck was the primary cause.

Strategies to mitigate the bottleneck without attention:

  1. Reverse the input sequence. Sutskever et al. (2014) found that reversing the source sentence improved translation quality. The intuition is that reversing places the first few source words closer to the first few target words in the unrolled computation graph, reducing the distance that gradients must travel for the most important early alignments.

  2. Use a bidirectional encoder. The final context vector concatenates the forward and backward final states, giving the decoder information about both the beginning and end of the source sequence.

  3. Use a deeper encoder. More layers increase the model's capacity to compress the sequence into a rich context vector.

However, none of these strategies fully solve the problem. The definitive solution is the attention mechanism, which we preview in Section 15.11 and cover in depth in Chapter 18.

15.8.3 Seq2Seq Implementation

import torch
import torch.nn as nn

torch.manual_seed(42)


class Encoder(nn.Module):
    """Seq2seq encoder using LSTM.

    Args:
        vocab_size: Size of the source vocabulary.
        embed_dim: Dimensionality of word embeddings.
        hidden_size: Dimensionality of LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout rate between LSTM layers.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_size: int,
        num_layers: int = 2,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, src: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Encode source sequence.

        Args:
            src: Source token IDs, shape (batch_size, src_len).

        Returns:
            Tuple of (encoder_outputs, (h_n, c_n)).
        """
        embedded = self.dropout(self.embedding(src))
        outputs, (h_n, c_n) = self.lstm(embedded)
        return outputs, (h_n, c_n)


class Decoder(nn.Module):
    """Seq2seq decoder using LSTM.

    Args:
        vocab_size: Size of the target vocabulary.
        embed_dim: Dimensionality of word embeddings.
        hidden_size: Dimensionality of LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout rate between LSTM layers.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_size: int,
        num_layers: int = 2,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc_out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        trg_token: torch.Tensor,
        hidden: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Decode one time step.

        Args:
            trg_token: Target token IDs, shape (batch_size, 1).
            hidden: Previous (h, c) states from LSTM.

        Returns:
            Tuple of (prediction, (h_n, c_n)).
        """
        embedded = self.dropout(self.embedding(trg_token))
        output, (h_n, c_n) = self.lstm(embedded, hidden)
        prediction = self.fc_out(output.squeeze(1))
        return prediction, (h_n, c_n)


class Seq2Seq(nn.Module):
    """Complete sequence-to-sequence model.

    Args:
        encoder: The encoder module.
        decoder: The decoder module.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(
        self,
        src: torch.Tensor,
        trg: torch.Tensor,
        teacher_forcing_ratio: float = 0.5,
    ) -> torch.Tensor:
        """Run seq2seq forward pass.

        Args:
            src: Source token IDs, shape (batch_size, src_len).
            trg: Target token IDs, shape (batch_size, trg_len).
            teacher_forcing_ratio: Probability of using ground truth
                as next input during training.

        Returns:
            Predictions, shape (batch_size, trg_len, vocab_size).
        """
        batch_size, trg_len = trg.shape
        vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(
            batch_size, trg_len, vocab_size, device=src.device
        )

        _, hidden = self.encoder(src)
        input_token = trg[:, 0:1]  # <SOS> token

        for t in range(1, trg_len):
            prediction, hidden = self.decoder(input_token, hidden)
            outputs[:, t, :] = prediction

            if torch.rand(1).item() < teacher_forcing_ratio:
                input_token = trg[:, t:t+1]
            else:
                input_token = prediction.argmax(dim=1, keepdim=True)

        return outputs
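
A minimal usage sketch, assuming hypothetical vocabulary sizes and random token IDs, shows how the pieces fit together during training (position 0 of the output tensor is left as zeros for the <SOS> slot and skipped in the loss):

# Minimal training-step sketch with hypothetical sizes and random data
SRC_VOCAB, TRG_VOCAB = 8000, 6000
model = Seq2Seq(
    Encoder(SRC_VOCAB, embed_dim=256, hidden_size=512),
    Decoder(TRG_VOCAB, embed_dim=256, hidden_size=512),
)

src = torch.randint(0, SRC_VOCAB, (16, 20))   # (batch_size, src_len)
trg = torch.randint(0, TRG_VOCAB, (16, 25))   # (batch_size, trg_len)

outputs = model(src, trg, teacher_forcing_ratio=0.5)

# Skip position 0 (the <SOS> slot) and flatten for the loss
criterion = nn.CrossEntropyLoss()
loss = criterion(
    outputs[:, 1:, :].reshape(-1, TRG_VOCAB),
    trg[:, 1:].reshape(-1),
)
print(f"Loss: {loss.item():.3f}")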

15.9 Teacher Forcing

15.9.1 The Training Dilemma

During training, the decoder must decide what input to use at each time step. There are two options:

  1. Free-running mode: Use the model's own previous prediction as input. This matches how the model will be used at inference time.
  2. Teacher forcing: Use the ground-truth previous token as input, regardless of what the model predicted.

15.9.2 Why Teacher Forcing Helps

Teacher forcing dramatically accelerates training. In free-running mode, if the model makes a wrong prediction at time step $t$, the subsequent inputs are corrupted, leading to a cascade of errors. Early in training, when the model's predictions are nearly random, this makes learning extremely slow.

With teacher forcing, each time step receives correct context, allowing the model to learn the conditional distribution $P(y_t | y_{<t}, \mathbf{x})$ directly, without compounding errors from its own early, inaccurate predictions.

15.9.3 The Exposure Bias Problem

Teacher forcing introduces a discrepancy between training and inference. During training, the model always sees correct previous tokens. During inference, it sees its own (potentially incorrect) predictions. This mismatch is called exposure bias.

Mitigation strategies include:

  • Scheduled sampling: Gradually reduce the teacher forcing ratio during training, transitioning from fully teacher-forced to fully free-running
  • Curriculum learning: Start with easy (short) sequences and increase difficulty
  • Sequence-level training: Use reinforcement learning objectives like REINFORCE to optimize sequence-level metrics directly

A simple scheduled decay of the teacher forcing ratio can be implemented as follows:

def get_teacher_forcing_ratio(
    epoch: int, total_epochs: int, strategy: str = "linear"
) -> float:
    """Compute teacher forcing ratio with scheduled decay.

    Args:
        epoch: Current training epoch (0-indexed).
        total_epochs: Total number of training epochs.
        strategy: Decay strategy, one of "linear" or "exponential".

    Returns:
        Teacher forcing ratio between 0 and 1.
    """
    if strategy == "linear":
        return max(0.0, 1.0 - epoch / total_epochs)
    elif strategy == "exponential":
        return 0.99 ** epoch
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

15.10 Decoding and Beam Search

15.10.1 Beyond Greedy Decoding

At inference time, the decoder generates tokens one at a time. The simplest strategy is greedy decoding: at each step, pick the token with the highest probability. But greedy decoding can miss better sequences because a locally optimal choice at step $t$ may lead to a globally suboptimal sequence.

For example, consider two partial sequences:

  • "The dog" (probability 0.6) followed by "runs" (probability 0.1): total 0.06
  • "A cat" (probability 0.4) followed by "sleeps" (probability 0.5): total 0.20

Greedy decoding would choose "The dog" at step 1, missing the higher-probability sequence.

15.10.2 The Beam Search Algorithm

Beam search maintains a set of $k$ best partial hypotheses (the "beam") at each time step, where $k$ is the beam width. At each step:

  1. Expand each hypothesis with all possible next tokens
  2. Score each expanded hypothesis by its cumulative log probability
  3. Keep only the top $k$ hypotheses
  4. Stop when all hypotheses have generated an end-of-sequence token (or a maximum length is reached)

The probability of a sequence is:

$$\log P(\mathbf{y} | \mathbf{x}) = \sum_{t=1}^{T} \log P(y_t | y_{<t}, \mathbf{x})$$

15.10.3 Length Normalization

Beam search has a bias toward shorter sequences because each additional token multiplies a probability less than 1 (adds a negative log probability). To counteract this, we normalize by sequence length:

$$\text{score}(\mathbf{y}) = \frac{1}{|\mathbf{y}|^\alpha} \sum_{t=1}^{|\mathbf{y}|} \log P(y_t | y_{<t}, \mathbf{x})$$

where $\alpha \in [0, 1]$ controls the strength of length normalization. With $\alpha = 0$, there is no normalization; with $\alpha = 1$, we normalize fully by length.

15.10.4 Beam Search Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)


def beam_search(
    decoder: nn.Module,
    initial_hidden: tuple[torch.Tensor, torch.Tensor],
    sos_token: int,
    eos_token: int,
    beam_width: int = 5,
    max_len: int = 50,
    length_penalty_alpha: float = 0.7,
) -> list[tuple[list[int], float]]:
    """Perform beam search decoding.

    Args:
        decoder: The decoder module (processes one token at a time).
        initial_hidden: Initial decoder hidden state (h, c).
        sos_token: Start-of-sequence token ID.
        eos_token: End-of-sequence token ID.
        beam_width: Number of hypotheses to maintain.
        max_len: Maximum output sequence length.
        length_penalty_alpha: Length normalization exponent.

    Returns:
        List of (token_ids, score) tuples sorted by score (best first).
    """
    device = initial_hidden[0].device

    # Each beam: (token_ids, log_prob, hidden_state, is_complete)
    beams = [(
        [sos_token],
        0.0,
        initial_hidden,
        False,
    )]

    completed = []

    for _ in range(max_len):
        candidates = []

        for tokens, log_prob, hidden, is_complete in beams:
            if is_complete:
                completed.append((tokens, log_prob))
                continue

            # Get decoder prediction for last token
            input_token = torch.tensor(
                [[tokens[-1]]], device=device
            )
            prediction, new_hidden = decoder(input_token, hidden)
            log_probs = F.log_softmax(prediction, dim=-1)

            # Get top-k tokens
            topk_log_probs, topk_indices = log_probs.topk(beam_width)

            for i in range(beam_width):
                next_token = topk_indices[0, i].item()
                next_log_prob = log_prob + topk_log_probs[0, i].item()
                next_tokens = tokens + [next_token]
                is_done = next_token == eos_token

                candidates.append((
                    next_tokens,
                    next_log_prob,
                    new_hidden,
                    is_done,
                ))

        if not candidates:
            # All active beams finished and were already moved to `completed`
            beams = []
            break

        # Sort by log probability and keep top beam_width
        candidates.sort(key=lambda x: x[1], reverse=True)
        beams = candidates[:beam_width]

    # Add any beams still active at termination (not yet moved to completed)
    for tokens, log_prob, _, _ in beams:
        completed.append((tokens, log_prob))

    # Apply length normalization and sort
    scored = []
    for tokens, log_prob in completed:
        length = len(tokens)
        normalized_score = log_prob / (length ** length_penalty_alpha)
        scored.append((tokens, normalized_score))

    scored.sort(key=lambda x: x[1], reverse=True)
    return scored
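
A usage sketch, reusing the Decoder class from Section 15.8.3 with an untrained model and hypothetical <SOS>/<EOS> token IDs (1 and 2), so the output is meaningless but the mechanics are visible:

decoder = Decoder(vocab_size=6000, embed_dim=256, hidden_size=512)
h0 = torch.zeros(2, 1, 512)   # (num_layers, batch=1, hidden_size)
c0 = torch.zeros(2, 1, 512)

with torch.no_grad():
    hypotheses = beam_search(
        decoder, (h0, c0), sos_token=1, eos_token=2, beam_width=4, max_len=20
    )

best_tokens, best_score = hypotheses[0]
print(best_tokens, f"{best_score:.3f}")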

15.10.5 Practical Beam Search Settings

Typical beam widths range from 4 to 10. Larger beams improve quality but increase computation linearly. In practice:

  • Machine translation: beam width 4--6
  • Image captioning: beam width 3--5
  • Speech recognition: beam width 10--100

Beyond a certain point, increasing beam width provides diminishing returns and can even degrade quality (the "beam search curse"): wider beams tend to surface short, high-likelihood but low-quality sequences that the model scores well but that read poorly.


15.11 Attention Mechanism Preview

15.11.1 The Core Idea

The attention mechanism, introduced by Bahdanau et al. (2015), addresses the information bottleneck of seq2seq models. Instead of compressing the entire input into a single context vector, attention allows the decoder to look back at all encoder hidden states and focus on the most relevant ones at each decoding step.

At each decoder time step $t$, attention computes:

  1. Alignment scores: How relevant is each encoder state to the current decoder state?

$$e_{t,s} = \text{score}(\mathbf{h}_t^{\text{dec}}, \mathbf{h}_s^{\text{enc}})$$

  1. Attention weights: Normalized scores via softmax

$$\alpha_{t,s} = \frac{\exp(e_{t,s})}{\sum_{s'=1}^{S} \exp(e_{t,s'})}$$

  1. Context vector: Weighted sum of encoder states

$$\mathbf{c}_t = \sum_{s=1}^{S} \alpha_{t,s} \mathbf{h}_s^{\text{enc}}$$

This context vector $\mathbf{c}_t$ is then used alongside the decoder hidden state to predict the output.

15.11.2 Scoring Functions

The alignment score function $\text{score}(\mathbf{h}_t^{\text{dec}}, \mathbf{h}_s^{\text{enc}})$ can take several forms:

Additive attention (Bahdanau, 2015): $$e_{t,s} = \mathbf{v}^T \tanh(\mathbf{W}_1 \mathbf{h}_t^{\text{dec}} + \mathbf{W}_2 \mathbf{h}_s^{\text{enc}})$$

where $\mathbf{W}_1$, $\mathbf{W}_2$, and $\mathbf{v}$ are learned parameters.

Multiplicative (dot-product) attention (Luong, 2015): $$e_{t,s} = (\mathbf{h}_t^{\text{dec}})^T \mathbf{h}_s^{\text{enc}}$$

This is computationally cheaper (a simple dot product) and scales better to large hidden dimensions.

Scaled dot-product attention (Vaswani et al., 2017): $$e_{t,s} = \frac{(\mathbf{h}_t^{\text{dec}})^T \mathbf{h}_s^{\text{enc}}}{\sqrt{d}}$$

where $d$ is the dimension of the hidden state. The scaling factor $\frac{1}{\sqrt{d}}$ prevents the dot products from growing too large in high dimensions, which would push the softmax into saturation and produce near-one-hot attention weights. This scaled dot-product attention is the foundation of the Transformer architecture that we will explore in Chapter 18.
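
To make the three steps concrete, here is a minimal sketch of scaled dot-product attention for a single decoder step, with assumed batch and dimension sizes:

import torch
import torch.nn.functional as F

torch.manual_seed(42)

# Scaled dot-product attention for one decoder step (assumed shapes)
batch, src_len, d = 4, 12, 256
enc_states = torch.randn(batch, src_len, d)    # h_s^enc for s = 1..S
dec_state = torch.randn(batch, d)              # h_t^dec

scores = torch.bmm(enc_states, dec_state.unsqueeze(2)).squeeze(2) / d ** 0.5
weights = F.softmax(scores, dim=1)                                 # alpha_{t,s}
context = torch.bmm(weights.unsqueeze(1), enc_states).squeeze(1)   # c_t
print(weights.shape, context.shape)  # (4, 12) and (4, 256)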

15.11.3 Impact of Attention

Attention transformed sequence modeling:

  • Eliminated the bottleneck: The model can access any part of the input at any time
  • Improved long-sequence performance: Translation quality no longer degrades with input length
  • Provided interpretability: Attention weights show which input tokens the model focuses on for each output
  • Led to Transformers: The insight that attention alone (without recurrence) could be sufficient led to the Transformer architecture (Vaswani et al., 2017), which we cover in Chapter 18

The progression from seq2seq to seq2seq-with-attention to Transformers is one of the most important narrative arcs in modern deep learning. Each step removed a limitation: seq2seq removed the requirement that input and output sequences be aligned and equal in length; attention removed the fixed-length information bottleneck of seq2seq; and Transformers removed the sequential processing bottleneck of RNN-based attention models. Understanding this progression---which you are now well positioned to do---is essential for appreciating why Transformers work and what they replaced.

We will explore attention mechanisms thoroughly in Chapter 18.


15.12 Comparison with 1D Convolutions for Sequences

Before covering practical considerations, it is worth comparing RNNs with an alternative approach to sequence modeling: 1D convolutional neural networks (1D CNNs). As we saw in Chapter 13, convolutions are powerful for extracting local patterns from grid-structured data. They can also be applied to sequences.

15.12.1 1D Convolutions for Sequences

A 1D convolutional layer slides a filter of width $k$ across the sequence, computing a weighted sum at each position. For a sequence of embeddings $\mathbf{x}_1, \ldots, \mathbf{x}_T$, a 1D convolution with filter $\mathbf{w} \in \mathbb{R}^{k \times d}$ produces:

$$\mathbf{h}_t = f\left(\sum_{i=0}^{k-1} \mathbf{w}_i \cdot \mathbf{x}_{t+i} + b\right)$$

Stacking multiple convolutional layers with increasing dilation rates (as in WaveNet) allows the receptive field to grow exponentially with depth, enabling long-range dependencies without recurrence.
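
A short sketch of applying a dilated 1D convolution to a batch of embedded sequences (sizes are illustrative; note the channel-first layout that nn.Conv1d expects):

import torch
import torch.nn as nn

torch.manual_seed(42)

x = torch.randn(32, 50, 128)                 # (batch, seq_len, embed_dim)
conv = nn.Conv1d(
    in_channels=128, out_channels=256, kernel_size=3, padding=2, dilation=2
)
# nn.Conv1d expects (batch, channels, seq_len), so transpose the last two dims
h = conv(x.transpose(1, 2)).transpose(1, 2)  # (32, 50, 256)
print(h.shape)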

15.12.2 RNNs vs. 1D CNNs

| Aspect | RNNs (LSTM/GRU) | 1D CNNs |
|---|---|---|
| Receptive field | Theoretically infinite (entire history) | Fixed by kernel size and depth |
| Parallelism | Sequential (cannot parallelize across time) | Fully parallel across time |
| Training speed | Slow (sequential computation) | Fast (parallel computation) |
| Memory | Adaptive (hidden state evolves) | Fixed (captured by filter patterns) |
| Long-range dependencies | Good with LSTM/GRU gates | Requires deep stacks or dilated convolutions |
| Streaming/online | Natural (process one token at a time) | Requires buffering a window |
| Parameter count | Grows with hidden size, not sequence length | Grows with kernel size and number of filters |

When to prefer 1D CNNs. If your task involves detecting local patterns (e.g., n-gram features in text classification, motifs in genomic sequences), 1D CNNs are fast and effective. They also parallelize perfectly on GPUs, making training significantly faster than RNNs.

When to prefer RNNs. If your task requires maintaining long-term state (e.g., tracking the subject of a sentence across many clauses) or online processing (e.g., real-time speech recognition), RNNs are more natural. The hidden state provides a compact summary of the entire past that 1D CNNs lack.

Hybrid approaches. Some architectures combine both: use 1D convolutions for local feature extraction and feed the resulting features into an RNN for global temporal reasoning. This can be faster than a pure RNN approach while retaining the ability to model long-range dependencies.

In practice, both RNNs and 1D CNNs have been largely superseded by Transformers (Chapter 18) for most NLP tasks, but they remain important for time series, audio processing, and resource-constrained settings.


15.13 Practical Considerations

15.13.1 Input Representation

For text data, inputs are typically represented as word embeddings. Each token in the vocabulary is mapped to a dense vector:

# Maps each of the 10,000 vocabulary indices to a trainable 300-dimensional vector
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=300)

You can initialize embeddings randomly (trained from scratch) or use pretrained embeddings like GloVe or Word2Vec. Pretrained embeddings provide a strong starting point, especially with limited data.
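
For example, pretrained vectors can be loaded directly into the embedding layer (a minimal sketch; glove_vectors below is a random placeholder standing in for a real (vocab_size, 300) matrix built by looking each vocabulary word up in a GloVe file):

import torch
from torch import nn

# Placeholder for real pretrained vectors: one 300-dimensional row per vocabulary word
glove_vectors = torch.randn(10000, 300)

# freeze=False lets the embeddings be fine-tuned along with the rest of the model
embedding = nn.Embedding.from_pretrained(glove_vectors, freeze=False)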

For numerical time series, inputs may be raw values, normalized values, or engineered features.

15.13.2 Padding and Packing

Sequences in a batch typically have different lengths. PyTorch handles this with padding and packing:

from torch.nn.utils.rnn import (
    pack_padded_sequence,
    pad_packed_sequence,
)

# Assume sequences is a padded tensor of shape (batch, max_len, features),
# lengths holds each sequence's true (unpadded) length as a CPU tensor or list,
# and lstm is an nn.LSTM created with batch_first=True.
packed = pack_padded_sequence(
    sequences, lengths, batch_first=True, enforce_sorted=False
)
output, (h_n, c_n) = lstm(packed)        # padded positions are never processed
output, output_lengths = pad_packed_sequence(output, batch_first=True)

Packing ensures that the LSTM does not process padding tokens, which improves both correctness and efficiency.

15.13.3 Choosing Hyperparameters

Based on extensive empirical research and practical experience:

| Hyperparameter | Recommended Range | Notes |
| --- | --- | --- |
| Hidden size | 128--1024 | Scale with data size |
| Number of layers | 1--4 | 2 is a strong default |
| Dropout | 0.2--0.5 | Higher for larger models |
| Learning rate | 1e-4 to 1e-2 | Adam optimizer |
| Gradient clip norm | 1.0--5.0 | Essential for stability |
| Embedding dim | 100--300 | 256 is a good default |

15.13.4 Practical Tips for Training RNNs

Based on years of accumulated experience from the research community and practitioners:

1. Always use gradient clipping. This is non-negotiable for RNN training. Set max_norm between 1.0 and 5.0. Without it, a single batch can produce gradients large enough to destroy your model.
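
A minimal, self-contained illustration of where the clipping call goes in a training step (the model, data, and loss below are placeholders, not values from the text):

import torch
from torch import nn

model = nn.LSTM(input_size=10, hidden_size=32, batch_first=True)   # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

output, _ = model(torch.randn(4, 20, 10))
loss = output.pow(2).mean()                                        # placeholder loss

optimizer.zero_grad()
loss.backward()
# Rescale gradients in place so that their global L2 norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()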

2. Initialize forget gate biases to 1. For LSTMs, set the forget gate bias to 1.0 (or even 2.0) at initialization. This encourages the network to remember by default, allowing it to learn what to forget rather than what to remember. This simple trick, suggested by Jozefowicz et al. (2015), can significantly improve performance on tasks with long-range dependencies.

# After creating an LSTM layer, initialize the forget gate biases.
# PyTorch packs LSTM biases as [input_gate | forget_gate | cell_gate | output_gate],
# so the forget-gate slice is the second quarter of the bias vector. Note that this
# loop sets both bias_ih and bias_hh, giving an effective forget bias of 2.0.
for name, param in lstm.named_parameters():
    if "bias" in name:
        n = param.size(0)
        param.data[n // 4 : n // 2].fill_(1.0)

3. Use orthogonal initialization for recurrent weights. Initializing $\mathbf{W}_{hh}$ as an orthogonal matrix (all singular values equal to 1) prevents both vanishing and exploding gradients at the start of training:

# Apply orthogonal initialization to every recurrent (hidden-to-hidden) weight matrix
for name, param in model.named_parameters():
    if "weight_hh" in name:
        torch.nn.init.orthogonal_(param)

4. Bucket sequences by length. When training with mini-batches, group sequences of similar length into the same batch (length bucketing) and use pack_padded_sequence. This minimizes the amount of padding and avoids wasting computation on padding tokens.

5. Use bidirectional RNNs for classification. If the full sequence is available and you are performing classification or tagging, bidirectional RNNs almost always outperform unidirectional ones. The cost is modest---roughly twice the parameters and computation---and the accuracy gain is usually worth it.

6. Start with a single-layer LSTM. Do not jump to deep, multi-layer RNNs immediately. A single-layer LSTM with 256--512 hidden units is a strong baseline for most tasks. Add layers only if you have enough data and the single-layer model is clearly underfitting.

7. Monitor hidden state norms. During training, log the L2 norm of the hidden state at each time step. If the norm grows without bound over the sequence, you likely have exploding activations (even if gradients are clipped). Reduce the hidden size or add layer normalization.
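
A lightweight way to do this (sketched below with placeholder sizes and random input) is to take the per-step outputs---the top layer's hidden states---and log their norms:

import torch
from torch import nn

lstm = nn.LSTM(input_size=64, hidden_size=256, batch_first=True)   # placeholder sizes
x = torch.randn(8, 100, 64)                                        # placeholder batch

output, _ = lstm(x)                        # (batch, T, hidden): top-layer hidden states
step_norms = output.norm(dim=-1)           # L2 norm of the hidden state at every step
mean_per_step = step_norms.mean(dim=0)     # average over the batch
# Log mean_per_step (e.g., to TensorBoard); steady growth over time suggests exploding activations
print(mean_per_step[:3], mean_per_step[-3:])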

15.13.5 Common Pitfalls

  1. Forgetting to detach hidden states: When training on long sequences in chunks, detach the hidden state between chunks to prevent backpropagating through the entire history (see the sketch after this list).
  2. Not using packing: Processing padding tokens wastes computation and can bias the model.
  3. Wrong hidden state indexing: For bidirectional multi-layer LSTMs, h_n has shape (num_layers * num_directions, batch, hidden_size), ordered layer by layer with the forward direction before the backward one; the last two entries are therefore the final layer's forward and backward states. Index it carefully.
  4. Ignoring overfitting: RNNs with large hidden sizes overfit easily. Use dropout, weight decay, and early stopping.
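
The sketch below illustrates the detaching pattern from pitfall 1 (truncated backpropagation through time); the chunk length, model sizes, and loss are arbitrary placeholders:

import torch
from torch import nn

lstm = nn.LSTM(input_size=16, hidden_size=64, batch_first=True)    # placeholder model
optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3)

long_sequence = torch.randn(1, 1000, 16)          # one long sequence, processed in chunks
hidden = None
for chunk in long_sequence.split(100, dim=1):     # chunks of 100 time steps
    output, hidden = lstm(chunk, hidden)
    loss = output.pow(2).mean()                   # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Detach so that the next chunk does not backpropagate through this one
    hidden = tuple(h.detach() for h in hidden)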

15.14 RNNs in the Modern Landscape

15.14.1 The Rise of Transformers

Since 2017, Transformer-based models have largely replaced RNNs for many sequence tasks, especially in NLP. Transformers process all positions in parallel (rather than sequentially), which enables much faster training and better scaling.

However, RNNs remain relevant in several contexts:

  • Low-resource settings: When data is limited, simpler RNN models may outperform large Transformers
  • Real-time processing: RNNs process tokens one at a time, which suits streaming applications
  • Time series: For many time series tasks, LSTMs and GRUs remain competitive
  • Edge devices: RNNs have smaller memory footprints than Transformers
  • Hybrid architectures: Some modern models combine RNN-like recurrence with Transformer-like attention

15.14.2 State Space Models

Recent architectures like S4 (the Structured State Space sequence model; Gu et al., 2022) and Mamba draw inspiration from both RNNs and linear dynamical systems. They retain the sequential-processing advantage of RNNs while achieving Transformer-like performance, hinting that the story of recurrence is far from over.

The core idea of state space models is to model the hidden state dynamics as a continuous-time linear system:

$$\frac{d\mathbf{h}(t)}{dt} = \mathbf{A}\mathbf{h}(t) + \mathbf{B}\mathbf{x}(t), \quad \mathbf{y}(t) = \mathbf{C}\mathbf{h}(t)$$

When discretized, this becomes a recurrence similar to an RNN, but with carefully structured weight matrices that enable both efficient sequential processing (like an RNN) and parallel training (like a CNN). The key innovation is that the matrix $\mathbf{A}$ is parameterized using special structures (such as diagonal plus low-rank, or the HiPPO framework) that enable long-range memory. Mamba (Gu and Dao, 2023) further adds input-dependent gating, making the state space model's dynamics depend on the content being processed---much like LSTM gates, but with the efficiency advantages of the state space formulation. These models represent an exciting convergence of ideas from control theory, signal processing, and deep learning.
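
For concreteness, one standard choice is the zero-order-hold discretization with step size $\Delta$ (the discretized matrices $\bar{\mathbf{A}}$ and $\bar{\mathbf{B}}$ below are introduced here for illustration, following the S4/Mamba papers):

$$\mathbf{h}_t = \bar{\mathbf{A}}\mathbf{h}_{t-1} + \bar{\mathbf{B}}\mathbf{x}_t, \qquad \mathbf{y}_t = \mathbf{C}\mathbf{h}_t, \qquad \bar{\mathbf{A}} = e^{\Delta\mathbf{A}}, \quad \bar{\mathbf{B}} = (\Delta\mathbf{A})^{-1}\left(e^{\Delta\mathbf{A}} - \mathbf{I}\right)\Delta\mathbf{B}$$

Because no nonlinearity sits between steps, this recurrence can be unrolled into a single long convolution over the input, which is what makes parallel training possible.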

15.14.3 When to Choose RNNs

Choose RNNs when:

  • You need to process truly streaming data
  • Computational resources are limited
  • Sequence lengths are moderate (< 1000 tokens)
  • You are working on time series forecasting
  • Interpretability of the hidden state dynamics matters

Choose Transformers when:

  • You have large datasets and computational resources
  • Sequences are long and require global attention
  • You can benefit from pretrained models (BERT, GPT)
  • Parallelism during training is important


15.15 Summary

This chapter has taken you on a journey through the evolution of recurrent architectures:

  1. Vanilla RNNs introduced the recurrence relation, enabling sequential processing with shared parameters, but suffered from vanishing and exploding gradients.

  2. LSTMs solved the vanishing gradient problem with gated cell states, using forget, input, and output gates to control information flow. The additive cell state update is the key innovation.

  3. GRUs simplified the LSTM architecture to two gates (update and reset) while maintaining comparable performance with fewer parameters.

  4. Bidirectional RNNs process sequences in both directions, capturing context from past and future for classification and labeling tasks.

  5. Sequence-to-sequence models use an encoder-decoder architecture to map variable-length input sequences to variable-length output sequences.

  6. Teacher forcing accelerates training by providing ground-truth tokens as decoder input, though it introduces exposure bias.

  7. Beam search improves inference quality over greedy decoding by maintaining multiple hypotheses.

  8. Attention mechanisms preview how the information bottleneck is resolved by allowing the decoder to focus on relevant parts of the input at each step. The progression from fixed context vectors to attention to self-attention (Transformers) is one of the most important evolutionary paths in deep learning.

  9. 1D convolutions offer a parallel-friendly alternative to RNNs for local pattern extraction in sequences, though they lack the adaptive memory that makes RNNs powerful for long-range dependencies.

  10. Practical training of RNNs requires gradient clipping, careful initialization (orthogonal weights, forget gate bias of 1), and attention to sequence padding and packing.

The key takeaway is that sequence modeling requires architectures that can maintain and update a representation of context over time. Vanilla RNNs introduced this idea but suffered from vanishing gradients. LSTMs and GRUs solved this with learned gating mechanisms that control information flow. While Transformers have largely replaced RNNs for many tasks, the concepts of sequential processing, hidden state dynamics, and gated information flow remain central to understanding modern deep learning.

In Chapter 18, we will explore attention mechanisms in depth and build the foundation for understanding the Transformer architecture, which has become the dominant paradigm in modern deep learning for sequences.


References

  1. Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735--1780.
  2. Cho, K., et al. (2014). Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. arXiv:1406.1078.
  3. Sutskever, I., Vinyals, O., & Le, Q. V. (2014). Sequence to Sequence Learning with Neural Networks. NeurIPS.
  4. Bahdanau, D., Cho, K., & Bengio, Y. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR.
  5. Bengio, Y., Simard, P., & Frasconi, P. (1994). Learning Long-Term Dependencies with Gradient Descent is Difficult. IEEE Transactions on Neural Networks, 5(2), 157--166.
  6. Pascanu, R., Mikolov, T., & Bengio, Y. (2013). On the Difficulty of Training Recurrent Neural Networks. ICML.
  7. Schuster, M., & Paliwal, K. K. (1997). Bidirectional Recurrent Neural Networks. IEEE Transactions on Signal Processing, 45(11), 2673--2681.
  8. Gu, A., et al. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICLR.