Case Study 1: Building a Transformer from Scratch for Machine Translation

Overview

In this case study, we assemble every component developed in Chapter 19 --- positional encoding, layer normalization, multi-head attention, feed-forward networks, residual connections, encoder, decoder, and output projection --- into a complete, working Transformer model. We train it on a synthetic translation task (reversing a sequence of digit tokens) and demonstrate end-to-end training, inference with greedy decoding, and analysis of the learned representations.

This case study consolidates Sections 19.2 through 19.10 into a single, self-contained, runnable project.


The Task: Sequence Reversal as "Translation"

We frame sequence reversal as a translation task: given a source sequence of digit tokens, the model must produce the same tokens in reverse order. While simpler than real machine translation, this task exercises every component of the encoder-decoder Transformer:

  • The encoder must build representations that capture all source tokens and their positions.
  • The decoder must attend to the encoder output via cross-attention and generate the reversed sequence autoregressively.
  • Causal masking ensures the decoder cannot see future target tokens during training.
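
To make the data concrete, the sketch below generates source/target pairs for this task. It is a minimal sketch that assumes the special-token IDs from the configuration later in this case study (PAD = 0, BOS = 1, EOS = 2) and that content tokens occupy the remaining IDs 3 through 12; the helper name make_batch and the padded length are illustrative, not taken from code/case-study-code.py.

import torch

PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2          # special tokens, matching the configuration below

def make_batch(batch_size, min_len=3, max_len=10, pad_to=12):
    """Build a batch of (source, target) pairs for the reversal task."""
    src_rows, tgt_rows = [], []
    for _ in range(batch_size):
        length = torch.randint(min_len, max_len + 1, (1,)).item()
        seq = torch.randint(3, 13, (length,))            # content tokens only (IDs 3..12)
        tgt = torch.cat([torch.tensor([BOS_IDX]),        # decoder input begins with <BOS>
                         seq.flip(0),                    # the reversed source
                         torch.tensor([EOS_IDX])])       # target ends with <EOS>
        src_rows.append(torch.nn.functional.pad(seq, (0, pad_to - seq.numel()), value=PAD_IDX))
        tgt_rows.append(torch.nn.functional.pad(tgt, (0, pad_to - tgt.numel()), value=PAD_IDX))
    return torch.stack(src_rows), torch.stack(tgt_rows)

src, tgt = make_batch(4)
print(src.shape, tgt.shape)                  # torch.Size([4, 12]) torch.Size([4, 12])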

Implementation

The full implementation is available in code/case-study-code.py. Here we highlight the key training and evaluation components that tie the architecture together.

Model Configuration

import torch

torch.manual_seed(42)

# Hyperparameters
SRC_VOCAB_SIZE = 13   # 10 digit tokens + PAD(0) + BOS(1) + EOS(2)
TGT_VOCAB_SIZE = 13
D_MODEL = 64
N_HEADS = 4
D_FF = 256
N_LAYERS = 2
MAX_LEN = 50
DROPOUT = 0.1
PAD_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
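
To connect this configuration to the results below, here is a hedged sketch of the training loop. It reuses the make_batch helper sketched earlier; the Transformer class name and its forward(src, tgt) call are placeholders for the model assembled in code/case-study-code.py, while the loss, optimizer, and gradient-clipping settings follow the values reported in this case study.

import torch.nn as nn

model = Transformer(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, N_HEADS,
                    D_FF, N_LAYERS, MAX_LEN, DROPOUT)         # placeholder class name
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)         # padding positions do not contribute
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)     # fixed learning rate (see Training Dynamics)

train_src, train_tgt = make_batch(2000)                       # fixed set of 2,000 synthetic examples

for epoch in range(30):
    perm = torch.randperm(train_src.size(0))
    for i in range(0, train_src.size(0), 64):                 # mini-batch size of 64 is illustrative
        idx = perm[i:i + 64]
        src, tgt = train_src[idx], train_tgt[idx]
        tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]             # teacher forcing: shift the target by one
        logits = model(src, tgt_in)                           # causal and padding masks built inside the model
        loss = criterion(logits.reshape(-1, TGT_VOCAB_SIZE), tgt_out.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)   # clip gradients at norm 1.0
        optimizer.step()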

Training Results

After training for 30 epochs on 2,000 synthetic examples:

Epoch    Loss    Sample Accuracy
5        1.82    ~40%
10       0.65    ~78%
20       0.12    ~96%
30       0.03    ~99%

The model quickly learns the reversal pattern, demonstrating that even a small Transformer (approximately 200K parameters) can solve this task efficiently.
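
The parameter count can be checked directly on the instantiated model (continuing with the placeholder model object from the training sketch above):

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")   # on the order of 200K for this configuration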

Greedy Decoding at Inference

At inference time, the decoder generates one token at a time:

  1. Encode the full source sequence.
  2. Start with only the <BOS> token as the decoder input.
  3. Predict the next token using the decoder output at the last position.
  4. Append the predicted token and repeat until <EOS> or max length.

Example:

Source:    [3, 5, 7, 9, 4]
Expected:  [4, 9, 7, 5, 3]
Predicted: [4, 9, 7, 5, 3, <EOS>]
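
These four steps can be written as a short loop. The sketch below is hedged: model.encode and model.decode stand in for the encoder and decoder stacks in code/case-study-code.py, and the real interface may differ.

@torch.no_grad()
def greedy_decode(model, src, max_len=50):
    memory = model.encode(src)                           # step 1: encode the full source once
    ys = torch.full((src.size(0), 1), BOS_IDX)           # step 2: start from <BOS>
    for _ in range(max_len - 1):
        logits = model.decode(ys, memory)                # placeholder call: (batch, cur_len, vocab)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # step 3: most likely next token
        ys = torch.cat([ys, next_tok], dim=1)            # step 4: append and repeat
        if (next_tok == EOS_IDX).all():                  # stop once every sequence has emitted <EOS>
            break
    return ys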

Analysis

Cross-Attention Patterns

After training, the cross-attention weights in the decoder reveal the model's alignment strategy. For the reversal task, we expect an anti-diagonal pattern: the first output token attends to the last input token, the second to the second-to-last, and so on.

Extracting attention weights from the trained model confirms this:

Target pos 0 (token 4) -> attends strongly to source pos 4 (token 4)
Target pos 1 (token 9) -> attends strongly to source pos 3 (token 9)
Target pos 2 (token 7) -> attends strongly to source pos 2 (token 7)
Target pos 3 (token 5) -> attends strongly to source pos 1 (token 5)
Target pos 4 (token 3) -> attends strongly to source pos 0 (token 3)
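
One way to obtain these weights is a forward hook on the last decoder layer's cross-attention module. The sketch below makes two assumptions about the implementation: that the decoder exposes its layers as model.decoder.layers, and that the cross-attention module returns its attention weights alongside its output; both the attribute names and the return convention are illustrative.

captured = {}

def save_attn(module, inputs, output):
    captured["cross_attn"] = output[1].detach()      # assumed (values, attention_weights) return

hook = model.decoder.layers[-1].cross_attn.register_forward_hook(save_attn)
src = torch.tensor([[3, 5, 7, 9, 4]])                # the example source above
out = greedy_decode(model, src)                      # the final decode step covers all generated positions
hook.remove()

# Assumed shape (batch, heads, tgt_len, src_len); average over heads to inspect alignment
print(captured["cross_attn"].mean(dim=1).round(decimals=2))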

Training Dynamics

  • Loss curve: The loss decreases sharply in the first 5 epochs as the model learns the general structure, then gradually converges.
  • Gradient norms: Gradient clipping (max norm 1.0) is rarely active after the first few epochs, indicating stable training.
  • Learning rate: A fixed learning rate of 1e-4 works well for this small model; the original paper's warm-up schedule is more important for larger models.
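
For comparison, the warm-up schedule from the original paper (Vaswani et al., 2017) raises the rate linearly for the first warmup_steps updates and then decays it with the inverse square root of the step number; it fits in a few lines:

def transformer_lr(step, d_model=D_MODEL, warmup_steps=4000):
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)"""
    step = max(step, 1)                              # avoid dividing by zero on the first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# To use it, pair LambdaLR with an optimizer whose base lr is 1.0, so the lambda
# supplies the absolute rate rather than a multiplicative factor.
warm_opt = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(warm_opt, transformer_lr)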

Ablation: Removing Components

Configuration              Final Loss    Accuracy
Full model                 0.03          99%
No positional encoding     0.85          52%
No cross-attention         2.10          15%
Single head (h=1)          0.08          97%
No residual connections    0.45          71%

Key observations:

  • Positional encoding is critical: Without it, the model cannot distinguish token order, reducing the task to predicting a bag of tokens.
  • Cross-attention is essential: Without it, the decoder has no access to the source sequence and can only learn marginal token distributions.
  • Multiple heads help modestly: A single head nearly matches multi-head performance on this simple task, but multi-head attention provides more robustness.
  • Residual connections improve convergence: Without them, the deeper layers struggle to propagate gradients effectively.
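
Each ablation can be run by swapping the relevant component for a no-op and retraining with the same loop. As a hedged illustration, assuming the model exposes its positional-encoding module under an attribute such as pos_encoding (the name is illustrative), the first ablation reduces to one line:

ablated = Transformer(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, N_HEADS,
                      D_FF, N_LAYERS, MAX_LEN, DROPOUT)   # placeholder class, as in the training sketch
ablated.pos_encoding = nn.Identity()                      # "No positional encoding": drop order information
# Retrain `ablated` with the same loop and compare final loss and accuracy.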


Key Takeaways

  1. A complete Transformer can be built from simple, composable components: embeddings, positional encoding, attention, FFN, layer norm, and residual connections.
  2. Cross-attention is the critical bridge between encoder and decoder --- it is what makes sequence-to-sequence learning possible.
  3. Causal masking allows parallel training while maintaining the autoregressive property for generation.
  4. Even a 200K-parameter Transformer can effectively learn non-trivial sequence transformations.

The full code for this case study is in code/case-study-code.py.