In This Chapter
- Introduction
- 19.1 The "Attention Is All You Need" Paper
- 19.2 Positional Encoding
- 19.3 Layer Normalization
- 19.4 Feed-Forward Networks
- 19.5 Residual Connections
- 19.6 The Encoder Block
- 19.7 The Decoder Block
- 19.8 Building a Complete Transformer
- 19.9 The Training Process
- 19.10 Training on a Toy Translation Task
- 19.11 Design Decisions and Variations
- 19.12 Understanding Information Flow
- 19.13 Practical Considerations
- 19.14 Production Implementation Details
- 19.15 From the Original Transformer to Modern Architectures
- 19.16 Summary
- References
Chapter 19: The Transformer Architecture
"Attention is all you need." --- Ashish Vaswani et al., 2017
Introduction
In Chapter 18, we explored the mechanics of attention: how a model can learn to selectively focus on different parts of its input when producing each element of its output. We saw scaled dot-product attention, multi-head attention, and the mathematical machinery that makes it all work. Now we arrive at the architecture that brought these ideas together into one of the most consequential papers in the history of deep learning.
The Transformer, introduced by Vaswani et al. in their 2017 paper "Attention Is All You Need," did not merely add attention to an existing architecture. It made attention the entire architecture. By discarding the recurrent connections that had defined sequence modeling for decades, the Transformer achieved something remarkable: it was faster to train, easier to parallelize, and --- as the field would soon discover --- more capable than anything that came before it.
This chapter takes you inside the Transformer, component by component. We will build every piece from scratch in PyTorch, understand the mathematical reasoning behind each design choice, and assemble a complete, working Transformer that you can train on a toy translation task. By the end, you will not just understand the Transformer --- you will have built one.
What You Will Learn
- The historical context and motivation behind the Transformer
- How positional encoding injects sequence order into an attention-based model
- The role of layer normalization, residual connections, and feed-forward networks
- How encoder blocks and decoder blocks are assembled
- How cross-attention connects the encoder and decoder
- How to build and train a complete Transformer in PyTorch
- The key design decisions that make the architecture work
Prerequisites
This chapter builds directly on Chapter 18 (Attention Mechanisms). You should be comfortable with:
- Scaled dot-product attention and multi-head attention (Chapter 18)
- Matrix multiplication and softmax operations (Chapter 3)
- Neural network fundamentals including backpropagation (Chapters 7--9)
- Basic PyTorch including nn.Module and autograd (Chapter 10)
19.1 The "Attention Is All You Need" Paper
19.1.1 Historical Context
By 2017, the dominant paradigm for sequence-to-sequence tasks like machine translation was the encoder-decoder architecture built from recurrent neural networks (RNNs), typically LSTMs or GRUs. These models processed input sequences one token at a time, maintaining a hidden state that carried information forward through the sequence. Attention mechanisms, as we discussed in Chapter 18, had been added on top of these recurrent architectures to allow the decoder to look back at all encoder states rather than compressing everything into a single vector.
The problem was speed. Because RNNs process tokens sequentially --- the hidden state at position $t$ depends on the hidden state at position $t-1$ --- training could not be effectively parallelized across sequence positions. As datasets grew larger and models grew deeper, this sequential bottleneck became increasingly painful. Researchers at Google Brain and Google Research asked a provocative question: what if we could build a sequence-to-sequence model using only attention, with no recurrence at all?
19.1.2 The Key Insight
The central insight of the Transformer is that self-attention --- where a sequence attends to itself --- can replace recurrence entirely. Instead of building up a representation of a sequence by processing it token by token, self-attention allows every token to interact with every other token in a single operation. This makes the computation inherently parallel: all positions can be processed simultaneously.
Of course, this introduces a challenge. Without recurrence, the model has no inherent notion of token order. The sentence "the cat sat on the mat" and "mat the on sat cat the" would produce identical representations. The Transformer solves this with positional encoding, which we will explore in detail shortly.
19.1.3 The Architecture at a Glance
The Transformer follows the encoder-decoder pattern that was already standard for machine translation:
- The encoder reads the entire input sequence and produces a rich representation of it.
- The decoder generates the output sequence one token at a time, attending both to its own previously generated tokens and to the encoder's representation.
What is different from previous encoder-decoder models is how each component is built. Both the encoder and decoder are stacks of identical layers (blocks), and each block is composed entirely of attention mechanisms, feed-forward networks, layer normalization, and residual connections. There are no recurrent or convolutional layers anywhere.
The original Transformer used $N = 6$ layers for both the encoder and decoder, a model dimension of $d_{\text{model}} = 512$, and $h = 8$ attention heads with $d_k = d_v = 64$.
19.2 Positional Encoding
19.2.1 The Problem of Position
As we noted above, self-attention is inherently permutation-invariant. Given a set of input vectors, multi-head attention produces the same output regardless of the order in which those vectors are presented. But order matters enormously in language --- "dog bites man" means something very different from "man bites dog."
To solve this, the Transformer adds positional encodings to the input embeddings before they enter the first encoder or decoder layer. These encodings carry information about each token's position in the sequence, allowing the model to reason about order.
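To make this concrete, the short check below (an illustrative sketch with arbitrary sizes, not part of the model we build later) shows that multi-head self-attention without positional information is permutation-equivariant: shuffling the input tokens simply shuffles the output rows, so the attention output by itself carries no information about order.

import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
mha.eval()  # make the comparison deterministic

x = torch.randn(1, 5, 16)   # a "sentence" of 5 token vectors
perm = torch.randperm(5)    # a random reordering of those tokens

with torch.no_grad():
    out, _ = mha(x, x, x)
    out_perm, _ = mha(x[:, perm], x[:, perm], x[:, perm])

# The output for the shuffled input is just the shuffled original output.
print(torch.allclose(out[:, perm], out_perm, atol=1e-6))  # True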
19.2.2 Sinusoidal Positional Encoding
The original Transformer uses a deterministic, sinusoidal encoding scheme. For a token at position $pos$ and embedding dimension $i$, the positional encoding is:
$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
where $d_{\text{model}}$ is the dimensionality of the model. Even dimensions use sine, odd dimensions use cosine.
Why sinusoids? There are several elegant properties:
- Unique encoding: Each position gets a unique pattern across the $d_{\text{model}}$ dimensions.
- Bounded values: Sine and cosine are bounded in $[-1, 1]$, so the positional signal does not overwhelm the embedding values.
- Relative position information: For any fixed offset $k$, $PE_{pos+k}$ can be expressed as a linear function of $PE_{pos}$. This means the model can easily learn to attend to relative positions.
- Generalization to unseen lengths: Because the encoding is a continuous function of position, the model can extrapolate to sequence lengths longer than those seen during training (though in practice, performance degrades).
The intuition is that each dimension of the positional encoding oscillates at a different frequency. The low-order dimensions change rapidly (high frequency), encoding fine-grained position information. The high-order dimensions change slowly (low frequency), encoding coarse-grained position information. Together, they create a unique "fingerprint" for each position.
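To see the relative-position property concretely, fix one frequency $\omega_i = 1/10000^{2i/d_{\text{model}}}$ and consider the sine/cosine pair at that frequency. The angle-addition identities give
$$\begin{pmatrix} \sin(\omega_i (pos+k)) \\ \cos(\omega_i (pos+k)) \end{pmatrix} = \begin{pmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{pmatrix} \begin{pmatrix} \sin(\omega_i \, pos) \\ \cos(\omega_i \, pos) \end{pmatrix}$$
so the encoding at position $pos + k$ is a rotation of the encoding at position $pos$, with a rotation angle that depends only on the offset $k$. This is the linear relationship that makes relative positions easy for the model to recover.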
19.2.3 Learned Positional Encoding
An alternative to sinusoidal encoding is to use learned positional embeddings. In this approach, we create a learnable embedding matrix of shape $(L_{\max}, d_{\text{model}})$, where $L_{\max}$ is the maximum sequence length. Each position gets its own learnable vector, and these are optimized along with the rest of the model parameters during training.
The original Transformer paper found that sinusoidal and learned encodings performed comparably. In practice:
- Sinusoidal: No additional parameters, can potentially generalize to longer sequences, deterministic.
- Learned: More flexible, can capture task-specific positional patterns, but limited to the maximum length seen during training.
Many modern Transformer variants (BERT, GPT-2) use learned positional embeddings. More recent models use rotary positional embeddings (RoPE) or ALiBi, but these are beyond the scope of this chapter.
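For comparison, a learned positional embedding takes only a few lines. The sketch below is ours (the class name and defaults are illustrative, not from any particular library) and could be swapped in wherever the sinusoidal module defined next is used:

import torch
import torch.nn as nn

class LearnedPositionalEncoding(nn.Module):
    """Learned positional embeddings: one trainable vector per position."""

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1) -> None:
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_model); sequences longer than max_len are not supported
        positions = torch.arange(x.size(1), device=x.device)
        return self.dropout(x + self.pos_emb(positions))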
19.2.4 Implementation
The positional encoding is added (not concatenated) to the token embeddings:
$$\mathbf{x}_{\text{input}} = \text{Embedding}(\text{token}) + PE_{\text{pos}}$$
This addition works because both the token embedding and the positional encoding live in the same $d_{\text{model}}$-dimensional space. The model learns to disentangle the two signals.
import torch
import torch.nn as nn
import math
torch.manual_seed(42)
class SinusoidalPositionalEncoding(nn.Module):
"""Sinusoidal positional encoding from 'Attention Is All You Need'.
Args:
d_model: Dimension of the model embeddings.
max_len: Maximum sequence length to precompute.
dropout: Dropout rate applied after adding positional encoding.
"""
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# Create positional encoding matrix
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # Shape: (1, max_len, d_model)
# Register as buffer (not a parameter, but saved with the model)
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Add positional encoding to input embeddings.
Args:
x: Input tensor of shape (batch_size, seq_len, d_model).
Returns:
Tensor with positional encoding added, same shape as input.
"""
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
We register the positional encoding matrix as a buffer rather than a parameter. This means it will be saved with the model and moved to the correct device, but it will not receive gradient updates during training.
See code/example-01-positional-encoding.py for a complete, runnable example with visualization.
19.3 Layer Normalization
19.3.1 Why Normalize?
Training deep networks is challenging because the distribution of inputs to each layer changes as the parameters of preceding layers are updated. This phenomenon, sometimes called internal covariate shift, can slow down training and make it harder to use large learning rates. Normalization techniques address this by standardizing the inputs to each layer.
While batch normalization (normalizing across the batch dimension) is standard in vision models, it is problematic for sequence tasks where sequences have variable lengths and batch statistics can be noisy with small batches. Layer normalization normalizes across the feature dimension for each individual example, making it independent of batch size and sequence length.
19.3.2 Layer Normalization: The Math
Given an input vector $\mathbf{x} \in \mathbb{R}^{d}$, layer normalization computes:
$$\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where:
- $\mu = \frac{1}{d}\sum_{i=1}^{d} x_i$ is the mean across features
- $\sigma^2 = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2$ is the variance across features
- $\gamma, \beta \in \mathbb{R}^{d}$ are learnable scale and shift parameters
- $\epsilon$ is a small constant for numerical stability (typically $10^{-5}$ or $10^{-6}$)
- $\odot$ denotes element-wise multiplication
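The formula is easy to verify numerically. The sketch below (ours, purely for illustration) applies the computation by hand and checks it against PyTorch's nn.LayerNorm with its default scale and shift:

import torch
import torch.nn as nn

torch.manual_seed(0)
d = 8
x = torch.randn(2, 5, d)  # (batch, seq_len, features)

# Manual layer normalization over the feature dimension
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as LayerNorm uses
manual = (x - mu) / torch.sqrt(var + 1e-5)          # gamma = 1, beta = 0

reference = nn.LayerNorm(d, eps=1e-5)(x)            # defaults: gamma = 1, beta = 0
print(torch.allclose(manual, reference, atol=1e-6))  # True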
19.3.3 Post-Norm vs. Pre-Norm
The original Transformer paper places layer normalization after the residual connection --- a configuration known as post-norm:
$$\text{output} = \text{LayerNorm}(\mathbf{x} + \text{Sublayer}(\mathbf{x}))$$
Later work found that placing layer normalization before the sublayer --- known as pre-norm --- leads to more stable training, especially for deeper models:
$$\text{output} = \mathbf{x} + \text{Sublayer}(\text{LayerNorm}(\mathbf{x}))$$
The pre-norm variant has become the more common choice in modern Transformer implementations. The key advantage is that the residual path remains an unimpeded identity mapping from input to output, which helps gradients flow through very deep networks. GPT-2, GPT-3, and many other large language models use pre-norm.
In this chapter, we will implement both variants but default to pre-norm in our complete Transformer, noting where the original paper differs.
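The difference between the two arrangements is only where the normalization sits relative to the residual addition. A minimal schematic (the helper functions are ours, written to mirror the equations above):

import torch
import torch.nn as nn

def post_norm_sublayer(x: torch.Tensor, sublayer: nn.Module,
                       norm: nn.LayerNorm, dropout: nn.Dropout) -> torch.Tensor:
    # Original Transformer: normalize after adding the residual
    return norm(x + dropout(sublayer(x)))

def pre_norm_sublayer(x: torch.Tensor, sublayer: nn.Module,
                      norm: nn.LayerNorm, dropout: nn.Dropout) -> torch.Tensor:
    # Pre-norm: normalize the sublayer input; the residual path stays an identity
    return x + dropout(sublayer(norm(x)))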
19.3.4 Why Pre-Norm Improves Training Stability
The practical difference between pre-norm and post-norm becomes dramatic for deep Transformers. Xiong et al. (2020) provided a theoretical analysis showing that in post-norm Transformers, the expected gradient magnitude at initialization grows with the number of layers, requiring careful learning rate warmup to avoid divergence. In pre-norm Transformers, the gradient magnitude is well-behaved regardless of depth.
To understand this intuitively, consider the backward pass. In post-norm, the gradient flows through the layer normalization at each step, which can amplify or attenuate it depending on the current activation statistics. In pre-norm, the gradient through the residual connection bypasses normalization entirely -- it flows through a clean identity path from the output back to the input. This is why pre-norm models often train stably even without learning rate warmup, while post-norm models can diverge without it.
A practical consequence: if you are implementing a Transformer from scratch and your model is training unstably (loss spikes, NaN values), switching from post-norm to pre-norm is often the first remedy to try. As we discussed in Chapter 12 regarding training stability, the interaction between normalization and residual connections is one of the most critical factors in deep network training.
19.3.5 RMSNorm: A Simplified Alternative
RMSNorm (Root Mean Square Layer Normalization) simplifies layer normalization by removing the mean centering step:
$$\text{RMSNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}$$
RMSNorm is computationally cheaper (no need to compute and subtract the mean) and has been adopted by several major language models including LLaMA and Gemma. Empirically, the mean centering in standard layer normalization appears to contribute little beyond the variance normalization, making RMSNorm a practical simplification with negligible performance loss.
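RMSNorm is not part of the original architecture, but it is short enough to sketch. The module below is a minimal implementation of the formula above (ours, not copied from any particular library):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root mean square normalization: rescale only, no mean centering."""

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root mean square over the feature dimension
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x / rms)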
19.4 Feed-Forward Networks
19.4.1 Position-wise Feed-Forward Networks
Each Transformer block contains a position-wise feed-forward network (FFN). This is a simple two-layer fully connected network applied independently to each position (token) in the sequence:
$$\text{FFN}(\mathbf{x}) = W_2 \cdot \text{ReLU}(W_1 \mathbf{x} + b_1) + b_2$$
where:
- $W_1 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ projects from the model dimension to a larger intermediate dimension
- $W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ projects back to the model dimension
- $d_{\text{ff}}$ is typically $4 \times d_{\text{model}}$ (the original paper uses $d_{\text{ff}} = 2048$ with $d_{\text{model}} = 512$)
The term "position-wise" means that the same linear transformation is applied to every position, but independently. Across positions, there is no interaction in the FFN --- that job belongs to the attention mechanism. You can think of the FFN as the part of the block that processes the information gathered by attention, performing nonlinear transformations on the representation at each position.
19.4.2 Why a Two-Layer Network?
The FFN introduces the model's primary source of nonlinearity (beyond what attention itself provides). The expansion to $d_{\text{ff}} = 4 \times d_{\text{model}}$ gives the block an expand-and-contract structure (sometimes called an inverted bottleneck): the representation is projected into a higher-dimensional space where the nonlinear activation can operate, then projected back down. This gives the model a larger "working space" for computation while keeping the residual stream at a manageable dimension.
19.4.3 Activation Functions
The original Transformer uses ReLU. Many modern variants use GELU (Gaussian Error Linear Unit), which is smoother and empirically performs slightly better:
$$\text{GELU}(x) = x \cdot \Phi(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{2/\pi}(x + 0.044715x^3)\right]\right)$$
where $\Phi(x)$ is the standard normal cumulative distribution function. Other variants use SwiGLU or GeGLU gated activations, but we will use GELU in our implementation for a good balance of simplicity and performance.
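To get a feel for how close the tanh expression is to the exact GELU, here is a quick numerical comparison (ours; the error stays small across the whole range):

import math
import torch

x = torch.linspace(-4.0, 4.0, steps=81)
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))  # x * Phi(x)
approx = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
print((exact - approx).abs().max())  # a small value, well below 1e-2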
class PositionWiseFeedForward(nn.Module):
"""Position-wise feed-forward network used in each Transformer block.
Args:
d_model: Dimension of the model.
d_ff: Dimension of the inner feed-forward layer.
dropout: Dropout rate.
"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1) -> None:
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.activation = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply position-wise feed-forward transformation.
Args:
x: Input tensor of shape (batch_size, seq_len, d_model).
Returns:
Output tensor of same shape as input.
"""
return self.linear2(self.dropout(self.activation(self.linear1(x))))
19.5 Residual Connections
19.5.1 The Gradient Highway
Residual connections, introduced by He et al. (2016) for deep convolutional networks, are critical to the Transformer's ability to train deep stacks of layers. The idea is simple: instead of computing $\mathbf{y} = F(\mathbf{x})$, we compute:
$$\mathbf{y} = \mathbf{x} + F(\mathbf{x})$$
The function $F$ only needs to learn the residual --- the difference between the desired output and the input. This creates a "gradient highway" through the network: during backpropagation, the gradient can flow directly through the addition operation, bypassing $F$ entirely. This prevents the vanishing gradient problem that plagues very deep networks.
19.5.2 Residual Connections in the Transformer
In the Transformer, every sublayer (multi-head attention and feed-forward network) is wrapped in a residual connection. With pre-norm, each sublayer takes the form:
$$\mathbf{y} = \mathbf{x} + \text{Sublayer}(\text{LayerNorm}(\mathbf{x}))$$
This means the residual stream --- the path through the $+$ operations --- forms a direct connection from the input of the first layer to the output of the last layer. Each sublayer contributes additive updates to this stream. This perspective, sometimes called the "residual stream" view, is important for understanding how Transformers process information.
19.5.3 Dropout on Sublayer Outputs
The original Transformer applies dropout to the output of each sublayer before it is added to the residual. This acts as a form of regularization, randomly zeroing out some of the sublayer's contribution during training:
$$\mathbf{y} = \mathbf{x} + \text{Dropout}(\text{Sublayer}(\text{LayerNorm}(\mathbf{x})))$$
The original paper uses a dropout rate of 0.1 for the base model.
19.6 The Encoder Block
19.6.1 Structure
An encoder block (also called an encoder layer) consists of two sublayers:
- Multi-head self-attention: Each position in the input attends to all positions, including itself. This is exactly the multi-head attention mechanism from Chapter 18, where the queries, keys, and values all come from the same sequence.
- Position-wise feed-forward network: The two-layer FFN described in Section 19.4.
Each sublayer is wrapped with a residual connection and layer normalization. The complete encoder block (using pre-norm) computes:
$$\mathbf{z} = \mathbf{x} + \text{Dropout}(\text{MultiHeadAttn}(\text{LN}(\mathbf{x}), \text{LN}(\mathbf{x}), \text{LN}(\mathbf{x})))$$
$$\text{output} = \mathbf{z} + \text{Dropout}(\text{FFN}(\text{LN}(\mathbf{z})))$$
where $\text{LN}$ denotes layer normalization.
19.6.2 No Masking in the Encoder
A crucial distinction: the encoder uses unmasked self-attention. Every position can attend to every other position in the input sequence. This is appropriate because the encoder processes the entire input at once --- there is no notion of "future" tokens in the encoder.
The only masking in the encoder is padding masking: if input sequences are padded to the same length within a batch, the attention mechanism should ignore the padding positions. This is achieved by setting the attention scores for padding positions to $-\infty$ before the softmax, which drives their attention weights to zero.
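Padding masks are easy to get backwards, so a concrete example helps. The sketch below builds the boolean key-padding mask in the convention that nn.MultiheadAttention expects, where True marks positions to ignore (PAD_IDX = 0 matches the toy task later in this chapter):

import torch

PAD_IDX = 0
# A batch of two sequences; the second is padded to length 5
src = torch.tensor([
    [7, 4, 9, 3, 5],
    [6, 8, 3, PAD_IDX, PAD_IDX],
])
src_key_padding_mask = (src == PAD_IDX)
print(src_key_padding_mask)
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])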
19.6.3 Implementation
class TransformerEncoderBlock(nn.Module):
"""Single encoder block of the Transformer.
Uses pre-norm configuration with multi-head self-attention
followed by a position-wise feed-forward network.
Args:
d_model: Dimension of the model.
n_heads: Number of attention heads.
d_ff: Dimension of the feed-forward inner layer.
dropout: Dropout rate.
"""
def __init__(
self,
d_model: int,
n_heads: int,
d_ff: int,
dropout: float = 0.1,
) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=n_heads,
dropout=dropout,
batch_first=True,
)
self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass through encoder block.
Args:
src: Source sequence of shape (batch_size, seq_len, d_model).
src_mask: Attention mask of shape (seq_len, seq_len).
src_key_padding_mask: Padding mask of shape (batch_size, seq_len).
Returns:
Output tensor of same shape as src.
"""
# Self-attention sublayer with pre-norm
normed = self.norm1(src)
attn_output, _ = self.self_attn(
normed, normed, normed,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)
src = src + self.dropout1(attn_output)
# Feed-forward sublayer with pre-norm
normed = self.norm2(src)
ff_output = self.ffn(normed)
src = src + self.dropout2(ff_output)
return src
19.6.4 The Encoder Stack
The complete encoder is a stack of $N$ identical encoder blocks, preceded by embedding and positional encoding, and followed by a final layer normalization (necessary with pre-norm to normalize the output of the last block):
class TransformerEncoder(nn.Module):
"""Complete Transformer encoder stack.
Args:
vocab_size: Size of the source vocabulary.
d_model: Dimension of the model.
n_heads: Number of attention heads.
d_ff: Dimension of the feed-forward inner layer.
n_layers: Number of encoder blocks.
max_len: Maximum sequence length.
dropout: Dropout rate.
"""
def __init__(
self,
vocab_size: int,
d_model: int,
n_heads: int,
d_ff: int,
n_layers: int,
max_len: int = 5000,
dropout: float = 0.1,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
TransformerEncoderBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.d_model = d_model
def forward(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Encode source sequence.
Args:
src: Source token IDs of shape (batch_size, seq_len).
src_mask: Attention mask.
src_key_padding_mask: Padding mask.
Returns:
Encoded representation of shape (batch_size, seq_len, d_model).
"""
# Scale embeddings by sqrt(d_model) as in the original paper
x = self.embedding(src) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.layers:
x = layer(x, src_mask, src_key_padding_mask)
return self.norm(x)
Note the scaling of embeddings by $\sqrt{d_{\text{model}}}$. The original paper does this to ensure that the embedding values are on a similar scale to the positional encodings, which have values roughly in $[-1, 1]$. Without this scaling, the positional signal could be drowned out by the embeddings.
19.7 The Decoder Block
19.7.1 Structure
The decoder block is more complex than the encoder block because it has three sublayers instead of two:
- Masked multi-head self-attention: The decoder attends to its own previous outputs, but with a causal mask that prevents positions from attending to future positions. This is essential during training to maintain the autoregressive property --- the prediction for position $t$ should only depend on positions $1, \ldots, t-1$.
- Multi-head cross-attention: The decoder attends to the encoder's output. Here, the queries come from the decoder, while the keys and values come from the encoder. This is the mechanism by which the decoder "reads" the source sequence. As we discussed in Chapter 18, this is the classic attention pattern from sequence-to-sequence models.
- Position-wise feed-forward network: Same as in the encoder.
Each sublayer is wrapped with a residual connection and layer normalization.
19.7.2 Causal Masking
The causal mask (also called the "look-ahead mask" or "subsequent mask") is an upper-triangular matrix of $-\infty$ values:
$$M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$$
This mask is added to the attention scores before the softmax. Positions where the mask is $-\infty$ get attention weights of zero after softmax, effectively preventing the decoder from "peeking" at future tokens.
def generate_causal_mask(size: int) -> torch.Tensor:
"""Generate a causal (look-ahead) mask for the decoder.
Args:
size: Sequence length.
Returns:
Upper-triangular mask of shape (size, size) with -inf above
the diagonal and 0 on and below.
"""
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.masked_fill(mask == 1, float("-inf"))
return mask
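For a sequence of length 4, the function above produces zeros on and below the diagonal and $-\infty$ above it, which you can confirm directly:

print(generate_causal_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])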
19.7.3 Cross-Attention in Detail
Cross-attention is the bridge between the encoder and decoder. In multi-head attention terminology:
- Queries ($Q$): Derived from the decoder's representation
- Keys ($K$) and Values ($V$): Derived from the encoder's output
This asymmetric structure is what makes cross-attention fundamentally different from self-attention. In self-attention, a token asks "what in my sequence is relevant to me?" In cross-attention, a decoder token asks "what in the source sequence is relevant for generating me?"
The mathematical formulation is identical to standard multi-head attention from Chapter 18:
$$\text{CrossAttn}(\mathbf{Q}_{\text{dec}}, \mathbf{K}_{\text{enc}}, \mathbf{V}_{\text{enc}}) = \text{softmax}\left(\frac{\mathbf{Q}_{\text{dec}} \mathbf{K}_{\text{enc}}^T}{\sqrt{d_k}}\right) \mathbf{V}_{\text{enc}}$$
where $\mathbf{Q}_{\text{dec}} = \mathbf{H}_{\text{dec}} \mathbf{W}^Q$ comes from the decoder hidden states and $\mathbf{K}_{\text{enc}} = \mathbf{H}_{\text{enc}} \mathbf{W}^K$, $\mathbf{V}_{\text{enc}} = \mathbf{H}_{\text{enc}} \mathbf{W}^V$ come from the encoder output. Note that the attention matrix has shape $(\text{tgt\_len}, \text{src\_len})$ -- each target position produces a distribution over source positions, determining how much to "read" from each part of the input.
An important efficiency note: since the encoder output does not change during decoding, the key and value projections $\mathbf{K}_{\text{enc}}$ and $\mathbf{V}_{\text{enc}}$ need to be computed only once and can be reused at every decoding step. This is another form of KV-caching that production implementations exploit.
This allows each decoder position to attend to all positions in the source sequence, deciding how much information to draw from each source token when generating the current target token. If you recall from Chapter 18, this is exactly the attention mechanism used in the original Bahdanau attention for neural machine translation -- but now it is multi-headed and integrated more cleanly into the architecture.
19.7.4 Implementation
class TransformerDecoderBlock(nn.Module):
"""Single decoder block of the Transformer.
Uses pre-norm configuration with masked self-attention,
cross-attention to encoder output, and a feed-forward network.
Args:
d_model: Dimension of the model.
n_heads: Number of attention heads.
d_ff: Dimension of the feed-forward inner layer.
dropout: Dropout rate.
"""
def __init__(
self,
d_model: int,
n_heads: int,
d_ff: int,
dropout: float = 0.1,
) -> None:
super().__init__()
# Masked self-attention
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=n_heads,
dropout=dropout,
batch_first=True,
)
# Cross-attention
self.cross_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=n_heads,
dropout=dropout,
batch_first=True,
)
# Feed-forward network
self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
# Layer norms (one per sublayer)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
# Dropout for residual connections
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
tgt_key_padding_mask: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass through decoder block.
Args:
tgt: Target sequence of shape (batch_size, tgt_len, d_model).
memory: Encoder output of shape (batch_size, src_len, d_model).
tgt_mask: Causal mask of shape (tgt_len, tgt_len).
memory_mask: Cross-attention mask.
tgt_key_padding_mask: Target padding mask.
memory_key_padding_mask: Source padding mask.
Returns:
Output tensor of same shape as tgt.
"""
# Sublayer 1: Masked self-attention
normed = self.norm1(tgt)
self_attn_out, _ = self.self_attn(
normed, normed, normed,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
)
tgt = tgt + self.dropout1(self_attn_out)
# Sublayer 2: Cross-attention
normed = self.norm2(tgt)
cross_attn_out, _ = self.cross_attn(
normed, memory, memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)
tgt = tgt + self.dropout2(cross_attn_out)
# Sublayer 3: Feed-forward network
normed = self.norm3(tgt)
ff_out = self.ffn(normed)
tgt = tgt + self.dropout3(ff_out)
return tgt
19.7.5 The Decoder Stack
class TransformerDecoder(nn.Module):
"""Complete Transformer decoder stack.
Args:
vocab_size: Size of the target vocabulary.
d_model: Dimension of the model.
n_heads: Number of attention heads.
d_ff: Dimension of the feed-forward inner layer.
n_layers: Number of decoder blocks.
max_len: Maximum sequence length.
dropout: Dropout rate.
"""
def __init__(
self,
vocab_size: int,
d_model: int,
n_heads: int,
d_ff: int,
n_layers: int,
max_len: int = 5000,
dropout: float = 0.1,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
TransformerDecoderBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.d_model = d_model
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
tgt_key_padding_mask: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Decode target sequence using encoder memory.
Args:
tgt: Target token IDs of shape (batch_size, tgt_len).
memory: Encoder output of shape (batch_size, src_len, d_model).
tgt_mask: Causal mask.
memory_mask: Cross-attention mask.
tgt_key_padding_mask: Target padding mask.
memory_key_padding_mask: Source padding mask.
Returns:
Decoded representation of shape (batch_size, tgt_len, d_model).
"""
x = self.embedding(tgt) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.layers:
x = layer(
x, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask,
)
return self.norm(x)
19.8 Building a Complete Transformer
19.8.1 Putting It All Together
Now we assemble the encoder, decoder, and output projection into a complete Transformer model:
class Transformer(nn.Module):
"""Complete Transformer model for sequence-to-sequence tasks.
Combines encoder, decoder, and output projection into
a single module following the architecture from
'Attention Is All You Need' (Vaswani et al., 2017).
Args:
src_vocab_size: Size of the source vocabulary.
tgt_vocab_size: Size of the target vocabulary.
d_model: Dimension of the model.
n_heads: Number of attention heads.
d_ff: Dimension of the feed-forward inner layer.
n_encoder_layers: Number of encoder blocks.
n_decoder_layers: Number of decoder blocks.
max_len: Maximum sequence length.
dropout: Dropout rate.
"""
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
d_model: int = 512,
n_heads: int = 8,
d_ff: int = 2048,
n_encoder_layers: int = 6,
n_decoder_layers: int = 6,
max_len: int = 5000,
dropout: float = 0.1,
) -> None:
super().__init__()
self.encoder = TransformerEncoder(
src_vocab_size, d_model, n_heads, d_ff,
n_encoder_layers, max_len, dropout,
)
self.decoder = TransformerDecoder(
tgt_vocab_size, d_model, n_heads, d_ff,
n_decoder_layers, max_len, dropout,
)
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
# Initialize parameters with Xavier uniform
self._init_parameters()
def _init_parameters(self) -> None:
"""Initialize parameters using Xavier uniform initialization."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
src_mask: torch.Tensor | None = None,
tgt_mask: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
tgt_key_padding_mask: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass through the complete Transformer.
Args:
src: Source token IDs of shape (batch_size, src_len).
tgt: Target token IDs of shape (batch_size, tgt_len).
src_mask: Source attention mask.
tgt_mask: Target causal mask.
memory_mask: Cross-attention mask.
src_key_padding_mask: Source padding mask.
tgt_key_padding_mask: Target padding mask.
memory_key_padding_mask: Memory padding mask for cross-attention.
Returns:
Logits of shape (batch_size, tgt_len, tgt_vocab_size).
"""
# Encode source sequence
memory = self.encoder(src, src_mask, src_key_padding_mask)
# Decode target sequence using encoder output
decoder_output = self.decoder(
tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask,
)
# Project to vocabulary size
logits = self.output_projection(decoder_output)
return logits
19.8.2 Parameter Initialization
The original paper uses Xavier (Glorot) uniform initialization for most parameters. This initialization strategy sets weights such that the variance of activations remains roughly constant across layers, which is important for stable training. In our implementation, the initialization loop applies Xavier uniform to every parameter with more than one dimension, which covers the weight matrices and the embedding tables; biases and layer norm parameters are one-dimensional, so they keep PyTorch's sensible defaults.
19.8.3 Weight Tying
A common optimization is to tie the decoder's embedding weights with the output projection layer. Since the embedding matrix maps tokens to vectors and the output projection maps vectors back to token logits, they are performing inverse operations. Sharing these weights reduces the parameter count and often improves performance:
# Weight tying: share embedding and output projection weights
model.output_projection.weight = model.decoder.embedding.weight
The original Transformer paper mentions this technique, and it is used in many subsequent models.
19.8.4 Model Size
Let us count the parameters of our Transformer with the original paper's settings ($d_{\text{model}} = 512$, $d_{\text{ff}} = 2048$, $N = 6$, $h = 8$):
| Component | Parameters per Layer | Total |
|---|---|---|
| Multi-head attention | $4 \times d_{\text{model}}^2 = 4 \times 512^2 \approx 1.05M$ | --- |
| Feed-forward network | $2 \times d_{\text{model}} \times d_{\text{ff}} \approx 2.1M$ | --- |
| Layer norms | $2 \times 2 \times d_{\text{model}} \approx 2K$ | --- |
| Encoder (6 layers) | --- | $\approx 19M$ |
| Decoder (6 layers, with cross-attn) | --- | $\approx 25M$ |
| Embeddings (shared, 37K vocab) | --- | $\approx 19M$ |
| Total | --- | $\approx 63M$ |
The original "base" model had about 65 million parameters. The "big" model ($d_{\text{model}} = 1024$, $d_{\text{ff}} = 4096$, $N = 6$, $h = 16$) had about 213 million.
19.9 The Training Process
19.9.1 Teacher Forcing
During training, the Transformer decoder uses teacher forcing: at each position, the correct previous tokens (from the training data) are provided as input, rather than the model's own predictions. This means the decoder input at training time is the target sequence shifted right by one position, with a special <bos> (beginning of sequence) token prepended.
For example, if the target is ["I", "like", "cats", "<eos>"], the decoder input would be ["<bos>", "I", "like", "cats"], and the expected output (labels) would be ["I", "like", "cats", "<eos>"].
The causal mask ensures that even though all target tokens are provided simultaneously, each position can only attend to earlier positions. This allows training to happen in parallel --- all positions are processed in a single forward pass --- while maintaining the autoregressive property.
19.9.2 Label-Smoothed Cross-Entropy Loss
The original Transformer uses label smoothing with $\epsilon_{ls} = 0.1$. Instead of using hard one-hot targets, label smoothing distributes a small amount of probability mass to all other tokens:
$$q_i = \begin{cases} 1 - \epsilon_{ls} & \text{if } i = y \\ \epsilon_{ls} / (V - 1) & \text{otherwise} \end{cases}$$
where $y$ is the correct token and $V$ is the vocabulary size. This prevents the model from becoming overconfident and improves generalization. It also slightly hurts perplexity (since the model learns not to assign 100% probability to the correct answer) but improves BLEU scores on translation tasks.
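In modern PyTorch you rarely need to implement label smoothing by hand: nn.CrossEntropyLoss accepts a label_smoothing argument, which spreads the smoothing mass uniformly over all $V$ classes (a slight variant of the formula above) and combines naturally with padding-token masking. A minimal sketch:

import torch.nn as nn

PAD_IDX = 0
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)
# Usage is unchanged:
#   loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))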
19.9.3 The Adam Optimizer with Warm-Up
The original paper uses the Adam optimizer with a custom learning rate schedule that includes a warm-up phase:
$$lr = d_{\text{model}}^{-0.5} \cdot \min(step^{-0.5}, step \cdot warmup\_steps^{-1.5})$$
This increases the learning rate linearly for the first $warmup\_steps$ steps (typically 4,000), then decreases it proportionally to the inverse square root of the step number. The warm-up prevents the model from making large, poorly-directed updates early in training when the parameters are still random.
class TransformerLRScheduler:
"""Learning rate scheduler from 'Attention Is All You Need'.
Implements the warm-up and decay schedule described in the paper.
Args:
optimizer: The optimizer to adjust.
d_model: Model dimension (used for scaling).
warmup_steps: Number of warm-up steps.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
d_model: int,
warmup_steps: int = 4000,
) -> None:
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.step_num = 0
def step(self) -> None:
"""Update learning rate for the next step."""
self.step_num += 1
lr = self.d_model ** (-0.5) * min(
self.step_num ** (-0.5),
self.step_num * self.warmup_steps ** (-1.5),
)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
19.10 Training on a Toy Translation Task
Let us bring everything together by training our Transformer on a simple toy sequence-to-sequence task. We will create a synthetic dataset in which each target sequence is the source sequence reversed; this stands in for translation and is enough to demonstrate that the encoder, decoder, and cross-attention learn to work together.
19.10.1 Setting Up the Data
torch.manual_seed(42)
# Special tokens
PAD_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
# Simple vocabulary for a copy/reverse task
# Source: sequences of integers
# Target: same sequence reversed (a simple seq2seq task)
def create_toy_data(
n_samples: int = 1000,
min_len: int = 3,
max_len: int = 8,
vocab_size: int = 10,
) -> list[tuple[list[int], list[int]]]:
"""Create synthetic data for a sequence reversal task.
Args:
n_samples: Number of training examples to generate.
min_len: Minimum sequence length.
max_len: Maximum sequence length.
vocab_size: Number of distinct tokens (excluding special tokens).
Returns:
List of (source, target) pairs where target is the reversed source.
"""
data = []
for _ in range(n_samples):
length = torch.randint(min_len, max_len + 1, (1,)).item()
# Tokens start at 3 (after PAD, BOS, EOS)
src = torch.randint(3, 3 + vocab_size, (length,)).tolist()
tgt = list(reversed(src))
data.append((src, tgt))
return data
19.10.2 Collation and Batching
def collate_batch(
batch: list[tuple[list[int], list[int]]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Collate a batch of (source, target) pairs with padding.
Args:
batch: List of (source_tokens, target_tokens) pairs.
Returns:
Tuple of (src_padded, tgt_input_padded, tgt_output_padded).
"""
src_list, tgt_list = [], []
for src, tgt in batch:
src_list.append(torch.tensor(src, dtype=torch.long))
tgt_list.append(torch.tensor([BOS_IDX] + tgt + [EOS_IDX], dtype=torch.long))
src_padded = nn.utils.rnn.pad_sequence(
src_list, batch_first=True, padding_value=PAD_IDX
)
tgt_padded = nn.utils.rnn.pad_sequence(
tgt_list, batch_first=True, padding_value=PAD_IDX
)
# tgt_input: everything except the last token (input to decoder)
tgt_input = tgt_padded[:, :-1]
# tgt_output: everything except the first token (labels)
tgt_output = tgt_padded[:, 1:]
return src_padded, tgt_input, tgt_output
19.10.3 The Training Loop
def train_transformer(
model: Transformer,
train_data: list[tuple[list[int], list[int]]],
n_epochs: int = 20,
batch_size: int = 32,
device: str = "cpu",
) -> list[float]:
"""Train the Transformer model on toy data.
Args:
model: The Transformer model.
train_data: Training data as (source, target) pairs.
n_epochs: Number of training epochs.
batch_size: Batch size.
device: Device to train on.
Returns:
List of average loss values per epoch.
"""
model = model.to(device)
model.train()
optimizer = torch.optim.Adam(
model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
losses = []
for epoch in range(n_epochs):
epoch_loss = 0.0
n_batches = 0
# Simple batching
for i in range(0, len(train_data), batch_size):
batch = train_data[i:i + batch_size]
src, tgt_input, tgt_output = collate_batch(batch)
src = src.to(device)
tgt_input = tgt_input.to(device)
tgt_output = tgt_output.to(device)
# Create masks
tgt_mask = generate_causal_mask(tgt_input.size(1)).to(device)
src_padding_mask = (src == PAD_IDX)
tgt_padding_mask = (tgt_input == PAD_IDX)
# Forward pass
logits = model(
src, tgt_input,
tgt_mask=tgt_mask,
src_key_padding_mask=src_padding_mask,
tgt_key_padding_mask=tgt_padding_mask,
memory_key_padding_mask=src_padding_mask,
)
# Compute loss
loss = criterion(
logits.reshape(-1, logits.size(-1)),
tgt_output.reshape(-1),
)
# Backward pass
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
epoch_loss += loss.item()
n_batches += 1
avg_loss = epoch_loss / n_batches
losses.append(avg_loss)
if (epoch + 1) % 5 == 0:
print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")
return losses
19.10.4 Greedy Decoding
At inference time, we generate the output sequence one token at a time using greedy decoding --- always choosing the most probable next token:
def greedy_decode(
model: Transformer,
src: torch.Tensor,
max_len: int = 20,
device: str = "cpu",
) -> list[int]:
"""Generate output sequence using greedy decoding.
Args:
model: Trained Transformer model.
src: Source token IDs of shape (1, src_len).
max_len: Maximum output length.
device: Device to run on.
Returns:
List of predicted token IDs.
"""
model.eval()
src = src.to(device)
# Encode source
with torch.no_grad():
src_padding_mask = (src == PAD_IDX)
memory = model.encoder(src, src_key_padding_mask=src_padding_mask)
# Start with BOS token
tgt_tokens = [BOS_IDX]
for _ in range(max_len):
tgt_tensor = torch.tensor([tgt_tokens], dtype=torch.long, device=device)
tgt_mask = generate_causal_mask(len(tgt_tokens)).to(device)
with torch.no_grad():
decoder_out = model.decoder(
tgt_tensor, memory, tgt_mask=tgt_mask,
memory_key_padding_mask=src_padding_mask,
)
logits = model.output_projection(decoder_out[:, -1, :])
next_token = logits.argmax(dim=-1).item()
tgt_tokens.append(next_token)
if next_token == EOS_IDX:
break
return tgt_tokens[1:] # Exclude BOS
19.10.5 Running the Experiment
torch.manual_seed(42)
# Create data
train_data = create_toy_data(n_samples=2000, vocab_size=10)
# Create model (small for toy task)
model = Transformer(
src_vocab_size=13, # 10 tokens + PAD + BOS + EOS
tgt_vocab_size=13,
d_model=64,
n_heads=4,
d_ff=256,
n_encoder_layers=2,
n_decoder_layers=2,
max_len=50,
dropout=0.1,
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Train
losses = train_transformer(model, train_data, n_epochs=30, batch_size=64)
# Test
test_src = torch.tensor([[3, 5, 7, 9, 4]]) # Example source sequence
predicted = greedy_decode(model, test_src)
print(f"Source: {test_src[0].tolist()}")
print(f"Expected: {list(reversed(test_src[0].tolist()))}")
print(f"Predicted: {predicted}")
When you run this code, you should see the loss decrease steadily and the model learning to reverse sequences. This simple task demonstrates all the components of the Transformer working together: the encoder processes the source sequence, the decoder generates the reversed sequence autoregressively, and cross-attention allows the decoder to attend to the encoder's representation.
See code/example-03-full-transformer.py for the complete, runnable training script.
19.11 Design Decisions and Variations
19.11.1 Why No Recurrence?
The absence of recurrence in the Transformer is not just an aesthetic choice --- it has profound practical consequences:
- Parallelism: All positions in a sequence can be processed simultaneously. An RNN must process tokens 1, 2, 3, ..., $n$ in order. Self-attention processes all $n$ tokens at once.
- Path length: In an RNN, information from position 1 must travel through $n-1$ sequential steps to reach position $n$. In self-attention, every position can attend to every other position in a single step. This shorter path length makes it easier for the model to learn long-range dependencies.
- Training speed: The parallelism translates directly to faster training on GPUs and TPUs. The original paper reports that the Transformer base model was trained in 12 hours on 8 GPUs --- far faster than comparable RNN models.
The trade-off is computational complexity. Self-attention has $O(n^2)$ complexity in sequence length (every position attends to every other position), compared to $O(n)$ for an RNN. For very long sequences, this quadratic cost can be prohibitive, motivating approximate attention variants such as Linformer and Performer, as well as exact but memory-efficient implementations such as Flash Attention (Section 19.14.2).
19.11.2 Encoder-Only vs. Decoder-Only vs. Encoder-Decoder
The original Transformer is an encoder-decoder model, but the architecture's components have been rearranged in influential ways:
- Encoder-only (e.g., BERT): Uses only the encoder with unmasked self-attention. Best for tasks that require understanding the full input (classification, named entity recognition, question answering with a given context).
- Decoder-only (e.g., GPT, LLaMA): Uses only the decoder with causal masking. Best for generative tasks (text generation, code generation). These models have dominated the landscape of large language models.
- Encoder-decoder (e.g., T5, BART, the original Transformer): Uses both components. Best for sequence-to-sequence tasks (translation, summarization).
19.11.3 Architectural Variants in Detail
Let us examine the three major Transformer variants more carefully, as understanding their differences is essential for choosing the right architecture for a given task:
GPT-style (Decoder-only, Causal). The GPT family uses only the decoder component with causal masking. There is no encoder and no cross-attention. The model processes a sequence left-to-right, predicting each next token based on all previous tokens. The training objective is causal language modeling: maximize the probability of each token given its predecessors:
$$\mathcal{L} = -\sum_{t=1}^{T} \log P(x_t | x_1, \ldots, x_{t-1})$$
This approach is conceptually simple and scales remarkably well. GPT-3 (175B parameters), GPT-4, and Claude are all based on decoder-only architectures. The key advantage is that the same architecture serves both pre-training and generation -- the model always operates in the same left-to-right mode.
BERT-style (Encoder-only, Bidirectional). BERT uses only the encoder with unmasked self-attention. Every token can attend to every other token, including tokens that appear later in the sequence. This bidirectional context is a significant advantage for understanding tasks: when classifying a sentence, the meaning of a word in the middle depends on words both before and after it.
BERT is pre-trained with masked language modeling (MLM): randomly mask 15% of tokens and train the model to predict the masked tokens from context. This forces the model to build rich bidirectional representations. However, BERT cannot generate text autoregressively because it has no causal masking.
T5-style (Encoder-Decoder, Full). T5 (Text-to-Text Transfer Transformer) uses the complete encoder-decoder architecture. It frames every NLP task as a text-to-text problem: the encoder processes the input text, and the decoder generates the output text. For classification, the "output text" might simply be the class label as a word. This unified framing simplifies multi-task learning and transfer.
| Variant | Architecture | Pre-training | Best For | Examples |
|---|---|---|---|---|
| GPT-style | Decoder-only | Causal LM | Generation, reasoning | GPT-3/4, LLaMA, Claude |
| BERT-style | Encoder-only | Masked LM | Classification, NER, QA | BERT, RoBERTa, DeBERTa |
| T5-style | Encoder-decoder | Span corruption | Translation, summarization | T5, BART, mBART |
19.11.4 Scaling Laws
One of the most important empirical discoveries about Transformers is that their performance improves predictably as a power law with increases in model size, dataset size, and compute budget. This observation, formalized by Kaplan et al. (2020), can be expressed as:
$$L(N) \propto N^{-\alpha_N}, \qquad L(D) \propto D^{-\alpha_D}, \qquad L(C) \propto C^{-\alpha_C}$$
where $L$ is the loss, $N$ is the number of parameters, $D$ is the dataset size, and $C$ is the compute budget. The exponents $\alpha$ are positive constants determined empirically. This means that each 10x increase in compute yields a roughly constant improvement in loss -- a remarkable regularity that holds over many orders of magnitude.
Chinchilla (Hoffmann et al., 2022) refined these scaling laws, showing that the original estimates underweighted the importance of data relative to model size. The Chinchilla-optimal approach trains smaller models on more data, achieving the same performance at lower inference cost. This finding has influenced the design of subsequent language models, shifting the field toward more data-efficient training regimes.
19.12 Understanding Information Flow
19.12.1 The Residual Stream View
A helpful mental model for the Transformer is the residual stream perspective. The residual connections create a "stream" of information that flows from the input to the output. Each attention layer and feed-forward layer reads from this stream and writes an additive update back to it.
In the encoder:
1. The stream starts with token embeddings + positional encodings.
2. Each self-attention sublayer reads the stream, computes attention-weighted combinations, and writes the result back.
3. Each FFN sublayer reads the stream, applies a nonlinear transformation, and writes the result back.
In the decoder, the same process occurs, but with the addition of cross-attention layers that also read from the encoder's final stream.
19.12.2 What Each Component Does
- Self-attention: Moves information between positions. After self-attention, a token's representation contains information from other tokens it attended to. This is the mechanism for contextual understanding.
- Cross-attention: Moves information from the source to the target. This is how the decoder knows what the input says.
- Feed-forward network: Processes information at each position independently. Research suggests that FFN layers act as key-value memories, storing factual knowledge and performing token-level computations.
- Layer normalization: Stabilizes the magnitude of activations, preventing them from growing or shrinking as they flow through many layers.
- Residual connections: Ensure that information can flow unchanged through layers, preventing gradient degradation and allowing each layer to contribute incrementally.
19.12.3 Attention Patterns
When we visualize attention weights in trained Transformers, we observe characteristic patterns:
- Positional attention: Some heads attend to specific relative positions (e.g., the previous token or the next token), effectively recovering some of the sequential inductive bias that recurrence provides.
- Content-based attention: Some heads attend based on semantic similarity, connecting related words regardless of position.
- Delimiter attention: Some heads attend to special tokens like periods or commas, possibly using them as aggregate information stores.
- Rare-token attention: Some heads specifically focus on low-frequency tokens that carry high information content.
These patterns emerge through training without any explicit instruction --- the model learns whatever attention patterns are useful for the task.
19.13 Practical Considerations
19.13.1 Gradient Clipping
Training Transformers, especially larger ones, can suffer from gradient explosions. The standard practice is to clip gradients to a maximum norm (typically 1.0):
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
19.13.2 Mixed-Precision Training
Modern Transformer training almost always uses mixed-precision (FP16 or BF16) arithmetic to reduce memory usage and increase throughput. PyTorch provides this through torch.cuda.amp:
scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    logits = model(src, tgt_input, tgt_mask=tgt_mask)
    loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
19.13.3 Regularization
The original Transformer uses three forms of regularization:
- Dropout ($p = 0.1$): Applied to attention weights, sublayer outputs, and after positional encoding.
- Label smoothing ($\epsilon = 0.1$): Softens the target distribution.
- Weight decay: Not used in the original paper but common in modern implementations.
19.13.4 Training Stability Techniques
Training large Transformers can be surprisingly fragile. Beyond learning rate warmup and gradient clipping, several additional techniques help ensure stable training:
Weight initialization matters. The original paper uses Xavier uniform initialization. For very deep Transformers, scaling the initialization of residual sublayers by $1/\sqrt{N}$ (where $N$ is the number of layers) prevents the residual stream magnitude from growing with depth. GPT-2 introduced this technique, and it has become standard practice.
Gradient accumulation for effective batch size. Large language models are typically trained with very large batch sizes (thousands of sequences). When GPU memory is limited, gradient accumulation simulates large batches by accumulating gradients over multiple forward-backward passes before performing an optimizer step:
accumulation_steps = 8  # Effective batch size = actual_batch * 8
optimizer.zero_grad()
for step, (src, tgt_input, tgt_output) in enumerate(dataloader):
    logits = model(src, tgt_input)
    loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))
    loss = loss / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
Loss spike detection. Production training pipelines monitor for sudden increases in loss (loss spikes), which can indicate numerical instability. Common remedies include reducing the learning rate, increasing warmup steps, or reverting to an earlier checkpoint.
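There is no single standard recipe for spike detection, but a simple version compares each new loss value against a running average and flags large deviations. A minimal sketch, with the threshold and momentum chosen arbitrarily for illustration:
class LossSpikeDetector:
    """Flag training steps whose loss far exceeds a running average."""

    def __init__(self, factor: float = 2.0, momentum: float = 0.99) -> None:
        self.factor = factor
        self.momentum = momentum
        self.running_loss = None   # running average of recent losses

    def update(self, loss: float) -> bool:
        """Return True if this step's loss looks like a spike."""
        if self.running_loss is None:
            self.running_loss = loss
            return False
        is_spike = loss > self.factor * self.running_loss
        # Only fold non-spike losses into the running average, so a single
        # spike does not inflate the baseline.
        if not is_spike:
            self.running_loss = (
                self.momentum * self.running_loss + (1 - self.momentum) * loss
            )
        return is_spike
In a training loop you would call detector.update(loss.item()) every step and, on a spike, skip the optimizer step, lower the learning rate, or fall back to a recent checkpoint.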
19.13.5 Computational Cost
The computational cost of the Transformer is dominated by:
- Self-attention: $O(n^2 \cdot d)$ where $n$ is sequence length and $d$ is model dimension
- Feed-forward network: $O(n \cdot d \cdot d_{ff})$
For typical model sizes ($d_{ff} = 4d$), the FFN is actually more expensive than attention for short sequences. Attention only dominates for very long sequences where $n > 4d$.
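To make the crossover point concrete, here is a small back-of-the-envelope comparison for the base configuration ($d = 512$, $d_{ff} = 2048$); constant factors are dropped, so only the ratio matters:
d, d_ff = 512, 2048
for n in (128, 512, 2048, 4096):
    attn_cost = n * n * d      # self-attention term: n^2 * d
    ffn_cost = n * d * d_ff    # feed-forward term: n * d * d_ff
    print(f"n={n:5d}  attention / FFN = {attn_cost / ffn_cost:.2f}")
# Prints 0.06, 0.25, 1.00, 2.00: the two terms cross at n = d_ff = 4d = 2048.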
19.14 Production Implementation Details
When moving from a toy Transformer to a production system, several implementation details become critical.
19.14.1 KV-Cache for Efficient Inference
During autoregressive generation, the naive approach recomputes the key and value matrices for all previous tokens at every step. This is wasteful because the keys and values for already-generated tokens do not change. A KV-cache stores the key and value tensors from previous steps and only computes them for the new token:
class CachedSelfAttention(nn.Module):
    """Self-attention with KV-caching for efficient autoregressive generation.

    During generation, previously computed key and value tensors are cached
    and reused, reducing the per-step complexity from O(n^2 * d) to O(n * d).

    Args:
        d_model: Model dimension.
        n_heads: Number of attention heads.
    """

    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        cache: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass with optional KV-cache.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).
                During cached generation, seq_len = 1.
            cache: Tuple of (cached_keys, cached_values) from previous steps,
                or None for the first step.

        Returns:
            Tuple of (output, new_cache) where new_cache includes the
            current step's keys and values.
        """
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Append to cache if available
        if cache is not None:
            k = torch.cat([cache[0], k], dim=1)
            v = torch.cat([cache[1], v], dim=1)
        new_cache = (k, v)

        # Standard scaled dot-product attention
        # (reshaping for multi-head omitted for clarity)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)
        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, v)
        output = self.out_proj(output)
        return output, new_cache
Without KV-caching, each generation step re-runs the projections and attention over the entire prefix, so step $t$ costs roughly $O(t^2 \cdot d)$ and generating $n$ tokens costs $O(n^3 \cdot d)$ in total. With KV-caching, each step projects only the new token and its single query attends to all cached keys, so step $t$ costs $O(t \cdot d)$ and the total drops to $O(n^2 \cdot d)$, at the price of storing a cache that grows to $O(n \cdot d)$ memory per layer. For long sequences, this is a massive speedup.
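For completeness, here is a sketch of how such a cached module might be driven during greedy generation. The names embed, attn, and lm_head are assumed to be an embedding layer, the CachedSelfAttention module above, and a linear projection to the vocabulary; a real decoder would also add positional information and stack many layers:
@torch.no_grad()
def generate(embed, attn, lm_head, start_token: int, max_len: int = 20) -> list[int]:
    """Greedy decoding sketch using the KV-cache (single layer, single sequence)."""
    tokens = [start_token]
    cache = None
    for _ in range(max_len):
        # Only the newest token is fed through the network each step.
        x = embed(torch.tensor([[tokens[-1]]]))    # shape (1, 1, d_model)
        out, cache = attn(x, cache=cache)          # cache grows by one position
        next_token = lm_head(out[:, -1]).argmax(dim=-1).item()
        tokens.append(next_token)
    return tokens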
19.14.2 Flash Attention
Standard attention computes the full $n \times n$ attention matrix, which requires $O(n^2)$ memory. Flash Attention (Dao et al., 2022) is a GPU-optimized implementation that computes attention without materializing the full attention matrix, reducing memory from $O(n^2)$ to $O(n)$ while also being faster due to better GPU memory access patterns. In PyTorch 2.0+, Flash Attention is available through torch.nn.functional.scaled_dot_product_attention:
# PyTorch 2.0+ automatically uses Flash Attention when available
import torch.nn.functional as F
# This call will use Flash Attention on compatible hardware
output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
Flash Attention has become essential for training and serving large Transformers. It enables processing of longer sequences (4K, 8K, or even 128K tokens) that would otherwise exceed GPU memory.
19.14.3 Tensor Parallelism for Multi-GPU Training
Large Transformers often do not fit on a single GPU. Tensor parallelism splits individual layers across multiple GPUs. For the feed-forward network, this means partitioning the weight matrices column-wise across GPUs, performing local matrix multiplications, and then combining results with an all-reduce operation. For attention, different heads can be distributed across GPUs since they operate independently. These techniques, along with pipeline parallelism (splitting layers across GPUs) and data parallelism (splitting batches), enable training models with hundreds of billions of parameters. We will explore distributed training in more depth in Chapter 35.
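Because a real tensor-parallel implementation requires a multi-process torch.distributed setup, the sketch below only simulates the idea on a single device: the FFN's first weight matrix is split column-wise into two shards (as if on two GPUs), the second matrix is split row-wise, each shard computes its slice independently, and summing the partial outputs plays the role of the all-reduce:
import torch

d_model, d_ff, n = 8, 32, 4
x = torch.randn(n, d_model)
w1 = torch.randn(d_model, d_ff)                  # first FFN projection
w2 = torch.randn(d_ff, d_model)                  # second FFN projection

# Reference computation on a single "device".
reference = torch.relu(x @ w1) @ w2

# Shard w1 column-wise and w2 row-wise across two simulated devices.
w1_shards = w1.chunk(2, dim=1)                   # each (d_model, d_ff // 2)
w2_shards = w2.chunk(2, dim=0)                   # each (d_ff // 2, d_model)

# Each device computes its partial output with no communication ...
partials = [torch.relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

# ... and an all-reduce (here: a plain sum) combines the results.
combined = partials[0] + partials[1]
print(torch.allclose(reference, combined, atol=1e-5))   # True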
19.15 From the Original Transformer to Modern Architectures
The Transformer architecture described in this chapter is the foundation upon which nearly all modern language models are built. Here is a brief roadmap of what came next (many of these will be covered in subsequent chapters):
- BERT (2018): Encoder-only Transformer, pre-trained with masked language modeling (Chapter 21).
- GPT (2018): Decoder-only Transformer, pre-trained with causal language modeling (Chapter 22).
- GPT-2/GPT-3 (2019/2020): Scaled-up decoder-only Transformers demonstrating the power of scale.
- T5 (2020): Encoder-decoder Transformer treating every NLP task as text-to-text.
- Vision Transformer (ViT) (2020): Applied the Transformer to image classification by treating image patches as tokens.
The core architecture --- multi-head attention, feed-forward networks, residual connections, layer normalization --- has remained remarkably stable. Most innovations have been in training methodology, scaling, and efficiency rather than fundamental architectural changes.
19.16 Summary
In this chapter, we have dissected and rebuilt the Transformer architecture from the ground up. Let us review the key components and their roles:
- Positional encoding injects sequence order information that the permutation-invariant attention mechanism cannot capture on its own. Sinusoidal encodings provide this without additional parameters.
- Layer normalization stabilizes training by normalizing activations across the feature dimension. Pre-norm placement (before the sublayer) offers more stable training than the original post-norm placement.
- Multi-head self-attention (from Chapter 18) allows each position to gather information from all other positions. In the encoder, this is unmasked; in the decoder, it is causally masked.
- Cross-attention connects the decoder to the encoder, allowing the decoder to read the source sequence when generating each target token.
- Position-wise feed-forward networks apply nonlinear transformations independently at each position, providing the model's primary computational capacity.
- Residual connections create gradient highways through the network, enabling the training of deep stacks of layers.
- The encoder stack processes the entire source sequence in parallel, producing a rich contextual representation.
- The decoder stack generates the output autoregressively, attending both to its own previous outputs and to the encoder's representation.
- The full model is trained end-to-end with teacher forcing, using the Adam optimizer with a warm-up schedule.
The Transformer did not just replace RNNs for machine translation --- it became the foundation for a revolution in artificial intelligence, enabling the development of large language models, vision transformers, and multimodal models that have transformed the field. Understanding this architecture is essential for anyone working in modern AI.
In Chapter 20, we will explore how the Transformer's components are adapted for different tasks, examining the pre-training and fine-tuning paradigm that has become the dominant approach in natural language processing. The architectural foundation laid in this chapter -- attention, normalization, residual connections, and feed-forward networks -- will appear again and again throughout the remainder of this book, in contexts ranging from language modeling to computer vision to reinforcement learning.
References
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
- Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization. arXiv preprint arXiv:1607.06450.
- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 770--778.
- Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., & Amodei, D. (2020). Scaling laws for neural language models. arXiv preprint arXiv:2001.08361.
- Xiong, R., Yang, Y., He, D., Zheng, K., Zheng, S., Xing, C., Zhang, H., Lan, Y., Wang, L., & Liu, T.-Y. (2020). On layer normalization in the Transformer architecture. International Conference on Machine Learning.
- Press, O., & Wolf, L. (2017). Using the output embedding to improve language models. Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics.