

Chapter 18: The Attention Mechanism

"Attention is all you need." --- Ashish Vaswani et al., Attention Is All You Need (2017)

In Chapter 17, we explored sequence-to-sequence (seq2seq) models and saw how an encoder--decoder architecture can map variable-length input sequences to variable-length output sequences. We also encountered a fundamental limitation: the entire input sequence is compressed into a single fixed-length context vector. For short sequences, this works tolerably well. For long sequences, it fails dramatically --- the model simply cannot squeeze hundreds of tokens' worth of information into a single vector without catastrophic loss.

The attention mechanism is the solution to this information bottleneck. Rather than forcing the decoder to rely on a single summary vector, attention allows the decoder to look back at the entire input sequence at every decoding step, focusing on the parts most relevant to the current output token. This simple idea --- letting the model dynamically select which inputs matter --- revolutionized natural language processing and ultimately gave rise to the Transformer architecture that powers modern large language models.

In this chapter, you will learn how attention works from first principles. We will build the mathematical foundations step by step, implement multi-head attention in PyTorch, and develop the intuition needed to understand why attention has become the dominant paradigm in deep learning for sequences.


18.1 The Information Bottleneck Problem

18.1.1 Fixed-Length Context Vectors Revisited

Recall from Chapter 17 the standard seq2seq encoder--decoder architecture. The encoder processes an input sequence $x_1, x_2, \ldots, x_T$ and produces a sequence of hidden states $\mathbf{h}_1, \mathbf{h}_2, \ldots, \mathbf{h}_T$. The final hidden state $\mathbf{h}_T$ serves as the context vector $\mathbf{c}$, which is passed to the decoder as its initial hidden state.

The decoder then generates the output sequence $y_1, y_2, \ldots, y_{T'}$ one token at a time, with each decoder hidden state computed as:

$$\mathbf{s}_t = f(\mathbf{s}_{t-1}, y_{t-1}, \mathbf{c})$$

where $\mathbf{s}_t$ is the decoder hidden state at time step $t$, $y_{t-1}$ is the previously generated token, and $\mathbf{c} = \mathbf{h}_T$ is the fixed context vector.

The problem is immediately apparent: $\mathbf{c}$ is a vector of fixed dimensionality (say, 256 or 512 dimensions), yet it must encode everything about the input sequence. Whether the input is 5 tokens or 500 tokens long, all information passes through the same narrow bottleneck.

18.1.2 Empirical Evidence of the Bottleneck

Cho et al. (2014) demonstrated this bottleneck empirically. They trained a standard encoder--decoder model on English-to-French translation and measured BLEU scores (a translation quality metric) as a function of sentence length. The results were stark:

  • For sentences up to 10--15 words, the model performed well
  • For sentences of 20--30 words, quality degraded significantly
  • For sentences beyond 30 words, translations became nearly incoherent

This length-dependent degradation is the hallmark of the information bottleneck. The context vector simply cannot faithfully represent long sequences.

18.1.3 The Key Insight: Dynamic Relevance

Consider how a human translator works. When translating a long sentence, you do not memorize the entire source sentence as a single mental "snapshot" and then produce the translation. Instead, you refer back to specific parts of the source sentence as you write each word of the translation. When translating a verb, you look at the source verb. When translating a noun phrase, you look at the corresponding source noun phrase.

This is precisely what attention does. Instead of compressing all encoder hidden states into one vector, attention allows the decoder to dynamically attend to different encoder hidden states at each decoding step. The decoder computes a weighted combination of all encoder hidden states, where the weights reflect the relevance of each input position to the current output position.

Formally, instead of a single fixed context vector $\mathbf{c}$, we now have a time-dependent context vector $\mathbf{c}_t$ for each decoder time step $t$:

$$\mathbf{c}_t = \sum_{j=1}^{T} \alpha_{tj} \mathbf{h}_j$$

where $\alpha_{tj}$ is the attention weight that the decoder at step $t$ assigns to the encoder hidden state at position $j$. These weights satisfy:

$$\sum_{j=1}^{T} \alpha_{tj} = 1, \quad \alpha_{tj} \geq 0$$

The attention weights form a probability distribution over the input positions, and the context vector is the corresponding expected value of the encoder hidden states.


18.2 Bahdanau Attention (Additive Attention)

18.2.1 The Bahdanau Architecture

The first attention mechanism for seq2seq models was proposed by Bahdanau, Cho, and Bengio (2015) in their landmark paper "Neural Machine Translation by Jointly Learning to Align and Translate." The mechanism is often called additive attention or concat attention because it computes alignment scores by feeding a concatenation of vectors through a small neural network.

The Bahdanau attention mechanism computes the attention weight $\alpha_{tj}$ in three steps:

Step 1: Compute alignment scores. For each encoder hidden state $\mathbf{h}_j$ and the current decoder hidden state $\mathbf{s}_{t-1}$, compute a scalar alignment score:

$$e_{tj} = \mathbf{v}_a^\top \tanh(\mathbf{W}_a \mathbf{s}_{t-1} + \mathbf{U}_a \mathbf{h}_j)$$

where:

  • $\mathbf{W}_a \in \mathbb{R}^{d_a \times d_s}$ is a learnable weight matrix that projects the decoder state
  • $\mathbf{U}_a \in \mathbb{R}^{d_a \times d_h}$ is a learnable weight matrix that projects the encoder state
  • $\mathbf{v}_a \in \mathbb{R}^{d_a}$ is a learnable weight vector
  • $d_a$ is the dimension of the alignment model (a hyperparameter)
  • $d_s$ is the decoder hidden state dimension
  • $d_h$ is the encoder hidden state dimension

The alignment score $e_{tj}$ measures how well the input at position $j$ "aligns with" (or is relevant to) the output at position $t$.

Step 2: Normalize to attention weights. Apply the softmax function across all input positions:

$$\alpha_{tj} = \frac{\exp(e_{tj})}{\sum_{k=1}^{T} \exp(e_{tk})}$$

This ensures the attention weights form a valid probability distribution.

Step 3: Compute the context vector. Take the weighted sum of encoder hidden states:

$$\mathbf{c}_t = \sum_{j=1}^{T} \alpha_{tj} \mathbf{h}_j$$

The decoder then uses this context vector along with its previous state and the previous output to compute the next hidden state:

$$\mathbf{s}_t = f(\mathbf{s}_{t-1}, y_{t-1}, \mathbf{c}_t)$$

18.2.2 Why "Additive"?

The name "additive attention" comes from the fact that the alignment score involves an addition inside the $\tanh$:

$$e_{tj} = \mathbf{v}_a^\top \tanh(\mathbf{W}_a \mathbf{s}_{t-1} + \mathbf{U}_a \mathbf{h}_j)$$

The projected decoder state and projected encoder state are added together before being passed through the nonlinearity. This is in contrast to multiplicative attention, which we will see next.

18.2.3 Worked Example: Bahdanau Attention

Let us trace through a concrete example. Suppose we have an encoder with 4 time steps and hidden dimension $d_h = 3$, a decoder hidden dimension $d_s = 3$, and alignment dimension $d_a = 2$.

The encoder hidden states are:

$$\mathbf{h}_1 = \begin{bmatrix} 0.2 \\ 0.5 \\ -0.1 \end{bmatrix}, \quad \mathbf{h}_2 = \begin{bmatrix} 0.8 \\ -0.3 \\ 0.4 \end{bmatrix}, \quad \mathbf{h}_3 = \begin{bmatrix} -0.1 \\ 0.7 \\ 0.6 \end{bmatrix}, \quad \mathbf{h}_4 = \begin{bmatrix} 0.3 \\ 0.1 \\ -0.5 \end{bmatrix}$$

The current decoder state is:

$$\mathbf{s}_{t-1} = \begin{bmatrix} 0.4 \\ -0.2 \\ 0.7 \end{bmatrix}$$

With learned parameters $\mathbf{W}_a \in \mathbb{R}^{2 \times 3}$, $\mathbf{U}_a \in \mathbb{R}^{2 \times 3}$, and $\mathbf{v}_a \in \mathbb{R}^{2}$, we would compute the projected decoder state $\mathbf{W}_a \mathbf{s}_{t-1}$ (a 2D vector), add it to each projected encoder state $\mathbf{U}_a \mathbf{h}_j$, apply $\tanh$, and then dot with $\mathbf{v}_a$ to get a scalar alignment score $e_{tj}$ for each position $j$.

After softmax, suppose we obtain:

$$\boldsymbol{\alpha}_t = [0.05, \; 0.10, \; 0.75, \; 0.10]$$

This tells us the decoder at step $t$ is paying 75% of its attention to position 3. The context vector would then be:

$$\mathbf{c}_t = 0.05 \cdot \mathbf{h}_1 + 0.10 \cdot \mathbf{h}_2 + 0.75 \cdot \mathbf{h}_3 + 0.10 \cdot \mathbf{h}_4$$
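To check the arithmetic, here is a minimal sketch that plugs the encoder states and the assumed attention weights above into the weighted sum (the weights are illustrative, not produced by actual learned parameters):

import torch

# Encoder hidden states from the worked example, stacked as rows: (4, 3)
H = torch.tensor([
    [0.2, 0.5, -0.1],
    [0.8, -0.3, 0.4],
    [-0.1, 0.7, 0.6],
    [0.3, 0.1, -0.5],
])

# Illustrative attention weights (assumed, not computed from learned parameters)
alpha = torch.tensor([0.05, 0.10, 0.75, 0.10])

# Context vector: weighted sum of encoder states, c_t = sum_j alpha_tj * h_j
c_t = alpha @ H
print(c_t)  # tensor([0.0450, 0.5300, 0.4350])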

18.2.4 PyTorch Implementation of Bahdanau Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)


class BahdanauAttention(nn.Module):
    """Bahdanau (additive) attention mechanism.

    Computes attention weights using a learned alignment model
    that combines encoder and decoder states through addition
    and a tanh nonlinearity.

    Args:
        encoder_dim: Dimensionality of encoder hidden states.
        decoder_dim: Dimensionality of decoder hidden states.
        attention_dim: Dimensionality of the alignment model.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        attention_dim: int,
    ) -> None:
        super().__init__()
        self.W_a = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.U_a = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.v_a = nn.Linear(attention_dim, 1, bias=False)

    def forward(
        self,
        decoder_state: torch.Tensor,
        encoder_outputs: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute attention-weighted context vector.

        Args:
            decoder_state: Current decoder hidden state
                of shape (batch_size, decoder_dim).
            encoder_outputs: All encoder hidden states
                of shape (batch_size, seq_len, encoder_dim).

        Returns:
            context: Weighted context vector
                of shape (batch_size, encoder_dim).
            attention_weights: Attention distribution
                of shape (batch_size, seq_len).
        """
        # decoder_state: (batch_size, decoder_dim)
        # -> (batch_size, 1, attention_dim)
        query = self.W_a(decoder_state).unsqueeze(1)

        # encoder_outputs: (batch_size, seq_len, encoder_dim)
        # -> (batch_size, seq_len, attention_dim)
        keys = self.U_a(encoder_outputs)

        # Alignment scores: (batch_size, seq_len, 1)
        scores = self.v_a(torch.tanh(query + keys))
        scores = scores.squeeze(-1)  # (batch_size, seq_len)

        # Attention weights via softmax
        attention_weights = F.softmax(scores, dim=-1)

        # Context vector: weighted sum of encoder outputs
        context = torch.bmm(
            attention_weights.unsqueeze(1), encoder_outputs
        ).squeeze(1)

        return context, attention_weights


# --- Demo ---
batch_size, seq_len, encoder_dim, decoder_dim = 2, 10, 64, 64
attention_dim = 32

attention = BahdanauAttention(encoder_dim, decoder_dim, attention_dim)

encoder_outputs = torch.randn(batch_size, seq_len, encoder_dim)
decoder_state = torch.randn(batch_size, decoder_dim)

context, weights = attention(decoder_state, encoder_outputs)
print(f"Context shape: {context.shape}")      # (2, 64)
print(f"Weights shape: {weights.shape}")       # (2, 10)
print(f"Weights sum:   {weights.sum(dim=-1)}") # [1.0, 1.0]

18.3 Luong Attention (Multiplicative Attention)

18.3.1 Three Scoring Functions

Shortly after Bahdanau's work, Luong, Pham, and Manning (2015) proposed a family of attention mechanisms that are simpler and often faster. The key difference is in how alignment scores are computed. Luong proposed three scoring functions:

Dot product:

$$e_{tj} = \mathbf{s}_t^\top \mathbf{h}_j$$

This is the simplest form --- just a dot product between the decoder and encoder hidden states. It requires that $d_s = d_h$ (same dimensionality).

General (bilinear):

$$e_{tj} = \mathbf{s}_t^\top \mathbf{W}_a \mathbf{h}_j$$

where $\mathbf{W}_a \in \mathbb{R}^{d_s \times d_h}$ is a learnable weight matrix. This allows different dimensionalities and adds a learnable interaction.

Concat (similar to Bahdanau):

$$e_{tj} = \mathbf{v}_a^\top \tanh(\mathbf{W}_a [\mathbf{s}_t ; \mathbf{h}_j])$$

where $[\mathbf{s}_t ; \mathbf{h}_j]$ denotes concatenation.

The general scoring function is the one most commonly associated with "Luong attention" and is often called multiplicative attention because it involves a matrix multiplication between the decoder and encoder representations.

18.3.2 Luong vs. Bahdanau: Key Differences

Beyond the scoring function, there are several architectural differences between Luong and Bahdanau attention:

| Aspect | Bahdanau | Luong |
| --- | --- | --- |
| Scoring | Additive (two projections + tanh) | Multiplicative (dot or bilinear) |
| Decoder state used | $\mathbf{s}_{t-1}$ (previous) | $\mathbf{s}_t$ (current) |
| Encoder states | Bidirectional RNN | Top-layer unidirectional |
| Context usage | Concatenated with input to decoder | Concatenated with decoder output |
| Computational cost | Higher (more parameters) | Lower (especially dot product) |

18.3.3 Global vs. Local Attention

Luong et al. also introduced the distinction between global attention and local attention:

  • Global attention attends to all encoder hidden states at every time step. This is what we have described so far.
  • Local attention first predicts an aligned position $p_t$ in the source sentence, then attends only to a window of encoder hidden states around $p_t$. This reduces computational cost for long sequences.

Local attention can be seen as a blend between hard attention (which selects a single position) and soft attention (which attends to all positions). In practice, global attention with efficient implementations has become the dominant approach, but local attention foreshadowed ideas like sliding-window attention in modern efficient Transformers.

18.3.4 Bahdanau vs. Luong: A Detailed Comparison

The choice between Bahdanau and Luong attention involves several practical tradeoffs that are worth understanding in depth.

Computational cost. Each Bahdanau alignment score involves two learned projections, $\mathbf{W}_a \mathbf{s}_{t-1}$ and $\mathbf{U}_a \mathbf{h}_j$, followed by a $\tanh$ nonlinearity and a dot product with $\mathbf{v}_a$. Luong's dot-product attention requires only a single dot product $\mathbf{s}_t^\top \mathbf{h}_j$. For a sequence of length $T$ and hidden dimension $d$, Bahdanau attention performs $O(T \cdot d_a \cdot (d_s + d_h))$ operations per decoder step, while Luong dot attention performs $O(T \cdot d)$. In practice, this makes Luong attention significantly faster, especially for long sequences.

Expressiveness. Bahdanau's additive formulation is more expressive because the $\tanh$ nonlinearity allows it to model non-linear relationships between the decoder state and encoder states. Luong's dot product is bilinear --- it can only model interactions that are linear in both the query and key. The "general" variant $\mathbf{s}_t^\top \mathbf{W}_a \mathbf{h}_j$ adds a learnable interaction matrix but is still bilinear. In practice, the difference in expressiveness is often compensated by the rest of the model's capacity.

When the decoder state is used. Bahdanau attention uses the previous decoder state $\mathbf{s}_{t-1}$ to compute attention, then incorporates the resulting context vector into the computation of $\mathbf{s}_t$. Luong attention uses the current decoder state $\mathbf{s}_t$ (computed without attention), computes the context vector, and then combines both into a refined output. This is a subtle but consequential difference in the information flow.

Which to use? In modern practice, neither variant is used directly. The scaled dot-product attention in Transformers (Section 18.5) has superseded both. However, understanding Bahdanau and Luong attention is valuable because they demonstrate the core principles --- alignment scoring, softmax normalization, and context aggregation --- that underlie all attention mechanisms.

18.3.5 PyTorch Implementation of Luong Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)


class LuongAttention(nn.Module):
    """Luong (multiplicative) attention mechanism.

    Supports three scoring methods: dot, general, and concat.

    Args:
        encoder_dim: Dimensionality of encoder hidden states.
        decoder_dim: Dimensionality of decoder hidden states.
        method: Scoring method ('dot', 'general', or 'concat').
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        method: str = "general",
    ) -> None:
        super().__init__()
        self.method = method

        if method == "general":
            self.W_a = nn.Linear(encoder_dim, decoder_dim, bias=False)
        elif method == "concat":
            self.W_a = nn.Linear(
                encoder_dim + decoder_dim, decoder_dim, bias=False
            )
            self.v_a = nn.Linear(decoder_dim, 1, bias=False)
        elif method == "dot":
            assert encoder_dim == decoder_dim, (
                "Dot scoring requires encoder_dim == decoder_dim"
            )
        else:
            raise ValueError(f"Unknown method: {method}")

    def forward(
        self,
        decoder_state: torch.Tensor,
        encoder_outputs: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute Luong attention.

        Args:
            decoder_state: Current decoder hidden state
                of shape (batch_size, decoder_dim).
            encoder_outputs: Encoder hidden states
                of shape (batch_size, seq_len, encoder_dim).

        Returns:
            context: Context vector of shape (batch_size, encoder_dim).
            attention_weights: Weights of shape (batch_size, seq_len).
        """
        if self.method == "dot":
            # (batch, seq_len)
            scores = torch.bmm(
                encoder_outputs,
                decoder_state.unsqueeze(-1),
            ).squeeze(-1)

        elif self.method == "general":
            # Project encoder outputs, then dot with decoder state
            projected = self.W_a(encoder_outputs)  # (batch, seq, dec_dim)
            scores = torch.bmm(
                projected,
                decoder_state.unsqueeze(-1),
            ).squeeze(-1)

        elif self.method == "concat":
            seq_len = encoder_outputs.size(1)
            # Repeat decoder state for each position
            repeated = decoder_state.unsqueeze(1).expand(
                -1, seq_len, -1
            )
            combined = torch.cat([repeated, encoder_outputs], dim=-1)
            scores = self.v_a(torch.tanh(self.W_a(combined))).squeeze(-1)

        attention_weights = F.softmax(scores, dim=-1)
        context = torch.bmm(
            attention_weights.unsqueeze(1), encoder_outputs
        ).squeeze(1)

        return context, attention_weights


# --- Demo ---
batch_size, seq_len, dim = 2, 10, 64

for method in ["dot", "general", "concat"]:
    attention = LuongAttention(dim, dim, method=method)
    enc_out = torch.randn(batch_size, seq_len, dim)
    dec_state = torch.randn(batch_size, dim)
    ctx, wts = attention(dec_state, enc_out)
    print(f"[{method:>7s}] Context: {ctx.shape}, Weights sum: "
          f"{wts.sum(dim=-1).tolist()}")

18.4 Self-Attention and the Query-Key-Value Framework

18.4.1 From Cross-Attention to Self-Attention

The attention mechanisms we have seen so far are examples of cross-attention: the decoder attends to the encoder. The query comes from one sequence (the decoder), and the keys and values come from another sequence (the encoder).

Self-attention (also called intra-attention) is a different paradigm: a sequence attends to itself. Each position in the input sequence computes attention weights over all other positions in the same sequence. Self-attention allows the model to capture dependencies between any two positions in a sequence, regardless of their distance.

For example, in the sentence "The animal didn't cross the street because it was too tired," self-attention at the position of "it" should assign high weight to "animal" --- learning that "it" refers to "the animal." This kind of long-range dependency resolution is exactly what self-attention excels at.

18.4.2 The Query-Key-Value Abstraction

To unify all forms of attention under a single framework, we introduce the Query-Key-Value (QKV) abstraction. This is one of the most important conceptual frameworks in modern deep learning.

Think of attention as a differentiable dictionary lookup:

  • Query (Q): What you are looking for --- "I need information relevant to this."
  • Key (K): An index or label for each item --- "Here is what I contain."
  • Value (V): The actual content --- "Here is my information."

In a traditional dictionary (hash map), you provide a key and get back the exact matching value. In attention, you provide a query, compute its similarity to every key, and return a weighted combination of all values, where the weights are proportional to the query-key similarities.

Formally, given:

  • A query matrix $\mathbf{Q} \in \mathbb{R}^{n \times d_k}$ (one row per query)
  • A key matrix $\mathbf{K} \in \mathbb{R}^{m \times d_k}$ (one row per key)
  • A value matrix $\mathbf{V} \in \mathbb{R}^{m \times d_v}$ (one row per value)

attention computes:

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}$$

where:

  • $\mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{n \times m}$ is the matrix of all pairwise query-key dot products
  • $d_k$ is the dimension of the keys (and queries)
  • The softmax is applied row-wise (each query gets its own probability distribution over keys)
  • The result is in $\mathbb{R}^{n \times d_v}$ --- one output vector per query

18.4.3 Mapping Previous Attention Variants to QKV

With the QKV framework, we can express all previous attention variants:

Bahdanau/Luong cross-attention:

  • $\mathbf{Q}$: decoder hidden states
  • $\mathbf{K}$: encoder hidden states
  • $\mathbf{V}$: encoder hidden states ($\mathbf{K} = \mathbf{V}$ in most formulations)

Self-attention: $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ are all derived from the same input sequence, typically through learned linear projections:

$$\mathbf{Q} = \mathbf{X}\mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}^K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}^V$$

where $\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}$ is the input sequence and $\mathbf{W}^Q, \mathbf{W}^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $\mathbf{W}^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are learnable projection matrices.
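As a minimal sketch (with arbitrary illustrative dimensions), the three projections for self-attention look like this in PyTorch:

import torch
import torch.nn as nn

torch.manual_seed(42)

n, d_model, d_k, d_v = 6, 16, 8, 8   # illustrative sizes

X = torch.randn(n, d_model)          # input sequence, one row per token

W_q = nn.Linear(d_model, d_k, bias=False)  # query projection
W_k = nn.Linear(d_model, d_k, bias=False)  # key projection
W_v = nn.Linear(d_model, d_v, bias=False)  # value projection

Q, K, V = W_q(X), W_k(X), W_v(X)
print(Q.shape, K.shape, V.shape)     # (6, 8) (6, 8) (6, 8)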

18.4.4 Why Three Separate Projections?

A natural question is: why not just use $\mathbf{X}$ directly for queries, keys, and values? The answer is that separate projections give the model much more flexibility:

  1. The query projection learns what to look for
  2. The key projection learns what to advertise
  3. The value projection learns what to communicate

These are fundamentally different roles. A token might advertise itself as "a noun" (via its key) so that verbs can find it (via their queries), but the information it actually communicates (via its value) might be its semantic embedding. Without separate projections, these roles would be conflated.


18.5 Scaled Dot-Product Attention: Full Derivation

18.5.1 Why Scale by $\sqrt{d_k}$?

The scaling factor $\frac{1}{\sqrt{d_k}}$ in the attention formula is not arbitrary --- it serves a critical numerical purpose. Let us derive why it is necessary.

Consider two random vectors $\mathbf{q}, \mathbf{k} \in \mathbb{R}^{d_k}$, where each component is independently drawn from a distribution with mean 0 and variance 1. Their dot product is:

$$\mathbf{q} \cdot \mathbf{k} = \sum_{i=1}^{d_k} q_i k_i$$

Each term $q_i k_i$ is the product of two independent random variables, each with mean 0 and variance 1. Therefore:

  • $\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\mathbb{E}[k_i] = 0$
  • $\text{Var}(q_i k_i) = \mathbb{E}[q_i^2]\mathbb{E}[k_i^2] = 1 \cdot 1 = 1$

Since the dot product is a sum of $d_k$ independent terms, each with mean 0 and variance 1:

$$\mathbb{E}[\mathbf{q} \cdot \mathbf{k}] = 0, \quad \text{Var}(\mathbf{q} \cdot \mathbf{k}) = d_k$$

So the standard deviation of the dot product is $\sqrt{d_k}$. As $d_k$ grows, the dot products grow in magnitude, pushing the softmax into regions where its gradients are extremely small (the saturated regime of softmax).

The fix: Divide by $\sqrt{d_k}$ to normalize the variance back to 1:

$$\text{Var}\!\left(\frac{\mathbf{q} \cdot \mathbf{k}}{\sqrt{d_k}}\right) = \frac{\text{Var}(\mathbf{q} \cdot \mathbf{k})}{d_k} = \frac{d_k}{d_k} = 1$$

This keeps the softmax inputs in a well-conditioned range regardless of the dimensionality.

18.5.2 Numerical Demonstration

To see this concretely, consider $d_k = 512$. Without scaling, a typical dot product might be around $\pm 22$ (since $\sqrt{512} \approx 22.6$). If we compute softmax over scores like $[20, -15, 22, -10, 18, \ldots]$, the result would be extremely peaked --- essentially a one-hot vector --- with near-zero gradients for most positions. Scaling by $\sqrt{512}$ brings these values down to approximately $\pm 1$, where softmax produces a more informative gradient landscape.
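A quick simulation confirms this. The sketch below draws unit-variance random queries and keys and measures the empirical standard deviation of their dot products with and without the scaling:

import torch

torch.manual_seed(42)

d_k = 512
q = torch.randn(10_000, d_k)   # 10,000 random queries, components ~ N(0, 1)
k = torch.randn(10_000, d_k)   # 10,000 random keys

dots = (q * k).sum(dim=-1)     # one dot product per (q, k) pair

print(f"std without scaling: {dots.std():.2f}")              # ~ sqrt(512) ~ 22.6
print(f"std with scaling:    {(dots / d_k**0.5).std():.2f}")  # ~ 1.0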

18.5.3 Step-by-Step Computation

Let us walk through a complete example of scaled dot-product attention with concrete numbers.

Setup: Suppose we have 3 tokens and $d_k = d_v = 4$.

$$\mathbf{Q} = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix}, \quad \mathbf{K} = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ 1 & 0 & 1 & 0 \end{bmatrix}, \quad \mathbf{V} = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix}$$

Step 1: Compute $\mathbf{Q}\mathbf{K}^\top$

$$\mathbf{Q}\mathbf{K}^\top = \begin{bmatrix} 1 \cdot 1 + 0 \cdot 1 + 1 \cdot 0 + 0 \cdot 0 & 1 \cdot 0 + 0 \cdot 0 + 1 \cdot 1 + 0 \cdot 1 & 1 \cdot 1 + 0 \cdot 0 + 1 \cdot 1 + 0 \cdot 0 \\ 0 + 1 + 0 + 0 & 0 + 0 + 0 + 1 & 0 + 0 + 0 + 0 \\ 1 + 1 + 0 + 0 & 0 + 0 + 0 + 0 & 1 + 0 + 0 + 0 \end{bmatrix} = \begin{bmatrix} 1 & 1 & 2 \\ 1 & 1 & 0 \\ 2 & 0 & 1 \end{bmatrix}$$

Step 2: Scale by $\sqrt{d_k} = \sqrt{4} = 2$

$$\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}} = \begin{bmatrix} 0.5 & 0.5 & 1.0 \\ 0.5 & 0.5 & 0.0 \\ 1.0 & 0.0 & 0.5 \end{bmatrix}$$

Step 3: Apply softmax (row-wise)

For row 1: softmax$([0.5, 0.5, 1.0]) = [0.2741, 0.2741, 0.4519]$

For row 2: softmax$([0.5, 0.5, 0.0]) = [0.3837, 0.3837, 0.2327]$

For row 3: softmax$([1.0, 0.0, 0.5]) = [0.5065, 0.1863, 0.3072]$

$$\boldsymbol{\alpha} = \begin{bmatrix} 0.2741 & 0.2741 & 0.4519 \\ 0.3837 & 0.3837 & 0.2327 \\ 0.5065 & 0.1863 & 0.3072 \end{bmatrix}$$

Step 4: Multiply by $\mathbf{V}$

Each output row is a weighted combination of the value rows:

$$\text{Output}_1 = 0.2741 \cdot [1,2,3,4] + 0.2741 \cdot [5,6,7,8] + 0.4519 \cdot [9,10,11,12]$$ $$= [5.7112, 6.7112, 7.7112, 8.7112]$$

We see that token 1 is attending mostly to token 3 (weight 0.4519), so its output is pulled toward $\mathbf{V}_3 = [9, 10, 11, 12]$.

18.5.4 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor | None = None,
    dropout: nn.Dropout | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute scaled dot-product attention.

    Args:
        query: Query tensor of shape (..., seq_len_q, d_k).
        key: Key tensor of shape (..., seq_len_k, d_k).
        value: Value tensor of shape (..., seq_len_k, d_v).
        mask: Optional mask tensor. Positions with True (or 1)
            are masked (set to -inf before softmax).
        dropout: Optional dropout layer applied to attention weights.

    Returns:
        output: Attention output of shape (..., seq_len_q, d_v).
        attention_weights: Weights of shape (..., seq_len_q, seq_len_k).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        scores = scores.masked_fill(mask.bool(), float("-inf"))

    attention_weights = F.softmax(scores, dim=-1)

    if dropout is not None:
        attention_weights = dropout(attention_weights)

    output = torch.matmul(attention_weights, value)
    return output, attention_weights


# --- Demo ---
seq_len, d_k, d_v = 3, 4, 4
Q = torch.tensor(
    [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0]], dtype=torch.float32
)
K = torch.tensor(
    [[1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 1, 0]], dtype=torch.float32
)
V = torch.tensor(
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32
)

output, weights = scaled_dot_product_attention(Q, K, V)
print("Attention weights:\n", weights)
print("Output:\n", output)

18.6 Multi-Head Attention

18.6.1 Motivation: Multiple Representation Subspaces

A single attention head computes one set of attention weights. But in language, a token may need to attend to different positions for different reasons simultaneously. For example, in "The cat sat on the mat because it was soft," the word "it" needs to attend to "mat" for coreference resolution and to "soft" for semantic agreement.

Multi-head attention addresses this by running multiple attention operations in parallel, each with its own learned projections. Each "head" can learn to focus on a different type of relationship:

  • One head might learn positional relationships (nearby words)
  • Another might learn syntactic relationships (subject-verb agreement)
  • Another might learn semantic relationships (entity coreference)

18.6.2 Mathematical Formulation

Multi-head attention with $h$ heads is defined as:

$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O$$

where each head is:

$$\text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V)$$

and:

  • $\mathbf{W}_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$ projects queries for head $i$
  • $\mathbf{W}_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ projects keys for head $i$
  • $\mathbf{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ projects values for head $i$
  • $\mathbf{W}^O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$ is the output projection

Typically, $d_k = d_v = d_{\text{model}} / h$, so the total computational cost is similar to single-head attention with full dimensionality.

18.6.3 Why Divide the Dimensions?

If $d_{\text{model}} = 512$ and $h = 8$, then each head operates on $d_k = d_v = 64$ dimensions. Computing the attention scores $\mathbf{Q}\mathbf{K}^\top$ costs $n^2 \times 64$ multiplications per head, so across all 8 heads the total is:

$$8 \times (n^2 \times 64) = n^2 \times 512$$

which is exactly the cost of a single head operating on the full dimensionality:

$$1 \times (n^2 \times 512) = n^2 \times 512$$

The projection cost is likewise unchanged (see the parameter count below). Splitting the dimensions therefore does not save compute; rather, it buys us eight specialized attention patterns for approximately the same cost as one full-dimensional head.

Parameter count analysis. Let us compute the exact parameter count for multi-head attention with $d_{\text{model}} = 512$ and $h = 8$:

  • $\mathbf{W}_i^Q$: $h$ matrices of size $d_{\text{model}} \times d_k = 512 \times 64$. Combined: $512 \times 512 = 262{,}144$ parameters.
  • $\mathbf{W}_i^K$: Same as $\mathbf{W}_i^Q$: $262{,}144$ parameters.
  • $\mathbf{W}_i^V$: Same: $262{,}144$ parameters.
  • $\mathbf{W}^O$: $hd_v \times d_{\text{model}} = 512 \times 512 = 262{,}144$ parameters.
  • Total: $4 \times 262{,}144 = 1{,}048{,}576$ parameters.

This is the same as four $512 \times 512$ matrices. The multi-head structure is an inductive bias, not a parameter efficiency trick --- we use the same number of parameters but organize them to learn multiple attention patterns.
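We can verify this count directly with four bias-free linear layers of the stated sizes (a minimal sketch, not tied to any particular implementation):

import torch.nn as nn

d_model, num_heads = 512, 8

# Q, K, V, and output projections, each d_model x d_model when the heads are fused
projections = [nn.Linear(d_model, d_model, bias=False) for _ in range(4)]

total = sum(p.numel() for layer in projections for p in layer.parameters())
print(f"{total:,}")  # 1,048,576 = 4 * 512 * 512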

18.6.4 Mathematical Derivation: Why Multi-Head Attention Works

To build deeper intuition, consider what happens when we use a single head with the full $d_{\text{model}}$ dimension. The attention weights are:

$$\alpha_{ij} = \frac{\exp(\mathbf{q}_i^\top \mathbf{k}_j / \sqrt{d_{\text{model}}})}{\sum_l \exp(\mathbf{q}_i^\top \mathbf{k}_l / \sqrt{d_{\text{model}}})}$$

This is a single soft selection over positions. Token $i$ can attend most strongly to one position, or spread attention broadly, but it computes only one attention distribution.

With $h$ heads, token $i$ computes $h$ independent attention distributions, each in a different subspace:

$$\alpha_{ij}^{(r)} = \frac{\exp((\mathbf{W}_r^Q \mathbf{x}_i)^\top (\mathbf{W}_r^K \mathbf{x}_j) / \sqrt{d_k})}{\sum_l \exp((\mathbf{W}_r^Q \mathbf{x}_i)^\top (\mathbf{W}_r^K \mathbf{x}_l) / \sqrt{d_k})} \quad \text{for } r = 1, \ldots, h$$

Each head can attend to a different position. Head 1 might attend to the syntactic head of a phrase, head 2 to a semantically related word, and head 3 to the previous token. The output projection $\mathbf{W}^O$ learns to combine these different perspectives into a single representation.

This is analogous to having multiple "read heads" on a memory --- each head can independently access a different part of the sequence, and the results are combined to form a richer representation than any single head could provide.

18.6.5 Full PyTorch Implementation from Scratch

Here is a complete, production-style implementation of multi-head attention:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)


class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism.

    Implements the multi-head attention from 'Attention Is All You Need'
    (Vaswani et al., 2017). Splits the model dimension across multiple
    heads, applies scaled dot-product attention in parallel, and
    concatenates the results.

    Args:
        d_model: The model embedding dimension.
        num_heads: Number of parallel attention heads.
        dropout: Dropout probability for attention weights.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by "
            f"num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None  # Store for visualization

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into (num_heads, d_k).

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, num_heads, seq_len, d_k).
        """
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reverse the split_heads operation.

        Args:
            x: Tensor of shape (batch_size, num_heads, seq_len, d_k).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, _, seq_len, _ = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, seq_len, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply multi-head attention.

        Args:
            query: Query tensor of shape (batch_size, seq_len_q, d_model).
            key: Key tensor of shape (batch_size, seq_len_k, d_model).
            value: Value tensor of shape (batch_size, seq_len_k, d_model).
            mask: Optional mask of shape (batch_size, 1, 1, seq_len_k)
                or (batch_size, 1, seq_len_q, seq_len_k).
                True positions are masked out.

        Returns:
            Output tensor of shape (batch_size, seq_len_q, d_model).
        """
        # Project inputs
        Q = self._split_heads(self.W_q(query))   # (B, h, n, d_k)
        K = self._split_heads(self.W_k(key))     # (B, h, m, d_k)
        V = self._split_heads(self.W_v(value))   # (B, h, m, d_k)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, n, m)
        scores = scores / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask.bool(), float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)  # (B, h, n, m)
        self.attention_weights = attn_weights.detach()
        attn_weights = self.dropout(attn_weights)

        # Weighted combination of values
        context = torch.matmul(attn_weights, V)  # (B, h, n, d_k)

        # Merge heads and project
        output = self._merge_heads(context)  # (B, n, d_model)
        output = self.W_o(output)

        return output


# --- Demo ---
batch_size = 2
seq_len = 10
d_model = 64
num_heads = 8

mha = MultiHeadAttention(d_model, num_heads, dropout=0.1)

x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: Q = K = V = x
output = mha(x, x, x)
print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {mha.attention_weights.shape}")
# (2, 8, 10, 10) -> batch, heads, queries, keys

# Cross-attention: Q from one sequence, K and V from another
memory = torch.randn(batch_size, 20, d_model)
output_cross = mha(x, memory, memory)
print(f"\nCross-attention output shape: {output_cross.shape}")
print(f"Cross-attention weights shape: "
      f"{mha.attention_weights.shape}")
# (2, 8, 10, 20) -> queries attend to 20 memory positions

18.6.6 Efficient Implementation Notes

In practice, the three separate linear projections for Q, K, and V can be combined into a single large matrix multiplication for self-attention (where all three come from the same input):

# Instead of three separate projections:
# Q = self.W_q(x)
# K = self.W_k(x)
# V = self.W_v(x)

# Use a single projection and split:
# self.qkv_proj = nn.Linear(d_model, 3 * d_model)
# qkv = self.qkv_proj(x)
# Q, K, V = qkv.chunk(3, dim=-1)

This fused projection is more GPU-friendly because it performs one large matrix multiplication instead of three smaller ones. PyTorch's nn.MultiheadAttention uses this optimization internally.
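Here is a small runnable sketch of the fused projection under these assumptions (the attribute name qkv_proj above and the sizes below are illustrative choices):

import torch
import torch.nn as nn

torch.manual_seed(42)

batch_size, seq_len, d_model = 2, 10, 64

x = torch.randn(batch_size, seq_len, d_model)

# Fused projection: one matmul produces Q, K, and V together
qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
qkv = qkv_proj(x)                  # (batch, seq_len, 3 * d_model)
Q, K, V = qkv.chunk(3, dim=-1)     # three (batch, seq_len, d_model) tensors

print(Q.shape, K.shape, V.shape)   # each: (2, 10, 64)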


18.7 Attention Masking

18.7.1 Why Masking Is Necessary

Attention masking serves two essential purposes:

  1. Padding masks: When processing batches of sequences with different lengths, shorter sequences are padded to the maximum length. We must prevent attention from attending to padding tokens.

  2. Causal (look-ahead) masks: In autoregressive models (e.g., language models), each position should only attend to previous positions. A position should not be able to "see the future."

18.7.2 Padding Masks

Given a batch of sequences with lengths $[5, 3, 7]$ padded to length 7, the padding mask for the second sequence (length 3) would be:

$$\text{mask} = [\underbrace{0, 0, 0}_{\text{real tokens}}, \underbrace{1, 1, 1, 1}_{\text{padding}}]$$

This mask is applied to the attention scores before softmax by setting masked positions to $-\infty$:

$$\text{scores}_{ij} \leftarrow \begin{cases} \text{scores}_{ij} & \text{if } \text{mask}_j = 0 \\ -\infty & \text{if } \text{mask}_j = 1 \end{cases}$$

After softmax, $e^{-\infty} = 0$, so padding positions receive zero attention weight.

18.7.3 Causal Masks

A causal mask is an upper-triangular matrix that prevents position $i$ from attending to any position $j > i$:

$$\text{CausalMask} = \begin{bmatrix} 0 & 1 & 1 & 1 \\ 0 & 0 & 1 & 1 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 \end{bmatrix}$$

where 1 indicates "masked" (set to $-\infty$). Position 1 can only attend to itself, position 2 can attend to positions 1 and 2, and so on.

import torch

torch.manual_seed(42)


def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create a causal (look-ahead) mask.

    Args:
        seq_len: Length of the sequence.

    Returns:
        Upper triangular boolean mask of shape (seq_len, seq_len).
        True values indicate positions to be masked.
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask.bool()


def create_padding_mask(
    lengths: torch.Tensor, max_len: int
) -> torch.Tensor:
    """Create a padding mask from sequence lengths.

    Args:
        lengths: Actual lengths of shape (batch_size,).
        max_len: Maximum sequence length (padded length).

    Returns:
        Boolean mask of shape (batch_size, 1, 1, max_len).
        True values indicate padding positions to be masked.
    """
    batch_size = lengths.size(0)
    positions = torch.arange(max_len).unsqueeze(0)  # (1, max_len)
    mask = positions >= lengths.unsqueeze(1)           # (batch, max_len)
    return mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, max_len)


# --- Demo ---
causal = create_causal_mask(5)
print("Causal mask:")
print(causal.int())

lengths = torch.tensor([5, 3, 4])
padding = create_padding_mask(lengths, max_len=5)
print("\nPadding mask (sequence 2, length=3):")
print(padding[1, 0, 0].int())  # [0, 0, 0, 1, 1]

18.7.4 Combining Masks

In decoder self-attention for sequence-to-sequence models, we often need both causal and padding masks simultaneously. These are combined with a logical OR:

combined_mask = causal_mask | padding_mask

This ensures that a position can only attend to previous, non-padding positions.
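A short sketch of the combination, assuming the two helper functions defined in Section 18.7.3 are in scope; broadcasting aligns the two mask shapes:

import torch

seq_len = 5
causal = create_causal_mask(seq_len)                       # (5, 5)
padding = create_padding_mask(torch.tensor([5, 3, 4]), 5)  # (3, 1, 1, 5)

combined = causal | padding          # broadcasts to (3, 1, 5, 5)
print(combined.shape)                # torch.Size([3, 1, 5, 5])
print(combined[1, 0].int())          # sequence 2: future AND padding positions masked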


18.8 Computational Complexity of Attention

18.8.1 Time and Space Complexity

The core operation in attention is the matrix multiplication $\mathbf{Q}\mathbf{K}^\top$, which produces an $n \times m$ attention matrix (where $n$ is the number of queries and $m$ is the number of keys). For self-attention, $n = m$, so:

| Operation | Time Complexity | Space Complexity |
| --- | --- | --- |
| $\mathbf{Q}\mathbf{K}^\top$ | $O(n^2 d_k)$ | $O(n^2)$ |
| Softmax | $O(n^2)$ | $O(n^2)$ |
| $\boldsymbol{\alpha} \mathbf{V}$ | $O(n^2 d_v)$ | $O(n d_v)$ |
| Total | $O(n^2 d)$ | $O(n^2 + n d)$ |

The quadratic dependence on sequence length $n$ is the primary computational concern with attention. For a sequence of length 1,000, the attention matrix has 1,000,000 entries. For length 10,000, it has 100,000,000 entries.

18.8.2 Comparison with RNNs

Interestingly, attention's quadratic cost is not always worse than RNNs. Consider the trade-offs:

| Property | Self-Attention | RNN |
| --- | --- | --- |
| Complexity per layer | $O(n^2 d)$ | $O(n d^2)$ |
| Sequential operations | $O(1)$ | $O(n)$ |
| Maximum path length | $O(1)$ | $O(n)$ |
| Parallelizable | Yes | No |

For sequences where $n < d$ (which is common for $d = 512$ or $d = 1024$), self-attention is actually faster than RNNs. Moreover, attention is fully parallelizable across positions, while RNNs must process positions sequentially. The $O(1)$ maximum path length means attention can directly connect any two positions, while information in an RNN must traverse $O(n)$ steps.

18.8.3 Efficient Attention Variants

The quadratic cost has motivated a rich line of research into efficient attention. These variants trade some expressiveness for reduced computational complexity, enabling attention over much longer sequences.

Sparse Attention (Child et al., 2019). Instead of attending to all positions, each position attends only to a fixed subset. The Sparse Transformer combines two patterns:

  • Strided attention: Each position attends to every $l$-th position (e.g., positions 0, 128, 256, ...). This captures long-range dependencies.
  • Local attention: Each position attends to its $w$ nearest neighbors (e.g., positions $i-128$ to $i$). This captures short-range dependencies.

Combining both patterns ensures every pair of positions is connected within a constant number of layers while reducing the per-layer cost from $O(n^2)$ to $O(n\sqrt{n})$.
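As an illustration (with arbitrary window and stride values, not the Sparse Transformer's actual configuration), the sketch below builds a boolean mask that marks which key positions each query may attend to under a combined local-plus-strided pattern:

import torch


def sparse_attention_mask(n: int, window: int, stride: int) -> torch.Tensor:
    """Boolean mask where True means 'masked out' (not attended)."""
    i = torch.arange(n).unsqueeze(1)   # query positions
    j = torch.arange(n).unsqueeze(0)   # key positions
    local = (i - j).abs() <= window    # local window around each position
    strided = (j % stride == 0)        # every stride-th position
    allowed = local | strided
    return ~allowed                    # True = masked


mask = sparse_attention_mask(n=16, window=2, stride=4)
print((~mask).float().mean())  # fraction of (query, key) pairs that are attended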

Linear Attention (Katharopoulos et al., 2020). The key insight is to decompose the softmax kernel. Standard attention computes:

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V})_i = \frac{\sum_j \exp(\mathbf{q}_i^\top \mathbf{k}_j) \mathbf{v}_j}{\sum_j \exp(\mathbf{q}_i^\top \mathbf{k}_j)}$$

Linear attention replaces $\exp(\mathbf{q}^\top \mathbf{k})$ with $\phi(\mathbf{q})^\top \phi(\mathbf{k})$ for a feature map $\phi$:

$$\text{LinearAttention}(\mathbf{Q}, \mathbf{K}, \mathbf{V})_i = \frac{\phi(\mathbf{q}_i)^\top \sum_j \phi(\mathbf{k}_j) \mathbf{v}_j^\top}{\phi(\mathbf{q}_i)^\top \sum_j \phi(\mathbf{k}_j)}$$

The crucial trick is that $\sum_j \phi(\mathbf{k}_j) \mathbf{v}_j^\top$ can be precomputed once and reused for all queries, giving $O(n d^2)$ complexity instead of $O(n^2 d)$. When $d < n$ (which is common for long sequences), this is a significant improvement.
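A minimal sketch of this reordering, using the $\phi(x) = \text{elu}(x) + 1$ feature map from Katharopoulos et al. and illustrative dimensions:

import torch
import torch.nn.functional as F

torch.manual_seed(42)

n, d = 1024, 64
Q = torch.randn(n, d)
K = torch.randn(n, d)
V = torch.randn(n, d)

phi = lambda x: F.elu(x) + 1   # positive feature map

# Precompute the (d, d) and (d,) summaries once, then reuse them for every query
KV = phi(K).t() @ V            # sum_j phi(k_j) v_j^T  -> (d, d)
Z = phi(K).sum(dim=0)          # sum_j phi(k_j)        -> (d,)

numerator = phi(Q) @ KV                    # (n, d)
denominator = (phi(Q) @ Z).unsqueeze(-1)   # (n, 1)
output = numerator / denominator           # O(n d^2) overall, no n x n matrix

print(output.shape)  # torch.Size([1024, 64])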

Sliding Window Attention (Beltagy et al., 2020). The Longformer uses a combination of local sliding window attention (each token attends to $w$ neighbors) and global attention (a few special tokens attend to all positions). This achieves $O(n \cdot w)$ complexity. BigBird (Zaheer et al., 2020) adds random attention connections on top of local and global patterns, proving that this combination is theoretically as expressive as full attention (it can approximate any sequence-to-sequence function).

Flash Attention (Dao et al., 2022). Unlike the methods above, Flash Attention computes exact standard attention but uses IO-aware GPU optimizations to avoid materializing the full $n \times n$ attention matrix in GPU high-bandwidth memory (HBM). Instead, it tiles the computation into blocks that fit in the fast SRAM cache, computing attention in a fused kernel that is 2--4x faster than standard implementations with $O(n)$ memory instead of $O(n^2)$.

The practical impact of Flash Attention cannot be overstated: it makes standard attention viable for much longer sequences (up to 64K tokens or more) without any approximation. Most modern Transformer implementations use Flash Attention by default. In PyTorch:

import torch
import torch.nn.functional as F

torch.manual_seed(42)

# Flash Attention via PyTorch's scaled_dot_product_attention
# (automatically selects the most efficient implementation)
Q = torch.randn(2, 8, 1024, 64)
K = torch.randn(2, 8, 1024, 64)
V = torch.randn(2, 8, 1024, 64)

# On supported CUDA GPUs this dispatches to Flash Attention;
# on CPU it falls back to another efficient implementation
output = F.scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 8, 1024, 64)

Summary of efficient attention variants:

| Method | Complexity | Exact? | Key Idea |
| --- | --- | --- | --- |
| Full attention | $O(n^2 d)$ | Yes | Standard baseline |
| Sparse attention | $O(n\sqrt{n} d)$ | No | Attend to fixed subset |
| Linear attention | $O(n d^2)$ | No | Kernel decomposition |
| Sliding window | $O(nwd)$ | No | Local + global patterns |
| Flash Attention | $O(n^2 d)$ | Yes | IO-aware GPU tiling |

18.9 Positional Encodings

18.9.1 Why Position Information Is Needed

Self-attention is permutation equivariant: if you shuffle the input tokens, the outputs are shuffled in exactly the same way, and nothing else changes. The output at each position depends on the content of the other tokens, not on their positions. Without positional information, the sentences "dog bites man" and "man bites dog" would produce the same token representations, merely in a different order.

To break this symmetry, we add positional encodings to the input embeddings. These encodings inject information about each token's position in the sequence, allowing the model to distinguish between the same word at different positions.

18.9.2 Sinusoidal Positional Encodings

The original Transformer (Vaswani et al., 2017) used fixed sinusoidal positional encodings:

$$\text{PE}(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

$$\text{PE}(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

where:

  • $pos$ is the position in the sequence (0, 1, 2, ...)
  • $i$ is the dimension index (0, 1, ..., $d_{\text{model}}/2 - 1$)
  • $d_{\text{model}}$ is the model dimension

Intuition. Each dimension of the encoding oscillates at a different frequency. Low dimensions ($i$ near 0) oscillate rapidly (high frequency), capturing fine positional differences. High dimensions ($i$ near $d_{\text{model}}/2$) oscillate slowly (low frequency), capturing coarse positional information. Together, they form a unique "fingerprint" for each position.

Key property: relative position encoding. For any fixed offset $k$, the positional encoding $\text{PE}(pos + k)$ can be expressed as a linear transformation of $\text{PE}(pos)$. This means the model can learn to attend to relative positions (e.g., "two tokens to the left") using a fixed linear operation, regardless of the absolute position.

import torch
import math

torch.manual_seed(42)


def sinusoidal_positional_encoding(
    max_len: int, d_model: int
) -> torch.Tensor:
    """Compute sinusoidal positional encodings.

    Args:
        max_len: Maximum sequence length.
        d_model: Model dimension (must be even).

    Returns:
        Tensor of shape (max_len, d_model).
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float()
        * (-math.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


pe = sinusoidal_positional_encoding(100, 64)
print(f"PE shape: {pe.shape}")  # (100, 64)
print(f"PE[0] norm: {pe[0].norm():.4f}")
print(f"PE[50] norm: {pe[50].norm():.4f}")

18.9.3 Learned Positional Embeddings

BERT and GPT-2 use learned positional embeddings: a standard embedding layer that maps each position index to a dense vector:

import torch
import torch.nn as nn

torch.manual_seed(42)

max_len = 512
d_model = 768
position_embedding = nn.Embedding(max_len, d_model)

positions = torch.arange(max_len)
pe = position_embedding(positions)
print(f"Learned PE shape: {pe.shape}")  # (512, 768)

Learned positional embeddings are more flexible than sinusoidal encodings but are limited to a fixed maximum sequence length (512 for BERT, 1024 for GPT-2). They cannot generalize to positions beyond the training range.

18.9.4 Rotary Position Embedding (RoPE)

RoPE (Su et al., 2021) is a more recent approach that encodes position information by rotating the query and key vectors. For a query vector $\mathbf{q}$ at position $m$, RoPE applies a rotation matrix $\mathbf{R}_m$:

$$\mathbf{q}_m' = \mathbf{R}_m \mathbf{q}_m, \quad \mathbf{k}_n' = \mathbf{R}_n \mathbf{k}_n$$

The rotation matrix is block-diagonal, operating on pairs of dimensions:

$$\mathbf{R}_m = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 \\ \sin m\theta_1 & \cos m\theta_1 \\ & & \cos m\theta_2 & -\sin m\theta_2 \\ & & \sin m\theta_2 & \cos m\theta_2 \\ & & & & \ddots \end{pmatrix}$$

where $\theta_i = 10000^{-2i/d}$ follows the same frequency schedule as sinusoidal encodings.

The key property is that the dot product between rotated queries and keys depends only on the relative position $m - n$:

$$\mathbf{q}_m'^\top \mathbf{k}_n' = \mathbf{q}_m^\top \mathbf{R}_{n-m} \mathbf{k}_n$$

This makes RoPE a relative positional encoding that is naturally integrated into the attention computation. RoPE is used in LLaMA, GPT-NeoX, and many modern large language models because it supports length generalization better than learned absolute embeddings.
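A simplified sketch of applying the rotation: it treats consecutive dimension pairs as 2D points and rotates each pair by $m\theta_i$ (the pairing convention here is an illustrative choice and may differ from specific library implementations):

import torch


def apply_rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Rotate pairs of dimensions of x by position-dependent angles.

    Args:
        x: Tensor of shape (seq_len, d) with d even.
        positions: Integer positions of shape (seq_len,).
    """
    d = x.size(-1)
    theta = 10000.0 ** (-torch.arange(0, d, 2).float() / d)  # (d/2,) frequencies
    angles = positions.unsqueeze(-1).float() * theta         # (seq_len, d/2)
    cos, sin = angles.cos(), angles.sin()

    x1, x2 = x[..., 0::2], x[..., 1::2]                      # paired dimensions
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin                     # 2D rotation per pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


q = torch.randn(8, 64)
q_rot = apply_rope(q, torch.arange(8))
print(q_rot.shape)  # torch.Size([8, 64])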

18.9.5 ALiBi: Attention with Linear Biases

ALiBi (Press et al., 2022) takes an even simpler approach: instead of modifying the input embeddings, it adds a position-dependent bias directly to the attention scores:

$$\text{scores}_{ij} = \mathbf{q}_i^\top \mathbf{k}_j - m \cdot |i - j|$$

where $m$ is a head-specific slope (different heads use different slopes, e.g., $m = 2^{-1}, 2^{-2}, \ldots, 2^{-8}$ for 8 heads). The bias linearly penalizes attending to distant positions, with different heads having different "attention spans."

ALiBi requires no additional parameters and generalizes remarkably well to sequences much longer than those seen during training. The linear penalty is a simple inductive bias: nearby tokens are more relevant than distant ones, all else being equal.
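A short sketch of adding the ALiBi bias to raw attention scores, assuming 8 heads with the slopes $2^{-1}, \ldots, 2^{-8}$ (the tensor sizes are illustrative):

import torch

torch.manual_seed(42)

num_heads, seq_len, d_k = 8, 6, 16

# Head-specific slopes: 2^-1, 2^-2, ..., 2^-8 (one per head)
slopes = torch.tensor([2.0 ** -(h + 1) for h in range(num_heads)])

# Distance matrix |i - j| between query position i and key position j
i = torch.arange(seq_len).unsqueeze(1)
j = torch.arange(seq_len).unsqueeze(0)
distance = (i - j).abs().float()                 # (seq_len, seq_len)

# Bias: -m * |i - j|, broadcast over heads -> (num_heads, seq_len, seq_len)
alibi_bias = -slopes.view(-1, 1, 1) * distance

q = torch.randn(num_heads, seq_len, d_k)
k = torch.randn(num_heads, seq_len, d_k)
scores = q @ k.transpose(-2, -1) / d_k ** 0.5 + alibi_bias
weights = scores.softmax(dim=-1)
print(weights.shape)  # torch.Size([8, 6, 6])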

18.9.6 Comparison of Positional Encoding Methods

| Method | Type | Extrapolation | Parameters | Used By |
| --- | --- | --- | --- | --- |
| Sinusoidal | Absolute, fixed | Moderate | 0 | Original Transformer |
| Learned | Absolute, learned | Poor | $L \times d$ | BERT, GPT-2 |
| RoPE | Relative, applied to Q/K | Good | 0 | LLaMA, GPT-NeoX |
| ALiBi | Relative, attention bias | Excellent | 0 | BLOOM, MPT |

The trend in modern architectures is toward relative positional encodings (RoPE, ALiBi) that can generalize to longer sequences than seen during training.


18.10 Attention Visualization and Interpretation

18.10.1 What Do Attention Weights Tell Us?

Attention weights $\alpha_{ij}$ form a probability distribution that tells us how much output position $i$ "attends to" input position $j$. Visualizing these weights as heatmaps can reveal interpretable patterns:

  • Diagonal patterns in translation indicate monotonic alignment (word 1 maps to word 1, etc.)
  • Off-diagonal blocks indicate word reordering between languages
  • Diffuse attention indicates positions that aggregate global context
  • Sharp attention indicates strong pairwise dependencies

18.10.2 Attention Is Not Explanation

An important caveat: attention weights should not be uncritically interpreted as "explanations" for model behavior. Jain and Wallace (2019) showed that:

  1. Attention weights often do not correlate with gradient-based feature importance
  2. Alternative attention distributions can yield the same predictions
  3. Attention over intermediate representations is hard to map back to input features

Wiegreffe and Pinter (2019) offered a more nuanced view, arguing that attention can provide useful information when analyzed carefully. The key lesson is that attention weights are descriptive (they show what the model looks at) but not necessarily prescriptive (they may not explain why the model makes a particular decision).

18.10.3 Multi-Head Attention Patterns

With multi-head attention, different heads often specialize in different roles. Clark et al. (2019) analyzed BERT's attention heads and found:

  • Some heads attend to the next or previous token (positional)
  • Some heads attend to separator tokens (structural)
  • Some heads approximate dependency parse relations (syntactic)
  • Some heads are highly diffuse, acting as "bag of words" aggregators

This specialization emerges naturally from training --- it is not explicitly engineered.

18.10.4 Attention Visualization in Practice

Visualizing attention can be done at multiple levels of granularity:

Single-head heatmaps. For a single attention head, the attention weights $\alpha_{ij}$ form a matrix that can be displayed as a heatmap. Rows correspond to query positions (output tokens), and columns correspond to key positions (input tokens). Bright cells indicate high attention.

Head-level summary. With multiple heads, you can display a grid of heatmaps (one per head) or compute summary statistics like the entropy of each head's attention distribution (low entropy = focused, high entropy = diffuse).

Layer-level analysis. Attention patterns change across layers. Lower layers tend to show more local, positional patterns (attending to nearby tokens), while upper layers show more semantic patterns (attending to related content regardless of position). Tracing how attention patterns evolve across layers reveals how the model progressively builds up its representation.

Attention rollout (Abnar and Zuidema, 2020) addresses the fact that attention in one layer flows through residual connections from all previous layers. Simple per-layer visualization can be misleading because it ignores this accumulation. Attention rollout recursively combines attention weights across layers to estimate how much each input token contributes to each output position, accounting for the full computation graph.

Tools. The BertViz library provides interactive attention visualizations for Transformer models. For custom models, you can extract attention weights by storing them during the forward pass (as we did in our MultiHeadAttention implementation with self.attention_weights) and plotting them with matplotlib or seaborn.
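As a minimal sketch, assuming the MultiHeadAttention class from Section 18.6 is in scope, the stored weights can be rendered as per-head heatmaps with matplotlib:

import matplotlib.pyplot as plt
import torch

torch.manual_seed(42)

# Assumes the MultiHeadAttention class defined in Section 18.6
mha = MultiHeadAttention(d_model=64, num_heads=4)
x = torch.randn(1, 12, 64)
mha(x, x, x)                            # forward pass stores the weights

weights = mha.attention_weights[0]      # (num_heads, seq_len, seq_len)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for h, ax in enumerate(axes):
    ax.imshow(weights[h], cmap="viridis")   # rows: queries, columns: keys
    ax.set_title(f"Head {h}")
    ax.set_xlabel("Key position")
    ax.set_ylabel("Query position")
plt.tight_layout()
plt.show()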


18.11 Attention as a Differentiable Dictionary Lookup

18.11.1 The Dictionary Analogy

Let us formalize the dictionary analogy we introduced earlier. A traditional dictionary maps keys to values:

lookup(query, dictionary):
    for (key_i, value_i) in dictionary:
        if query == key_i:
            return value_i

This is a hard lookup: you either get an exact match or nothing.

Attention generalizes this to a soft, differentiable lookup:

attention(query, keys, values):
    weights = softmax(similarity(query, key_i) for each key_i in keys)
    return sum(weight_i * value_i over all i)

Key differences:

| Property | Hard Lookup | Soft Attention |
|---|---|---|
| Match type | Exact | Approximate (similarity) |
| Output | Single value | Weighted sum of all values |
| Differentiable | No | Yes |
| Trainable | No | Yes (via learned projections) |
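
To make the analogy concrete, here is a minimal, self-contained sketch of the soft lookup in PyTorch; the similarity function is a plain dot product, and the names (soft_lookup, query, keys, values) are illustrative rather than drawn from any library:

```python
import torch
import torch.nn.functional as F

def soft_lookup(query, keys, values):
    """Differentiable dictionary lookup.

    query:  (d,)       the probe vector
    keys:   (m, d)     one key per stored entry
    values: (m, d_v)   one value per stored entry
    """
    scores = keys @ query                 # (m,) similarity of query to each key
    weights = F.softmax(scores, dim=-1)   # (m,) soft "match" distribution
    return weights @ values               # (d_v,) weighted sum of all values

# Toy usage: when one key matches the query much better than the others,
# the output is dominated by that key's value, approximating a hard lookup.
query = torch.tensor([1.0, 0.0])
keys = torch.tensor([[10.0, 0.0], [0.0, 10.0]])
values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(soft_lookup(query, keys, values))   # close to [1.0, 2.0]
```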

18.11.2 Content-Based vs. Location-Based Addressing

The QKV attention mechanism is a form of content-based addressing: the query's content determines which keys it matches. This is in contrast to location-based addressing, where attention is determined by position alone (e.g., "always attend to the previous 3 positions").

Neural Turing Machines (Graves et al., 2014) used both content-based and location-based addressing. Modern Transformers primarily use content-based addressing through QKV attention, but positional encodings (which we will explore in Chapter 19) inject location information into the content, indirectly enabling location-sensitive attention.

18.11.3 Memory Networks and Key-Value Stores

The attention-as-dictionary perspective connects to a broader family of models:

  • Memory Networks (Weston et al., 2015): Store facts as key-value pairs and use attention to retrieve relevant facts for question answering
  • Key-Value Memory Networks (Miller et al., 2016): Explicitly separate keys (for addressing) from values (for retrieval)
  • Product-Key Memory (Lample et al., 2019): Efficient large-scale memory using product quantization of keys

These models can be seen as specialized attention mechanisms with very large key-value stores, blurring the line between attention and retrieval-augmented generation.


18.12 Putting It All Together: Attention in Context

18.12.1 Summary of Attention Variants

| Variant | Scoring | Complexity | Use Case |
|---|---|---|---|
| Bahdanau (additive) | $\mathbf{v}^\top \tanh(\mathbf{W}\mathbf{s} + \mathbf{U}\mathbf{h})$ | $O(d_a(d_s + d_h))$ | Seq2seq with RNNs |
| Luong dot | $\mathbf{s}^\top \mathbf{h}$ | $O(d)$ | Fast cross-attention |
| Luong general | $\mathbf{s}^\top \mathbf{W} \mathbf{h}$ | $O(d^2)$ | Flexible cross-attention |
| Scaled dot-product | $\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}$ | $O(n^2 d)$ | Transformers |
| Multi-head | $h$ parallel scaled dot-product heads | $O(n^2 d)$ | Transformers (standard) |

18.12.2 Attention in Other Domains

While we have focused on sequence modeling, attention has become a universal mechanism across deep learning:

Computer vision. Vision Transformers (ViT) split images into patches and apply self-attention to patch embeddings. Each patch can attend to every other patch, enabling the model to capture global context from the first layer --- unlike CNNs, which build up receptive fields gradually. Cross-attention is used in text-to-image models (like Stable Diffusion) where image features attend to text embeddings.

Graph neural networks. Graph Attention Networks (GATs) use attention to weight the contributions of a node's neighbors. Each node computes attention weights over its neighbors, then aggregates their features using the attention-weighted sum. This allows the model to learn which neighbors are most informative for each node.

Multimodal models. Cross-attention enables information flow between different modalities. In vision-language models, image tokens serve as keys and values while text tokens serve as queries, allowing the text to "look at" relevant parts of the image.
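
As a small sketch of this pattern using torch.nn.MultiheadAttention (the tensor names and shapes below are assumptions for illustration, not taken from any specific vision-language model), the text tokens supply the queries while the image tokens supply the keys and values:

```python
import torch
import torch.nn as nn

d_model, n_heads = 256, 8
cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

text_tokens = torch.randn(2, 12, d_model)    # (batch, text_len, d_model): queries
image_tokens = torch.randn(2, 49, d_model)   # (batch, num_patches, d_model): keys/values

# Each text token retrieves an attention-weighted mixture of image-patch features.
out, weights = cross_attn(query=text_tokens, key=image_tokens, value=image_tokens)
print(out.shape)      # torch.Size([2, 12, 256])
print(weights.shape)  # torch.Size([2, 12, 49]), averaged over heads by default
```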

Protein structure prediction. AlphaFold2 uses a sophisticated attention mechanism (the "Evoformer") where residue pairs attend to each other, enabling the model to reason about 3D structure from amino acid sequences.

The universality of attention is one of its most remarkable properties: the same mathematical framework --- queries seeking keys to retrieve values --- applies to tokens, image patches, graph nodes, amino acids, and any other structured data.

18.12.3 The Road Ahead

The attention mechanism is the foundational building block of the Transformer architecture, which we will explore in depth in Chapter 19. There, you will learn how self-attention layers are combined with feed-forward networks, layer normalization, and residual connections to create the full Transformer architecture.

In Chapter 20, we will see how pre-training objectives like masked language modeling and span corruption use the attention mechanism to learn powerful representations from unlabeled text. The attention mechanism you have learned in this chapter is the engine that powers GPT, BERT, and virtually every modern language model.


18.13 Practical Tips for Working with Attention

18.13.1 Debugging Attention Models

When attention-based models fail to train or produce poor results, consider these debugging strategies:

Check attention weight distributions. If attention weights are nearly uniform across all positions (high entropy), the model is not learning to focus on relevant positions. This can happen when:

  • The learning rate is too low for the attention parameters.
  • The input embeddings are not informative enough.
  • The positional encodings are missing or misconfigured.
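
A minimal diagnostic along these lines, assuming attention weights of shape (batch, heads, query_len, key_len) extracted as in Section 18.10.4: compare the measured entropy against the maximum possible value, $\log m$ for $m$ key positions; values persistently near that maximum are a sign of the uniform-attention failure mode described above.

```python
import math
import torch

def attention_entropy_report(attn, eps=1e-9):
    """attn: (batch, heads, query_len, key_len) post-softmax attention weights."""
    entropy = -(attn * (attn + eps).log()).sum(dim=-1)   # (batch, heads, query_len)
    max_entropy = math.log(attn.size(-1))                # entropy of a uniform distribution
    print(f"mean entropy: {entropy.mean().item():.3f}  (uniform would be {max_entropy:.3f})")
    print("per-head mean entropy:", entropy.mean(dim=(0, 2)))
```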

Check for numerical issues. Very large dot products before the softmax can overflow in the exponential, producing inf and then NaN values. This is why scaling by $\sqrt{d_k}$ is essential. If you still see numerical issues, ensure that:

  • Input embeddings are properly initialized (not too large).
  • Layer normalization is applied before attention (pre-norm style).
  • Gradient clipping is used during training.

Verify masking correctness. A common bug is applying the mask with the wrong sign or shape. After applying the mask, check that:

  • Padding positions receive attention weight exactly 0.0.
  • In causal attention, no future positions receive non-zero weight.
  • The mask is broadcastable to the attention score shape $(B, h, n, m)$.
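
These checks can be automated with a few assertions. Below is a minimal sketch, assuming post-softmax weights of shape $(B, h, n, m)$ and a boolean key_padding_mask of shape $(B, m)$ in which True marks real (non-padding) tokens; both names are illustrative.

```python
import torch

def check_attention_masks(weights, key_padding_mask=None, causal=False, atol=1e-6):
    """weights: (B, h, n, m) post-softmax attention weights.
    key_padding_mask: (B, m) bool, True = real token. Assumes every query
    position has at least one unmasked key."""
    B, h, n, m = weights.shape
    # Rows should still be valid probability distributions after masking.
    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4)
    if key_padding_mask is not None:
        # Padding columns must receive (near-)zero attention weight.
        pad = ~key_padding_mask.reshape(B, 1, 1, m)
        pad_weights = weights.masked_select(pad)
        assert pad_weights.numel() == 0 or pad_weights.abs().max() < atol
    if causal:
        # Strictly upper-triangular entries correspond to future positions.
        future = torch.triu(torch.ones(n, m, dtype=torch.bool, device=weights.device), diagonal=1)
        future_weights = weights.masked_select(future.reshape(1, 1, n, m))
        assert future_weights.numel() == 0 or future_weights.abs().max() < atol
```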

18.13.2 Attention and Memory: Key-Value Caching

In autoregressive generation (as we will see in Chapter 21), the model generates one token at a time. At each step, it must attend to all previously generated tokens. Without optimization, the model re-runs its projections over the entire prefix at every step, recomputing keys and values that have not changed; this redundant work grows quadratically with the number of generated tokens $n$.

Key-value (KV) caching avoids this redundancy. Once a token's key and value vectors are computed, they are stored in a cache. At each new generation step, only the new token's query, key, and value are computed. The new query attends to all cached keys, and the new key-value pair is appended to the cache.

This reduces the per-step cost of computing keys and values from $O(n \cdot d^2)$ (re-projecting the whole prefix) to $O(d^2)$ (projecting only the new token); the attention computation itself still costs $O(t \cdot d)$ at step $t$. The saving is critical for practical inference speed. The tradeoff is memory: the KV cache stores $2 \times L \times h \times d_k$ values per token (keys and values for each layer and head), which can become the bottleneck for long sequences and large batch sizes.
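
The bookkeeping is simple enough to sketch for a single head; the class and function names below (KVCache, cached_attention_step) are illustrative, not from a particular library, and the projections of the new token are assumed to be computed elsewhere.

```python
import torch

class KVCache:
    """Append-only cache of key/value vectors for one attention head."""

    def __init__(self):
        self.keys = None     # (B, t, d_k) after t cached steps
        self.values = None   # (B, t, d_k)

    def append(self, k_new, v_new):
        """k_new, v_new: (B, 1, d_k) projections of the newest token only."""
        if self.keys is None:
            self.keys, self.values = k_new, v_new
        else:
            self.keys = torch.cat([self.keys, k_new], dim=1)
            self.values = torch.cat([self.values, v_new], dim=1)
        return self.keys, self.values

def cached_attention_step(q_new, cache, k_new, v_new):
    """q_new: (B, 1, d_k) query for the newest token; returns (B, 1, d_k)."""
    K, V = cache.append(k_new, v_new)                       # (B, t, d_k)
    scores = q_new @ K.transpose(1, 2) / K.size(-1) ** 0.5  # (B, 1, t)
    weights = torch.softmax(scores, dim=-1)
    return weights @ V                                      # (B, 1, d_k)
```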


Summary

In this chapter, we have built a thorough understanding of the attention mechanism:

  1. The information bottleneck in standard seq2seq models motivates the need for attention
  2. Bahdanau (additive) attention uses a learned alignment model to compute attention weights between encoder and decoder states
  3. Luong (multiplicative) attention simplifies scoring with dot products or bilinear forms
  4. Self-attention allows a sequence to attend to itself, capturing long-range dependencies
  5. The QKV framework unifies all attention variants: queries seek, keys advertise, values communicate
  6. Scaled dot-product attention divides by $\sqrt{d_k}$ to prevent softmax saturation
  7. Multi-head attention runs multiple attention heads in parallel, each specializing in different relationship types
  8. Masking handles padding and enforces causal (autoregressive) constraints
  9. Quadratic complexity $O(n^2)$ is attention's primary limitation, motivating efficient variants
  10. Attention as a differentiable dictionary provides a powerful mental model for understanding the mechanism

The attention mechanism transformed sequence modeling from a sequential, bottlenecked process into a parallel, dynamic one. In the next chapter, we will see how this building block assembles into the Transformer --- the architecture that changed everything.