Chapter 18 Key Takeaways
The Big Picture
The attention mechanism is the foundational building block of modern deep learning for sequences. It solves the information bottleneck problem in encoder--decoder models by allowing the decoder to dynamically look back at all encoder states at each generation step. Attention has evolved from a helpful add-on for RNN-based translation into the core computational primitive of Transformers and large language models.
Core Concepts at a Glance
The Information Bottleneck
- Standard seq2seq models compress the entire input into a single fixed-length context vector $\mathbf{c} = \mathbf{h}_T$.
- This bottleneck causes severe quality degradation for sequences longer than roughly 20--30 tokens.
- Attention replaces the fixed context vector with a time-dependent context vector $\mathbf{c}_t$ that is recomputed at every decoder step (see the sketch below).
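To make the contrast concrete, here is a minimal sketch (shapes are made up for illustration) of how a single time-dependent context vector is formed as a weighted sum of encoder states:

```python
import torch

# Toy dimensions for illustration only: T encoder states of size d.
T, d = 6, 8
H = torch.randn(T, d)                  # encoder states h_1, ..., h_T
scores = torch.randn(T)                # alignment scores e_{t1}, ..., e_{tT} for one decoder step t
alpha = torch.softmax(scores, dim=-1)  # attention weights alpha_{tj}, which sum to 1
c_t = alpha @ H                        # context vector c_t: a different mix of H at every step t
```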
Bahdanau (Additive) Attention
- Computes alignment scores via a learned feed-forward network: $e_{tj} = \mathbf{v}_a^\top \tanh(\mathbf{W}_a \mathbf{s}_{t-1} + \mathbf{U}_a \mathbf{h}_j)$.
- Uses the previous decoder state $\mathbf{s}_{t-1}$.
- More expressive but slower than multiplicative alternatives, since each score requires an extra feed-forward pass (see the sketch below).
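A minimal PyTorch sketch of the additive scoring function, assuming hypothetical sizes `dec_dim`, `enc_dim`, and `attn_dim` and omitting bias terms:

```python
import torch
import torch.nn as nn

class AdditiveScore(nn.Module):
    def __init__(self, dec_dim, enc_dim, attn_dim):
        super().__init__()
        self.W_a = nn.Linear(dec_dim, attn_dim, bias=False)   # projects s_{t-1}
        self.U_a = nn.Linear(enc_dim, attn_dim, bias=False)   # projects h_j
        self.v_a = nn.Linear(attn_dim, 1, bias=False)         # reduces to a scalar score

    def forward(self, s_prev, enc_states):
        # s_prev: (batch, dec_dim); enc_states: (batch, T, enc_dim)
        scores = self.v_a(torch.tanh(self.W_a(s_prev).unsqueeze(1) + self.U_a(enc_states)))
        return scores.squeeze(-1)                              # (batch, T): one e_{tj} per position j
```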
Luong (Multiplicative) Attention
- Three scoring functions: dot ($\mathbf{s}_t^\top \mathbf{h}_j$), general ($\mathbf{s}_t^\top \mathbf{W}_a \mathbf{h}_j$), concat.
- Uses the current decoder state $\mathbf{s}_t$.
- Dot scoring is the fastest; general scoring adds flexibility with a learnable matrix $\mathbf{W}_a$ (both are sketched below).
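A rough sketch of the dot and general variants (the concat variant is omitted); `dec_dim` and `enc_dim` are assumed hyperparameters, and the dot variant requires them to match:

```python
import torch
import torch.nn as nn

def dot_score(s_t, enc_states):
    # e_{tj} = s_t^T h_j; s_t: (batch, d), enc_states: (batch, T, d)
    return torch.bmm(enc_states, s_t.unsqueeze(-1)).squeeze(-1)   # (batch, T)

class GeneralScore(nn.Module):
    def __init__(self, dec_dim, enc_dim):
        super().__init__()
        self.W_a = nn.Linear(enc_dim, dec_dim, bias=False)        # learnable bilinear map

    def forward(self, s_t, enc_states):
        # e_{tj} = s_t^T W_a h_j
        return torch.bmm(self.W_a(enc_states), s_t.unsqueeze(-1)).squeeze(-1)
```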
The Query-Key-Value Framework
- Query (Q): What you are looking for.
- Key (K): What each position advertises.
- Value (V): What each position communicates.
- Attention is a differentiable dictionary lookup: softmax(similarity(Q, K)) produces weights, and the output is a weighted sum of V.
- Separate projections for Q, K, V give the model flexibility to learn different roles for searching, indexing, and communicating; a minimal sketch follows below.
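As an illustration of the three roles, the sketch below projects the same input sequence into Q, K, and V with separate learned matrices (the self-attention case; `d_model` and `d_k` are made-up sizes):

```python
import torch
import torch.nn as nn

d_model, d_k = 64, 64
W_q = nn.Linear(d_model, d_k, bias=False)   # query: "what am I looking for?"
W_k = nn.Linear(d_model, d_k, bias=False)   # key:   "what do I advertise?"
W_v = nn.Linear(d_model, d_k, bias=False)   # value: "what do I communicate?"

x = torch.randn(2, 10, d_model)             # (batch, seq_len, d_model)
Q, K, V = W_q(x), W_k(x), W_v(x)
weights = torch.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)  # soft "dictionary lookup"
out = weights @ V                           # weighted sum of values
```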
Scaled Dot-Product Attention
$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}$$
- The scaling factor $1/\sqrt{d_k}$ prevents softmax saturation as $d_k$ grows.
- Without scaling, the dot product of two $d_k$-dimensional vectors with unit-variance components has variance $d_k$, pushing the softmax into near-zero-gradient regions (the sketch below includes the scaling step).
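A direct transcription of the formula into code, offered as a sketch rather than a production implementation (shapes `(batch, seq_len, d_k)` are assumptions):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, q_len, k_len)
    weights = torch.softmax(scores, dim=-1)            # row-wise: one distribution per query
    return weights @ V, weights
```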
Multi-Head Attention
$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O$$
- Runs $h$ attention operations in parallel, each in a $d_k = d_{\text{model}}/h$ dimensional subspace.
- Different heads specialize in different relationship types (positional, syntactic, semantic).
- Total parameter count: $4 \times d_{\text{model}}^2$ for the Q, K, V, and output projections (ignoring bias terms); see the sketch below.
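A compact sketch of multi-head attention via reshaping, assuming $d_{\text{model}}$ is divisible by the number of heads $h$ and omitting masking and dropout:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)   # output projection

    def forward(self, x):
        B, T, _ = x.shape
        def split(t):
            # split d_model into h heads of size d_k: (B, h, T, d_k)
            return t.view(B, T, self.h, self.d_k).transpose(1, 2)
        Q, K, V = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5
        out = torch.softmax(scores, dim=-1) @ V               # (B, h, T, d_k)
        out = out.transpose(1, 2).contiguous().view(B, T, -1) # concatenate heads
        return self.W_o(out)
```

The four `nn.Linear` layers account for the $4 \times d_{\text{model}}^2$ parameters noted above.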
Attention Masking
- Padding masks prevent attention to meaningless padding tokens.
- Causal masks prevent positions from attending to the future (essential for autoregressive models).
- Masks are applied by setting the masked scores to $-\infty$ before the softmax, which drives those weights to zero (as in the sketch below).
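A short sketch of both mask types, with a toy sequence length and an arbitrary padding pattern chosen for illustration:

```python
import torch

T = 5
scores = torch.randn(1, T, T)                        # raw attention scores (batch, query, key)

# Causal mask: position i may attend only to positions <= i
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))

# Padding mask: suppose the last two key positions are padding
pad = torch.tensor([[False, False, False, True, True]])   # (batch, T)
scores = scores.masked_fill(pad.unsqueeze(1), float("-inf"))

weights = torch.softmax(scores, dim=-1)              # masked positions receive weight 0
```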
Key Equations
| Concept | Formula |
|---|---|
| Context vector | $\mathbf{c}_t = \sum_{j=1}^{T} \alpha_{tj} \mathbf{h}_j$ |
| Attention weights | $\alpha_{tj} = \text{softmax}(e_{tj})$ |
| Scaled dot-product | $\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}$ |
| Multi-head output | $\text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O$ |
| Dot product variance | $\text{Var}(\mathbf{q} \cdot \mathbf{k}) = d_k$ |
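A quick numerical check of the last row of the table, assuming the usual setup of i.i.d. zero-mean, unit-variance components:

```python
import torch

d_k, n_samples = 64, 100_000
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)
dots = (q * k).sum(dim=-1)               # q . k for each sample
print(dots.var().item())                 # roughly d_k = 64
print((dots / d_k ** 0.5).var().item())  # roughly 1 after scaling by sqrt(d_k)
```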
Computational Complexity
| Operation | Time Complexity | Space Complexity |
|---|---|---|
| Self-attention | $O(n^2 d)$ | $O(n^2 + nd)$ |
| RNN | $O(n d^2)$ | $O(nd)$ |
- Attention is quadratic in sequence length $n$ --- the primary scalability concern (see the estimate after this list).
- Attention is fully parallelizable across positions (unlike RNNs).
- Attention has $O(1)$ maximum path length between any two positions (RNNs require $O(n)$).
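A back-of-the-envelope illustration of the quadratic term: the memory needed to materialize a single $n \times n$ score matrix in fp32, per head and per batch element (the configuration is hypothetical):

```python
for n in (512, 2048, 8192):
    score_bytes = n * n * 4                      # fp32 scores for one head, one example
    print(f"n={n:5d}: {score_bytes / 2**20:7.1f} MiB")
# n=  512:     1.0 MiB
# n= 2048:    16.0 MiB
# n= 8192:   256.0 MiB
```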
Practical Insights
- Self-attention vs. cross-attention: Self-attention has Q, K, V from the same sequence; cross-attention has Q from one sequence and K, V from another.
- Attention is not explanation. Attention weights show where the model looks but do not reliably explain why it makes a prediction (Jain and Wallace, 2019).
- Fused QKV projections (one matrix multiply instead of three) are more GPU-efficient than three separate projections; a sketch follows after this list.
- Flash Attention computes exact attention with $O(n)$ memory by avoiding materializing the full $n \times n$ attention matrix.
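A sketch of the fused-projection idea: a single matrix multiply produces Q, K, and V together, and the result is then split (sizes are illustrative):

```python
import torch
import torch.nn as nn

d_model = 64                                             # assumed model width
qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)   # single fused weight matrix

x = torch.randn(2, 10, d_model)                          # (batch, seq_len, d_model)
Q, K, V = qkv_proj(x).chunk(3, dim=-1)                   # split back into Q, K, V
```

In recent PyTorch releases, `torch.nn.functional.scaled_dot_product_attention` provides a fused attention kernel along these lines, dispatching to Flash-Attention-style backends when the hardware supports them.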
Common Pitfalls to Avoid
- Forgetting to scale by $\sqrt{d_k}$: Leads to softmax saturation and vanishing gradients, especially with large $d_k$.
- Not masking padding tokens: Attention will attend to padding, diluting the representation with meaningless values.
- Applying causal masks to encoder self-attention: The encoder should be bidirectional --- causal masks belong only in the decoder.
- Assuming attention weights sum to 1 per key: They sum to 1 per query (row-wise softmax), not per key (see the check after this list).
- Interpreting attention as causal explanation: Attention weights are descriptive, not prescriptive.
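A quick check of the normalization pitfall: the softmax is taken over keys for each query, so rows sum to 1 while columns do not (the score matrix is random for illustration):

```python
import torch

scores = torch.randn(4, 6)                # 4 queries x 6 keys
weights = torch.softmax(scores, dim=-1)   # softmax over the key axis (per query)
print(weights.sum(dim=-1))                # each of the 4 rows sums to 1
print(weights.sum(dim=0))                 # the 6 column sums are unconstrained
```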
Looking Ahead
- Chapter 19 assembles attention into the full Transformer architecture with positional encoding, layer normalization, residual connections, and feed-forward networks.
- Chapter 20 explores how Transformers are pre-trained on large corpora and fine-tuned for downstream tasks (BERT, T5).
- Chapter 21 focuses on decoder-only models (GPT) that use causal attention for autoregressive text generation.
- Efficient attention variants (sparse, linear, sliding window) are covered as scaling challenges grow in later chapters.