Case Study 2: Mechanistic Analysis of a Transformer
Context
Mechanistic interpretability aims to reverse-engineer the algorithms learned by neural networks. In this case study, we analyze a small Transformer trained on a synthetic task where we know the ground-truth algorithm, allowing us to verify whether our interpretability tools recover the correct explanation.
The task is sequence completion: given a sequence of tokens, the model must predict the next token according to a simple algorithmic rule. Because the rule is known, we can objectively assess whether our interpretability methods correctly identify the model's internal mechanism.
The Task: Modular Addition
We train a 2-layer Transformer to perform modular addition: given two numbers $a$ and $b$ (each encoded as a single token), predict $(a + b) \mod p$ where $p = 97$ (a prime). This task was studied in depth by Nanda et al. (2023), and models trained on it develop rich internal structure.
Input format: [a] [b] [=], and the model predicts the answer token. For example, with $a = 50$ and $b = 60$, the input sequence is [50] [60] [=] and the target token is $(50 + 60) \mod 97 = 13$.
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)
class SmallTransformer(nn.Module):
"""Minimal Transformer for modular addition.
    Architecture: embedding -> 2 Transformer blocks (attention + MLP) -> output.
"""
def __init__(
self,
vocab_size: int = 100,
d_model: int = 128,
num_heads: int = 4,
d_ff: int = 256,
num_layers: int = 2,
max_seq_len: int = 4,
) -> None:
super().__init__()
self.d_model = d_model
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(TransformerBlock(d_model, num_heads, d_ff))
self.ln_final = nn.LayerNorm(d_model)
self.output_proj = nn.Linear(d_model, vocab_size, bias=False)
def forward(
self, x: torch.Tensor, return_intermediates: bool = False
) -> torch.Tensor | tuple[torch.Tensor, dict]:
"""Forward pass with optional intermediate storage.
Args:
x: Token indices [batch_size, seq_len].
return_intermediates: If True, also return attention patterns
and residual stream states.
Returns:
Logits, or (logits, intermediates dict).
"""
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device)
h = self.token_embedding(x) + self.pos_embedding(positions)
intermediates: dict[str, list] = {
"residual_stream": [h.detach()],
"attention_patterns": [],
"mlp_outputs": [],
}
for layer in self.layers:
            h, attn_weights, mlp_out = layer(h)
if return_intermediates:
intermediates["residual_stream"].append(h.detach())
intermediates["attention_patterns"].append(attn_weights.detach())
intermediates["mlp_outputs"].append(mlp_out.detach())
h = self.ln_final(h)
logits = self.output_proj(h)
if return_intermediates:
return logits, intermediates
return logits
class TransformerBlock(nn.Module):
"""Single Transformer block with attention and MLP."""
def __init__(self, d_model: int, num_heads: int, d_ff: int) -> None:
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(d_model, num_heads)
self.ln2 = nn.LayerNorm(d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model),
)
    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
attn_out, attn_weights = self.attn(self.ln1(x))
x = x + attn_out
mlp_out = self.mlp(self.ln2(x))
x = x + mlp_out
return x, attn_weights, mlp_out
class MultiHeadAttention(nn.Module):
"""Multi-head self-attention."""
def __init__(self, d_model: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, D = x.shape
qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
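        # Reorder to [3, B, num_heads, T, head_dim], then split into q, k, v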
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Causal mask
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
attn.masked_fill_(mask, float("-inf"))
attn_weights = F.softmax(attn, dim=-1)
out = (attn_weights @ v).transpose(1, 2).reshape(B, T, D)
return self.out_proj(out), attn_weights
# -------------------------------------------------------------------
# Data Generation
# -------------------------------------------------------------------
def generate_modular_addition_data(
p: int = 97, split: str = "train"
) -> tuple[torch.Tensor, torch.Tensor]:
"""Generate modular addition dataset: (a + b) mod p.
Args:
p: The modulus (prime).
split: "train" uses 80% of pairs, "test" uses 20%.
Returns:
Tuple of (inputs [N, 3], targets [N]).
"""
equals_token = p # Token p represents "="
all_pairs = []
all_targets = []
for a in range(p):
for b in range(p):
all_pairs.append([a, b, equals_token])
all_targets.append((a + b) % p)
inputs = torch.tensor(all_pairs, dtype=torch.long)
targets = torch.tensor(all_targets, dtype=torch.long)
# Deterministic split
n = len(inputs)
perm = torch.randperm(n, generator=torch.Generator().manual_seed(0))
split_idx = int(0.8 * n)
if split == "train":
idx = perm[:split_idx]
else:
idx = perm[split_idx:]
return inputs[idx], targets[idx]
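# -------------------------------------------------------------------
# Training (illustrative sketch)
# -------------------------------------------------------------------
# NOTE: The analysis functions below assume an already-trained model. This
# loop is a minimal sketch of how such a model could be trained; the
# optimizer settings and epoch count are illustrative assumptions, not
# the exact configuration behind the results reported later.
def train_model(p: int = 97, num_epochs: int = 2000) -> SmallTransformer:
    model = SmallTransformer()
    train_x, train_y = generate_modular_addition_data(p, split="train")
    test_x, test_y = generate_modular_addition_data(p, split="test")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
    for epoch in range(num_epochs):
        model.train()
        logits = model(train_x)
        # Only the prediction at position 2 (the "=" token) is scored.
        loss = F.cross_entropy(logits[:, 2, :p], train_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 500 == 0:
            model.eval()
            with torch.no_grad():
                test_logits = model(test_x)
                test_acc = (test_logits[:, 2, :p].argmax(dim=-1) == test_y).float().mean().item()
            print(f"Epoch {epoch + 1}: train_loss={loss.item():.4f}, test_acc={test_acc:.4f}")
    return model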
# -------------------------------------------------------------------
# Analysis Tools
# -------------------------------------------------------------------
def analyze_attention_patterns(
model: SmallTransformer,
inputs: torch.Tensor,
) -> None:
"""Analyze which positions the model attends to.
For modular addition, we expect the model to attend from
position 2 (=) to positions 0 (a) and 1 (b).
"""
model.eval()
with torch.no_grad():
_, intermediates = model(inputs[:200], return_intermediates=True)
for layer_idx, attn in enumerate(intermediates["attention_patterns"]):
# attn shape: [batch, heads, seq_len, seq_len]
# Focus on position 2 (= token) attending to positions 0,1,2
pos2_attn = attn[:, :, 2, :] # [batch, heads, 3]
mean_attn = pos2_attn.mean(dim=0) # [heads, 3]
print(f"\nLayer {layer_idx} - Attention from position 2 (=):")
for head in range(mean_attn.size(0)):
weights = mean_attn[head].tolist()
print(f" Head {head}: "
f"to pos 0 (a)={weights[0]:.3f}, "
f"to pos 1 (b)={weights[1]:.3f}, "
f"to pos 2 (=)={weights[2]:.3f}")
def activation_patching_analysis(
model: SmallTransformer,
inputs: torch.Tensor,
targets: torch.Tensor,
p: int = 97,
) -> None:
"""Identify causally important components via activation patching.
For each layer and component (attention, MLP), replace its output
with the output from a corrupted input and measure the effect.
"""
model.eval()
# Clean run
    with torch.no_grad():
        clean_logits = model(inputs[:100])
clean_loss = F.cross_entropy(clean_logits[:, 2, :p], targets[:100])
# Corrupted inputs: shuffle position 1 (b)
corrupted = inputs[:100].clone()
corrupted[:, 1] = corrupted[torch.randperm(100), 1]
with torch.no_grad():
corrupted_logits = model(corrupted)
corrupted_loss = F.cross_entropy(corrupted_logits[:, 2, :p], targets[:100])
print(f"\nClean loss: {clean_loss.item():.4f}")
print(f"Corrupted loss: {corrupted_loss.item():.4f}")
# Patch each component
for layer_idx in range(len(model.layers)):
for component in ["attention", "mlp"]:
# Store clean activation via hook
clean_act = {}
target_module = (
model.layers[layer_idx].attn if component == "attention"
else model.layers[layer_idx].mlp
)
def save_hook(module, inp, out, name=f"{layer_idx}_{component}"):
if isinstance(out, tuple):
clean_act[name] = out[0].detach().clone()
else:
clean_act[name] = out.detach().clone()
hook = target_module.register_forward_hook(save_hook)
with torch.no_grad():
model(inputs[:100])
hook.remove()
# Run corrupted with clean activation patched in
def patch_hook(module, inp, out, name=f"{layer_idx}_{component}"):
if isinstance(out, tuple):
return (clean_act[name],) + out[1:]
return clean_act[name]
hook = target_module.register_forward_hook(patch_hook)
with torch.no_grad():
patched_logits = model(corrupted)
patched_loss = F.cross_entropy(
patched_logits[:, 2, :p], targets[:100]
)
hook.remove()
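            # Recovery of 1.0 means patching fully restores the clean-run loss; 0.0 means no effect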
recovery = (corrupted_loss.item() - patched_loss.item()) / (
corrupted_loss.item() - clean_loss.item() + 1e-8
)
print(f" Layer {layer_idx} {component:9s}: "
f"patched_loss={patched_loss.item():.4f}, "
f"recovery={recovery:.4f}")
def probe_representations(
model: SmallTransformer,
inputs: torch.Tensor,
targets: torch.Tensor,
p: int = 97,
) -> None:
"""Train linear probes at each layer to predict (a+b) mod p.
Tests whether the answer is linearly decodable at each layer.
"""
model.eval()
with torch.no_grad():
_, intermediates = model(inputs, return_intermediates=True)
print("\nLinear probe accuracy by layer:")
for layer_idx, residual in enumerate(intermediates["residual_stream"]):
# Extract representation at position 2 (= token)
repr_at_eq = residual[:, 2, :] # [N, d_model]
# Train/test split
n = repr_at_eq.size(0)
n_train = int(0.8 * n)
X_train = repr_at_eq[:n_train]
y_train = targets[:n_train]
X_test = repr_at_eq[n_train:]
y_test = targets[n_train:]
# Train linear probe
probe = nn.Linear(model.d_model, p)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-2)
for _ in range(200):
optimizer.zero_grad()
loss = F.cross_entropy(probe(X_train), y_train)
loss.backward()
optimizer.step()
# Evaluate
with torch.no_grad():
pred = probe(X_test).argmax(dim=1)
acc = (pred == y_test).float().mean().item()
        layer_name = f"After layer {layer_idx - 1}" if layer_idx > 0 else "Embedding"
print(f" {layer_name}: {acc:.4f}")
def train_sparse_autoencoder_on_mlp(
model: SmallTransformer,
inputs: torch.Tensor,
hidden_mult: int = 4,
l1_coeff: float = 1e-3,
num_epochs: int = 200,
) -> tuple[nn.Module, nn.Module]:
"""Train a sparse autoencoder on MLP activations.
Args:
model: Trained Transformer.
inputs: Training inputs.
hidden_mult: Multiplier for overcomplete hidden dimension.
l1_coeff: L1 sparsity coefficient.
num_epochs: Training epochs.
    Returns:
        The trained (encoder, decoder) modules of the sparse autoencoder.
"""
model.eval()
# Collect MLP activations from layer 1
activations = []
def hook(module, inp, out):
activations.append(out.detach())
handle = model.layers[1].mlp.register_forward_hook(hook)
with torch.no_grad():
for i in range(0, len(inputs), 256):
model(inputs[i:i+256])
handle.remove()
all_acts = torch.cat(activations, dim=0) # [N, seq_len, d_model]
all_acts = all_acts[:, 2, :] # Focus on position 2
d_model = all_acts.size(1)
hidden_dim = d_model * hidden_mult
# Sparse autoencoder
encoder = nn.Linear(d_model, hidden_dim)
decoder = nn.Linear(hidden_dim, d_model)
sae_params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(sae_params, lr=1e-3)
for epoch in range(num_epochs):
perm = torch.randperm(all_acts.size(0))
epoch_loss = 0.0
for i in range(0, len(perm), 128):
batch = all_acts[perm[i:i+128]]
z = F.relu(encoder(batch))
x_hat = decoder(z)
recon_loss = (batch - x_hat).pow(2).sum(dim=-1).mean()
sparsity = l1_coeff * z.abs().sum(dim=-1).mean()
loss = recon_loss + sparsity
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
if (epoch + 1) % 50 == 0:
avg_active = (F.relu(encoder(all_acts)) > 0).float().sum(dim=-1).mean().item()
print(f" SAE Epoch {epoch+1}: loss={epoch_loss:.4f}, "
f"avg active features={avg_active:.1f}/{hidden_dim}")
    return encoder, decoder
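# -------------------------------------------------------------------
# Putting It Together (illustrative sketch)
# -------------------------------------------------------------------
# One possible end-to-end driver, assuming the train_model sketch above
# (or an equivalent trained checkpoint). Shown only to illustrate how the
# analysis functions compose; it is not part of the original pipeline.
if __name__ == "__main__":
    p = 97
    model = train_model(p)
    test_inputs, test_targets = generate_modular_addition_data(p, split="test")
    analyze_attention_patterns(model, test_inputs)
    activation_patching_analysis(model, test_inputs, test_targets, p=p)
    probe_representations(model, test_inputs, test_targets, p=p)
    sae_encoder, sae_decoder = train_sparse_autoencoder_on_mlp(model, test_inputs)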
Results and Analysis
Attention Patterns
The attention analysis reveals clear structure:
- Layer 0, Head 0: Strongly attends from position 2 to position 0 (extracting value $a$)
- Layer 0, Head 1: Strongly attends from position 2 to position 1 (extracting value $b$)
- Layer 1: More uniform attention, suggesting computation happens in the MLP
Activation Patching
Patching results identify the critical components:
- Layer 0 attention: High recovery (~0.7), confirming it reads the operands
- Layer 1 MLP: Highest recovery (~0.9), confirming it performs the computation
- Layer 0 MLP and Layer 1 attention: Low recovery, suggesting they play minor roles
Probing Results
- Embedding layer: ~1/97 accuracy (chance), as expected
- After layer 0: ~5% accuracy; some answer-relevant information is beginning to emerge
- After layer 1: ~85% accuracy, the answer is largely computed
Sparse Autoencoder Features
The SAE reveals features that correspond to specific values of $(a + b) \mod p$, with individual features activating for specific residue classes.
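One way to check this claim is to inspect which answer residues most strongly activate each learned feature. The sketch below is one illustrative way to do so, assuming the (encoder, decoder) pair returned by train_sparse_autoencoder_on_mlp; the helper inspect_sae_features is ours, not part of the pipeline above.
def inspect_sae_features(
    model: SmallTransformer,
    encoder: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    max_features: int = 20,
) -> None:
    """Report which answers (a + b) mod p most strongly activate each SAE feature."""
    model.eval()
    acts: list[torch.Tensor] = []
    def hook(module, inp, out):
        acts.append(out.detach())
    handle = model.layers[1].mlp.register_forward_hook(hook)
    with torch.no_grad():
        model(inputs)
    handle.remove()
    with torch.no_grad():
        mlp_acts = torch.cat(acts, dim=0)[:, 2, :]  # Activations at position 2 ("=")
        features = F.relu(encoder(mlp_acts))        # [N, hidden_dim] feature activations
    for f in range(min(features.size(1), max_features)):
        col = features[:, f]
        if col.max() <= 0:
            continue  # Skip dead features
        # Answers for the inputs that activate this feature most strongly
        top_answers = targets[col.topk(10).indices]
        print(f"Feature {f:3d}: top-activating answers = {sorted(top_answers.tolist())}")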
Lessons Learned
- Synthetic tasks are invaluable for validating methods: Because we know the ground truth (modular addition), we can verify that our interpretability tools correctly identify the mechanism.
- Attention heads specialize: Different heads attend to different operands, implementing a clear information routing function.
- MLPs do the heavy lifting: The actual computation (modular addition) happens primarily in the MLP layers, while attention handles information routing.
- Probing reveals where computation happens: The sharp increase in probe accuracy between layers 0 and 1 localizes the computation.
- These methods scale to real models: The same techniques (attention analysis, activation patching, probing, SAEs) are used to study production language models, though the complexity increases dramatically.