Case Study 2: Mechanistic Analysis of a Transformer

Context

Mechanistic interpretability aims to reverse-engineer the algorithms learned by neural networks. In this case study, we analyze a small Transformer trained on a synthetic task where we know the ground-truth algorithm, allowing us to verify whether our interpretability tools recover the correct explanation.

The task is sequence completion with a known rule: given a sequence of tokens, the model must predict the next token based on a simple algorithmic rule. By choosing a task with a known solution, we can objectively assess whether our interpretability methods correctly identify the model's internal mechanism.

The Task: Modular Addition

We train a 2-layer Transformer to perform modular addition: given two numbers $a$ and $b$ (each encoded as a single token), predict $(a + b) \mod p$ where $p = 97$ (a prime). This task was studied by Nanda et al. (2023) and reveals rich internal structure.

Input format: [a] [b] [=], and the model predicts the answer token at the final (=) position.
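
For concreteness, here is one encoded example following the conventions of the data generator below, where the "=" sign is represented by token $p$:

# Example with p = 97; "=" is encoded as token 97 (see generate_modular_addition_data below)
a, b, p = 3, 5, 97
example_input = [a, b, p]      # the token sequence [a] [b] [=]
example_target = (a + b) % p   # 8; the token the model must predict at the final position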

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.manual_seed(42)


class SmallTransformer(nn.Module):
    """Minimal Transformer for modular addition.

    Architecture: embedding -> 2 attention layers with MLPs -> output.
    """

    def __init__(
        self,
        vocab_size: int = 100,
        d_model: int = 128,
        num_heads: int = 4,
        d_ff: int = 256,
        num_layers: int = 2,
        max_seq_len: int = 4,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(TransformerBlock(d_model, num_heads, d_ff))

        self.ln_final = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size, bias=False)

    def forward(
        self, x: torch.Tensor, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, dict]:
        """Forward pass with optional intermediate storage.

        Args:
            x: Token indices [batch_size, seq_len].
            return_intermediates: If True, also return attention patterns
                and residual stream states.

        Returns:
            Logits, or (logits, intermediates dict).
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        h = self.token_embedding(x) + self.pos_embedding(positions)

        intermediates: dict[str, list] = {
            "residual_stream": [h.detach()],
            "attention_patterns": [],
            "mlp_outputs": [],
        }

        for layer in self.layers:
            h, attn_weights, mlp_out = layer(h, return_extras=True)
            if return_intermediates:
                intermediates["residual_stream"].append(h.detach())
                intermediates["attention_patterns"].append(attn_weights.detach())
                intermediates["mlp_outputs"].append(mlp_out.detach())

        h = self.ln_final(h)
        logits = self.output_proj(h)

        if return_intermediates:
            return logits, intermediates
        return logits


class TransformerBlock(nn.Module):
    """Single Transformer block with attention and MLP."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(
        self, x: torch.Tensor, return_extras: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        attn_out, attn_weights = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        x = x + mlp_out
        return x, attn_weights, mlp_out


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, T, D = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Causal mask
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn.masked_fill_(mask, float("-inf"))
        attn_weights = F.softmax(attn, dim=-1)

        out = (attn_weights @ v).transpose(1, 2).reshape(B, T, D)
        return self.out_proj(out), attn_weights


# -------------------------------------------------------------------
# Data Generation
# -------------------------------------------------------------------


def generate_modular_addition_data(
    p: int = 97, split: str = "train"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate modular addition dataset: (a + b) mod p.

    Args:
        p: The modulus (prime).
        split: "train" uses 80% of pairs, "test" uses 20%.

    Returns:
        Tuple of (inputs [N, 3], targets [N]).
    """
    equals_token = p  # Token p represents "="
    all_pairs = []
    all_targets = []
    for a in range(p):
        for b in range(p):
            all_pairs.append([a, b, equals_token])
            all_targets.append((a + b) % p)

    inputs = torch.tensor(all_pairs, dtype=torch.long)
    targets = torch.tensor(all_targets, dtype=torch.long)

    # Deterministic split
    n = len(inputs)
    perm = torch.randperm(n, generator=torch.Generator().manual_seed(0))
    split_idx = int(0.8 * n)

    if split == "train":
        idx = perm[:split_idx]
    else:
        idx = perm[split_idx:]

    return inputs[idx], targets[idx]


# -------------------------------------------------------------------
# Analysis Tools
# -------------------------------------------------------------------


def analyze_attention_patterns(
    model: SmallTransformer,
    inputs: torch.Tensor,
) -> None:
    """Analyze which positions the model attends to.

    For modular addition, we expect the model to attend from
    position 2 (=) to positions 0 (a) and 1 (b).
    """
    model.eval()
    with torch.no_grad():
        _, intermediates = model(inputs[:200], return_intermediates=True)

    for layer_idx, attn in enumerate(intermediates["attention_patterns"]):
        # attn shape: [batch, heads, seq_len, seq_len]
        # Focus on position 2 (= token) attending to positions 0,1,2
        pos2_attn = attn[:, :, 2, :]  # [batch, heads, 3]
        mean_attn = pos2_attn.mean(dim=0)  # [heads, 3]

        print(f"\nLayer {layer_idx} - Attention from position 2 (=):")
        for head in range(mean_attn.size(0)):
            weights = mean_attn[head].tolist()
            print(f"  Head {head}: "
                  f"to pos 0 (a)={weights[0]:.3f}, "
                  f"to pos 1 (b)={weights[1]:.3f}, "
                  f"to pos 2 (=)={weights[2]:.3f}")


def activation_patching_analysis(
    model: SmallTransformer,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    p: int = 97,
) -> None:
    """Identify causally important components via activation patching.

    For each layer and component (attention, MLP), replace its output
    with the output from a corrupted input and measure the effect.
    """
    model.eval()

    # Clean run
    with torch.no_grad():
        clean_logits, clean_intermediates = model(
            inputs[:100], return_intermediates=True
        )
        clean_loss = F.cross_entropy(clean_logits[:, 2, :p], targets[:100])

    # Corrupted inputs: shuffle position 1 (b)
    corrupted = inputs[:100].clone()
    corrupted[:, 1] = corrupted[torch.randperm(100), 1]

    with torch.no_grad():
        corrupted_logits = model(corrupted)
        corrupted_loss = F.cross_entropy(corrupted_logits[:, 2, :p], targets[:100])

    print(f"\nClean loss: {clean_loss.item():.4f}")
    print(f"Corrupted loss: {corrupted_loss.item():.4f}")

    # Patch each component
    for layer_idx in range(len(model.layers)):
        for component in ["attention", "mlp"]:
            # Store clean activation via hook
            clean_act = {}
            target_module = (
                model.layers[layer_idx].attn if component == "attention"
                else model.layers[layer_idx].mlp
            )

            def save_hook(module, inp, out, name=f"{layer_idx}_{component}"):
                if isinstance(out, tuple):
                    clean_act[name] = out[0].detach().clone()
                else:
                    clean_act[name] = out.detach().clone()

            hook = target_module.register_forward_hook(save_hook)
            with torch.no_grad():
                model(inputs[:100])
            hook.remove()

            # Run corrupted with clean activation patched in
            def patch_hook(module, inp, out, name=f"{layer_idx}_{component}"):
                if isinstance(out, tuple):
                    return (clean_act[name],) + out[1:]
                return clean_act[name]

            hook = target_module.register_forward_hook(patch_hook)
            with torch.no_grad():
                patched_logits = model(corrupted)
                patched_loss = F.cross_entropy(
                    patched_logits[:, 2, :p], targets[:100]
                )
            hook.remove()

            recovery = (corrupted_loss.item() - patched_loss.item()) / (
                corrupted_loss.item() - clean_loss.item() + 1e-8
            )
            print(f"  Layer {layer_idx} {component:9s}: "
                  f"patched_loss={patched_loss.item():.4f}, "
                  f"recovery={recovery:.4f}")


def probe_representations(
    model: SmallTransformer,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    p: int = 97,
) -> None:
    """Train linear probes at each layer to predict (a+b) mod p.

    Tests whether the answer is linearly decodable at each layer.
    """
    model.eval()
    with torch.no_grad():
        _, intermediates = model(inputs, return_intermediates=True)

    print("\nLinear probe accuracy by layer:")
    for layer_idx, residual in enumerate(intermediates["residual_stream"]):
        # Extract representation at position 2 (= token)
        repr_at_eq = residual[:, 2, :]  # [N, d_model]

        # Train/test split
        n = repr_at_eq.size(0)
        n_train = int(0.8 * n)

        X_train = repr_at_eq[:n_train]
        y_train = targets[:n_train]
        X_test = repr_at_eq[n_train:]
        y_test = targets[n_train:]

        # Train linear probe
        probe = nn.Linear(model.d_model, p)
        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-2)

        for _ in range(200):
            optimizer.zero_grad()
            loss = F.cross_entropy(probe(X_train), y_train)
            loss.backward()
            optimizer.step()

        # Evaluate
        with torch.no_grad():
            pred = probe(X_test).argmax(dim=1)
            acc = (pred == y_test).float().mean().item()

        layer_name = f"After layer {layer_idx}" if layer_idx > 0 else "Embedding"
        print(f"  {layer_name}: {acc:.4f}")


def train_sparse_autoencoder_on_mlp(
    model: SmallTransformer,
    inputs: torch.Tensor,
    hidden_mult: int = 4,
    l1_coeff: float = 1e-3,
    num_epochs: int = 200,
) -> nn.Module:
    """Train a sparse autoencoder on MLP activations.

    Args:
        model: Trained Transformer.
        inputs: Training inputs.
        hidden_mult: Multiplier for overcomplete hidden dimension.
        l1_coeff: L1 sparsity coefficient.
        num_epochs: Training epochs.

    Returns:
        The trained SAE encoder; feature activations are F.relu(encoder(x)).
    """
    model.eval()

    # Collect MLP activations from layer 1
    activations = []
    def hook(module, inp, out):
        activations.append(out.detach())

    handle = model.layers[1].mlp.register_forward_hook(hook)
    with torch.no_grad():
        for i in range(0, len(inputs), 256):
            model(inputs[i:i+256])
    handle.remove()

    all_acts = torch.cat(activations, dim=0)  # [N, seq_len, d_model]
    all_acts = all_acts[:, 2, :]  # Focus on position 2

    d_model = all_acts.size(1)
    hidden_dim = d_model * hidden_mult

    # Sparse autoencoder
    encoder = nn.Linear(d_model, hidden_dim)
    decoder = nn.Linear(hidden_dim, d_model)

    sae_params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.Adam(sae_params, lr=1e-3)

    for epoch in range(num_epochs):
        perm = torch.randperm(all_acts.size(0))
        epoch_loss = 0.0
        for i in range(0, len(perm), 128):
            batch = all_acts[perm[i:i+128]]
            z = F.relu(encoder(batch))
            x_hat = decoder(z)
            recon_loss = (batch - x_hat).pow(2).sum(dim=-1).mean()
            sparsity = l1_coeff * z.abs().sum(dim=-1).mean()
            loss = recon_loss + sparsity
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        if (epoch + 1) % 50 == 0:
            with torch.no_grad():
                avg_active = (F.relu(encoder(all_acts)) > 0).float().sum(dim=-1).mean().item()
            print(f"  SAE Epoch {epoch+1}: loss={epoch_loss:.4f}, "
                  f"avg active features={avg_active:.1f}/{hidden_dim}")

    return encoder
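

The training loop itself is not shown above; the following is a minimal sketch of how these pieces could be wired together and the analyses invoked. The optimizer choice, learning rate, weight decay, and epoch count are illustrative assumptions, not settings tied to the results reported below.

def train_model(num_epochs: int = 2000, p: int = 97) -> SmallTransformer:
    """Minimal training sketch; hyperparameters here are illustrative assumptions."""
    model = SmallTransformer(vocab_size=p + 1)
    train_x, train_y = generate_modular_addition_data(p, split="train")
    test_x, test_y = generate_modular_addition_data(p, split="test")
    # Weight decay is often important on this task (grokking); the exact value is an assumption.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(train_x)
        # Train only on the prediction at position 2 (the "=" token), as in the analysis code.
        loss = F.cross_entropy(logits[:, 2, :p], train_y)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 500 == 0:
            model.eval()
            with torch.no_grad():
                test_logits = model(test_x)
                test_acc = (test_logits[:, 2, :p].argmax(dim=-1) == test_y).float().mean().item()
            print(f"Epoch {epoch + 1}: train_loss={loss.item():.4f}, test_acc={test_acc:.4f}")

    return model


# Once trained, the analysis tools above can be run directly on the test split.
model = train_model()
test_x, test_y = generate_modular_addition_data(split="test")
analyze_attention_patterns(model, test_x)
activation_patching_analysis(model, test_x, test_y)
probe_representations(model, test_x, test_y)
sae_encoder = train_sparse_autoencoder_on_mlp(model, test_x)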

Results and Analysis

Attention Patterns

The attention analysis reveals clear structure:

  • Layer 0, Head 0: Strongly attends from position 2 to position 0 (extracting value $a$)
  • Layer 0, Head 1: Strongly attends from position 2 to position 1 (extracting value $b$)
  • Layer 1: More uniform attention, suggesting computation happens in the MLP

Activation Patching

Patching results identify the critical components:

  • Layer 0 attention: High recovery (~0.7), confirming it reads the operands
  • Layer 1 MLP: Highest recovery (~0.9), confirming it performs the computation
  • Layer 0 MLP and Layer 1 attention: Low recovery, suggesting they play minor roles

Probing Results

  • Embedding layer: ~1/97 accuracy (chance), as expected
  • After layer 0: ~5% accuracy, some information beginning to form
  • After layer 1: ~85% accuracy, the answer is largely computed

Sparse Autoencoder Features

The SAE reveals features that correspond to specific values of $(a + b) \mod p$, with individual features activating for specific residue classes.
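
One rough way to check this is to look at which residues $(a + b) \mod p$ most strongly activate each learned feature. The sketch below assumes the encoder returned by train_sparse_autoencoder_on_mlp; the helper inspect_sae_features is hypothetical rather than part of the implementation above.

def inspect_sae_features(
    model: SmallTransformer,
    encoder: nn.Module,
    inputs: torch.Tensor,
    p: int = 97,
    top_k: int = 5,
) -> None:
    """Hypothetical helper: show which (a + b) mod p residues most activate each SAE feature."""
    model.eval()
    acts: list[torch.Tensor] = []
    handle = model.layers[1].mlp.register_forward_hook(
        lambda module, inp, out: acts.append(out.detach())
    )
    with torch.no_grad():
        model(inputs)
    handle.remove()

    with torch.no_grad():
        mlp_at_eq = torch.cat(acts, dim=0)[:, 2, :]   # layer-1 MLP output at the "=" position
        features = F.relu(encoder(mlp_at_eq))         # [N, hidden_dim] SAE feature activations
    answers = (inputs[:, 0] + inputs[:, 1]) % p       # ground-truth residues

    # For a few of the most active features, list the residues of their top-activating inputs;
    # a "residue class" feature should show the same few values repeatedly.
    for feat in features.mean(dim=0).topk(top_k).indices.tolist():
        top_idx = features[:, feat].topk(10).indices
        print(f"Feature {feat}: top residues {sorted(answers[top_idx].tolist())}")


# e.g. inspect_sae_features(model, sae_encoder, test_x) after running the training sketch above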

Lessons Learned

  1. Synthetic tasks are invaluable for validating methods: Because we know the ground truth (modular addition), we can verify that our interpretability tools correctly identify the mechanism.
  2. Attention heads specialize: Different heads attend to different operands, implementing a clear information routing function.
  3. MLPs do the heavy lifting: The actual computation (modular addition) happens primarily in the MLP layers, while attention handles information routing.
  4. Probing reveals where computation happens: The sharp increase in probe accuracy between layers 0 and 1 localizes the computation.
  5. These methods scale to real models: The same techniques (attention analysis, activation patching, probing, SAEs) are used to study production language models, though the complexity increases dramatically.