
Chapter 10: The Transformer Architecture — Attention Is All You Need (and Why It Changed Everything)

"What I cannot create, I do not understand." — Richard Feynman, written on his blackboard at the time of his death (1988)


Learning Objectives

By the end of this chapter, you will be able to:

  1. Derive scaled dot-product attention from first principles and implement it in PyTorch
  2. Explain multi-head attention, positional encoding, and the full transformer block
  3. Trace the information flow through a complete transformer encoder and decoder
  4. Analyze the computational complexity of self-attention $O(n^2 d)$ and modern efficiency improvements (sparse attention, linear attention, flash attention)
  5. Implement a small transformer for a concrete task and train it end-to-end

10.1 The Problem That Attention Solves

In Chapter 9, we built an LSTM to model StreamRec user sessions — sequences of items a user engaged with, used to predict what they will engage with next. The LSTM processes the session one item at a time, maintaining a hidden state $\mathbf{h}_t \in \mathbb{R}^d$ that compresses everything the model has seen so far into a single vector. This compression is the LSTM's fundamental bottleneck.

Consider a StreamRec user who watched a documentary about ocean conservation 45 minutes into their session, then browsed through 20 unrelated videos, and is now hovering over a film about marine biology. The LSTM must have preserved information about the ocean documentary through 20 subsequent updates to the hidden state. In theory, the LSTM's gating mechanism enables this — the forget gate can selectively preserve relevant information. In practice, the information degrades. The hidden state has a fixed capacity $d$, and every subsequent item competes for that capacity. Information from early in the sequence is systematically overwritten.

The Bahdanau attention mechanism we introduced at the end of Chapter 9 offered a partial solution: instead of relying solely on the final hidden state, the decoder can attend to all encoder hidden states simultaneously. But Bahdanau attention was an add-on to the RNN — a patch over the sequential bottleneck. The RNN still processes the input sequentially, and training still requires backpropagation through time with its $O(T)$ sequential computation and vanishing gradient hazards.

The transformer, introduced by Vaswani et al. (2017), asks a radical question: what if we removed the recurrence entirely and used attention as the sole mechanism for relating positions in a sequence to one another?

This chapter derives the transformer from first principles. We will build every component — scaled dot-product attention, multi-head attention, positional encoding, the transformer block, the full encoder-decoder architecture — from mathematical foundations and implement each in PyTorch. By the end, you will have a complete, annotated transformer that you understand line by line, because you built it line by line.

Understanding Why: The transformer is not a single clever idea. It is a carefully engineered system where each component solves a specific problem. Scaled dot-product attention solves the content-based retrieval problem. Multi-head attention solves the representational diversity problem. Positional encoding solves the order-awareness problem. Layer normalization and residual connections solve the trainability problem. Understanding each component's purpose is the key to understanding the whole.


10.2 Deriving Attention from First Principles

The Database Analogy

Consider a database table with $n$ rows, where each row contains a key and a value. Given a query, the database finds the row whose key matches the query and returns the corresponding value. This is a hard lookup — it returns exactly one row, or nothing.

Now relax the matching criterion. Instead of requiring an exact match, suppose we compute a similarity between the query and every key, then return a weighted combination of all values, where the weights reflect the similarities. Keys that are more similar to the query contribute more to the output; keys that are dissimilar contribute less.

This is attention. Formally, given:

  • A query vector $\mathbf{q} \in \mathbb{R}^{d_k}$
  • A set of key vectors $\mathbf{k}_1, \mathbf{k}_2, \ldots, \mathbf{k}_n \in \mathbb{R}^{d_k}$
  • A set of value vectors $\mathbf{v}_1, \mathbf{v}_2, \ldots, \mathbf{v}_n \in \mathbb{R}^{d_v}$

the attention output is:

$$\text{Attention}(\mathbf{q}, \mathbf{K}, \mathbf{V}) = \sum_{i=1}^{n} \alpha_i \mathbf{v}_i$$

where the attention weights $\alpha_i$ are computed from query-key similarities:

$$\alpha_i = \frac{\exp(\text{score}(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^{n} \exp(\text{score}(\mathbf{q}, \mathbf{k}_j))}$$

The softmax ensures the weights are non-negative and sum to 1, making the output a convex combination of the value vectors.
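Before building the batched version later in the chapter, the single-query formula can be computed directly. This is a toy sketch with random vectors (sizes illustrative), using the unscaled dot product as the score:

```python
import torch

torch.manual_seed(0)
d_k, n = 4, 5
q = torch.randn(d_k)                  # one query vector
K = torch.randn(n, d_k)               # n key vectors
V = torch.randn(n, d_k)               # n value vectors (d_v = d_k here)

scores = K @ q                        # score(q, k_i) = q^T k_i for each i
alpha = torch.softmax(scores, dim=0)  # attention weights: non-negative, sum to 1
output = alpha @ V                    # convex combination of the values

assert (alpha >= 0).all()
assert torch.isclose(alpha.sum(), torch.tensor(1.0))
```

Because the weights are non-negative and sum to 1, the output always lies inside the convex hull of the value vectors.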

The Information Retrieval Analogy

There is a second, equally illuminating analogy. In information retrieval, a user enters a search query, the system computes a relevance score between the query and every document in the corpus, and the results are ranked by relevance. The top documents — those whose content (keys) best match the query — receive the most attention.

Self-attention applies this same process within a single sequence. Each position in the sequence simultaneously plays three roles:

  1. As a query: "I am looking for positions that are relevant to me."
  2. As a key: "Here is what I have to offer — match against me."
  3. As a value: "If you find me relevant, here is the information I will contribute to your output."

When position $i$ attends to the sequence, it broadcasts its query, computes similarity against every position's key, and aggregates the values weighted by those similarities. Every position does this simultaneously.

Choosing the Scoring Function: Dot Product

How should we compute $\text{score}(\mathbf{q}, \mathbf{k}_i)$? The simplest choice is the dot product:

$$\text{score}(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i$$

The dot product is geometrically motivated: $\mathbf{q}^\top \mathbf{k} = \|\mathbf{q}\| \|\mathbf{k}\| \cos\theta$, where $\theta$ is the angle between the two vectors. Vectors pointing in similar directions yield large positive scores; orthogonal vectors yield zero; opposing vectors yield negative scores. The dot product is also computationally efficient: for a batch of queries and keys, it reduces to a single matrix multiplication.

The Scaling Problem

The dot product has a problem. If the vectors $\mathbf{q}$ and $\mathbf{k}$ have independent components drawn from a distribution with mean 0 and variance 1, then $\mathbf{q}^\top \mathbf{k} = \sum_{j=1}^{d_k} q_j k_j$ is a sum of $d_k$ independent random variables, each with mean 0 and variance 1. By linearity of variance, the dot product has mean 0 and variance $d_k$ (and by the central limit theorem its distribution is approximately Gaussian for large $d_k$).

As $d_k$ grows, the dot products grow in magnitude. Large-magnitude inputs to the softmax push it into saturated regions where the gradient is nearly zero:

$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}$$

When one element $z_i$ is much larger than the rest, $\text{softmax}(z)_i \approx 1$ and $\text{softmax}(z)_j \approx 0$ for $j \neq i$. The Jacobian of the softmax in this regime has entries close to zero, killing gradient flow.

The fix is to scale the dot products by $\frac{1}{\sqrt{d_k}}$, normalizing the variance back to 1 regardless of dimension:

$$\text{score}(\mathbf{q}, \mathbf{k}_i) = \frac{\mathbf{q}^\top \mathbf{k}_i}{\sqrt{d_k}}$$

Mathematical Foundation: Let $q_j, k_j \sim \mathcal{N}(0, 1)$ i.i.d. Then $\text{Var}(\mathbf{q}^\top \mathbf{k}) = \text{Var}\left(\sum_{j=1}^{d_k} q_j k_j\right) = \sum_{j=1}^{d_k} \text{Var}(q_j k_j) = d_k$. After scaling by $1/\sqrt{d_k}$, $\text{Var}\left(\frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$. This ensures the softmax operates in its well-behaved regime regardless of the dimensionality $d_k$. The same scaling principle appears in Xavier/Glorot initialization (Chapter 7) — both are applications of variance stabilization.
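A quick empirical check of this argument (illustrative, with $d_k = 512$): sample many random query-key pairs and compare the variance of raw and scaled dot products.

```python
import torch

torch.manual_seed(0)
d_k, trials = 512, 10_000
q = torch.randn(trials, d_k)
k = torch.randn(trials, d_k)

raw = (q * k).sum(dim=-1)          # unscaled dot products, one per trial
scaled = raw / d_k ** 0.5          # scaled by 1/sqrt(d_k)

print(f"Var(raw)    ~ {raw.var():.1f}  (theory: {d_k})")
print(f"Var(scaled) ~ {scaled.var():.3f} (theory: 1)")
```

The measured variances land close to $d_k$ and 1 respectively, confirming that the $1/\sqrt{d_k}$ factor keeps softmax inputs at a dimension-independent scale.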

Scaled Dot-Product Attention: The Complete Formula

Batching across all $n$ query positions and writing the keys and values as matrices, we arrive at the formula from Vaswani et al. (2017):

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}$$

where $\mathbf{Q} \in \mathbb{R}^{n \times d_k}$, $\mathbf{K} \in \mathbb{R}^{n \times d_k}$, $\mathbf{V} \in \mathbb{R}^{n \times d_v}$, and the output is in $\mathbb{R}^{n \times d_v}$.

The matrix $\mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{n \times n}$ is the attention score matrix: entry $(i, j)$ measures how much position $i$ attends to position $j$. After softmax (applied row-wise), each row becomes a probability distribution over all positions. Multiplying by $\mathbf{V}$ computes the weighted average of value vectors for each query position.

Implementation from Scratch

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    dropout: Optional[nn.Dropout] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute scaled dot-product attention.

    Args:
        query: Tensor of shape (..., seq_len_q, d_k).
        key: Tensor of shape (..., seq_len_k, d_k).
        value: Tensor of shape (..., seq_len_k, d_v).
        mask: Optional boolean tensor. Positions with True are masked
              (set to -inf before softmax). Shape broadcastable to
              (..., seq_len_q, seq_len_k).
        dropout: Optional dropout module applied to attention weights.

    Returns:
        output: Weighted sum of values, shape (..., seq_len_q, d_v).
        attention_weights: Softmax weights, shape (..., seq_len_q, seq_len_k).
    """
    d_k = query.size(-1)

    # Compute raw attention scores: (batch, ..., seq_q, d_k) @ (batch, ..., d_k, seq_k)
    # Result: (batch, ..., seq_q, seq_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask: set masked positions to -inf so softmax assigns ~0 weight
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))

    # Normalize across the key dimension
    attention_weights = F.softmax(scores, dim=-1)

    # Optional dropout on attention weights (used during training)
    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # Weighted sum of values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

Let us verify this implementation by examining its behavior on a small example:

torch.manual_seed(42)

# Three positions, d_k = d_v = 4
Q = torch.randn(1, 3, 4)
K = torch.randn(1, 3, 4)
V = torch.randn(1, 3, 4)

output, weights = scaled_dot_product_attention(Q, K, V)

print("Attention weights (each row sums to 1):")
print(weights.squeeze(0))
print(f"\nRow sums: {weights.squeeze(0).sum(dim=-1)}")
print(f"\nOutput shape: {output.shape}")
Attention weights (each row sums to 1):
tensor([[0.2897, 0.4082, 0.3021],
        [0.3592, 0.4014, 0.2394],
        [0.4674, 0.2905, 0.2421]])

Row sums: tensor([1.0000, 1.0000, 1.0000])

Output shape: torch.Size([1, 3, 4])

Each position produces a different distribution over all positions — it has learned (even with random weights) to attend differently depending on its content. Training will shape these distributions to capture meaningful relationships.


10.3 Where Do Q, K, V Come From? Learned Projections

In the formula above, we assumed Q, K, and V were given. In a transformer, they are computed from the input by learned linear projections. Given input embeddings $\mathbf{X} \in \mathbb{R}^{n \times d_{\text{model}}}$:

$$\mathbf{Q} = \mathbf{X}\mathbf{W}^Q, \qquad \mathbf{K} = \mathbf{X}\mathbf{W}^K, \qquad \mathbf{V} = \mathbf{X}\mathbf{W}^V$$

where $\mathbf{W}^Q, \mathbf{W}^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $\mathbf{W}^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$.

This is crucial. The same input token is projected into three different spaces, each serving a different purpose:

  • $\mathbf{W}^Q$ learns what to ask for — what kind of information this position is looking for.
  • $\mathbf{W}^K$ learns what to advertise — what kind of information this position has to offer.
  • $\mathbf{W}^V$ learns what to provide — the actual information transmitted when this position is attended to.

The separation of key and value is what makes attention powerful. A token's key determines when it is retrieved; its value determines what is retrieved. These can encode very different information. Consider the word "bank" in a sentence: its key might encode syntactic position information (noun, object-of-verb), while its value encodes semantic content (financial institution vs. river bank), resolved by the context provided through the query.

Fundamentals > Frontier: The Q/K/V decomposition is a projection into subspaces — the same operation you studied as change of basis in Chapter 1. Each projection matrix selects a linear subspace of the input embedding space that is optimized for a specific purpose. The matrix $\mathbf{Q}\mathbf{K}^\top$ computes inner products in the projected space, which is equivalent to computing inner products with respect to the Gram matrix $\mathbf{W}^Q (\mathbf{W}^K)^\top$ in the original space. The model is learning a bilinear similarity function.
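The equivalence claimed here takes two lines to verify (random matrices, illustrative sizes):

```python
import torch

torch.manual_seed(0)
n, d_model, d_k = 3, 8, 4
X = torch.randn(n, d_model)
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)

# Inner products computed in the projected space...
projected = (X @ W_Q) @ (X @ W_K).T
# ...equal a bilinear form with matrix W_Q W_K^T in the original space.
bilinear = X @ (W_Q @ W_K.T) @ X.T

assert torch.allclose(projected, bilinear, atol=1e-4)
```

The score matrix is therefore a learned bilinear similarity $\mathbf{X} (\mathbf{W}^Q (\mathbf{W}^K)^\top) \mathbf{X}^\top$ rather than a fixed inner product.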


10.4 Multi-Head Attention: Different Heads Learn Different Relationships

A single attention function computes one set of attention weights — one way of relating positions to one another. But language (and sequence data generally) contains many types of relationships simultaneously. In a sentence like "The cat sat on the mat because it was tired," position "it" must attend to "cat" for coreference resolution, to "sat" for predicate-argument structure, and to "tired" for semantic role labeling. A single attention distribution cannot capture all of these simultaneously, because softmax produces a distribution that must sum to 1 — attending strongly to one position means attending weakly to others.

Multi-head attention solves this by running $h$ attention functions in parallel, each with its own learned projections, and concatenating the results:

$$\text{MultiHead}(\mathbf{X}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O$$

where each head computes:

$$\text{head}_i = \text{Attention}(\mathbf{X}\mathbf{W}_i^Q, \mathbf{X}\mathbf{W}_i^K, \mathbf{X}\mathbf{W}_i^V)$$

Each head operates in a lower-dimensional space: $d_k = d_v = d_{\text{model}} / h$. With $h = 8$ heads and $d_{\text{model}} = 512$, each head operates in $\mathbb{R}^{64}$. The concatenation of all heads produces a vector in $\mathbb{R}^{512}$, and the output projection $\mathbf{W}^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ mixes the heads.

The total computation is the same as a single head with full dimensionality, but the representational capacity is greater: each head can specialize in a different type of relationship — syntactic, semantic, positional, or otherwise.

Implementation

class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism.

    Projects input into h separate Q, K, V spaces, computes
    scaled dot-product attention in each, and concatenates results.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Combined projection for Q, K, V (more efficient than three separate)
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for self-attention.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).
            mask: Optional mask of shape (batch, 1, seq_len, seq_len)
                  or (1, 1, seq_len, seq_len) for causal masking.

        Returns:
            output: Shape (batch, seq_len, d_model).
            attention_weights: Shape (batch, num_heads, seq_len, seq_len).
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V in one matrix multiply, then split
        qkv = self.W_qkv(x)  # (batch, seq_len, 3 * d_model)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq_len, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention per head
        output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, dropout=self.dropout
        )

        # Concatenate heads and project
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        return output, attention_weights

Implementation Note: We project Q, K, V with a single nn.Linear(d_model, 3 * d_model) and then split the result, rather than using three separate projections. This is mathematically equivalent but more efficient: a single large matrix multiplication is faster than three smaller ones due to better hardware utilization. Most production transformer implementations use this fused projection.

What Do Different Heads Learn?

Empirical analysis of trained transformer models reveals that different heads consistently specialize in different relationship types:

| Head Type  | What It Attends To         | Example                                        |
|------------|----------------------------|------------------------------------------------|
| Positional | Adjacent positions         | Head 3 attends primarily to position $i \pm 1$ |
| Syntactic  | Grammatical dependencies   | Head 5 links verbs to their subjects           |
| Semantic   | Meaning-related tokens     | Head 7 links pronouns to their antecedents     |
| Separator  | Delimiter tokens           | Head 2 attends to [SEP] or period tokens       |
| Broad      | Nearly uniform distribution| Head 1 computes a bag-of-words average         |

This specialization is not hard-coded — it emerges from training. The model discovers that distributing different relationship types across heads is the most efficient use of its representational capacity.


10.5 Positional Encoding: Teaching the Transformer About Order

Attention is a set operation — it computes pairwise similarities and weighted averages without any notion of position. If we permute the input sequence, the outputs are the same permutation of the original outputs: self-attention is permutation-equivariant, and no position receives special treatment. The transformer, without positional information, treats its input as a bag of tokens.

This is both a strength (parallelism — all positions can be processed simultaneously, unlike an RNN) and a weakness (the model cannot distinguish "the dog bit the man" from "the man bit the dog").

The solution is to inject positional information into the input embeddings. Vaswani et al. proposed sinusoidal positional encodings:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

where $pos$ is the position in the sequence and $i$ indexes the dimension pairs. Each dimension pair oscillates at a different frequency: the wavelengths form a geometric progression from $2\pi$ (the fastest, at $i = 0$) to $10000 \times 2\pi$ (the slowest). The positional encoding is added to the input embedding, not concatenated.

Why Sinusoids?

Three properties make sinusoidal encodings effective:

  1. Unique position representation. Each position maps to a unique vector. The combination of different frequencies at different dimensions creates a unique fingerprint for each position, analogous to how Fourier analysis decomposes a signal into frequency components.

  2. Relative position via linear transformation. For any fixed offset $k$, $PE_{pos+k}$ can be expressed as a linear function of $PE_{pos}$:

$$\begin{bmatrix} \sin(\omega(pos + k)) \\ \cos(\omega(pos + k)) \end{bmatrix} = \begin{bmatrix} \cos(\omega k) & \sin(\omega k) \\ -\sin(\omega k) & \cos(\omega k) \end{bmatrix} \begin{bmatrix} \sin(\omega \cdot pos) \\ \cos(\omega \cdot pos) \end{bmatrix}$$

This rotation matrix relationship means the model can learn relative positions through linear projections in Q and K — it does not need to memorize absolute positions.

  3. Extrapolation. Sinusoidal encodings are defined for any position, so the model can potentially handle sequences longer than those seen during training (though in practice, extrapolation is limited).
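The rotation identity is easy to verify numerically for a single frequency $\omega$ (the values below are illustrative):

```python
import math
import torch

omega, pos, k = 0.37, 5.0, 3.0   # one frequency, an absolute position, an offset

pe_pos = torch.tensor([math.sin(omega * pos), math.cos(omega * pos)])
pe_shifted = torch.tensor([math.sin(omega * (pos + k)), math.cos(omega * (pos + k))])

# Rotation matrix from the identity above -- it depends only on the offset k
R = torch.tensor([[math.cos(omega * k), math.sin(omega * k)],
                  [-math.sin(omega * k), math.cos(omega * k)]])

assert torch.allclose(R @ pe_pos, pe_shifted, atol=1e-6)
```

Because $R$ depends only on $k$ and not on $pos$, a single linear map shifts every position's encoding by the same offset — exactly what a learned Q/K projection can exploit.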

Learned Positional Embeddings

An alternative is to learn the positional embeddings as parameters:

$$\mathbf{E}_{pos} = \text{Embedding}(pos) \in \mathbb{R}^{d_{\text{model}}}$$

This is a lookup table of $n_{\max} \times d_{\text{model}}$ parameters, where $n_{\max}$ is the maximum sequence length. Learned embeddings are simpler to implement and empirically perform comparably to sinusoidal encodings for fixed-length contexts.

The downside is rigidity: the model cannot handle sequences longer than $n_{\max}$ without retraining or interpolation.
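In code, the learned alternative is just an embedding lookup added to the token embeddings. A minimal sketch (sizes illustrative):

```python
import torch
import torch.nn as nn

d_model, max_len = 16, 128            # illustrative sizes
pos_embedding = nn.Embedding(max_len, d_model)  # one learned vector per position

x = torch.randn(2, 10, d_model)       # (batch, seq_len, d_model) token embeddings
positions = torch.arange(x.size(1))   # 0, 1, ..., seq_len - 1
x = x + pos_embedding(positions)      # broadcast over the batch dimension

assert x.shape == (2, 10, 16)
```

The rigidity is visible here: `positions` can never exceed `max_len - 1`, or the embedding lookup fails.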

Rotary Positional Embeddings (RoPE)

Modern large language models (GPT-NeoX, LLaMA, Mistral) predominantly use Rotary Positional Embeddings (Su et al., 2021). RoPE encodes position by rotating the query and key vectors rather than adding to the input embeddings:

$$\tilde{\mathbf{q}}_m = R_{\Theta, m} \mathbf{q}_m, \qquad \tilde{\mathbf{k}}_n = R_{\Theta, n} \mathbf{k}_n$$

where $R_{\Theta, pos}$ is a block-diagonal rotation matrix. The key property is that the dot product $\tilde{\mathbf{q}}_m^\top \tilde{\mathbf{k}}_n$ depends only on the relative position $m - n$, not the absolute positions $m$ and $n$.

This combines the best properties of sinusoidal and learned encodings: it is parameter-free, naturally encodes relative position, and works well with KV-caching for autoregressive generation.
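The relative-position property can be checked with a minimal sketch. The `rotate` helper below is an illustrative, single-vector version of RoPE (real implementations apply it per head, across batches); the dot product of rotated queries and keys is unchanged when both positions shift by the same amount:

```python
import torch

def rotate(x: torch.Tensor, pos: int, theta: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive dimension pairs of x by pos * theta (toy RoPE sketch)."""
    x1, x2 = x[0::2], x[1::2]
    cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)
    out = torch.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

torch.manual_seed(0)
d = 8
theta = 10000.0 ** (-torch.arange(0, d, 2).float() / d)  # one angle per dim pair
q, k = torch.randn(d), torch.randn(d)

# Score depends only on the relative offset m - n, not on m and n themselves
s1 = rotate(q, 5, theta) @ rotate(k, 3, theta)   # m=5, n=3 -> offset 2
s2 = rotate(q, 9, theta) @ rotate(k, 7, theta)   # m=9, n=7 -> offset 2

assert torch.isclose(s1, s2, atol=1e-4)
```

This works because $R_{\Theta,m}^\top R_{\Theta,n} = R_{\Theta, n-m}$: composing the two rotations leaves only the relative offset.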

Implementation

class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from Vaswani et al. (2017).

    Adds position-dependent sinusoidal signals to input embeddings.
    """

    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Compute the division term: 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input embeddings.

        Args:
            x: Input of shape (batch, seq_len, d_model).

        Returns:
            x + PE, same shape as input.
        """
        return x + self.pe[:, : x.size(1), :]

Common Misconception: Positional encodings are sometimes described as "telling the model where each token is." More precisely, they provide information that enables the model to learn position-dependent behavior through the Q/K dot product. The model still must learn to use this information — it is an affordance, not an instruction. This is why both sinusoidal and learned positional embeddings work similarly: the model adapts to whatever positional signal it receives.


10.6 The Transformer Block

With multi-head attention and positional encoding defined, we can assemble the transformer block — the repeating unit that is stacked to form the full transformer. Each block contains two sub-layers:

  1. Multi-head self-attention — lets each position attend to all other positions.
  2. Position-wise feed-forward network (FFN) — applies an independent nonlinear transformation at each position.

Both sub-layers are wrapped with a residual connection and layer normalization.

The Feed-Forward Network

The FFN is a simple two-layer MLP applied identically to each position:

$$\text{FFN}(\mathbf{x}) = \text{GELU}(\mathbf{x}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2$$

where $\mathbf{W}_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$, $\mathbf{W}_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$, and typically $d_{ff} = 4 \times d_{\text{model}}$. The original paper used ReLU; modern transformers use GELU (Gaussian Error Linear Unit) or SwiGLU.

The FFN operates independently on each position — no information flows between positions here. This might seem redundant after attention, which already relates all positions. But the two sub-layers serve complementary roles:

  • Attention routes information between positions (inter-token computation).
  • FFN transforms information within each position (intra-token computation).

Anthropic's research on transformer circuits (Elhage et al., 2021) has shown that attention heads move information while FFN layers store and retrieve knowledge. The FFN can be viewed as a learned key-value memory: the columns of $\mathbf{W}_1$ are keys, the rows of $\mathbf{W}_2$ are values, and the hidden activation selects which memories to retrieve.
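This memory view can be verified directly. With $\text{FFN}(\mathbf{x}) = \text{GELU}(\mathbf{x}\mathbf{W}_1)\mathbf{W}_2$ (biases omitted for clarity, sizes illustrative), the output equals a sum over $d_{ff}$ key-value memories, where column $i$ of $\mathbf{W}_1$ acts as the key and row $i$ of $\mathbf{W}_2$ as the value:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 8, 32
x = torch.randn(d_model)
W1 = torch.randn(d_model, d_ff)
W2 = torch.randn(d_ff, d_model)

# Standard two-layer FFN
ffn_out = F.gelu(x @ W1) @ W2

# Memory view: match x against key i, then add the gated value i
memory_out = sum(F.gelu(x @ W1[:, i]) * W2[i] for i in range(d_ff))

assert torch.allclose(ffn_out, memory_out, atol=1e-4)
```

The scalar $\text{GELU}(\mathbf{x} \cdot \mathbf{k}_i)$ plays the role of an (unnormalized) attention weight over the FFN's internal memories.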

Layer Normalization: Pre-LN vs. Post-LN

The original transformer used post-LN — layer normalization applied after the residual addition:

$$\mathbf{x} = \text{LayerNorm}(\mathbf{x} + \text{SubLayer}(\mathbf{x}))$$

Modern transformers overwhelmingly use pre-LN — layer normalization applied before the sub-layer:

$$\mathbf{x} = \mathbf{x} + \text{SubLayer}(\text{LayerNorm}(\mathbf{x}))$$

The pre-LN variant is easier to train because the residual connection provides a direct gradient path from the loss to any layer, unimpeded by normalization. Post-LN can produce slightly better final performance but requires careful learning rate warmup and is more sensitive to hyperparameters.

Research Insight: Xiong et al. (2020), "On Layer Normalization in the Transformer Architecture," provided theoretical analysis showing that pre-LN keeps the gradient norm bounded regardless of depth, while post-LN can produce gradient explosion during early training. This explains why pre-LN is more robust and why post-LN requires warmup: the warmup keeps the learning rate small enough to avoid the unstable region.

The Residual Stream

The residual connections in the transformer create what Elhage et al. (2021) call the residual stream: a communication channel that flows through the entire network. Each attention layer and FFN layer reads from the stream, computes a contribution, and writes back to the stream via addition. The output at any position is the sum of the original embedding plus all contributions from all layers.

This view is powerful: it means the transformer is not a strict pipeline (layer 1 → layer 2 → ... → layer N). It is a broadcasting architecture where every layer has access to the original input and to all previous layers' outputs simultaneously. Information does not need to pass through every layer in sequence — it can be written to the residual stream early and read late.

Implementation

class TransformerBlock(nn.Module):
    """Single transformer encoder block with pre-LN architecture.

    Components:
        1. Multi-head self-attention with residual connection
        2. Position-wise feed-forward network with residual connection
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through one transformer block.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            mask: Optional attention mask.

        Returns:
            output: Same shape as input.
            attention_weights: Shape (batch, num_heads, seq_len, seq_len).
        """
        # Pre-LN self-attention with residual
        normed = self.norm1(x)
        attn_out, attn_weights = self.attention(normed, mask=mask)
        x = x + self.dropout(attn_out)

        # Pre-LN FFN with residual
        normed = self.norm2(x)
        ffn_out = self.ffn(normed)
        x = x + ffn_out

        return x, attn_weights
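PyTorch ships a comparable pre-LN building block, nn.TransformerEncoderLayer with norm_first=True. As a quick sanity check (sizes illustrative), the residual structure means the block always preserves the input shape:

```python
import torch
import torch.nn as nn

# Built-in pre-LN encoder block, analogous to the TransformerBlock above
block = nn.TransformerEncoderLayer(
    d_model=64, nhead=8, dim_feedforward=256,
    dropout=0.1, activation="gelu",
    batch_first=True, norm_first=True,
)

x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model)
out = block(x)

assert out.shape == x.shape  # residual connections preserve the shape
```

Shape preservation is what makes stacking $N$ identical blocks possible: every block's output is a valid input to the next.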

10.7 The Full Transformer: Encoder and Decoder

The original Vaswani et al. (2017) transformer is an encoder-decoder architecture designed for sequence-to-sequence tasks like machine translation. Understanding the full architecture requires tracing information flow through both halves.

The Encoder

The encoder converts an input sequence $\mathbf{x}_1, \ldots, \mathbf{x}_n$ into a sequence of contextual representations $\mathbf{z}_1, \ldots, \mathbf{z}_n$. It consists of:

  1. Input embedding — converts token IDs to dense vectors $\in \mathbb{R}^{d_{\text{model}}}$
  2. Positional encoding — adds position information
  3. $N$ transformer blocks — each containing multi-head self-attention + FFN

Every encoder position can attend to every other encoder position (full bidirectional attention). After $N$ layers, each position's representation is contextualized — it has "seen" the entire input through the attention mechanism.

The Decoder

The decoder generates the output sequence one token at a time. It contains three components per block:

  1. Masked multi-head self-attention — the decoder attends to previous output positions only. A causal mask prevents position $i$ from attending to positions $j > i$, preserving the autoregressive property (each prediction depends only on past predictions, not future ones).

  2. Cross-attention — the decoder attends to the encoder's output representations. Here, the queries come from the decoder, while the keys and values come from the encoder. This is how the decoder "reads" the input.

  3. Position-wise FFN — same as in the encoder.

Causal Masking

The causal mask sets every strictly upper-triangular entry of the score matrix to $-\infty$ (or equivalently, marks those entries in a boolean mask):

$$\text{mask}_{i,j} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

After adding this mask to the attention scores before softmax, future positions receive weight $\exp(-\infty) = 0$. Position $i$ can attend to positions $1, 2, \ldots, i$ but not to positions $i+1, \ldots, n$.

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create a causal (upper-triangular) attention mask.

    Args:
        seq_len: Sequence length.

    Returns:
        Boolean mask of shape (1, 1, seq_len, seq_len).
        True values indicate positions to be masked (set to -inf).
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
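The mask's effect is easy to inspect. This standalone check recreates the masking step (with random queries and keys, sizes illustrative) and confirms that every weight above the diagonal is zero:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 4, 8
# Strictly upper-triangular boolean mask: True marks future positions
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

scores = (Q @ K.T) / d_k ** 0.5
scores = scores.masked_fill(mask, float("-inf"))
weights = F.softmax(scores, dim=-1)

# No position attends to any future position
assert torch.all(weights.triu(diagonal=1) == 0)
# The first position can only attend to itself
assert torch.isclose(weights[0, 0], torch.tensor(1.0))
```

Note that the first row has exactly one unmasked entry, so its softmax weight is forced to 1 — the first generated token conditions only on itself.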

Cross-Attention

Cross-attention uses the same scaled dot-product attention formula, but with queries from one sequence and keys/values from another:

$$\text{CrossAttention}(\mathbf{X}_{\text{dec}}, \mathbf{X}_{\text{enc}}) = \text{softmax}\left(\frac{(\mathbf{X}_{\text{dec}}\mathbf{W}^Q)(\mathbf{X}_{\text{enc}}\mathbf{W}^K)^\top}{\sqrt{d_k}}\right)(\mathbf{X}_{\text{enc}}\mathbf{W}^V)$$

Each decoder position creates a query (asking "what in the input is relevant to what I am generating right now?"), and the encoder representations provide the keys and values (answering "here is what is available"). This is the mechanism by which the output is conditioned on the input.
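A minimal single-head sketch of this query/key split (the tensor sizes and projection names here are illustrative assumptions, not from the text):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, n_enc, n_dec = 64, 7, 5            # illustrative sizes
enc_out = torch.randn(1, n_enc, d_model)    # encoder output representations
dec_in = torch.randn(1, n_dec, d_model)     # decoder hidden states

W_q, W_k, W_v = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))

q = W_q(dec_in)     # queries come from the decoder
k = W_k(enc_out)    # keys come from the encoder
v = W_v(enc_out)    # values come from the encoder

scores = q @ k.transpose(-2, -1) / math.sqrt(d_model)
weights = F.softmax(scores, dim=-1)         # (1, n_dec, n_enc)
out = weights @ v                           # (1, n_dec, d_model)
print(out.shape)  # torch.Size([1, 5, 64])
```

Note the rectangular attention matrix: one row per decoder position, one column per encoder position — unlike self-attention, the two sequence lengths need not match.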

Information Flow: Tracing a Translation

Let us trace the information flow for translating "The cat sat" → "Le chat assis":

ENCODER                                    DECODER

Input: ["The", "cat", "sat"]               Output (shifted): ["<BOS>", "Le", "chat"]
   ↓                                          ↓
Embedding + Pos Encoding                   Embedding + Pos Encoding
   ↓                                          ↓
┌─────────────────────────────┐   ┌──────────────────────────────────────┐
│ Self-Attention              │   │ Masked Self-Attention                │
│ - "cat" attends to "The",  │   │ - "chat" attends to "<BOS>", "Le"   │
│   "cat", "sat"              │   │   but NOT to future "assis"          │
│ - All positions attend to   │   │                                      │
│   all positions (bidir.)    │   │ Cross-Attention                      │
│                             │   │ - "chat" queries encoder outputs     │
│ FFN                         │   │ - Attends strongly to "cat" (its    │
│ - Transform each position   │   │   translation source)                │
│                             │   │                                      │
│ (Repeat N times)            │   │ FFN                                  │
└──────────────┬──────────────┘   │                                      │
               │                  │ (Repeat N times)                      │
               │                  └──────────────┬───────────────────────┘
               │                                 ↓
               └─────── K, V ──────────→  Linear + Softmax → "assis"

Architecture Variants

The full encoder-decoder transformer is one of three standard configurations:

| Variant | Used By | Structure | Typical Tasks |
|---|---|---|---|
| Encoder-decoder | T5, BART, mBART | Full architecture | Translation, summarization |
| Encoder-only | BERT, RoBERTa | Encoder only, bidirectional | Classification, NER, retrieval |
| Decoder-only | GPT, LLaMA, Mistral | Decoder only, causal | Text generation, general LLM |

Decoder-only architectures dominate modern LLMs, as we will discuss in Chapter 11. The key insight: a decoder-only model with causal masking can perform any seq2seq task by concatenating input and output with a separator token.
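The concatenation trick can be sketched at the token level. The IDs below are made up; `-100` is the default `ignore_index` of `nn.CrossEntropyLoss`, so masked positions contribute no loss:

```python
# Hypothetical token IDs for a seq2seq pair, handled by a decoder-only model.
src = [12, 45, 7]     # source tokens, e.g. "The cat sat"
tgt = [88, 23, 61]    # target tokens, e.g. "Le chat assis"
SEP = 1               # separator token

seq = src + [SEP] + tgt
input_ids = seq[:-1]  # standard next-token setup: predict seq[t+1] from seq[:t+1]
labels = seq[1:]

# Compute loss only on target predictions; the source is context, not supervision.
IGNORE = -100
labels = [IGNORE] * len(src) + labels[len(src):]

print(input_ids)  # [12, 45, 7, 1, 88, 23]
print(labels)     # [-100, -100, -100, 88, 23, 61]
```

The causal mask does the rest: every target prediction can attend back over the entire source, exactly as cross-attention would allow in the encoder-decoder variant.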


10.8 Walking Through "Attention Is All You Need"

The Vaswani et al. (2017) paper deserves a structured reading, both for its content and as a lesson in how to read research papers (which we will develop systematically in Chapter 37).

Paper Structure and Key Contributions

| Section | Contribution |
|---|---|
| 1. Introduction | Motivates removing recurrence; positions attention as sufficient |
| 2. Background | Reviews prior work on reducing sequential computation |
| 3. Model Architecture | Defines the full transformer: encoder-decoder, multi-head attention, positional encoding |
| 3.2. Attention | The scaled dot-product attention formula and multi-head variant |
| 3.3. Position-wise FFN | The two-layer MLP applied independently to each position |
| 3.4. Embeddings | Shared embedding weights between encoder, decoder, and output softmax (scaled by $\sqrt{d_{\text{model}}}$) |
| 3.5. Positional Encoding | Sinusoidal encoding; learned embeddings produce similar results |
| 4. Why Self-Attention | Compares self-attention, recurrence, and convolution on three criteria |
| 5. Training | WMT 2014 English-German and English-French; 8 GPUs, 3.5 days |
| 6. Results | New SOTA on both benchmarks; 28.4 BLEU on EN-DE |

The Key Argument: Section 4

The most cited part of the paper is the attention formula, but the intellectual core is Section 4, which compares self-attention against recurrence and convolution across three dimensions:

| Criterion | Self-Attention | Recurrence | Convolution |
|---|---|---|---|
| Complexity per layer | $O(n^2 \cdot d)$ | $O(n \cdot d^2)$ | $O(k \cdot n \cdot d^2)$ |
| Sequential operations | $O(1)$ | $O(n)$ | $O(1)$ |
| Maximum path length | $O(1)$ | $O(n)$ | $O(\log_k n)$ |

The critical row is sequential operations. Self-attention computes all pairwise interactions in a single matrix multiplication — $O(1)$ sequential steps. An RNN requires $O(n)$ sequential steps because each hidden state depends on the previous one. This means self-attention can exploit GPU parallelism fully, while RNNs cannot.

The maximum path length row explains why transformers handle long-range dependencies better than RNNs: any two positions can interact in a single layer ($O(1)$ path length), while in an RNN, information must traverse $O(n)$ time steps.

The cost is per-layer complexity: self-attention scales quadratically with sequence length ($n^2$), while recurrence scales linearly. For short sequences ($n < d$, which is common in NLP), self-attention is actually cheaper. For very long sequences, the quadratic scaling becomes problematic — motivating the efficiency improvements in Section 10.10.
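The crossover is easy to see numerically — dropping constant factors, the per-layer cost ratio reduces to $n/d$ (a back-of-envelope sketch):

```python
def self_attention_cost(n: int, d: int) -> int:
    return n * n * d          # O(n^2 d) per layer

def recurrence_cost(n: int, d: int) -> int:
    return n * d * d          # O(n d^2) per layer

d = 512
for n in (128, 512, 2048, 8192):
    ratio = self_attention_cost(n, d) / recurrence_cost(n, d)
    print(f"n={n:5d}: attention/recurrence cost ratio = {ratio:.2f}")
# The ratio is n/d: attention is cheaper below n = d, quadratically worse above it.
```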

Research Insight: The paper's training efficiency is remarkable by modern standards: the base model (65M parameters, $d_{\text{model}} = 512$, 6 layers, 8 heads) trained on 8 P100 GPUs in 12 hours; the big model ($d_{\text{model}} = 1024$, 6 layers, 16 heads) trained in 3.5 days. The fundamental transformer architecture has changed surprisingly little since 2017 — what has changed is scale (models are now 1000x larger) and training data (1000x more).


10.9 A Complete Annotated Transformer

We now assemble all components into a complete, trainable transformer encoder. This implementation is deliberately pedagogical — production code would use torch.nn.TransformerEncoder, but building it ourselves ensures understanding.

class TransformerEncoder(nn.Module):
    """Complete transformer encoder for sequence classification.

    Architecture:
        Token Embedding → Positional Encoding → N x TransformerBlock
        → Mean Pooling → Classification Head

    This is an encoder-only transformer (like BERT) suitable for
    classification tasks. For generation, see the decoder variant.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        num_heads: int = 8,
        num_layers: int = 4,
        d_ff: int = 1024,
        max_seq_len: int = 512,
        num_classes: int = 10,
        dropout: float = 0.1,
        pad_token_id: int = 0,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.pad_token_id = pad_token_id

        # Token embedding (scaled by sqrt(d_model) per Vaswani et al.)
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
        self.embedding_dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)

        # Classification head
        self.classifier = nn.Linear(d_model, num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize parameters following common practice.

        Embeddings: N(0, 1/sqrt(d_model))
        Linear layers: Xavier uniform
        Biases: zero
        LayerNorm: weight=1, bias=0
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=1.0 / math.sqrt(self.d_model))
                if module.padding_idx is not None:
                    nn.init.zeros_(module.weight[module.padding_idx])
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        return_attention: bool = False,
    ) -> dict:
        """Forward pass through the transformer encoder.

        Args:
            input_ids: Token IDs of shape (batch, seq_len).
            return_attention: If True, return attention weights from all layers.

        Returns:
            Dictionary with 'logits' and optionally 'attention_weights'.
        """
        # Create padding mask: True where padding tokens are
        pad_mask = (input_ids == self.pad_token_id)
        # Expand to attention mask shape: (batch, 1, 1, seq_len)
        # Broadcasting handles (batch, heads, seq_q, seq_k)
        attn_mask = pad_mask.unsqueeze(1).unsqueeze(2)

        # Embed tokens and add positional encoding
        x = self.token_embedding(input_ids) * self.scale
        x = self.pos_encoding(x)
        x = self.embedding_dropout(x)

        # Pass through transformer blocks
        all_attention_weights = []
        for block in self.blocks:
            x, attn_weights = block(x, mask=attn_mask)
            if return_attention:
                all_attention_weights.append(attn_weights)

        x = self.final_norm(x)

        # Mean pooling over non-padding positions
        mask_expanded = (~pad_mask).unsqueeze(-1).float()  # (batch, seq, 1)
        x = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1)

        logits = self.classifier(x)

        result = {"logits": logits}
        if return_attention:
            result["attention_weights"] = all_attention_weights
        return result

Counting Parameters

Let us verify the parameter count for a concrete configuration:

model = TransformerEncoder(
    vocab_size=10000,
    d_model=256,
    num_heads=8,
    num_layers=4,
    d_ff=1024,
    num_classes=20,
)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print()

# Breakdown by component
for name, param in model.named_parameters():
    if param.numel() > 10000:
        print(f"  {name}: {param.shape} = {param.numel():,}")
Total parameters: 5,765,396
Trainable parameters: 5,765,396

  token_embedding.weight: torch.Size([10000, 256]) = 2,560,000
  blocks.0.attention.W_qkv.weight: torch.Size([768, 256]) = 196,608
  blocks.0.attention.W_o.weight: torch.Size([256, 256]) = 65,536
  blocks.0.ffn.0.weight: torch.Size([1024, 256]) = 262,144
  blocks.0.ffn.3.weight: torch.Size([256, 1024]) = 262,144
  blocks.1.attention.W_qkv.weight: torch.Size([768, 256]) = 196,608
  ...

The embedding layer dominates at this scale (44% of parameters). In larger models, the transformer blocks dominate because the embedding scales as $O(\text{vocab} \times d_{\text{model}})$ while the blocks scale as $O(N \times d_{\text{model}}^2)$.
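A quick sketch of that scaling argument. The counts ignore biases and LayerNorms, and the 50k vocabulary and $d_{\text{ff}} = 4d$ are illustrative assumptions:

```python
def embedding_params(vocab: int, d: int) -> int:
    return vocab * d

def block_params(d: int, d_ff: int) -> int:
    # fused QKV (3*d*d) + output projection (d*d) + two FFN matrices
    return 3 * d * d + d * d + d * d_ff + d_ff * d

vocab = 50_000
for d, layers in [(256, 4), (1024, 24), (4096, 32)]:
    emb = embedding_params(vocab, d)
    blk = layers * block_params(d, 4 * d)
    print(f"d={d:4d}, N={layers:2d}: embedding {emb/1e6:6.1f}M, blocks {blk/1e6:7.1f}M")
```

At $d = 256$ the embedding dominates; by $d = 1024$ with 24 layers the blocks are several times larger, and the gap keeps widening with scale.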


10.10 Training the Transformer: A Complete Example

Let us train our transformer on a concrete task: classifying synthetic sequences by their patterns. This end-to-end example demonstrates the full training pipeline.

from torch.utils.data import Dataset, DataLoader, random_split


class PatternSequenceDataset(Dataset):
    """Synthetic dataset: classify sequences by which pattern they contain.

    Each sequence contains tokens from a vocabulary, with one of several
    recognizable patterns embedded at a random position. The task is to
    identify which pattern is present — requiring the model to attend to
    the relevant subsequence and ignore the noise.
    """

    def __init__(
        self,
        n_samples: int = 10000,
        seq_len: int = 64,
        vocab_size: int = 100,
        n_classes: int = 10,
        pattern_len: int = 5,
        seed: int = 42,
    ) -> None:
        super().__init__()
        rng = torch.Generator().manual_seed(seed)
        self.sequences = torch.randint(
            2, vocab_size, (n_samples, seq_len), generator=rng
        )
        self.labels = torch.randint(0, n_classes, (n_samples,), generator=rng)

        # Define class-specific patterns (unique token sequences)
        pattern_rng = torch.Generator().manual_seed(seed + 1)
        self.patterns = torch.randint(
            2, vocab_size, (n_classes, pattern_len), generator=pattern_rng
        )

        # Embed the class pattern at a random position in each sequence
        for i in range(n_samples):
            label = self.labels[i].item()
            pos = torch.randint(0, seq_len - pattern_len, (1,), generator=rng).item()
            self.sequences[i, pos : pos + pattern_len] = self.patterns[label]

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.sequences[idx], self.labels[idx]


def train_transformer(
    model: TransformerEncoder,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 20,
    lr: float = 3e-4,
    warmup_steps: int = 200,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> list:
    """Train transformer with AdamW optimizer and linear warmup + cosine decay.

    Args:
        model: TransformerEncoder model.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        epochs: Number of training epochs.
        lr: Peak learning rate.
        warmup_steps: Number of warmup steps.
        device: Device to train on.

    Returns:
        List of per-epoch metrics.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    # Linear warmup + cosine decay scheduler
    total_steps = epochs * len(train_loader)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    history = []
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        for sequences, labels in train_loader:
            sequences, labels = sequences.to(device), labels.to(device)
            optimizer.zero_grad()
            result = model(sequences)
            loss = criterion(result["logits"], labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item() * labels.size(0)
            train_correct += (result["logits"].argmax(dim=-1) == labels).sum().item()
            train_total += labels.size(0)

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for sequences, labels in val_loader:
                sequences, labels = sequences.to(device), labels.to(device)
                result = model(sequences)
                loss = criterion(result["logits"], labels)
                val_loss += loss.item() * labels.size(0)
                val_correct += (result["logits"].argmax(dim=-1) == labels).sum().item()
                val_total += labels.size(0)

        metrics = {
            "epoch": epoch + 1,
            "train_loss": train_loss / train_total,
            "train_acc": train_correct / train_total,
            "val_loss": val_loss / val_total,
            "val_acc": val_correct / val_total,
            "lr": scheduler.get_last_lr()[0],
        }
        history.append(metrics)
        print(
            f"Epoch {epoch+1:3d} | "
            f"Train Loss: {metrics['train_loss']:.4f} | "
            f"Train Acc: {metrics['train_acc']:.4f} | "
            f"Val Loss: {metrics['val_loss']:.4f} | "
            f"Val Acc: {metrics['val_acc']:.4f} | "
            f"LR: {metrics['lr']:.2e}"
        )

    return history


# --- Run training ---
dataset = PatternSequenceDataset(
    n_samples=10000, seq_len=64, vocab_size=100, n_classes=10, pattern_len=5
)
train_set, val_set = random_split(dataset, [8000, 2000])
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)

model = TransformerEncoder(
    vocab_size=100,
    d_model=128,
    num_heads=4,
    num_layers=3,
    d_ff=512,
    num_classes=10,
    dropout=0.1,
    pad_token_id=0,
)
history = train_transformer(model, train_loader, val_loader, epochs=20, lr=1e-3)
Epoch   1 | Train Loss: 2.1847 | Train Acc: 0.1688 | Val Loss: 1.9932 | Val Acc: 0.2375 | LR: 9.99e-04
Epoch   2 | Train Loss: 1.7624 | Train Acc: 0.3471 | Val Loss: 1.4538 | Val Acc: 0.4860 | LR: 9.88e-04
Epoch   3 | Train Loss: 1.1289 | Train Acc: 0.6046 | Val Loss: 0.8912 | Val Acc: 0.7025 | LR: 9.66e-04
...
Epoch  18 | Train Loss: 0.0187 | Train Acc: 0.9969 | Val Loss: 0.0312 | Val Acc: 0.9920 | LR: 3.41e-05
Epoch  19 | Train Loss: 0.0149 | Train Acc: 0.9976 | Val Loss: 0.0289 | Val Acc: 0.9930 | LR: 1.54e-05
Epoch  20 | Train Loss: 0.0138 | Train Acc: 0.9980 | Val Loss: 0.0281 | Val Acc: 0.9935 | LR: 3.81e-06

The model achieves 99.4% validation accuracy. It has learned to locate the embedded pattern within each sequence and classify it — exactly the kind of variable-position pattern recognition that attention excels at.


10.11 Attention Complexity and Efficiency

The Quadratic Bottleneck

The computational complexity of self-attention is dominated by the score matrix computation $\mathbf{Q}\mathbf{K}^\top$:

  • Time complexity: $O(n^2 d)$ — computing all $n^2$ pairwise dot products, each of dimensionality $d$.
  • Memory complexity: $O(n^2 + nd)$ — storing the $n \times n$ attention matrix, plus the input/output of shape $n \times d$.

For a sequence of length $n = 4{,}096$ and $d = 1{,}024$, the attention matrix alone is $4{,}096 \times 4{,}096 = 16.8$ million entries per head. With 32 heads, that is 537 million entries per layer, stored in float32 = 2.1 GB per layer, just for the attention weights.

This quadratic scaling is the reason transformer context windows were historically limited. GPT-2 (2019) used 1,024 tokens. GPT-3 (2020) used 2,048. Extending to millions of tokens required fundamental algorithmic improvements.
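The arithmetic behind the numbers above, as a sketch:

```python
n, heads, bytes_fp32 = 4096, 32, 4
entries_per_head = n * n                       # 16,777,216 entries
entries_per_layer = heads * entries_per_head   # 536,870,912 entries
gb_per_layer = entries_per_layer * bytes_fp32 / 1e9
print(f"{entries_per_head:,} entries per head")
print(f"{gb_per_layer:.1f} GB per layer for attention weights alone")
```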

Flash Attention: An Algorithm for Memory-Aware Attention

Flash attention (Dao et al., 2022) does not change what is computed — it computes the exact same attention as the naive algorithm. It changes how the computation is organized to exploit the GPU memory hierarchy.

A modern GPU has two levels of memory:

| Memory Type | Capacity | Bandwidth | Latency |
|---|---|---|---|
| HBM (High Bandwidth Memory) | 40–80 GB | 1.5–3.0 TB/s | ~400 ns |
| SRAM (on-chip) | ~20 MB total (~192 KB per SM) | ~19 TB/s | ~5 ns |

The naive attention algorithm materializes the full $n \times n$ attention matrix in HBM, then reads it back for the softmax and value multiplication. This is memory-bound: the arithmetic intensity (FLOPs per byte transferred) is low, meaning the GPU spends most of its time waiting for memory transfers rather than computing.

Flash attention avoids materializing the full attention matrix by computing attention in tiles that fit in SRAM:

  1. Divide Q, K, V into blocks that fit in SRAM.
  2. For each block of Q, iterate over blocks of K and V:
     • Compute the block attention scores in SRAM.
     • Accumulate the softmax numerator and denominator using the online softmax trick (Milakov and Gimelshein, 2018).
     • Write only the final output block to HBM.
  3. The full $n \times n$ attention matrix is never materialized in HBM.

The key insight is the online softmax: standard softmax requires knowing all values to compute the denominator $\sum_j \exp(z_j)$. The online algorithm maintains a running maximum and a running sum, allowing softmax to be computed incrementally as blocks of keys are processed. This is the mathematical trick that enables tiling.

def online_softmax_demo(scores_blocks: list) -> torch.Tensor:
    """Demonstrate the online softmax algorithm used in flash attention.

    Instead of computing softmax over all scores at once, processes
    blocks incrementally while maintaining numerical stability.

    Args:
        scores_blocks: List of 1D tensors, each a block of attention scores.

    Returns:
        Softmax probabilities equivalent to computing over the concatenated scores.
    """
    # Running statistics
    running_max = torch.tensor(float("-inf"))
    running_sum = torch.tensor(0.0)
    block_results = []

    for block in scores_blocks:
        block_max = block.max()
        new_max = torch.max(running_max, block_max)

        # Rescale previous accumulator to the new maximum
        running_sum = running_sum * torch.exp(running_max - new_max)
        # Add current block's contribution
        block_exp = torch.exp(block - new_max)
        running_sum = running_sum + block_exp.sum()
        running_max = new_max
        block_results.append((block, new_max.clone(), running_sum.clone()))

    # Final pass: compute normalized probabilities
    all_probs = []
    for block, _, _ in block_results:
        probs = torch.exp(block - running_max) / running_sum
        all_probs.append(probs)

    return torch.cat(all_probs)


# Verify equivalence with standard softmax
torch.manual_seed(0)
scores = torch.randn(16)
blocks = [scores[:4], scores[4:8], scores[8:12], scores[12:]]

online_result = online_softmax_demo(blocks)
standard_result = F.softmax(scores, dim=0)
print(f"Max difference: {(online_result - standard_result).abs().max():.2e}")
Max difference: 5.96e-08

Production Reality: Flash attention is not just a theoretical improvement — it provides 2–4x wall-clock speedup and enables 5–20x longer context windows by reducing memory usage from $O(n^2)$ to $O(n)$. As of 2024, flash attention (via F.scaled_dot_product_attention in PyTorch 2.0+) is the default attention implementation in all major deep learning frameworks. If you are writing production code, you should never implement the naive attention algorithm.

Sparse Attention

Sparse attention reduces the quadratic cost by restricting which positions can attend to which others. Instead of computing the full $n \times n$ attention matrix, each position attends only to a subset of size $k \ll n$:

  • Local attention (sliding window): each position attends only to its $w$ nearest neighbors. Complexity: $O(n \cdot w \cdot d)$. Captures local patterns but misses long-range dependencies.
  • Strided attention: each position attends to every $s$-th position. Captures global patterns at regular intervals.
  • Combined patterns (Longformer, BigBird): mix local attention with global tokens that attend to the entire sequence. A few designated positions serve as information hubs.

The Longformer (Beltagy et al., 2020) combines sliding window attention with global attention on special tokens (like [CLS]), achieving $O(n)$ complexity while maintaining strong performance on long-document tasks.
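A mask in the Longformer flavor can be sketched in a few lines; the window size and the choice of global positions are assumptions for illustration:

```python
import torch

def sliding_window_mask(seq_len: int, window: int, global_positions=()) -> torch.Tensor:
    """Boolean mask (True = blocked): local window plus a few global hub tokens."""
    idx = torch.arange(seq_len)
    allowed = (idx[None, :] - idx[:, None]).abs() <= window  # |i - j| <= window
    for g in global_positions:
        allowed[g, :] = True   # the global token attends everywhere
        allowed[:, g] = True   # every position attends to the global token
    return ~allowed

mask = sliding_window_mask(8, window=1, global_positions=(0,))
print((~mask).int())  # 1 = attention allowed; note the dense row and column 0
```

Each non-global position touches $O(w)$ others, so the attention cost per layer drops from $O(n^2 d)$ to $O(n \cdot w \cdot d)$ plus a small global term.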

Linear Attention

Linear attention replaces the softmax with a kernel function that allows the computation to be rearranged:

$$\text{LinearAttention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \frac{\phi(\mathbf{Q})(\phi(\mathbf{K})^\top \mathbf{V})}{\phi(\mathbf{Q})(\phi(\mathbf{K})^\top \mathbf{1})}$$

where $\phi$ is a feature map (e.g., $\phi(\mathbf{x}) = \text{elu}(\mathbf{x}) + 1$). The key is associativity: by computing $\phi(\mathbf{K})^\top \mathbf{V} \in \mathbb{R}^{d \times d}$ first, the cost becomes $O(n d^2)$ — linear in sequence length. However, linear attention typically underperforms standard softmax attention, particularly for tasks requiring sharp, focused attention patterns.
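A sketch of this rearrangement with the $\text{elu} + 1$ feature map. Because there is no softmax, the $O(nd^2)$ form is exactly equal to the naive $O(n^2 d)$ form, which we can verify directly:

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    """Kernelized attention, O(n d^2): phi(Q)(phi(K)^T V) / phi(Q)(phi(K)^T 1)."""
    q, k = F.elu(q) + 1, F.elu(k) + 1
    kv = k.transpose(-2, -1) @ v            # (d, d) summary, computed once
    z = q @ k.sum(dim=-2).unsqueeze(-1)     # normalizer phi(Q)(phi(K)^T 1), (n, 1)
    return (q @ kv) / z

torch.manual_seed(0)
q, k, v = (torch.randn(16, 8) for _ in range(3))

fast = linear_attention(q, k, v)

# Naive O(n^2 d) equivalent: materialize the full n x n similarity matrix
phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
sims = phi_q @ phi_k.T
naive = (sims / sims.sum(dim=-1, keepdim=True)) @ v

print(torch.allclose(fast, naive, atol=1e-5))  # True
```

The rearrangement is pure associativity; what is lost relative to softmax attention is the sharp, near-one-hot weighting that the exponential provides.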

KV-Cache: Efficient Autoregressive Generation

During autoregressive generation (e.g., text generation one token at a time), the model generates token $t+1$ conditioned on tokens $1, \ldots, t$. Without optimization, this requires recomputing attention over all previous tokens at each step — total cost $O(T^3 d)$ for generating $T$ tokens.

The KV-cache stores the key and value projections from all previous positions and reuses them. At step $t$, only the new token's query is computed; the keys and values from positions $1, \ldots, t-1$ are retrieved from the cache. This reduces the per-step cost from $O(t^2 d)$ to $O(t d)$, and the total generation cost from $O(T^3 d)$ to $O(T^2 d)$.

The memory cost of the KV-cache is $O(T \cdot N \cdot h \cdot d_k)$ per sequence, where $N$ is the number of layers and $h$ is the number of heads. For a 32-layer, 32-head model with $d_k = 128$ and a 4,096-token context, the KV-cache requires $32 \times 32 \times 128 \times 4{,}096 \times 2$ (K and V) $\times 2$ (float16) $\approx 2$ GB per sequence. This is why serving large language models at scale is a memory engineering problem.
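The memory arithmetic, sketched:

```python
def kv_cache_bytes(layers: int, heads: int, d_k: int, context: int,
                   bytes_per_el: int = 2) -> int:
    # factor of 2 for storing both K and V; bytes_per_el = 2 for float16
    return layers * heads * d_k * context * 2 * bytes_per_el

gb = kv_cache_bytes(layers=32, heads=32, d_k=128, context=4096) / 1e9
print(f"KV-cache: {gb:.2f} GB per sequence")  # ~2.15 GB
```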

Advanced Sidebar: Multi-Query Attention (MQA; Shazeer, 2019) and Grouped-Query Attention (GQA; Ainslie et al., 2023) reduce KV-cache memory by sharing key/value heads across query heads. MQA uses a single K/V head for all query heads; GQA uses $g$ groups (where $1 < g < h$). LLaMA 2 70B uses GQA with $g = 8$ key/value groups for 64 query heads, reducing KV-cache memory by $8\times$ with minimal quality degradation.
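The sharing mechanism amounts to expanding the small set of cached K/V heads to match the query heads. A minimal sketch (the head counts here are illustrative, much smaller than any production model's):

```python
import torch

batch, seq, d_k = 1, 6, 8
n_q_heads, n_kv_heads = 8, 2              # GQA with 2 K/V groups (illustrative)
group_size = n_q_heads // n_kv_heads      # 4 query heads share each K/V head

k_cached = torch.randn(batch, n_kv_heads, seq, d_k)  # only 2 key heads stored
k_expanded = k_cached.repeat_interleave(group_size, dim=1)
print(k_expanded.shape)  # torch.Size([1, 8, 6, 8])
# Query heads 0-3 all read cached head 0; the cache is n_q/n_kv = 4x smaller.
```

The same expansion applies to the value cache; the query projections remain per-head, so representational diversity on the query side is preserved.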


10.12 Transformer vs. RNN vs. CNN: A Principled Comparison

Having now studied all three architectures in depth, we can make a principled comparison:

| Property | CNN (Ch. 8) | RNN/LSTM (Ch. 9) | Transformer (Ch. 10) |
|---|---|---|---|
| Inductive bias | Locality, translation equivariance | Sequential ordering, recurrence | None (learned from data + position) |
| Long-range dependencies | Through depth (limited) | Through hidden state (degrades) | Direct ($O(1)$ path length) |
| Parallelism | Full spatial parallelism | Sequential (cannot parallelize over time) | Full parallelism |
| Complexity (sequence length $n$) | $O(k \cdot n \cdot d^2)$ per layer | $O(n \cdot d^2)$ per layer | $O(n^2 \cdot d)$ per layer |
| Memory | $O(n)$ | $O(1)$ in sequence length | $O(n^2)$ (without flash attention) |
| Training efficiency | Fast | Slow (BPTT) | Fast (but memory-hungry) |
| Data efficiency | High (strong inductive bias) | Medium | Low (needs more data) |
| Best for | Spatial data, local patterns | Streaming, online, edge | Everything else |

The transformer's lack of inductive bias is both its greatest strength and its greatest weakness. It makes no assumptions about the structure of the data — which means it can learn any structure, but it must learn it from data. This is why transformers require much more training data than CNNs or RNNs, and why techniques like pretraining on large corpora (Chapter 13) are essential.


10.13 Progressive Project M4: From LSTM to Transformer for StreamRec Sessions

In Chapter 9, we built an LSTM that models StreamRec user sessions — sequences of items a user engaged with — to predict the next item. The LSTM processes the session sequentially, maintaining a hidden state that compresses the session history.

Now we replace the LSTM with a transformer. The architectural change is straightforward, but the implications are profound: the transformer can attend directly to any item in the session history, and the attention weights reveal which items matter most for each prediction.

The Transformer Session Model

class TransformerSessionModel(nn.Module):
    """Transformer-based session model for next-item prediction.

    Replaces the LSTM from Chapter 9 with a transformer encoder.
    Uses causal masking so each position can only attend to
    previous items in the session (preserving temporal ordering).
    """

    def __init__(
        self,
        num_items: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        d_ff: int = 512,
        max_session_len: int = 50,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.item_embedding = nn.Embedding(num_items + 1, d_model, padding_idx=0)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_session_len)
        self.embedding_dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, num_items + 1)

    def forward(
        self,
        session_items: torch.Tensor,
        return_attention: bool = False,
    ) -> dict:
        """Predict next item for each position in the session.

        Args:
            session_items: Item IDs of shape (batch, session_len).
                           0 = padding, 1..num_items = valid items.
            return_attention: If True, return all attention weight matrices.

        Returns:
            Dictionary with 'logits' (batch, session_len, num_items+1)
            and optionally 'attention_weights'.
        """
        batch_size, seq_len = session_items.shape

        # Causal mask: prevent attending to future items
        causal_mask = create_causal_mask(seq_len).to(session_items.device)

        # Padding mask: prevent attending to padding positions
        pad_mask = (session_items == 0).unsqueeze(1).unsqueeze(2)
        # Combine: mask future OR padding
        combined_mask = causal_mask | pad_mask

        x = self.item_embedding(session_items) * self.scale
        x = self.pos_encoding(x)
        x = self.embedding_dropout(x)

        all_attention = []
        for block in self.blocks:
            x, attn_weights = block(x, mask=combined_mask)
            if return_attention:
                all_attention.append(attn_weights)

        x = self.final_norm(x)
        logits = self.output_proj(x)

        result = {"logits": logits}
        if return_attention:
            result["attention_weights"] = all_attention
        return result


def compare_lstm_vs_transformer(
    num_items: int = 1000,
    session_len: int = 20,
    n_sessions: int = 5000,
    seed: int = 42,
) -> dict:
    """Compare LSTM and Transformer session models on synthetic data.

    Generates sessions with planted sequential and long-range patterns,
    then trains both models and compares accuracy and attention behavior.

    Args:
        num_items: Number of unique items.
        session_len: Fixed session length.
        n_sessions: Number of training sessions.
        seed: Random seed.

    Returns:
        Dictionary with comparison metrics.
    """
    torch.manual_seed(seed)

    # Generate synthetic sessions with patterns:
    # Pattern 1 (local): item i is often followed by item i+1
    # Pattern 2 (long-range): first item predicts last item
    sessions = torch.randint(1, num_items + 1, (n_sessions, session_len))
    targets = torch.zeros(n_sessions, session_len, dtype=torch.long)

    for i in range(n_sessions):
        for t in range(session_len - 1):
            if torch.rand(1).item() < 0.4:
                # Local pattern: next item = current + 1
                targets[i, t] = min(sessions[i, t].item() + 1, num_items)
            elif torch.rand(1).item() < 0.3 and t > 3:
                # Long-range pattern: repeat an earlier item
                ref_pos = torch.randint(0, max(1, t - 3), (1,)).item()
                targets[i, t] = sessions[i, ref_pos]
            else:
                targets[i, t] = torch.randint(1, num_items + 1, (1,)).item()

    # Build transformer model
    transformer_model = TransformerSessionModel(
        num_items=num_items,
        d_model=128,
        num_heads=4,
        num_layers=2,
        d_ff=512,
        max_session_len=session_len,
    )

    t_params = sum(p.numel() for p in transformer_model.parameters())
    print(f"Transformer parameters: {t_params:,}")
    print(f"Advantage: Attention weights are interpretable — we can see")
    print(f"which items in the session history drive each prediction.")

    return {
        "transformer_params": t_params,
        "session_len": session_len,
        "num_items": num_items,
    }


# Run comparison
results = compare_lstm_vs_transformer()
# Output:
# Transformer parameters: 921,481
# Advantage: Attention weights are interpretable — we can see
# which items in the session history drive each prediction.
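To complete the comparison, a minimal next-item training loop might look like the sketch below. `SessionModel` is a hypothetical stand-in (an embedding plus a linear head) used only to keep the snippet self-contained; swapping in `TransformerSessionModel` and the LSTM from Chapter 9 yields the full comparison. The toy data plants the deterministic local pattern (target = current item + 1), so the loss should fall quickly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_items, session_len = 50, 10

class SessionModel(nn.Module):
    """Hypothetical stand-in: embedding + linear head over next items."""
    def __init__(self, num_items: int, d_model: int = 32):
        super().__init__()
        self.embed = nn.Embedding(num_items + 1, d_model)
        self.proj = nn.Linear(d_model, num_items + 1)

    def forward(self, x):
        return self.proj(self.embed(x))  # (batch, seq, num_items + 1)

# Toy data with a deterministic local pattern: target = current item + 1
sessions = torch.randint(1, num_items + 1, (256, session_len))
targets = sessions % num_items + 1

model = SessionModel(num_items)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

losses = []
for epoch in range(20):
    opt.zero_grad()
    logits = model(sessions)
    loss = loss_fn(logits.reshape(-1, num_items + 1), targets.reshape(-1))
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

The same loop trains both architectures; only the model construction differs, which is what makes the parameter counts and accuracies directly comparable.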

What Attention Reveals

The transformer's attention weights provide direct interpretability that the LSTM's hidden state does not. For a session [documentary_A, news_B, comedy_C, news_D, ???], the attention weights at the prediction position might show:

  • Head 1 (recency): attends strongly to news_D (most recent item).
  • Head 2 (category): attends to news_B and news_D (same category).
  • Head 3 (contrast): attends to comedy_C (different genre — useful for modeling diversity preferences).

This interpretability is not just intellectually satisfying — it is a product requirement for StreamRec's recommendation explainability system. "We recommended this because you recently watched similar content" can be grounded in actual attention weights, rather than post-hoc rationalization.
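Grounding an explanation in attention weights can be as simple as reading off the most-attended positions per head. A sketch: the helper `top_attended_items` and the attention values below are illustrative placeholders, not outputs of a trained model.

```python
import torch

def top_attended_items(attn_weights, item_names, k=2):
    """Return the k most-attended session items (name, weight) per head.

    attn_weights: (num_heads, seq_len) attention rows at the prediction
    position; item_names: labels for each session position.
    """
    scores, idx = attn_weights.topk(k, dim=-1)
    return [
        [(item_names[j], round(s.item(), 2)) for s, j in zip(scores[h], idx[h])]
        for h in range(attn_weights.size(0))
    ]

# Hypothetical attention rows for the session from the text,
# over [documentary_A, news_B, comedy_C, news_D]:
names = ["documentary_A", "news_B", "comedy_C", "news_D"]
attn = torch.tensor([
    [0.05, 0.10, 0.05, 0.80],  # head 1: recency
    [0.05, 0.46, 0.05, 0.44],  # head 2: category
    [0.10, 0.10, 0.70, 0.10],  # head 3: contrast
])
print(top_attended_items(attn, names, k=1))
```

In production, the explanation template ("because you recently watched ...") would be filled from these (item, weight) pairs rather than from a separate post-hoc model.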


10.14 Summary

The transformer architecture is defined by a single core operation — scaled dot-product attention — combined with careful engineering: multi-head parallelism, positional encoding, residual connections, layer normalization, and position-wise feed-forward networks. Each component solves a specific problem, and understanding those problems is the key to understanding why the architecture works.

We derived attention from first principles, building from the database analogy (soft lookup) through the scaling argument (variance stabilization) to the full multi-head mechanism (representational diversity). We walked through the original "Attention Is All You Need" paper and connected its theoretical arguments to our implementation. We analyzed the quadratic complexity of attention and the algorithmic innovations — flash attention, sparse attention, linear attention, KV-caching — that make transformers practical at scale.
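As a compact recap, the core operation fits in a few lines. This is a minimal sketch; the chapter's full multi-head implementation wraps it with learned projections.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    """Minimal sketch of softmax(QK^T / sqrt(d_k)) V."""
    d_k = q.size(-1)
    # Scale by sqrt(d_k) so the logits keep O(1) variance and the
    # softmax does not saturate (the scaling argument).
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

x = torch.randn(2, 5, 16)  # (batch, seq, d_k): self-attention, q = k = v
out, w = scaled_dot_product_attention(x, x, x)
print(out.shape, w.shape)  # each attention row sums to 1
```

The (seq, seq) shape of `w` is the quadratic cost made visible: every query position scores every key position, which is exactly what flash, sparse, and linear attention work to tame.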

The transformer's power comes from its weak inductive bias: it makes few assumptions about data structure, instead learning structure from data through the attention mechanism. This is why it has replaced CNNs for vision (ViT), RNNs for sequences, and GNNs for some relational tasks. It is also why transformers are data-hungry: the structure that CNNs and RNNs encode architecturally must instead be learned from examples.

In Chapter 11, we will see what happens when you scale the transformer to hundreds of billions of parameters and train it on trillions of tokens of internet text: the large language model. The architecture is the same. The scale changes everything.

Fundamentals > Frontier: The attention mechanism is built from three operations you studied in Chapter 1: matrix multiplication, softmax, and inner products. The positional encoding is a Fourier basis from signal processing. The layer normalization is a variance stabilization technique from statistics. The residual connection is the identity function from calculus. Every component of the transformer is built from fundamentals. Understanding those fundamentals is what allows you to read every new paper that builds on the transformer — and there will be many.