In This Chapter
- Introduction: Why Understanding Models Matters
- 38.1 Why Interpretability Matters
- 38.2 Feature Attribution: SHAP
- 38.3 LIME: Local Interpretable Model-agnostic Explanations
- 38.4 Gradient-Based Attribution
- 38.5 Attention as Explanation: Promises and Limitations
- 38.6 Probing Classifiers
- 38.7 Mechanistic Interpretability
- 38.8 Interpretability and Regulation
- 38.9 Model Editing
- 38.10 Practical Interpretability Workflow
- Summary
- Quick Reference
Chapter 38: Interpretability, Explainability, and Mechanistic Understanding
Introduction: Why Understanding Models Matters
Modern deep learning models are extraordinarily capable but profoundly opaque. A large language model can write code, diagnose diseases, and reason about complex problems, yet we cannot fully explain how it does any of these things. A vision model can detect tumors in medical images with superhuman accuracy, but it cannot tell us why it flagged a particular scan. This opacity is not merely an intellectual curiosity---it is a practical, ethical, and safety-critical problem.
Consider a hospital deploying a deep learning model to assist radiologists. If the model flags an image as cancerous, the radiologist needs to understand why. Is the model responding to the actual tumor, or to an artifact in the imaging equipment? A loan approval model must explain its decisions to comply with regulations. A self-driving car must make decisions that engineers can debug when something goes wrong. And as AI systems become more powerful, understanding their internal mechanisms becomes essential for ensuring they remain aligned with human values.
This chapter covers the full spectrum of interpretability, from practical explanation tools (SHAP, LIME, integrated gradients) through probing and analysis methods (probing classifiers, attention analysis) to the emerging field of mechanistic interpretability, which seeks to reverse-engineer the actual algorithms learned by neural networks. We will implement each technique in PyTorch and critically examine what each method does and does not tell us.
The Interpretability Landscape
The field uses several overlapping terms:
- Interpretability: The degree to which a human can understand the cause of a model's decision. Inherently interpretable models (linear regression, decision trees) are interpretable by design.
- Explainability: Post-hoc methods that provide explanations for an already-trained black-box model. SHAP, LIME, and saliency maps fall here.
- Mechanistic interpretability: Understanding the internal computations of neural networks---what individual neurons, circuits, and features represent and how they compose to produce behavior.
These form a spectrum from practical (explaining individual predictions) to fundamental (understanding how models work).
Prerequisites
Before diving in, you should be comfortable with:
- Neural network training with PyTorch (Chapters 4--7, 11--12)
- The Transformer architecture and attention mechanism (Chapters 18--19)
- Basic gradient computation and backpropagation (Chapter 3)
- Familiarity with large language models (Chapter 21) is helpful
38.1 Why Interpretability Matters
38.1.1 The Stakes of Opacity
Deep learning models are deployed in contexts where mistakes have severe consequences:
- Healthcare: A model that predicts disease risk based on spurious correlations (e.g., the hospital's watermark on X-rays rather than pathology) can harm patients.
- Criminal justice: Risk assessment tools used in sentencing must be transparent. The ProPublica investigation of COMPAS showed that an opaque recidivism prediction model exhibited racial bias.
- Finance: Regulations like the EU's GDPR and the Equal Credit Opportunity Act require that individuals receive explanations for automated decisions that affect them.
- Safety-critical systems: Self-driving cars, aircraft autopilots, and industrial control systems require interpretable decision-making for certification and debugging.
38.1.2 Trust, Debugging, and Scientific Discovery
Beyond compliance, interpretability serves three practical purposes:
- Building trust: Stakeholders (doctors, judges, users) are more likely to adopt AI systems they can understand and verify. Trust calibration---knowing when to trust a model---requires understanding its reasoning.
- Debugging: When a model fails, interpretability tools help diagnose why. Feature attributions can reveal shortcuts (models using background color instead of object shape), data leakage (labels correlated with image metadata), or distribution shift.
- Scientific discovery: Interpreting what a model has learned can generate scientific hypotheses. AlphaFold's learned representations encode genuine biochemical knowledge. Analyzing what language models learn about syntax and semantics advances linguistics.
38.1.3 The Accuracy-Interpretability Trade-off
A persistent belief holds that more interpretable models are necessarily less accurate. This is sometimes true---a linear model cannot capture the same patterns as a deep neural network. But the trade-off is often exaggerated:
- For tabular data, well-tuned gradient boosted trees (moderately interpretable) often match or beat deep neural networks.
- Explanation methods can make complex models partially interpretable without sacrificing accuracy.
- Mechanistic interpretability seeks to make models fully transparent without changing their architecture.
The goal is not to replace deep learning with decision trees, but to develop tools and techniques that bring transparency to powerful models.
38.1.4 A Historical Perspective
The quest for model interpretability has a long history, evolving alongside the models themselves:
- 1960s-1980s: Expert systems were inherently interpretable---they consisted of hand-crafted if-then rules. Interpretability was a design feature, not an afterthought.
- 1990s-2000s: As statistical learning (SVMs, random forests) gained prominence, interpretability began to diverge from performance. Feature importance measures and partial dependence plots emerged as post-hoc tools.
- 2010s: Deep learning achieved breakthrough performance but was largely opaque. The first generation of interpretability tools (saliency maps, attention visualization) appeared.
- 2016-2017: LIME and SHAP provided the first principled frameworks for post-hoc explanation. These remain the workhorses of applied interpretability.
- 2020-present: Mechanistic interpretability emerged as a distinct research program, driven by Anthropic, DeepMind, and independent researchers. The goal shifted from explaining individual predictions to understanding the model's internal algorithms. Sparse autoencoders, circuit analysis, and representation engineering became key tools.
This trajectory reflects a shift from "what did the model predict?" to "why did it predict that?" to "how does the model compute its predictions?" Each level provides deeper understanding but requires more sophisticated tools.
38.2 Feature Attribution: SHAP
38.2.1 Shapley Values: From Game Theory to Machine Learning
SHAP (SHapley Additive exPlanations, Lundberg and Lee, 2017) is arguably the most principled feature attribution method. It is grounded in cooperative game theory, a branch of mathematics developed by Lloyd Shapley in 1953 (for which he received the Nobel Prize in Economics in 2012). The original problem was: given a coalition of players who cooperate to achieve some total payoff, how should the payoff be divided fairly among the players?
The game-theoretic setup. In the original formulation, we have a set of $d$ players and a value function $v(S)$ that gives the total payoff when coalition $S$ cooperates. The Shapley value assigns to each player $j$ a fair share of the total value based on their marginal contribution across all possible coalitions.
Translating to ML: The "players" are input features. The "value function" $v(S)$ is the model's prediction when only features in $S$ are observed (features not in $S$ are marginalized out). The "total payoff" is the model's prediction minus the baseline (average) prediction.
Intuition first: Imagine features arriving one at a time in some random order. Each feature's contribution is the change in prediction when it arrives. The Shapley value is the average of this contribution over all possible orderings---this is what makes it fair.
Formula: For a model $f$ with $d$ input features, the Shapley value of feature $j$ for input $x$ is:
$$\phi_j(x) = \sum_{S \subseteq \{1, \ldots, d\} \setminus \{j\}} \frac{|S|!(d - |S| - 1)!}{d!} \left[f(x_{S \cup \{j\}}) - f(x_S)\right]$$
where:
- $S$ is a subset of features not including feature $j$
- $x_S$ denotes the input with features not in $S$ replaced by baseline values (typically sampled from the training data distribution)
- $|S|!(d - |S| - 1)!/d!$ is a combinatorial weight that equals $1/(d \binom{d-1}{|S|})$, ensuring each ordering is equally likely
- $f(x_{S \cup \{j\}}) - f(x_S)$ is the marginal contribution of feature $j$ to coalition $S$
Worked Example: SHAP for a 3-Feature Model. Consider a model $f$ with 3 features predicting house prices. For a specific house, $f(\text{all features}) = \$500K$, $f(\text{baseline}) = \$300K$, so the total value to distribute is \$200K. Suppose:
- $f(\{1\}) = \$380K$, $f(\{2\}) = \$350K$, $f(\{3\}) = \$310K$
- $f(\{1,2\}) = \$460K$, $f(\{1,3\}) = \$420K$, $f(\{2,3\}) = \$380K$
Then the Shapley value of feature 1 (say, square footage) is computed by averaging its marginal contribution over all orderings:
| Ordering | Feature 1's marginal contribution |
|---|---|
| 1, 2, 3 | $f(\{1\}) - f(\emptyset) = 380-300 = 80$ |
| 1, 3, 2 | $f(\{1\}) - f(\emptyset) = 380-300 = 80$ |
| 2, 1, 3 | $f(\{1,2\}) - f(\{2\}) = 460-350 = 110$ |
| 2, 3, 1 | $f(\{1,2,3\}) - f(\{2,3\}) = 500-380 = 120$ |
| 3, 1, 2 | $f(\{1,3\}) - f(\{3\}) = 420-310 = 110$ |
| 3, 2, 1 | $f(\{1,2,3\}) - f(\{2,3\}) = 500-380 = 120$ |
$$\phi_1 = \frac{80+80+110+120+110+120}{6} = \frac{620}{6} \approx \$103K$$
This tells us that square footage contributes approximately \$103K to this house's predicted price above the baseline.
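The arithmetic above is small enough to verify directly. The following snippet is a self-contained check of this worked example (coalition values hard-coded from the setup above, in $K): it enumerates all $3! = 6$ orderings and averages the marginal contributions.

```python
from itertools import permutations

# Coalition values from the worked example (in $K).
v = {
    (): 300, (1,): 380, (2,): 350, (3,): 310,
    (1, 2): 460, (1, 3): 420, (2, 3): 380, (1, 2, 3): 500,
}

def shapley(player: int) -> float:
    """Average the player's marginal contribution over all orderings."""
    total = 0.0
    for order in permutations((1, 2, 3)):
        before = order[: order.index(player)]
        without = tuple(sorted(before))
        with_player = tuple(sorted(before + (player,)))
        total += v[with_player] - v[without]
    return total / 6  # 3! = 6 orderings

print(shapley(1))                              # ~103.33
print(shapley(1) + shapley(2) + shapley(3))    # 200.0
```

Running it also confirms the efficiency property: the three Shapley values sum to the full \$200K gap between the prediction and the baseline, which leads directly into the axioms below.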
Shapley values satisfy four desirable properties---and crucially, they are the only attribution method satisfying all four simultaneously:
- Efficiency: $\sum_j \phi_j(x) = f(x) - E[f(X)]$. The attributions sum to the difference between the prediction and the average prediction. No value is "lost."
- Symmetry: If two features contribute identically in all coalitions, they receive equal attribution.
- Dummy (Null player): A feature that never changes the prediction in any coalition receives zero attribution.
- Linearity (Additivity): For a model that is a sum of two models, attributions are the sum of attributions.
The uniqueness theorem (Shapley, 1953) proves that no other attribution method satisfies all four axioms. This is why SHAP is considered the gold standard for feature attribution.
38.2.2 SHAP Variants
Computing exact Shapley values requires evaluating $2^d$ subsets, which is intractable for models with many features ($2^{100} \approx 10^{30}$). SHAP provides efficient approximations tailored to different model types:
- KernelSHAP: Uses a specially weighted linear regression to approximate Shapley values for any black-box model. The key insight is that the Shapley values are the solution to a specific weighted least squares problem, where the weight of each coalition $S$ is $w(S) = \frac{(d-1)}{\binom{d}{|S|}\,|S|\,(d-|S|)}$. This gives higher weight to coalitions near empty (few features) and full (almost all features), where marginal contributions are most informative (see the sketch after this list). KernelSHAP is model-agnostic but requires $O(2^d)$ model evaluations for exact computation; in practice, $O(d \log d)$ samples suffice for good approximations.
- TreeSHAP: Computes exact Shapley values in polynomial time $O(TLD^2)$ for tree-based models (random forests, XGBoost, LightGBM), where $T$ is the number of trees, $L$ is the maximum number of leaves, and $D$ is the maximum depth. TreeSHAP exploits the tree structure to efficiently compute the conditional expectations needed for Shapley values. It is exact (no approximation) and very fast---typically seconds even for large ensembles.
- DeepSHAP: Combines SHAP with DeepLIFT for neural networks, propagating "contribution scores" through the network layer by layer. Each layer distributes the contribution of its outputs to its inputs using rules customized for each activation function (ReLU, sigmoid, etc.). This is faster than KernelSHAP for neural networks but is an approximation.
- GradientSHAP: Uses expected gradients (integrating gradients over a distribution of baselines) as an approximation to SHAP values. The connection: expected gradients are equivalent to Aumann-Shapley values, which converge to Shapley values under certain conditions. In practice, GradientSHAP randomly samples baselines and interpolation points, making it stochastic but efficient.
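To make the KernelSHAP weighting concrete, here is a small sketch (plain Python, independent of any SHAP library) that evaluates the coalition weight $w(S)$ from the formula above and shows its U-shape:

```python
from math import comb

def shap_kernel_weight(d: int, s: int) -> float:
    """SHAP kernel weight for a coalition of size s out of d features.

    Undefined for s == 0 or s == d; KernelSHAP handles the empty and
    full coalitions as constraints rather than weighted samples.
    """
    return (d - 1) / (comb(d, s) * s * (d - s))

d = 10
for s in range(1, d):
    print(s, round(shap_kernel_weight(d, s), 5))
# Weights are largest for very small and very large coalitions.
```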
38.2.3 Implementation
import torch
import torch.nn as nn
import numpy as np
torch.manual_seed(42)
np.random.seed(42)
class SimpleClassifier(nn.Module):
"""A simple neural network for demonstration."""
def __init__(self, input_dim: int, hidden_dim: int, num_classes: int) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, num_classes),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
def compute_gradient_shap(
model: nn.Module,
x: torch.Tensor,
baselines: torch.Tensor,
target_class: int,
num_samples: int = 50,
) -> torch.Tensor:
"""Compute GradientSHAP attributions.
Approximates SHAP values by computing expected gradients over
random interpolations between the input and baseline samples.
Args:
model: Trained model.
x: Input tensor [1, input_dim].
baselines: Baseline samples [num_baselines, input_dim].
target_class: Class index to explain.
num_samples: Number of interpolation samples.
Returns:
SHAP values [input_dim].
"""
model.eval()
attributions = torch.zeros_like(x.squeeze(0))
for _ in range(num_samples):
# Random baseline
idx = torch.randint(0, baselines.size(0), (1,)).item()
baseline = baselines[idx : idx + 1]
# Random interpolation point
alpha = torch.rand(1)
interpolated = baseline + alpha * (x - baseline)
interpolated.requires_grad_(True)
output = model(interpolated)
score = output[0, target_class]
score.backward()
# Gradient * (input - baseline)
grad = interpolated.grad.squeeze(0)
diff = (x - baseline).squeeze(0)
attributions += grad * diff
attributions /= num_samples
return attributions
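A usage sketch for the function above, with an untrained SimpleClassifier and synthetic data purely to illustrate the call signature (in practice the model would be trained first and the baselines drawn from real training data):

```python
input_dim, hidden_dim, num_classes = 8, 32, 2
model = SimpleClassifier(input_dim, hidden_dim, num_classes)

x = torch.randn(1, input_dim)             # instance to explain
baselines = torch.randn(100, input_dim)   # background samples
target = model(x).argmax(dim=1).item()    # explain the predicted class

shap_values = compute_gradient_shap(model, x, baselines, target)
print(shap_values)        # one attribution per input feature
print(shap_values.sum())  # roughly f(x) - E[f(baseline)] for that class
```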
38.2.4 Interpreting SHAP Values
SHAP values provide rich information at multiple levels of analysis:
Local explanations. For each individual prediction, SHAP values answer: which features pushed the prediction up (positive SHAP value) and which pushed it down (negative SHAP value)? A SHAP waterfall plot shows these contributions as a sequence of arrows starting from the baseline prediction and arriving at the actual prediction. For example, for a house price prediction of \$500K:
$$\underbrace{300K}_{\text{baseline}} \xrightarrow[\text{sq.ft.}]{+120K} \xrightarrow[\text{location}]{+50K} \xrightarrow[\text{age}]{-20K} \xrightarrow[\text{other}]{+50K} \underbrace{500K}_{\text{prediction}}$$
Global feature importance. Average the absolute SHAP values across many predictions: $I_j = \frac{1}{N}\sum_{i=1}^N |\phi_j(x_i)|$. This gives a global ranking of feature importance that is more nuanced than simple permutation importance because it accounts for feature interactions and non-linear effects.
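As a sketch, computing this global ranking from a matrix of per-sample SHAP values is a few lines; the shap_matrix below is a random placeholder standing in for attributions you have already computed:

```python
import torch

# Placeholder: in practice, one row of SHAP values per explained prediction.
shap_matrix = torch.randn(1000, 8)  # [num_samples, num_features]

global_importance = shap_matrix.abs().mean(dim=0)            # I_j = mean |phi_j|
ranking = torch.argsort(global_importance, descending=True)  # most important first
print(global_importance)
print(ranking)
```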
Feature interactions. SHAP interaction values decompose the prediction into main effects and pairwise interactions: $f(x) = \phi_0 + \sum_j \phi_j + \sum_{j<k} \phi_{j,k}$, where $\phi_{j,k}$ captures the joint effect of features $j$ and $k$ beyond their individual contributions.

SHAP dependence plots. Plot SHAP values for a single feature against the feature's actual values. This reveals the feature's marginal effect on the prediction, including non-linear relationships and interactions. If the dependence plot is not a clean curve but shows a "cloud" of points, the feature interacts with other features.

SHAP summary plots. Combine local explanations across all samples into a single visualization: each dot represents one prediction, positioned by the feature's SHAP value (x-axis) and colored by the feature's actual value. This provides a bird's-eye view of how each feature affects predictions across the entire dataset.

LIME (Ribeiro, Singh, and Guestrin, 2016) explains individual predictions by fitting an interpretable model (typically a linear model or decision tree) to the model's behavior in the neighborhood of the input. The algorithm:
1. Generate perturbed samples around the input $x$ by randomly toggling features on/off.
2. Obtain the model's predictions for each perturbed sample.
3. Weight each perturbed sample by its proximity to $x$ (using a kernel, typically exponential).
4. Fit a weighted linear model to the (perturbation, prediction) pairs.
5. The linear model's coefficients are the feature attributions.

Formally, LIME solves:
$$\xi(x) = \arg\min_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g)$$
where $G$ is the class of interpretable models, $\mathcal{L}$ measures fidelity of $g$ to $f$ in the neighborhood of $x$ (weighted by kernel $\pi_x$), and $\Omega(g)$ penalizes complexity.

LIME's perturbation strategy varies by data type, which is part of what makes it so flexible:

Tabular data. Features are toggled on/off (replaced with baseline values) or perturbed with Gaussian noise. The interpretable representation is a binary vector indicating which features are present.

Text. Words are randomly removed from the input. The interpretable representation is a binary vector over words (1 = present, 0 = removed). This naturally produces explanations like "the word 'excellent' contributed +0.3 to the positive sentiment prediction."

Images. The image is first segmented into superpixels (using SLIC or quickshift), then superpixels are randomly masked (set to gray or mean pixel value). The interpretable representation is a binary vector over superpixels. This produces visual explanations highlighting which regions of the image drove the prediction.

Practical tip. LIME's explanations depend on the perturbation distribution and the kernel width. Two runs of LIME on the same input may produce different explanations due to the random sampling. Always run LIME multiple times and report the variance. If explanations are highly unstable, consider increasing num_samples or using SHAP instead.

When to use which. Use SHAP when you need theoretically grounded, consistent attributions and can afford the computational cost. Use LIME when you need quick, qualitative explanations, especially for image or text data where LIME's perturbation approach is intuitive. In practice, many teams use both as a cross-check: if SHAP and LIME agree on the important features, the explanation is more trustworthy.

The simplest attribution method computes the gradient of the output with respect to the input:
$$\text{attribution}_j = \frac{\partial f(x)_c}{\partial x_j}$$
where $c$ is the target class. Features with large gradients are "important" in the sense that small changes to them significantly change the output. Vanilla gradients suffer from saturation: for ReLU networks, gradients are zero in saturated regions, producing noisy and incomplete attributions.

Sundararajan, Taly, and Yan (2017) proposed Integrated Gradients to address the limitations of vanilla gradients. The key idea: instead of computing the gradient at a single point, integrate the gradient along a straight-line path from a baseline $x'$ (typically zero) to the input $x$:
$$\text{IG}_j(x) = (x_j - x'_j) \times \int_0^1 \frac{\partial f(x' + \alpha(x - x'))}{\partial x_j}\, d\alpha$$
Integrated Gradients satisfy two important axioms:
- Completeness: $\sum_j \text{IG}_j(x) = f(x) - f(x')$. Attributions sum to the difference between the output and the baseline output. This follows from the fundamental theorem of calculus.
- Sensitivity: If changing feature $j$ changes the output, feature $j$ receives non-zero attribution. This is stronger than vanilla gradients, which can assign zero attribution to features in saturated regions.

Intuition: Think of walking along a straight path from a blank image (baseline) to the actual image. At each step, ask: "which pixels is the model most sensitive to right now?" By summing these sensitivities over the entire path, we get a complete account of how each pixel contributed to the prediction. The path integral ensures we do not miss contributions in regions where the gradient happens to be zero at the input.

Choosing a baseline. The choice of baseline $x'$ significantly affects the attributions:
- Zero baseline (black image for vision, zero vector for tabular): Simple but can produce misleading attributions if black pixels are informative.
- Blurred input: Removes fine details while preserving coarse structure.
- Random baseline: Average attributions over multiple random baselines to reduce dependence on any single choice.
- Maximum-distance baseline: Choose the baseline that maximizes distance from the input in the model's output space.

For language models, the baseline is typically the embedding of the [PAD] or [MASK] token. For tabular data, using the training set mean or median for each feature is common.

Worked example. Consider a sentiment classifier that scores the sentence "This movie is absolutely fantastic" at 0.95 (positive) and the empty baseline at 0.50. Integrated Gradients might assign +0.02 to "This," +0.05 to "movie," +0.01 to "is," +0.12 to "absolutely," and +0.25 to "fantastic" (see the table in Section 38.4.2). The attributions sum to $0.02 + 0.05 + 0.01 + 0.12 + 0.25 = 0.45 = 0.95 - 0.50$, satisfying completeness. The word "fantastic" receives the largest attribution, followed by "absolutely," which aligns with our intuition about sentiment.

Several variants of gradient-based attribution exist, each addressing different limitations:

SmoothGrad (Smilkov et al., 2017): Averages gradients over multiple noisy copies of the input: $\text{SmoothGrad}_j = \frac{1}{N}\sum_{n=1}^N \frac{\partial f(x + \epsilon_n)}{\partial x_j}$, where $\epsilon_n \sim \mathcal{N}(0, \sigma^2)$. This produces visually smoother, less noisy saliency maps.

GradCAM (Selvaraju et al., 2017): For CNNs, computes the gradient of the target class score with respect to the feature maps of the last convolutional layer, averages over spatial dimensions, and uses the result as weights for a linear combination of feature maps. Produces coarse but localizable heatmaps.

DeepLIFT (Shrikumar et al., 2017): Propagates "difference from reference" signals through the network, handling nonlinearities more carefully than vanilla backpropagation. DeepLIFT's rescale rule decomposes the output difference into contributions from each input feature.

In Transformer models (Chapter 19), attention weights determine how much each token attends to every other token. It is tempting to interpret these weights as explanations: "The model predicted 'positive' because it attended strongly to the word 'excellent.'"

Jain and Wallace (2019) and Wiegreffe and Pinter (2019) demonstrated fundamental problems with using attention weights as explanations:

Attention is not unique: Different attention patterns can produce identical outputs. There exist adversarial attention distributions that give the same prediction but attend to completely different tokens.

Attention does not reflect feature importance: High attention to a token does not mean that token is causally important for the prediction. The model might attend to a token to extract syntactic information while basing its prediction on semantic content encoded elsewhere.

Multi-head and multi-layer attention is complex: Which heads? Which layers? Attention patterns across layers are composed, not independent. A token may receive high attention in layer 3 but be ignored in layer 6.

Attention weights are softmaxed: They sum to 1, creating zero-sum competition. A token's attention weight depends on all other tokens, not just its own relevance.

Despite these caveats, attention is not useless for interpretation:

The residual stream perspective. Modern interpretability work views the Transformer not as a stack of attention layers but as a residual stream with attention heads and MLP layers "reading from" and "writing to" it. Each attention head reads a query and key from the residual stream, computes attention weights, and writes a value back. This perspective, emphasized by Elhage et al. (2021), makes it clear that attention weights alone cannot explain the model's computation---the value vectors and the residual stream must also be considered.

Attention distance and syntactic structure. Hewitt and Manning (2019) showed that the syntactic tree structure of a sentence is linearly encoded in BERT's representation space. They defined a "structural probe" that recovers the parse tree from internal representations with high accuracy. This correspondence between distances in representation space and syntactic distance provides a meaningful (if limited) complement to attention-based interpretation.

Probing classifiers (also called diagnostic classifiers) test whether a model's internal representations encode specific information. The idea: freeze the model, extract representations from a specific layer, and train a simple classifier (linear or shallow MLP) to predict a linguistic property from those representations. If the probe succeeds, the information is encoded in the representation. Common probing tasks include part-of-speech tagging, dependency parsing, named entity recognition, and semantic role labeling.

One of the most informative uses of probing is to examine how information evolves through layers. By training probes at each layer of a deep model, we can trace the trajectory of information processing. Typical findings for BERT-like models:
- Layer 0 (embeddings): Encodes surface features (word identity, position). POS tagging probes already achieve moderate accuracy.
- Layers 1-4: Syntactic information (POS tags, dependency relations, constituency) is maximally represented. Probing accuracy for syntactic tasks peaks in this range.
- Layers 5-8: Semantic information emerges. Named entity recognition and semantic role labeling probes peak here.
- Layers 9-12: Task-relevant information becomes dominant. Representations become increasingly specialized for the pre-training objective.

This "information hierarchy" mirrors the processing stages of a traditional NLP pipeline (tokenization -> POS -> parsing -> semantics -> task), suggesting that the Transformer rediscovers this classical structure through end-to-end training.

Probing with causal interventions. Beyond correlational probing, causal probing techniques modify the representation to remove specific information and test whether downstream performance degrades. If removing syntactic information from layer 3 does not affect the model's performance on its downstream task, the model likely does not use syntactic information from that layer, even if a probe can recover it.

Probing has a fundamental limitation: a successful probe does not prove the model uses the information. A representation might encode syntactic information without the model ever accessing it for its downstream task. Hewitt and Liang (2019) introduced control tasks to address this: if a probe can learn to predict random labels nearly as well as linguistic labels, the probe itself is too powerful and the result is meaningless.

Best practices for probing:
- Use the simplest possible probe (linear). If a linear probe succeeds, the information is linearly accessible.
- Always compare against control tasks (predicting random labels, predicting from random representations).
- Report selectivity: the difference in probe accuracy between the real task and the control task.
- Consider the minimum description length (MDL) probing framework (Voita and Titov, 2020), which measures not just whether a probe can predict a property but how easily it can do so (measured by the online code length).

Mechanistic interpretability (MI) aims to reverse-engineer neural networks, understanding not just what a model predicts but how it computes those predictions. The aspiration is to produce a complete, human-understandable description of a model's algorithms. This is analogous to reverse-engineering a compiled program: the weights are the "machine code," and MI seeks to recover the "source code."

Key goals of MI:
- Identify interpretable features in neural network representations
- Understand the circuits that compute specific behaviors
- Detect and understand failure modes before they manifest
- Verify that models are safe and aligned with intended behavior

Three foundational concepts structure MI research:

Features: A feature is a property of the input that a neuron or direction in activation space represents. For example, a neuron in a vision model might represent "curve detectors," "dog ears," or "sky." In language models, features might represent "the token is a number," "the context is about biology," or "the next word should be a verb."

Circuits: A circuit is a subgraph of the model's computational graph that implements a specific behavior. For example, Olsson et al. (2022) identified "induction heads" in language models---attention heads that implement a simple pattern matching algorithm: if the model has seen the sequence [A][B] before and now sees [A], the induction head predicts [B].

Superposition: Elhage et al. (2022) discovered that neural networks represent more features than they have neurons by using superposition---encoding many features as nearly orthogonal directions in a lower-dimensional space. This is analogous to compressed sensing: if features are sparse (rarely active simultaneously), they can share dimensions.

Why does superposition occur? Consider a model with $d$ neurons in a hidden layer, but the world has $D \gg d$ meaningful features. If the model could only represent $d$ features (one per neuron), it would need to discard useful information. Instead, the model can represent up to $D$ features as nearly orthogonal directions in $\mathbb{R}^d$. The Johnson-Lindenstrauss lemma tells us that $D$ can be exponentially larger than $d$ while maintaining approximate orthogonality: $D \sim \exp(c \cdot d)$ vectors can be packed into $\mathbb{R}^d$ with pairwise angles close to 90 degrees. The trade-off is interference: when two features that share a dimension are simultaneously active, they interfere with each other, producing noise. Superposition works because natural features are sparse---most features are inactive for any given input, so interference is rare. The model implicitly performs a cost-benefit analysis: more features (better representation) at the cost of occasional interference (small errors).

Polysemanticity. Superposition causes polysemanticity: individual neurons respond to multiple, seemingly unrelated concepts. A famous example from early Anthropic research found a neuron that responded to a cat's face, a car's front, and the top of the Eiffel Tower---all V-shaped patterns. The neuron is not confused; it represents a geometric feature that happens to appear in multiple contexts. Understanding polysemanticity is critical because it means we cannot simply read off what a model "knows" by examining individual neurons.

Superposition is a fundamental challenge for MI: individual neurons do not correspond to individual features. A single neuron may participate in representing multiple unrelated features (polysemanticity), and a single feature may be distributed across multiple neurons.

Activation patching (also called causal tracing or interchange interventions) is a technique for identifying which components of a model are causally responsible for a specific behavior. The procedure:
1. Run the model on a "clean" input that produces behavior A.
2. Run the model on a "corrupted" input that produces behavior B.
3. For each component (layer, attention head, MLP), replace its activation on the corrupted input with its activation from the clean input.
4. Measure whether this restoration recovers behavior A.

If patching a specific component recovers the original behavior, that component is causally important for the behavior.

Worked Example: Factual Recall. Consider a language model completing "The Eiffel Tower is located in ___." We want to find which components store the association between "Eiffel Tower" and "Paris." Corrupt the subject tokens (for example, by adding noise to the embeddings of "Eiffel Tower") so the model no longer predicts "Paris," then patch clean activations back one layer and position at a time; the patches that restore "Paris" reveal where the association is computed. This technique was used by Meng et al. (2022) to identify where factual knowledge is stored in GPT-style models, leading to the ROME model editing technique discussed in Section 38.9.

Variants of activation patching:
- Resample ablation: Replace the activation with a sample from its empirical distribution (rather than from a specific corrupted input), measuring the component's overall importance.
- Path patching: Patch activations along specific computational paths (e.g., from one attention head to a specific MLP), isolating the contribution of specific circuits.
- Attribution patching (Neel Nanda, 2023): Use linear approximations to patch all components simultaneously, dramatically reducing the number of forward passes needed.

Since superposition means individual neurons do not correspond to individual features, we need methods to extract the true features from model activations. Sparse autoencoders (SAEs) have emerged as the primary tool for this. The idea: train an autoencoder on model activations with a sparsity constraint on the latent representation. The autoencoder's latent dimensions then correspond to interpretable features.

Given activations $\mathbf{x} \in \mathbb{R}^d$ from a model layer:
$$\mathbf{z} = \text{ReLU}(\mathbf{W}_{\text{enc}} (\mathbf{x} - \mathbf{b}_{\text{dec}}) + \mathbf{b}_{\text{enc}})$$
$$\hat{\mathbf{x}} = \mathbf{W}_{\text{dec}} \mathbf{z} + \mathbf{b}_{\text{dec}}$$
The loss is:
$$\mathcal{L} = \|\mathbf{x} - \hat{\mathbf{x}}\|^2 + \lambda \|\mathbf{z}\|_1$$
where $\lambda$ controls sparsity. The encoder typically has many more latent dimensions than input dimensions (e.g., $4\times$ to $64\times$ overcomplete), allowing it to represent the many features in superposition.

Once trained, SAE features can be interpreted by examining:
1. Maximally activating examples: Find the inputs that most strongly activate each feature.
2. Feature effects: What happens to the model's output when a feature is artificially activated or suppressed?
3. Feature co-occurrence: Which features tend to activate together?
4. Feature families: Cluster features by their activation patterns to find semantic groupings.

Anthropic's research on Claude's internal features (Templeton et al., 2024) demonstrated that SAEs can discover millions of interpretable features, including features for specific people, places, concepts, code patterns, and abstract reasoning strategies.

Practical considerations for training SAEs:

Scale of overcomplete dictionary: Typical SAEs use 4x to 64x expansion (if the model layer has dimension $d$, the SAE has $4d$ to $64d$ latent dimensions). Larger dictionaries capture more features but are harder to train and interpret.

Training data: Train the SAE on diverse activations from the target model. Using only one type of input (e.g., only English text) will miss features that activate primarily on other inputs (e.g., code, other languages, mathematical notation).

Sparsity tuning: The $\lambda$ coefficient controls the trade-off between reconstruction quality and feature interpretability. Too low $\lambda$ produces dense, uninterpretable features; too high $\lambda$ degrades reconstruction. A good target is approximately 10-50 active features per input (out of thousands or millions of total features).

Dead features: Some latent dimensions may never activate after training. Monitor the fraction of features that activate at least once per batch and reinitialize dead features periodically during training.

Validation: After training, manually inspect the top-activating examples for approximately 100 randomly selected features. If most features have a clear, coherent interpretation (e.g., "references to water," "Python function definitions," "emotional language"), the SAE is working well. Features that lack clear interpretation may represent noise or highly abstract concepts.

MI research has produced several landmark results that illustrate the promise of the approach:

Induction heads (Olsson et al., 2022). These are pairs of attention heads across two layers that implement a simple but powerful algorithm: if the model has previously seen the bigram [A][B] in context, and it now encounters [A], the induction head predicts [B]. Layer 0 contains a "previous token head" that copies information about each token to the next position's residual stream. Layer 1 contains an "induction head" that searches for tokens matching the current token and copies what followed them. Induction heads appear to be a key mechanism underlying in-context learning in Transformers.

Indirect Object Identification (IOI) (Wang et al., 2023). This study fully reverse-engineered the circuit in GPT-2 Small that performs indirect object identification (e.g., in "When Mary and John went to the store, John gave a drink to ___", the model should predict "Mary"). The researchers identified a circuit of 26 attention heads organized into functional groups: name mover heads, S-inhibition heads, duplicate token heads, and backup heads. Each group plays a specific computational role, and ablating any group degrades performance.

Grokking and phase transitions. Nanda et al. (2023) used MI to understand "grokking"---the phenomenon where a model memorizes training data quickly but suddenly learns to generalize much later. They found that the model first memorizes using a lookup table mechanism and then discovers a genuine algorithm (modular arithmetic via Fourier analysis), with the transition corresponding to a phase change in the model's internal representations.
Interpreting LLMs presents unique challenges due to their scale (billions of parameters), the breadth of their capabilities, and the difficulty of defining what "correct" internal behavior looks like.

Scaling sparse autoencoders to LLMs. Anthropic's research on Claude (Templeton et al., 2024) trained sparse autoencoders on Claude 3 Sonnet's residual stream activations, discovering millions of interpretable features. These included features for specific entities (Golden Gate Bridge, Albert Einstein), concepts (deception, sycophancy, code quality), languages, emotional tones, and abstract reasoning patterns. Notably, they found features related to safety-relevant concepts like deception, power-seeking, and harmful content---features that could be monitored or manipulated to improve model safety.

Representation engineering. Zou et al. (2023) proposed representation engineering, which identifies linear directions in activation space that correspond to high-level concepts like truthfulness, honesty, and harmfulness. By adding or subtracting these directions from the residual stream during inference, the model's behavior can be steered. For example, adding the "truthfulness" direction reduces hallucination, and subtracting the "sycophancy" direction produces more honest responses.

Logit lens and tuned lens. The logit lens (nostalgebraist, 2020) applies the final layer's unembedding matrix to intermediate residual stream states to "read off" what the model is predicting at each layer. This reveals the progressive refinement of predictions through the network. The tuned lens (Belrose et al., 2023) improves on this by training a small affine transformation at each layer, accounting for the fact that intermediate representations may not lie in the same space as the final representation.

Interpretability is not only a research interest---it is increasingly a legal requirement. Several regulations impose explainability obligations on AI systems:

EU AI Act (as we will explore in detail in Chapter 39). High-risk AI systems must provide "sufficient transparency to enable users to interpret the system's output and use it appropriately." This requires technical documentation explaining the system's logic, and the ability to provide meaningful explanations of individual decisions.

GDPR Right to Explanation. Article 22 of the EU's General Data Protection Regulation gives individuals the right not to be subject to decisions based solely on automated processing, and Articles 13-15 require "meaningful information about the logic involved." While the exact legal interpretation is debated, this has been widely interpreted as requiring some form of explainability for automated decisions.

US Equal Credit Opportunity Act (ECOA). Lenders must provide specific reasons for credit denials. This is one of the oldest explainability requirements and predates modern ML. In practice, SHAP values or similar attributions are used to generate the required adverse action notices.

US Executive Order on AI Safety (2023). Requires developers of large AI systems to share safety test results with the government and mandates risk assessments for high-impact AI systems, which implicitly require interpretability.

Meeting regulatory requirements typically involves:

Choosing appropriate explanation methods. For tabular models in finance, SHAP with TreeExplainer is often the standard. For image models in healthcare, GradCAM or Integrated Gradients provide visual explanations.

Documenting the explanation pipeline. Record which method was used, what baseline was chosen, and how explanations are generated and presented to end users.

Validating explanations. Regulators may ask whether explanations are faithful (accurately reflecting the model's reasoning) and useful (actually helping the affected individual understand the decision). SHAP's axioms provide a strong argument for faithfulness.

Human-readable summaries. Raw SHAP values or saliency maps are not sufficient for non-technical users. Translate feature attributions into natural language: "Your loan was denied primarily because your debt-to-income ratio (contributing 35% to the decision) exceeds our threshold, and your credit history length (contributing 25%) is shorter than average."

If we can identify where and how a model stores specific knowledge, we should be able to edit that knowledge without retraining. Model editing techniques modify a model's behavior on specific inputs while preserving its behavior on everything else.

Rank-One Model Editing (ROME, Meng et al., 2022) identifies the MLP layers in a Transformer that store factual associations and modifies them with a rank-one update. The procedure, in outline: (1) use causal tracing to locate the mid-layer MLP that mediates recall of the fact, (2) compute a key vector representing the subject and a value vector that produces the desired new object, and (3) apply a rank-one update to that MLP's weight matrix so that the key maps to the new value.

MEMIT (Mass-Editing Memory in a Transformer) extends ROME to edit many facts simultaneously by distributing updates across multiple layers.

Model editing is powerful but limited:
- Ripple effects: Editing "The Eiffel Tower is in London" should also change answers to "What country is the Eiffel Tower in?" but often does not. Cohen et al. (2024) proposed the "ripple effect" evaluation, measuring whether related facts are updated consistently.
- Specificity: Edits sometimes affect unrelated inputs. Changing "Eiffel Tower is in London" might inadvertently affect predictions about other towers or other landmarks in Paris.
- Scalability: Large numbers of edits can degrade model performance. After 1,000 ROME edits, Meng et al. found that general language modeling performance degrades noticeably.
- Representation vs. computation: Factual knowledge may be stored in complex distributed circuits, not simple key-value pairs. Recent work suggests that factual recall involves multiple layers and attention heads working together.

Implications for model governance. Model editing raises important governance questions: if we can surgically modify a model's knowledge, who decides what to edit? Should model editing be used to remove copyrighted knowledge, correct biases, or update outdated facts? These questions connect interpretability to the broader AI governance discussion in Chapter 39.

Use multiple methods: No single interpretability tool gives the full picture. Combine feature attributions with probing and mechanistic analysis. If SHAP and Integrated Gradients agree on the important features, the explanation is more trustworthy than either alone.

Validate explanations: Check that explanations are consistent across methods and align with domain knowledge. If SHAP says a feature is important, verify with a perturbation test: remove the feature and confirm that performance drops. If it does not, the attribution may be misleading.

Report limitations: Every method has assumptions. LIME explanations depend on the perturbation distribution. SHAP values depend on the baseline. Probing accuracy does not imply usage. Activation patching only tests sufficiency (the component is sufficient to produce the behavior when patched in) but not necessity (the model may have redundant circuits).

Consider the audience: Technical stakeholders may want detailed feature attributions; regulators may want simple explanations; end users may want counterfactuals ("If your income were $10K higher, the loan would be approved"). Design your explanation pipeline with the end consumer in mind.

Do not over-interpret: A saliency map highlighting a region does not prove the model "understands" the object there. It shows a correlation between that region and the output. Similarly, finding a "cat detector" neuron does not mean the model processes cats in a human-like way.

Document your interpretability pipeline: Record which methods were used, what hyperparameters (baseline, number of samples, kernel width), and how explanations were validated. This documentation is essential for reproducibility and regulatory compliance (as we discussed in Section 38.8).

Test on known cases: Before applying interpretability methods to novel data, test them on cases where you know the ground truth. For example, create a synthetic dataset where you know which features are truly important, and verify that your interpretability method recovers them.

Here is a step-by-step workflow for interpreting a credit scoring model:

Global analysis: Compute SHAP values for 1,000 representative predictions. Create summary plots to identify the most globally important features. Verify that the top features align with domain knowledge (credit history, income, debt-to-income ratio should appear prominently).

Bias detection: Disaggregate SHAP values by protected group. If a feature like "zip code" has different SHAP distributions across racial groups, it may be acting as a proxy for race (as we will discuss further in Chapter 39).

Local explanations for denials: For every loan denial, generate a SHAP waterfall plot showing how each feature contributed to the decision. Convert these into natural-language adverse action notices: "Your application was primarily affected by: (1) credit utilization ratio above 80% (-0.3), (2) credit history length less than 3 years (-0.2)."

Sanity checks: Verify that SHAP attributions pass sanity checks---features that should be irrelevant (e.g., application submission time) should have near-zero attributions. If they do not, investigate potential data leakage.

Model debugging: If the model produces unexpected predictions, use SHAP force plots to trace which features are responsible. If a high-income applicant is denied, the SHAP plot might reveal an unusual pattern like excessive credit inquiries.

Interpretability spans a wide spectrum, from practical explanation tools to fundamental reverse-engineering of neural networks.

Feature attribution methods---SHAP, LIME, and Integrated Gradients---explain individual predictions by assigning importance scores to input features. SHAP is the most principled (grounded in Shapley values), LIME is the most flexible (model-agnostic, works with any perturbation scheme), and Integrated Gradients is the most efficient for differentiable models.

Attention weights are tempting but unreliable explanations: different attention patterns can produce identical outputs, and attention does not imply causal importance. Probing classifiers test what information is encoded in representations but cannot prove the model uses that information.

Mechanistic interpretability aims to fully understand neural network computations. Key findings include the discovery of interpretable circuits (induction heads, indirect object identification), the phenomenon of superposition (more features than neurons), and the use of sparse autoencoders to extract human-interpretable features from model activations.

Model editing techniques like ROME demonstrate that mechanistic understanding can enable targeted modifications to model behavior, though current methods have significant limitations.

The field is evolving rapidly, driven by the dual imperatives of safety (understanding what models do before deploying them in critical systems) and scientific curiosity (understanding how neural networks learn and represent knowledge).

Looking ahead. Several trends will shape the future of interpretability:

Scaling interpretability to frontier models. Current mechanistic interpretability research often focuses on small models (GPT-2 scale). Scaling these techniques to models with hundreds of billions of parameters is both a technical and a conceptual challenge. Sparse autoencoders are one path forward, but the computational cost of training SAEs at scale is itself substantial.

Automated interpretability. Using AI to interpret AI: training models to automatically label SAE features, identify circuits, and generate natural-language descriptions of model behavior. Anthropic and OpenAI have both explored using LLMs to interpret other LLMs' internal representations, with promising early results.

Causal interpretability. Moving beyond correlational methods (probing, attention analysis) to causal methods (activation patching, causal mediation analysis) that can answer "what would happen if this component were different?" rather than just "what does this component correlate with?"

Interpretability for safety. Using interpretability tools to detect deceptive behavior, monitor for safety-relevant features (power-seeking, deception, harmful content generation), and verify alignment properties.
This is perhaps the highest-stakes application of interpretability research, connecting directly to the AI safety concerns we will discuss in Chapter 39.

Standardization. As interpretability becomes a regulatory requirement, standardized evaluation protocols, benchmarks, and reporting formats will emerge. The field currently lacks consensus on how to measure whether an explanation is "good enough," and developing such standards is an important open problem.
38.3 LIME: Local Interpretable Model-agnostic Explanations
38.3.1 The LIME Algorithm
38.3.2 Implementation
def lime_explain(
model: nn.Module,
x: torch.Tensor,
num_features: int,
num_samples: int = 1000,
kernel_width: float = 0.75,
) -> torch.Tensor:
"""Generate a LIME explanation for a single prediction.
Args:
model: Trained model.
x: Input to explain [input_dim].
num_features: Number of input features.
num_samples: Number of perturbation samples.
kernel_width: Width of the exponential kernel.
Returns:
Feature attribution weights [input_dim].
"""
model.eval()
# Get original prediction
with torch.no_grad():
original_pred = model(x.unsqueeze(0)).argmax(dim=1).item()
# Generate binary perturbation masks
masks = torch.bernoulli(
torch.full((num_samples, num_features), 0.5)
)
# Create perturbed inputs (masked features set to zero)
perturbed = x.unsqueeze(0).repeat(num_samples, 1) * masks
# Get model predictions
with torch.no_grad():
preds = model(perturbed)
pred_probs = torch.softmax(preds, dim=1)[:, original_pred]
# Compute distances (cosine) and kernel weights
distances = 1.0 - torch.cosine_similarity(
masks, torch.ones(1, num_features), dim=1
)
weights = torch.exp(-(distances ** 2) / (kernel_width ** 2))
# Weighted linear regression: solve (X^T W X) beta = X^T W y
X = masks
y = pred_probs
W = torch.diag(weights)
XtWX = X.T @ W @ X + 1e-6 * torch.eye(num_features)
XtWy = X.T @ W @ y
beta = torch.linalg.solve(XtWX, XtWy)
return beta
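A usage sketch for lime_explain, again with an untrained SimpleClassifier from Section 38.2.3 purely to exercise the function. The repeated calls at the end illustrate the stability check recommended earlier (run LIME several times and inspect the variance):

```python
model = SimpleClassifier(input_dim=8, hidden_dim=32, num_classes=2)
x = torch.randn(8)

weights = lime_explain(model, x, num_features=8, num_samples=2000)
print(weights)  # positive weights locally support the predicted class

# LIME is stochastic: repeat and inspect the spread before trusting it.
repeats = torch.stack([lime_explain(model, x, num_features=8) for _ in range(5)])
print(repeats.std(dim=0))
```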
38.3.3 LIME for Different Data Types
38.3.4 LIME vs. SHAP
| Aspect | SHAP | LIME |
|---|---|---|
| Theoretical foundation | Shapley values (axiomatic) | Local linear approximation |
| Consistency | Guaranteed (Shapley axioms) | Not guaranteed |
| Computational cost | Moderate to high | Moderate |
| Model-agnostic | KernelSHAP: yes | Yes |
| Deterministic | Some variants | No (sampling-based) |
| Additivity | Attributions sum to prediction difference | No guarantee |
| Stability | High (exact methods) | Low (depends on sampling) |
38.4 Gradient-Based Attribution
38.4.1 Vanilla Gradients
38.4.2 Integrated Gradients
| Token | IG Attribution |
|---|---|
| This | +0.02 |
| movie | +0.05 |
| is | +0.01 |
| absolutely | +0.12 |
| fantastic | +0.25 |
38.4.3 Other Gradient-Based Methods
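Of the variants described earlier, SmoothGrad is the simplest to implement. The sketch below (reusing the imports and model interface from Section 38.2.3; a minimal sketch rather than a reference implementation) averages input gradients over noisy copies of a 1-D input:

```python
def smoothgrad(
    model: nn.Module,
    x: torch.Tensor,
    target_class: int,
    num_samples: int = 50,
    noise_std: float = 0.1,
) -> torch.Tensor:
    """Average input gradients over noisy copies of the input (SmoothGrad).

    Args:
        model: Trained classifier.
        x: Input tensor [input_dim].
        target_class: Class index to explain.
        num_samples: Number of noisy samples.
        noise_std: Standard deviation of the Gaussian noise.

    Returns:
        Smoothed gradient attribution [input_dim].
    """
    model.eval()
    grads = torch.zeros_like(x)
    for _ in range(num_samples):
        noisy = (x + noise_std * torch.randn_like(x)).unsqueeze(0)
        noisy.requires_grad_(True)
        score = model(noisy)[0, target_class]
        score.backward()
        grads += noisy.grad.squeeze(0)
    return grads / num_samples
```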
38.4.4 Implementation
def integrated_gradients(
model: nn.Module,
x: torch.Tensor,
baseline: torch.Tensor,
target_class: int,
num_steps: int = 50,
) -> torch.Tensor:
"""Compute Integrated Gradients attribution.
Args:
model: Trained model.
x: Input tensor [input_dim].
baseline: Baseline tensor [input_dim] (typically zeros).
target_class: Class to explain.
num_steps: Number of integration steps (higher = more accurate).
Returns:
Attribution values [input_dim].
"""
model.eval()
# Generate interpolated inputs along the path
alphas = torch.linspace(0, 1, num_steps + 1)
interpolated = torch.stack([
baseline + alpha * (x - baseline) for alpha in alphas
]) # [num_steps+1, input_dim]
interpolated.requires_grad_(True)
# Forward pass
outputs = model(interpolated)
scores = outputs[:, target_class]
total_score = scores.sum()
total_score.backward()
# Average gradients along the path
avg_gradients = interpolated.grad.mean(dim=0)
# Scale by (input - baseline)
attributions = (x - baseline) * avg_gradients
return attributions
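A quick way to validate the implementation above is to check the completeness axiom numerically: with enough integration steps, the attributions should sum to approximately $f(x) - f(x')$ for the target class. A minimal sketch, reusing the (untrained) SimpleClassifier from Section 38.2.3 for illustration:

```python
model = SimpleClassifier(input_dim=8, hidden_dim=32, num_classes=2)
x = torch.randn(8)
baseline = torch.zeros(8)
target = 0

attributions = integrated_gradients(model, x, baseline, target, num_steps=200)

with torch.no_grad():
    diff = (
        model(x.unsqueeze(0))[0, target]
        - model(baseline.unsqueeze(0))[0, target]
    )
print(attributions.sum().item(), diff.item())  # should agree closely
```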
38.5 Attention as Explanation: Promises and Limitations
38.5.1 The Intuition
38.5.2 Why Attention Is Not Explanation
38.5.3 When Attention Can Be Informative
def attention_rollout(
attention_matrices: list[torch.Tensor],
) -> torch.Tensor:
"""Compute attention rollout across layers.
Accounts for residual connections by adding identity and
re-normalizing at each layer.
Args:
attention_matrices: List of attention matrices per layer,
each of shape [num_heads, seq_len, seq_len].
Returns:
Rolled-out attention [seq_len, seq_len].
"""
# Average attention heads at each layer
avg_attentions = [attn.mean(dim=0) for attn in attention_matrices]
# Add residual connection (identity matrix) and renormalize
rollout = torch.eye(avg_attentions[0].size(0))
for attn in avg_attentions:
# Account for residual: 0.5 * attn + 0.5 * I
attn_with_residual = 0.5 * attn + 0.5 * torch.eye(attn.size(0))
# Renormalize rows
attn_with_residual = attn_with_residual / attn_with_residual.sum(dim=-1, keepdim=True)
rollout = rollout @ attn_with_residual
return rollout
38.6 Probing Classifiers
38.6.1 What Do Representations Encode?
38.6.2 The Probing Methodology
class LinearProbe(nn.Module):
"""Linear probe for analyzing model representations.
Trains a simple linear classifier on frozen representations
to test whether specific information is encoded.
"""
def __init__(self, hidden_dim: int, num_classes: int) -> None:
super().__init__()
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Classify from frozen hidden states.
Args:
hidden_states: [batch_size, hidden_dim] or [batch_size, seq_len, hidden_dim].
Returns:
Logits [batch_size, num_classes] or [batch_size, seq_len, num_classes].
"""
return self.classifier(hidden_states)
def train_probe(
representations: torch.Tensor,
labels: torch.Tensor,
num_classes: int,
num_epochs: int = 100,
lr: float = 0.01,
) -> dict[str, float]:
"""Train a linear probe on frozen representations.
Args:
representations: Model representations [num_samples, hidden_dim].
labels: Ground truth labels [num_samples].
num_classes: Number of target classes.
num_epochs: Training epochs.
lr: Learning rate.
Returns:
Dictionary with train and test accuracy.
"""
# Train/test split
n = representations.size(0)
perm = torch.randperm(n)
n_train = int(0.8 * n)
train_idx = perm[:n_train]
test_idx = perm[n_train:]
probe = LinearProbe(representations.size(1), num_classes)
optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
for epoch in range(num_epochs):
probe.train()
optimizer.zero_grad()
logits = probe(representations[train_idx])
loss = nn.functional.cross_entropy(logits, labels[train_idx])
loss.backward()
optimizer.step()
probe.eval()
with torch.no_grad():
train_pred = probe(representations[train_idx]).argmax(dim=1)
train_acc = (train_pred == labels[train_idx]).float().mean().item()
test_pred = probe(representations[test_idx]).argmax(dim=1)
test_acc = (test_pred == labels[test_idx]).float().mean().item()
return {"train_acc": train_acc, "test_acc": test_acc}
38.6.3 Layer-by-Layer Probing
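A layer-by-layer study is a thin loop over train_probe from the previous section. The sketch below uses random placeholders for the per-layer representations and the part-of-speech labels; in practice these would come from forward hooks run over an annotated corpus:

```python
import torch

num_layers, num_samples, hidden_dim, num_pos_tags = 13, 2000, 768, 17

# Placeholders: representations extracted from each layer, plus POS labels.
layer_reps = {i: torch.randn(num_samples, hidden_dim) for i in range(num_layers)}
pos_labels = torch.randint(0, num_pos_tags, (num_samples,))

accuracy_by_layer = {}
for layer, reps in layer_reps.items():
    metrics = train_probe(reps, pos_labels, num_classes=num_pos_tags)
    accuracy_by_layer[layer] = metrics["test_acc"]

for layer, acc in accuracy_by_layer.items():
    print(f"layer {layer:2d}: probe test accuracy {acc:.3f}")
# With real BERT activations, syntactic probes typically peak in the
# lower-to-middle layers, as described earlier in this chapter.
```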
38.6.4 The Probing Paradox
38.7 Mechanistic Interpretability
38.7.1 The Vision of Mechanistic Interpretability
38.7.2 Features, Circuits, and Superposition
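A small numerical illustration of the geometry behind superposition (random unit vectors standing in for learned features; this is an illustrative sketch, not a trained model): many more directions than dimensions can coexist with only small pairwise interference.

```python
import torch

torch.manual_seed(0)

d, num_features = 64, 1024  # far more "features" than dimensions
features = torch.nn.functional.normalize(torch.randn(num_features, d), dim=1)

cosines = features @ features.T  # pairwise cosine similarities
off_diag = cosines[~torch.eye(num_features, dtype=torch.bool)]

print(f"max |cos| between distinct features:  {off_diag.abs().max():.3f}")
print(f"mean |cos| between distinct features: {off_diag.abs().mean():.3f}")
# Interference is small but nonzero: the price of packing 1024 features
# into 64 dimensions, tolerable when features are sparsely active.
```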
38.7.3 Activation Patching
def activation_patching(
model: nn.Module,
clean_input: torch.Tensor,
corrupted_input: torch.Tensor,
layer_index: int,
target_output_fn: callable,
) -> dict[str, float]:
"""Perform activation patching on a specified layer.
Identifies whether a specific layer is causally responsible
for the difference in model behavior between two inputs.
Args:
model: The model to analyze (must support hooks).
clean_input: Input producing the target behavior.
corrupted_input: Input producing different behavior.
layer_index: Which layer to patch.
target_output_fn: Function that extracts the metric of interest
from model output.
Returns:
Dict with clean, corrupted, and patched output metrics.
"""
# Get clean activations
clean_activations = {}
def save_hook(name: str):
def hook(module: nn.Module, input: tuple, output: torch.Tensor) -> None:
clean_activations[name] = output.detach().clone()
return hook
# Register hooks on all layers
hooks = []
for i, layer in enumerate(model.net):
h = layer.register_forward_hook(save_hook(f"layer_{i}"))
hooks.append(h)
# Forward pass on clean input
with torch.no_grad():
clean_output = target_output_fn(model(clean_input))
# Remove hooks
for h in hooks:
h.remove()
# Get corrupted output
with torch.no_grad():
corrupted_output = target_output_fn(model(corrupted_input))
# Patch: run corrupted input but replace target layer's activation
def patch_hook(module: nn.Module, input: tuple, output: torch.Tensor) -> torch.Tensor:
return clean_activations[f"layer_{layer_index}"]
hooks = []
target_layer = list(model.net.children())[layer_index]
h = target_layer.register_forward_hook(patch_hook)
hooks.append(h)
with torch.no_grad():
patched_output = target_output_fn(model(corrupted_input))
for h in hooks:
h.remove()
return {
"clean": clean_output,
"corrupted": corrupted_output,
"patched": patched_output,
"recovery": (patched_output - corrupted_output) / (clean_output - corrupted_output + 1e-8),
}
38.7.4 Sparse Autoencoders for Feature Discovery
class SparseAutoencoder(nn.Module):
"""Sparse autoencoder for extracting features from model activations.
Learns an overcomplete dictionary of features with sparsity
constraints, revealing interpretable directions in activation space.
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
l1_coeff: float = 1e-3,
) -> None:
super().__init__()
self.encoder = nn.Linear(input_dim, hidden_dim)
self.decoder = nn.Linear(hidden_dim, input_dim, bias=True)
self.l1_coeff = l1_coeff
# Initialize decoder weights to unit norm
with torch.no_grad():
self.decoder.weight.data = nn.functional.normalize(
self.decoder.weight.data, dim=0
)
def forward(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode and reconstruct activations.
Args:
x: Model activations [batch_size, input_dim].
Returns:
Tuple of (reconstructed, latent_activations, loss).
"""
# Center around decoder bias
x_centered = x - self.decoder.bias
# Encode with ReLU (enforces non-negativity)
z = torch.relu(self.encoder(x_centered))
# Decode
x_hat = self.decoder(z) # bias is added here
# Reconstruction loss + sparsity penalty
recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()
sparsity_loss = self.l1_coeff * z.abs().sum(dim=-1).mean()
total_loss = recon_loss + sparsity_loss
return x_hat, z, total_loss
def get_feature_directions(self) -> torch.Tensor:
"""Return the learned feature directions (decoder columns).
Returns:
Feature directions [hidden_dim, input_dim].
"""
return self.decoder.weight.data.T
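A training-loop sketch for the SparseAutoencoder above. The activations tensor is a random placeholder; in a real run these would be residual-stream activations collected from the target model with forward hooks, and training would use far more data and steps:

```python
import torch

activations = torch.randn(50_000, 512)  # placeholder [num_samples, input_dim]
sae = SparseAutoencoder(input_dim=512, hidden_dim=512 * 8, l1_coeff=1e-3)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

for step in range(1_000):
    idx = torch.randint(0, activations.size(0), (256,))
    batch = activations[idx]
    _, z, loss = sae(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 200 == 0:
        # L0 (features active per input) is the key sparsity diagnostic.
        l0 = (z > 0).float().sum(dim=-1).mean().item()
        print(f"step {step}: loss {loss.item():.2f}, active features {l0:.1f}")
```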
38.7.5 Interpreting SAE Features
38.7.6 Key Findings in Mechanistic Interpretability
38.7.7 Interpretability for Large Language Models
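The logit lens described earlier is straightforward to sketch with the Hugging Face transformers library. The snippet below assumes GPT-2 specifically (the transformer.ln_f and lm_head attribute names are GPT-2's; other architectures organize their modules differently) and prints the top next-token guess obtained by unembedding each layer's residual stream:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

prompt = "The Eiffel Tower is located in"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

for layer, hidden in enumerate(outputs.hidden_states):
    # Project the last position's hidden state through ln_f + unembedding.
    logits = model.lm_head(model.transformer.ln_f(hidden[0, -1]))
    top_token = tokenizer.decode(logits.argmax().item())
    print(f"layer {layer:2d}: top prediction = {top_token!r}")
```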
38.8 Interpretability and Regulation
38.8.1 Regulatory Requirements for Explainability
38.8.2 Practical Compliance
38.9 Model Editing
38.9.1 The Promise of Model Editing
38.9.2 ROME and MEMIT
38.9.3 Limitations of Model Editing
38.10 Practical Interpretability Workflow
38.10.1 Choosing the Right Tool
| Goal | Method | Computational Cost |
|---|---|---|
| Explain a single prediction | SHAP, LIME, Integrated Gradients | Low to moderate |
| Global feature importance | SHAP summary plots | Moderate |
| Understand what representations encode | Probing classifiers | Low |
| Identify causal components | Activation patching | Moderate |
| Discover learned features | Sparse autoencoders | High |
| Edit specific knowledge | ROME, MEMIT | Moderate |
| Understand attention patterns | Attention rollout | Low |
38.10.2 Best Practices
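One best practice discussed earlier in this chapter, validating attributions with a perturbation test, can be automated. A minimal sketch (assuming a trained classifier, a feature matrix X, and labels y, none of which are defined elsewhere in this chapter): permute each feature in turn and compare the resulting accuracy drops against the attribution ranking.

```python
import torch
import torch.nn as nn

def permutation_check(
    model: nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    num_repeats: int = 5,
) -> torch.Tensor:
    """Accuracy drop when each feature is independently permuted.

    Compare the resulting ranking with your attribution ranking; large
    disagreements suggest the attributions may be misleading.
    """
    model.eval()
    with torch.no_grad():
        base_acc = (model(X).argmax(dim=1) == y).float().mean()
        drops = torch.zeros(X.size(1))
        for j in range(X.size(1)):
            for _ in range(num_repeats):
                X_perm = X.clone()
                X_perm[:, j] = X_perm[torch.randperm(X.size(0)), j]
                acc = (model(X_perm).argmax(dim=1) == y).float().mean()
                drops[j] += (base_acc - acc) / num_repeats
    return drops
```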
38.10.3 A Complete Interpretability Workflow Example
Summary
Quick Reference
| Concept | Key Idea |
|---|---|
| SHAP | Shapley values for fair feature attribution |
| LIME | Local linear approximation of model behavior |
| Integrated Gradients | Path integral of gradients from baseline to input |
| Probing | Train simple classifier on frozen representations |
| Superposition | More features than neurons, encoded via near-orthogonality |
| Activation patching | Replace activations to identify causal components |
| Sparse autoencoders | Overcomplete dictionary learning for feature discovery |
| Induction heads | Attention heads implementing pattern matching |
| ROME | Rank-one edits to MLP layers for knowledge modification |