In This Chapter
- Learning Objectives
- 35.1 The Explanation Problem
- 35.2 Interpretability vs. Explainability: A Precise Distinction
- 35.3 Shapley Values: The Game-Theoretic Foundation
- 35.4 SHAP: From Theory to Production
- 35.5 Beyond SHAP: LIME, PDP, ALE, and ICE
- 35.6 Gradient-Based Explanations for Deep Learning
- 35.7 Concept-Based Explanations: Speaking the Domain Expert's Language
- 35.8 Counterfactual Explanations: "What Would Need to Change?"
- 35.9 Attention as Explanation — And Its Limits
- 35.10 Explanation Infrastructure for Production Systems
- 35.11 The Regulatory Landscape
- 35.12 Honest Limitations: What Explanations Cannot Do
- 35.13 Progressive Project: StreamRec Explanation Infrastructure
- 35.14 Synthesis: An Explanation Methodology
- Chapter Summary
- References
Chapter 35: Interpretability and Explainability at Scale — From SHAP to Concept-Based Explanations in Production
"The question is not whether a model can be explained — every model can be summarized, approximated, or narrated. The question is whether the explanation is faithful to what the model actually does, and whether it is useful to the person receiving it." — Cynthia Rudin, "Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead" (Nature Machine Intelligence, 2019)
Learning Objectives
By the end of this chapter, you will be able to:
- Distinguish interpretability from explainability, articulate the tradeoffs between inherently interpretable models and post-hoc explanations, and select appropriate explanation methods for a given audience and regulatory context
- Apply SHAP (TreeSHAP, DeepSHAP, KernelSHAP), LIME, partial dependence plots (PDP), accumulated local effects (ALE), and individual conditional expectation (ICE) plots at production scale
- Implement attention-based and gradient-based explanation methods (saliency maps, integrated gradients, GradCAM) for deep learning models using Captum
- Build concept-based explanations (TCAV, concept bottleneck models) that communicate model behavior in terms meaningful to domain experts
- Design explanation infrastructure for production systems — audit logging, explanation APIs, dashboards, and regulatory documentation
35.1 The Explanation Problem
A gradient-boosted tree model at Meridian Financial declines a credit application. The applicant calls. Under the Equal Credit Opportunity Act (ECOA) and Regulation B, the lender must provide a statement of specific reasons for the adverse action — not "your score was below the threshold," but specific, actionable factors like "your debt-to-income ratio exceeds the acceptable range" or "the length of your credit history is insufficient." The regulation does not require that the explanation be mathematically precise. It requires that it be useful: the applicant must be able to understand what they could change to receive a different outcome.
Twelve thousand miles away, a clinician at a pharmaceutical trial site reviews a deep learning model's prediction that a patient's tumor will not respond to a candidate therapy. The clinician does not need to know that neuron 4,217 in layer 12 activated at 0.83. She needs to know whether the model's prediction is driven by histological features she recognizes — cell morphology, tissue architecture, marker expression patterns — or by statistical artifacts she cannot evaluate.
A user on the StreamRec platform sees a recommendation and wonders: why this content? The user does not need a technical explanation. They need a sentence — "Because you watched similar cooking videos" or "Popular with viewers in your area" — that helps them calibrate trust in the recommendation.
These are three different explanation problems, requiring three different methods, serving three different audiences, operating under three different constraints. The field of interpretability and explainability provides a toolkit for all three. This chapter develops that toolkit rigorously, starting from the game-theoretic foundations of Shapley values and proceeding through gradient-based methods for neural networks, concept-based explanations for domain experts, counterfactual explanations for end users, and the engineering infrastructure required to deliver explanations at production scale.
Know How Your Model Is Wrong: Interpretability is the operational manifestation of this recurring theme. Throughout this textbook, we have assessed model quality through aggregate metrics (AUC, RMSE, Recall@20), disaggregated metrics (fairness audits, Chapter 31), calibration curves (Chapter 34), and prediction intervals (Chapter 34). Explanations add a complementary perspective: they tell you why the model made a specific prediction, which features drove it, and which inputs would have changed it. A model can be accurate, calibrated, fair, and still make individual predictions for the wrong reasons — reasons that would be obvious to a domain expert who could see the explanation but invisible in any aggregate metric. Explanations are the microscope to aggregate metrics' telescope.
Production ML = Software Engineering: Generating explanations in a research notebook is straightforward. Generating explanations at 15,000 requests per second, with sub-100ms latency, audit-logged, version-tracked, and compliant with three regulatory frameworks — that is a software engineering problem. This chapter treats explanation infrastructure with the same rigor that Chapters 24-30 applied to serving, monitoring, and deployment.
35.2 Interpretability vs. Explainability: A Precise Distinction
The terms are often used interchangeably. They should not be.
Interpretability is a property of the model itself. A model is interpretable if a human can understand the entire mapping from inputs to outputs — not a summary, not an approximation, but the actual function. Linear regression with 10 features is interpretable: each coefficient has a precise meaning, the prediction is a weighted sum, and the contribution of each feature is unambiguous. A decision tree with 5 levels is interpretable: you can trace any prediction through the tree and identify the exact conditions that led to it. A logistic regression, a GAM, a RuleFit model, a scoring card — these are interpretable models.
Explainability is a property of the explanation method, not the model. A model is explainable if we can produce a post-hoc explanation — a secondary, simplified account of why the model made a particular prediction. SHAP values, LIME, saliency maps, attention visualizations — these are explanation methods that produce explanations for models that are not themselves interpretable.
The distinction matters because interpretability and explainability have fundamentally different epistemic statuses:
| Dimension | Interpretability | Explainability |
|---|---|---|
| What is understood | The actual model | An approximation of the model |
| Faithfulness guarantee | By construction — the explanation IS the model | Not guaranteed — the explanation approximates the model |
| Audience | Anyone who can read the model's representation | Anyone who can read the explanation's representation |
| Failure mode | Model too complex to interpret (fails gracefully) | Explanation unfaithful to model (fails silently) |
| Regulatory acceptance | Generally preferred (ECOA, EU AI Act Annex III) | Accepted with documentation of limitations |
Cynthia Rudin has argued forcefully (2019) that for high-stakes decisions — criminal sentencing, medical diagnosis, credit scoring — interpretable models should always be preferred over post-hoc explanations of black-box models, because post-hoc explanations can be unfaithful: they can present a simplified story that does not accurately reflect the model's actual reasoning. This argument has genuine force. A SHAP explanation that attributes 40% of a credit denial to "debt-to-income ratio" may be an excellent global summary while masking a nonlinear interaction with zip code that is the actual marginal driver for this specific applicant.
But interpretability has limits. A linear model cannot capture the interaction effects that drive a 2-point AUC improvement in credit scoring. A decision tree cannot match the sequential pattern recognition that makes a transformer effective for session-based recommendation. A GAM cannot learn the hierarchical representations that make a CNN effective for medical imaging. When the performance gap between interpretable and complex models is large enough to affect real outcomes — when the linear credit scoring model's lower accuracy means denying credit to qualified applicants, when the simpler clinical model misses treatable cancers — the decision is not obvious.
This chapter takes the following position: use interpretable models when the performance gap is small; use complex models with rigorous post-hoc explanations when the gap is material; and never deploy a model in a high-stakes setting without either interpretability or explainability. The remainder of the chapter develops the tools for the second case.
35.3 Shapley Values: The Game-Theoretic Foundation
The Shapley value, introduced by Lloyd Shapley in 1953, is a solution concept from cooperative game theory. It answers a deceptively simple question: given a game played by a coalition of players, how should the total payoff be fairly divided among the players?
The Cooperative Game Setup
A cooperative game consists of a set of players $N = \{1, 2, \ldots, p\}$ and a value function $v: 2^N \to \mathbb{R}$ that assigns a real-valued payoff to every coalition (subset) of players. The value function satisfies $v(\emptyset) = 0$ — the empty coalition has zero value.
In the explanation context, the "players" are features, the "coalition" is a subset of features used for prediction, and the "value function" is the model's prediction (or some function of it) when only the features in the coalition are available. The "total payoff" to be divided is the difference between the model's prediction for a specific instance $x$ and the average prediction over the dataset (the baseline).
The Shapley Value Formula
The Shapley value $\phi_j$ for player (feature) $j$ is:
$$\phi_j = \sum_{S \subseteq N \setminus \{j\}} \frac{|S|! \; (p - |S| - 1)!}{p!} \left[ v(S \cup \{j\}) - v(S) \right]$$
where the sum ranges over all subsets $S$ of players that do not include $j$. The term $v(S \cup \{j\}) - v(S)$ is the marginal contribution of player $j$ to coalition $S$ — how much the payoff increases when $j$ joins $S$. The weighting factor $\frac{|S|!(p-|S|-1)!}{p!}$ ensures equal weight across all possible orderings in which $j$ could join the coalition.
The Four Axioms
Shapley proved that the Shapley value is the unique allocation satisfying four axioms:
- Efficiency. The Shapley values sum to the total payoff: $\sum_{j=1}^{p} \phi_j = v(N) - v(\emptyset)$. Every unit of prediction difference is attributed to some feature. Nothing is lost.
- Symmetry. If features $j$ and $k$ contribute equally to every coalition ($v(S \cup \{j\}) = v(S \cup \{k\})$ for all $S$ not containing either), they receive equal Shapley values. Features that behave identically are treated identically.
- Dummy. If feature $j$ contributes nothing to any coalition ($v(S \cup \{j\}) = v(S)$ for all $S$), its Shapley value is zero. Irrelevant features get zero attribution.
- Linearity (Additivity). If the value function is a sum of two games, $v = v_1 + v_2$, the Shapley values of the sum are the sums of the Shapley values: $\phi_j(v) = \phi_j(v_1) + \phi_j(v_2)$.
These axioms are not arbitrary mathematical conditions. They are fairness axioms — the same word appearing in a very different context than Chapter 31. The Shapley value is the unique "fair" allocation under these axioms. No other attribution method satisfies all four simultaneously. This is the theoretical foundation for SHAP's claim to be a principled attribution method, not merely a heuristic.
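These guarantees can be checked directly on a toy game. The sketch below is a brute-force implementation of the Shapley formula above — pure Python, with a made-up three-player value function — that verifies the efficiency, symmetry, and dummy axioms numerically:

```python
from itertools import combinations
from math import factorial

def shapley_values(players, v):
    """Exact Shapley values by enumerating all coalitions (O(2^p))."""
    p = len(players)
    phi = {}
    for j in players:
        others = [i for i in players if i != j]
        total = 0.0
        for size in range(p):
            for S in combinations(others, size):
                weight = factorial(len(S)) * factorial(p - len(S) - 1) / factorial(p)
                total += weight * (v(frozenset(S) | {j}) - v(frozenset(S)))
        phi[j] = total
    return phi

# Made-up value function: players 1 and 2 are worth 4 alone but 20 together;
# player 3 is a dummy and never changes the payoff.
def v(S):
    if {1, 2} <= S:
        return 20.0
    return 4.0 * len(S - {3})

phi = shapley_values([1, 2, 3], v)
assert abs(sum(phi.values()) - 20.0) < 1e-9   # efficiency: sums to v(N)
assert abs(phi[1] - phi[2]) < 1e-9            # symmetry: equal contributors
assert abs(phi[3]) < 1e-9                     # dummy: irrelevant player gets 0
```

With this game each of players 1 and 2 receives 10 — the synergy is split evenly, exactly as the symmetry axiom demands.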
Computational Complexity
The Shapley value formula requires evaluating $2^p$ coalitions for each feature — exponential in the number of features. For a model with 200 features (typical in production credit scoring), this requires $2^{200}$ evaluations per feature per instance. This is computationally intractable.
The genius of SHAP (SHapley Additive exPlanations), introduced by Lundberg and Lee (2017), is to exploit model structure to make Shapley value computation tractable for specific model families. The key insight: the general Shapley formula requires exponential time because it treats the model as a black box. If you know the model's internal structure — the tree splits in a gradient-boosted ensemble, the layer-by-layer computation in a neural network — you can compute exact or approximate Shapley values in polynomial time.
```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Any
import numpy as np

@dataclass
class ShapleyExplanation:
    """Container for a single instance's Shapley value explanation.

    Attributes:
        instance_id: Unique identifier for the explained instance.
        base_value: Expected model output over the reference dataset (E[f(X)]).
        shap_values: Dictionary mapping feature name to its Shapley value.
        prediction: Model output for this instance (base_value + sum(shap_values)).
        feature_values: Dictionary mapping feature name to input value.
    """

    instance_id: str
    base_value: float
    shap_values: Dict[str, float]
    prediction: float
    feature_values: Dict[str, Any]

    def top_features(self, k: int = 5) -> List[Tuple[str, float]]:
        """Return the k features with the largest absolute SHAP values."""
        sorted_features = sorted(
            self.shap_values.items(),
            key=lambda item: abs(item[1]),
            reverse=True,
        )
        return sorted_features[:k]

    def verify_efficiency(self, tolerance: float = 1e-4) -> bool:
        """Verify the efficiency axiom: base_value + sum(shap) ≈ prediction."""
        reconstructed = self.base_value + sum(self.shap_values.values())
        return abs(reconstructed - self.prediction) < tolerance

    def positive_contributors(self) -> Dict[str, float]:
        """Return features that push the prediction above the base value."""
        return {k: v for k, v in self.shap_values.items() if v > 0}

    def negative_contributors(self) -> Dict[str, float]:
        """Return features that push the prediction below the base value."""
        return {k: v for k, v in self.shap_values.items() if v < 0}
```
35.4 SHAP: From Theory to Production
TreeSHAP: Exact Shapley Values for Tree Ensembles
Lundberg, Erion, Chen, and Lee (2020) introduced TreeSHAP, an algorithm that computes exact Shapley values for tree-based models (decision trees, random forests, gradient-boosted trees) in $O(TLD^2)$ time, where $T$ is the number of trees, $L$ is the maximum number of leaves, and $D$ is the maximum tree depth. For a typical XGBoost model with 500 trees and depth 6, this is milliseconds per instance — an exponential-to-polynomial improvement over the brute-force computation.
The algorithm works by recursively propagating instance weights through each tree's structure, using the tree's splits to efficiently compute the conditional expectation $E[f(x) \mid x_S]$ for every feature subset $S$ simultaneously. The key insight is that a tree's structure already encodes the conditional independence assumptions needed: features not in the current branch's path are marginalized by following both child nodes weighted by the training data fraction at each split.
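To make the marginalization concrete, here is a minimal sketch on a hand-built two-feature tree (a toy structure and a toy traversal, not the actual TreeSHAP recursion): when a split feature is in the coalition the traversal follows the split; otherwise both children are averaged using the stored training-data fractions:

```python
# A node is either a leaf {"value": v} or an internal node carrying a split
# feature, a threshold, two children, and the fraction of training rows
# that went left at this split (cover_left).
tree = {
    "feature": "income", "threshold": 50.0, "cover_left": 0.6,
    "left": {"value": 0.2},
    "right": {
        "feature": "age", "threshold": 30.0, "cover_left": 0.5,
        "left": {"value": 0.5},
        "right": {"value": 0.9},
    },
}

def expected_value(node, x, present):
    """E[f(x) | x_S]: follow splits on present features, average absent ones."""
    if "value" in node:
        return node["value"]
    if node["feature"] in present:
        child = node["left"] if x[node["feature"]] < node["threshold"] else node["right"]
        return expected_value(child, x, present)
    w = node["cover_left"]
    return (w * expected_value(node["left"], x, present)
            + (1 - w) * expected_value(node["right"], x, present))

x = {"income": 80.0, "age": 40.0}
full = expected_value(tree, x, {"income", "age"})  # both features known
empty = expected_value(tree, x, set())             # baseline E[f(X)]
assert full == 0.9
assert abs(empty - 0.4) < 1e-12  # 0.6*0.2 + 0.4*(0.5*0.5 + 0.5*0.9)
```

TreeSHAP's contribution is computing these conditional expectations for all coalitions simultaneously in one pass per tree, rather than one recursion per subset as here.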
```python
import shap
import xgboost as xgb
import numpy as np
import pandas as pd

def compute_treeshap_explanations(
    model: xgb.Booster,
    X: pd.DataFrame,
    background: Optional[pd.DataFrame] = None,
    check_additivity: bool = True,
) -> shap.Explanation:
    """Compute TreeSHAP explanations for an XGBoost model.

    Args:
        model: Trained XGBoost booster.
        X: Instances to explain. Each row produces one explanation.
        background: Reference dataset for the base value. If provided,
            SHAP uses the interventional mode, which is required for
            probability-space output. If None, the faster
            tree-path-dependent mode explains the raw margin output.
        check_additivity: If True, verify the efficiency axiom for every
            instance; SHAP raises an error on violations.

    Returns:
        shap.Explanation object containing SHAP values, base values,
        feature names, and input data.
    """
    if background is not None:
        # Interventional mode: absent features are replaced with values
        # from the background samples.
        explainer = shap.TreeExplainer(
            model,
            data=background,
            feature_perturbation="interventional",
            model_output="probability",
        )
    else:
        # Tree-path-dependent mode: absent features are marginalized using
        # the training-data fractions stored at each split.
        explainer = shap.TreeExplainer(
            model,
            feature_perturbation="tree_path_dependent",
            model_output="raw",
        )
    return explainer(X, check_additivity=check_additivity)

def extract_top_factors(
    shap_explanation: shap.Explanation,
    instance_idx: int,
    k: int = 4,
) -> List[Tuple[str, float, Any]]:
    """Extract top-k contributing features for a single instance.

    Returns:
        List of (feature_name, shap_value, feature_value) tuples,
        sorted by absolute SHAP value descending.
    """
    values = shap_explanation.values[instance_idx]
    feature_names = shap_explanation.feature_names
    data = shap_explanation.data[instance_idx]
    indexed = list(zip(feature_names, values, data))
    indexed.sort(key=lambda t: abs(t[1]), reverse=True)
    return indexed[:k]
```
TreeSHAP's Two Modes
TreeSHAP operates in two modes that produce different results:
Tree-path-dependent (the default when no background dataset is supplied): marginalizes absent features by following the tree's internal data distribution — when a feature $x_j$ is absent from coalition $S$, the algorithm follows both children of any split on $x_j$, weighted by the fraction of training data going to each child. This respects the correlations in the training data and produces explanations that reflect the model's actual behavior.
Interventional (the original Shapley formulation): replaces absent features with values sampled from a reference distribution, breaking correlations. This answers a different question: "What would the prediction be if we intervened to set feature $x_j$ to a random value?" rather than "What would the prediction be if we didn't know feature $x_j$?"
The distinction is consequential in practice. For correlated features — income and credit limit, age and credit history length — the two modes produce different attributions. The tree-path-dependent mode attributes value to the first feature encountered in the tree path; the interventional mode distributes value more evenly among correlated features. Neither is universally "correct" — they answer different questions. The tree-path-dependent mode is faster and is used by default when no background dataset is supplied. The interventional mode aligns more closely with causal reasoning (Chapters 15-19) when the reference distribution is chosen carefully.
DeepSHAP: Approximate Shapley Values for Neural Networks
DeepSHAP (Lundberg and Lee, 2017) combines SHAP's game-theoretic framework with DeepLIFT's (Shrikumar, Greenside, and Kundaje, 2017) backpropagation-based attribution. The idea: propagate Shapley value computations layer by layer through the neural network, using each layer's linear structure to compute exact Shapley values within that layer and composing across layers.
For a network $f = f_L \circ f_{L-1} \circ \cdots \circ f_1$ where each $f_l$ is a layer function, DeepSHAP computes the Shapley values of the output with respect to the input by applying the chain rule of Shapley values (a consequence of the linearity axiom) across layers. At each layer, the method uses the layer's activation function to compute local Shapley values, then propagates them backward.
```python
import torch
import shap

def compute_deepshap_explanations(
    model: torch.nn.Module,
    X: torch.Tensor,
    background: torch.Tensor,
    output_index: Optional[int] = None,
) -> shap.Explanation:
    """Compute DeepSHAP explanations for a PyTorch model.

    Args:
        model: Trained PyTorch model in eval mode.
        X: Instances to explain, shape (n_instances, *input_shape).
        background: Reference dataset for the base value,
            shape (n_background, *input_shape). Typically 100-500
            random training samples.
        output_index: For multi-output models, which output to explain.
            None for single-output models.

    Returns:
        shap.Explanation object with SHAP values.
    """
    model.eval()
    explainer = shap.DeepExplainer(model, background)
    shap_values = explainer.shap_values(X)
    # For multi-output models, shap_values is a list with one array per output.
    if output_index is not None and isinstance(shap_values, list):
        shap_values = shap_values[output_index]
    return shap.Explanation(
        values=shap_values,
        base_values=explainer.expected_value,
        data=X.detach().cpu().numpy(),
    )
```
Limitations of DeepSHAP. DeepSHAP is an approximation, not exact. The layer-by-layer decomposition assumes that per-layer attributions compose cleanly through nonlinearities, which holds exactly only for linear operations. For ReLU networks (piecewise linear) the approximation is typically close; for sigmoid, tanh, or GELU activations it introduces error that can be significant for deep networks. GradientSHAP — which combines the SHAP framework with integrated-gradients-style sampling over the background set — is generally more robust for non-ReLU architectures.
KernelSHAP: Model-Agnostic Shapley Approximation
KernelSHAP is the most general SHAP variant: it works with any model as a black box, requiring only the ability to evaluate the model on perturbed inputs. The idea is to formulate the Shapley value computation as a weighted linear regression problem.
Given an instance $x$ to explain, KernelSHAP:
1. Samples coalition vectors $z \in \{0, 1\}^p$ (each $z_j = 1$ means feature $j$ is "present")
2. For each $z$, constructs a perturbed input $h_x(z)$ by replacing absent features with reference values
3. Evaluates the model on each perturbed input: $f(h_x(z))$
4. Fits a weighted linear regression $g(z) = \phi_0 + \sum_{j=1}^p \phi_j z_j$, weighted by the SHAP kernel:
$$\pi(z) = \frac{p - 1}{\binom{p}{|z|} \cdot |z| \cdot (p - |z|)}$$
The SHAP kernel assigns infinite weight to the empty and full coalitions (forcing the efficiency axiom) and high weight to coalitions of size 1 and $p-1$ (which are most informative about individual features). Lundberg and Lee (2017) proved that this weighted regression recovers exact Shapley values in the limit of infinite samples.
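The kernel's shape is easy to inspect numerically. This small sketch evaluates $\pi(z)$ for $p = 5$ and confirms that, among the finite weights, coalitions of size 1 and $p-1$ are weighted most heavily:

```python
from math import comb

def shap_kernel_weight(p: int, s: int) -> float:
    """SHAP kernel pi(z) for a coalition of size s out of p features."""
    if s == 0 or s == p:
        # Empty/full coalitions get infinite weight: they are enforced
        # as hard constraints (the efficiency axiom).
        return float("inf")
    return (p - 1) / (comb(p, s) * s * (p - s))

p = 5
weights = {s: shap_kernel_weight(p, s) for s in range(1, p)}
# Single-feature and all-but-one coalitions are the most informative probes.
assert weights[1] == weights[4] == max(weights.values())
assert weights[2] == weights[3] == min(weights.values())
```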
```python
def compute_kernelshap_explanations(
    predict_fn,
    X: np.ndarray,
    background: np.ndarray,
    n_samples: int = 2048,
    feature_names: Optional[List[str]] = None,
) -> shap.Explanation:
    """Compute KernelSHAP explanations for any model.

    Args:
        predict_fn: Callable that takes a numpy array and returns
            predictions. Shape: (n, p) -> (n,) or (n, c).
        X: Instances to explain, shape (n_instances, p).
        background: Reference dataset, shape (n_background, p).
            Larger background is more accurate but slower.
            100 samples is a common choice.
        n_samples: Number of coalition samples for the regression.
            Higher is more accurate; passing "auto" instead lets
            SHAP decide.
        feature_names: Optional list of feature names.

    Returns:
        shap.Explanation object with approximate SHAP values.
    """
    explainer = shap.KernelExplainer(predict_fn, background)
    shap_values = explainer.shap_values(X, nsamples=n_samples)
    return shap.Explanation(
        values=shap_values,
        base_values=explainer.expected_value,
        data=X,
        feature_names=feature_names,
    )
```
KernelSHAP's cost. For each instance, KernelSHAP requires $O(n_{\text{samples}} \cdot n_{\text{background}})$ model evaluations — each coalition sample is evaluated by replacing absent features with every background sample and averaging. For 2,048 samples, 100 background instances, and a model with 100ms inference time, that is $\sim$204,800 evaluations per instance — roughly 5.7 hours per explanation. This makes KernelSHAP unsuitable for real-time production use with slow models. It is a research and auditing tool, not a serving-time tool.
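That budget arithmetic is worth running before committing to KernelSHAP on a slow model; a minimal sketch of the calculation:

```python
def kernelshap_cost(n_coalition_samples: int, n_background: int,
                    model_latency_s: float) -> tuple:
    """Model evaluations and wall-clock seconds for ONE explanation."""
    n_evals = n_coalition_samples * n_background
    return n_evals, n_evals * model_latency_s

# The scenario from the text: 2,048 coalition samples, 100 background
# instances, 100 ms per model evaluation.
n_evals, seconds = kernelshap_cost(2048, 100, 0.100)
assert n_evals == 204_800
assert round(seconds / 3600, 1) == 5.7  # ~5.7 hours per explanation
```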
Choosing the Right SHAP Variant
| Variant | Model Type | Exact? | Time per Instance | Production Viable? |
|---|---|---|---|---|
| TreeSHAP | Tree ensembles | Yes (path-dependent) | 1-10 ms | Yes |
| DeepSHAP | Neural networks (ReLU) | Approximate | 10-100 ms | Yes, with batching |
| GradientSHAP | Neural networks (any) | Approximate | 50-200 ms | Yes, with batching |
| KernelSHAP | Any model | Approximate | 1-60 min | No (audit only) |
35.5 Beyond SHAP: LIME, PDP, ALE, and ICE
SHAP is not the only explanation method, and it is not always the best. Several complementary methods provide different perspectives on model behavior.
LIME (Local Interpretable Model-Agnostic Explanations)
Ribeiro, Singh, and Guestrin (2016) introduced LIME, which explains a prediction by fitting a local interpretable model (typically a sparse linear model) in the neighborhood of the instance to be explained.
The procedure:
1. Generate perturbed instances in the neighborhood of $x$ by randomly toggling interpretable features on/off
2. Evaluate the model on each perturbed instance
3. Weight the perturbed instances by their proximity to $x$ (using an exponential kernel)
4. Fit a weighted sparse linear regression (LASSO) to the perturbed instances
5. Return the coefficients of the linear model as feature attributions
```python
import lime
import lime.lime_tabular

def compute_lime_explanation(
    predict_fn,
    X_train: np.ndarray,
    instance: np.ndarray,
    feature_names: List[str],
    num_features: int = 10,
    num_samples: int = 5000,
) -> Dict[str, float]:
    """Compute a LIME explanation for a single instance.

    Args:
        predict_fn: Callable, shape (n, p) -> (n, c) for classification
            or (n, p) -> (n,) for regression.
        X_train: Training data for fitting the LIME discretizer.
        instance: Single instance to explain, shape (p,).
        feature_names: List of feature names.
        num_features: Number of features in the local explanation.
        num_samples: Number of perturbed samples to generate.

    Returns:
        Dictionary mapping feature name to local importance weight.
    """
    explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=X_train,
        feature_names=feature_names,
        mode="classification",
        discretize_continuous=True,
    )
    explanation = explainer.explain_instance(
        data_row=instance,
        predict_fn=predict_fn,
        num_features=num_features,
        num_samples=num_samples,
    )
    return dict(explanation.as_list())
```
LIME's limitations are well-documented. Alvarez-Melis and Jaakkola (2018) showed that LIME explanations are unstable: running LIME twice on the same instance with different random seeds can produce substantially different explanations. The instability arises from the random perturbation sampling — the local neighborhood is sampled differently each time, and the LASSO fit can change dramatically. SHAP does not have this problem because Shapley values are deterministic (for exact methods) or converge to a unique solution (for sampling methods).
A second limitation: LIME's kernel width — the parameter controlling how "local" the neighborhood is — has no principled selection criterion. A narrow kernel focuses on very nearby instances but may not sample enough of the model's decision boundary. A wide kernel samples a broader region but produces explanations that are less locally faithful. The default kernel width works reasonably in many cases but can fail silently.
When LIME is still useful: LIME produces explanations in the form of a sparse linear model, which is natural for users who think in terms of "feature X increased the prediction by Y." LIME is also model-agnostic and does not require access to model internals, making it useful for auditing proprietary models via API. For production deployment, prefer TreeSHAP (for tree models) or DeepSHAP (for neural networks) for their stability and theoretical guarantees.
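The heart of LIME — a proximity-weighted local linear surrogate — fits in a few lines of NumPy. The sketch below simplifies aggressively (Gaussian perturbations and weighted least squares in place of LIME's discretization and LASSO, on a made-up model) but makes the role of the kernel width tangible:

```python
import numpy as np

def local_surrogate(predict_fn, x, kernel_width, n_samples=5000, seed=0):
    """Fit a proximity-weighted linear model around x; return coefficients."""
    rng = np.random.default_rng(seed)
    Z = x + rng.normal(scale=1.0, size=(n_samples, x.shape[0]))
    y = predict_fn(Z)
    d2 = ((Z - x) ** 2).sum(axis=1)
    w = np.exp(-d2 / kernel_width ** 2)          # exponential proximity kernel
    A = np.hstack([np.ones((n_samples, 1)), Z])  # intercept + features
    W = np.sqrt(w)[:, None]
    coef, *_ = np.linalg.lstsq(W * A, np.sqrt(w) * y, rcond=None)
    return coef[1:]                              # per-feature local weights

# A made-up model: exactly linear in x0, nonlinear (and locally flat) in x1.
predict = lambda Z: 3.0 * Z[:, 0] + np.sin(Z[:, 1]) ** 2
coefs = local_surrogate(predict, np.zeros(2), kernel_width=0.75)
assert abs(coefs[0] - 3.0) < 0.2  # x0's local slope is recovered
```

Rerunning with different seeds perturbs the coefficients — a small-scale version of the instability Alvarez-Melis and Jaakkola documented.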
Partial Dependence Plots (PDP) and Accumulated Local Effects (ALE)
PDP and ALE are global explanation methods — they explain the model's average behavior across the entire dataset, not individual predictions.
Partial Dependence Plots (Friedman, 2001) show the marginal effect of one or two features on the model's prediction, averaging over all other features:
$$\hat{f}_S(x_S) = \frac{1}{n} \sum_{i=1}^{n} f(x_S, x_{C}^{(i)})$$
where $x_S$ is the feature(s) of interest and $x_C^{(i)}$ are the remaining features from the $i$-th training instance. For each value of $x_S$, the PDP averages the model's prediction over all training instances, replacing only the features of interest.
The PDP's limitation: When features are correlated, PDP averages over feature combinations that may not exist in the data. If income and credit limit are strongly correlated, the PDP for income at \$200,000 averages over all credit limits in the dataset — including credit limits of \$2,000 that no \$200,000-income applicant would have. This produces unrealistic predictions that distort the PDP.
Accumulated Local Effects (Apley and Zhu, 2020) address this by computing local effects within conditional bands:
$$\hat{f}_{j,\text{ALE}}(x_j) = \sum_{k=1}^{K_j(x_j)} \frac{1}{n_k} \sum_{i: x_j^{(i)} \in (z_{k-1}, z_k]} \left[ f(z_k, x_{-j}^{(i)}) - f(z_{k-1}, x_{-j}^{(i)}) \right]$$
Instead of averaging over all possible values of other features, ALE computes the change in prediction as the feature of interest moves within small intervals, using the actual conditional distribution of other features within each interval. This avoids extrapolation to unrealistic feature combinations.
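Both estimators are short enough to implement from scratch. The sketch below uses a made-up interaction model with strongly correlated inputs: the PDP substitutes unrealistic feature combinations, while a simplified (uncentered) ALE accumulates local differences within quantile bins of realistic data:

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(size=2000)
x2 = x1 + rng.normal(scale=0.1, size=2000)     # strongly correlated with x1
X = np.column_stack([x1, x2])
f = lambda X: X[:, 0] * X[:, 1]                # pure interaction model

def pdp_point(f, X, j, value):
    """PDP: average prediction with feature j forced to `value` everywhere."""
    Xg = X.copy()
    Xg[:, j] = value
    return f(Xg).mean()

def ale_curve(f, X, j, n_bins=20):
    """Uncentered ALE: accumulate mean local differences per quantile bin."""
    edges = np.quantile(X[:, j], np.linspace(0, 1, n_bins + 1))
    effects = []
    for k in range(n_bins):
        mask = (X[:, j] >= edges[k]) & (X[:, j] <= edges[k + 1])
        lo, hi = X[mask].copy(), X[mask].copy()
        lo[:, j], hi[:, j] = edges[k], edges[k + 1]
        effects.append((f(hi) - f(lo)).mean())
    return edges, np.cumsum(effects)

# PDP at x1 = 2 pairs that value with ALL x2 values (mean ~0), even though
# in the data x1 = 2 always co-occurs with x2 ~ 2.
assert abs(pdp_point(f, X, 0, 2.0) - 2.0 * x2.mean()) < 1e-9
edges, ale = ale_curve(f, X, 0)
assert ale[-1] > ale[0]  # ALE only ever uses realistic x2 values per bin
```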
```python
from sklearn.inspection import PartialDependenceDisplay
import matplotlib.pyplot as plt

def plot_pdp_and_ale(
    model,
    X: pd.DataFrame,
    feature_name: str,
    kind: str = "both",
    grid_resolution: int = 50,
) -> plt.Figure:
    """Plot PDP and/or ICE for a single feature.

    Args:
        model: Trained sklearn-compatible model with .predict() or
            .predict_proba().
        X: Feature matrix as a DataFrame (for feature names).
        feature_name: Name of the feature to plot.
        kind: "average" for PDP, "individual" for ICE, "both" for PDP+ICE.
        grid_resolution: Number of grid points for the feature axis.

    Returns:
        Matplotlib figure with the PDP/ICE plot.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    PartialDependenceDisplay.from_estimator(
        model,
        X,
        features=[feature_name],
        kind=kind,
        ax=ax,
        grid_resolution=grid_resolution,
        ice_lines_kw={"color": "steelblue", "alpha": 0.05, "linewidth": 0.5},
        pd_line_kw={"color": "darkred", "linewidth": 2.5},
    )
    ax.set_title(f"PDP/ICE: {feature_name}")
    fig.tight_layout()
    return fig
```
Individual Conditional Expectation (ICE)
ICE plots (Goldstein et al., 2015) disaggregate the PDP: instead of plotting the average effect, ICE plots show one curve per instance. Where the PDP shows the average marginal effect of income on default probability, the ICE plot shows a separate curve for each applicant — revealing whether the effect of income is heterogeneous. If all ICE curves are parallel, the feature's effect is additive (no interactions). If the curves cross or diverge, there are interaction effects that the PDP's average masks.
ICE plots are computationally cheap (each curve requires one model evaluation per grid point) and diagnostically powerful. They are the first tool to reach for when the PDP shows a surprising shape — the ICE curves reveal whether the shape reflects a uniform effect or an average of heterogeneous effects.
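A toy model with an explicit interaction shows this diagnostic value directly. In the sketch below (a hypothetical model, not StreamRec code), every ICE curve has slope $+1$ or $-1$ depending on the sign of the other feature, so the averaged PDP slope is near zero — exactly the masking effect described above:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
f = lambda X: X[:, 0] * np.sign(X[:, 1])   # slope of x0 flips with x1's sign

grid = np.linspace(-2, 2, 21)

def ice_curves(f, X, j, grid):
    """One curve per instance: prediction as feature j sweeps the grid."""
    curves = np.empty((X.shape[0], grid.size))
    for g, value in enumerate(grid):
        Xg = X.copy()
        Xg[:, j] = value
        curves[:, g] = f(Xg)
    return curves

curves = ice_curves(f, X, 0, grid)
slopes = (curves[:, -1] - curves[:, 0]) / (grid[-1] - grid[0])
pdp_slope = slopes.mean()
# Each instance's slope is exactly +1 or -1; the PDP averages them to ~0
# and would wrongly suggest x0 has no effect.
assert set(np.round(slopes).astype(int)) == {-1, 1}
assert abs(pdp_slope) < 0.5
```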
35.6 Gradient-Based Explanations for Deep Learning
For neural networks, the gradient of the output with respect to the input provides a natural attribution: features whose perturbation changes the output the most are the most important. Gradient-based methods formalize this intuition at varying levels of sophistication.
Saliency Maps (Vanilla Gradients)
Simonyan, Vedaldi, and Zisserman (2014) proposed the simplest gradient-based attribution: compute $\frac{\partial f(x)}{\partial x_j}$ for each input feature $x_j$. The absolute gradient magnitude $|\frac{\partial f}{\partial x_j}|$ indicates how sensitive the output is to small changes in feature $j$.
The problem: Vanilla gradients are noisy, visually fragmented (for images), and suffer from gradient saturation — for inputs in the flat region of a sigmoid or ReLU, the gradient is near zero even if the feature was critical to the prediction. The saliency map shows where the model is locally sensitive, not what the model used.
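Saturation is easy to demonstrate with one sigmoid unit. In this sketch, the input clearly drove the output to its ceiling, yet the local gradient — what a saliency map would report — has all but vanished:

```python
import math

def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))

def grad_sigmoid(t: float) -> float:
    s = sigmoid(t)
    return s * (1.0 - s)

# f(x) = sigmoid(5 * x). At x = 2 the output is pinned near 1.0 BECAUSE
# of x, yet the chain-rule gradient is tiny: the unit is saturated.
x = 2.0
output = sigmoid(5 * x)
grad = 5 * grad_sigmoid(5 * x)
assert output > 0.99
assert grad < 1e-3  # a saliency map would call this feature unimportant
```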
Integrated Gradients
Sundararajan, Taly, and Yan (2017) introduced integrated gradients (IG), which address the saturation problem by accumulating gradients along a path from a baseline $x'$ (typically a zero input or a black image) to the actual input $x$:
$$\text{IG}_j(x) = (x_j - x_j') \times \int_{\alpha=0}^{1} \frac{\partial f(x' + \alpha(x - x'))}{\partial x_j} \, d\alpha$$
Integrated gradients satisfy two key axioms: completeness (the attributions sum to $f(x) - f(x')$, analogous to Shapley's efficiency axiom) and sensitivity (if changing feature $j$ from the baseline changes the output, feature $j$ receives nonzero attribution). These axioms rule out many pathological behaviors of vanilla gradients.
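The completeness axiom can be verified numerically with a Riemann approximation of the path integral. The sketch below uses a toy function with an analytic gradient (both invented for illustration) and checks that the attributions sum to $f(x) - f(x')$:

```python
import numpy as np

def f(x):
    return x[0] * x[1] + x[2] ** 2

def grad_f(x):
    return np.array([x[1], x[0], 2.0 * x[2]])

def integrated_gradients(x, baseline, n_steps=200):
    """Midpoint Riemann approximation of the IG path integral."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    total = np.zeros_like(x)
    for a in alphas:
        total += grad_f(baseline + a * (x - baseline))
    return (x - baseline) * total / n_steps

x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)
ig = integrated_gradients(x, baseline)

# Completeness: attributions sum to f(x) - f(baseline).
print(ig, ig.sum(), f(x) - f(baseline))
```

For a neural network the analytic gradient is replaced by a backward pass at each interpolation step, which is exactly what Captum's `IntegratedGradients` batches for you.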
GradCAM: Class Activation Mapping
Selvaraju et al. (2017) introduced GradCAM, a visualization technique for CNNs that produces coarse spatial heatmaps highlighting which regions of an image were important for a particular class prediction. GradCAM computes the gradient of the target class score with respect to the feature maps of the last convolutional layer, uses global average pooling of the gradients to obtain importance weights for each feature map channel, and then computes a weighted combination of the feature maps followed by a ReLU.
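Those three steps — gradient of the class score with respect to the last convolutional feature maps, global average pooling of those gradients, ReLU of the weighted feature-map combination — fit in a few lines of plain PyTorch. The tiny randomly initialized CNN below is purely illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny illustrative CNN: conv features -> global pool -> linear head.
conv = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
head = nn.Linear(8, 5)

def grad_cam(image, target_class):
    fmap = conv(image)                  # (1, 8, H, W) feature maps
    pooled = fmap.mean(dim=(2, 3))      # global average pooling
    score = head(pooled)[0, target_class]
    # Gradient of the class score w.r.t. the feature maps
    grads = torch.autograd.grad(score, fmap)[0]       # (1, 8, H, W)
    weights = grads.mean(dim=(2, 3), keepdim=True)    # per-channel importance
    cam = torch.relu((weights * fmap).sum(dim=1))     # (1, H, W) heatmap
    return cam.detach()

image = torch.randn(1, 3, 16, 16)
cam = grad_cam(image, target_class=2)
print(cam.shape, cam.min().item() >= 0.0)
```

The resulting heatmap has the spatial resolution of the chosen convolutional layer, which is why GradCAM maps are coarse; in production, prefer a maintained implementation such as Captum's `LayerGradCam`.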
Captum: A Unified Framework
Meta's Captum library provides a unified PyTorch implementation of all major gradient-based methods. Using Captum is the recommended approach for production systems.
import torch
from captum.attr import (
IntegratedGradients,
GradientShap,
LayerGradCam,
Saliency,
NoiseTunnel,
)
@dataclass
class GradientExplanationConfig:
"""Configuration for gradient-based explanations."""
method: str # "integrated_gradients", "gradient_shap", "gradcam", "saliency"
n_steps: int = 50 # For integrated gradients: number of interpolation steps
n_samples: int = 25 # For GradientSHAP: number of reference samples
noise_tunnel: bool = False # Apply SmoothGrad noise averaging
noise_tunnel_samples: int = 10 # Number of noisy copies
noise_tunnel_type: str = "smoothgrad" # "smoothgrad", "smoothgrad_sq", "vargrad"
internal_batch_size: int = 32 # Batch size for attribution computation
target_layer: Optional[str] = None # For GradCAM: which conv layer
def compute_gradient_attribution(
model: torch.nn.Module,
inputs: torch.Tensor,
target: int,
config: GradientExplanationConfig,
baselines: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute gradient-based feature attributions using Captum.
Args:
model: Trained PyTorch model in eval mode.
inputs: Input tensor, shape (1, *input_shape).
target: Target class index for attribution.
config: Configuration specifying method and hyperparameters.
baselines: Baseline input(s) for methods that require them.
Shape (1, *input_shape) or (n_baselines, *input_shape).
Returns:
Attribution tensor, same shape as inputs.
"""
model.eval()
if baselines is None:
baselines = torch.zeros_like(inputs)
if config.method == "integrated_gradients":
attr_method = IntegratedGradients(model)
attributions = attr_method.attribute(
inputs,
baselines=baselines,
target=target,
n_steps=config.n_steps,
internal_batch_size=config.internal_batch_size,
)
elif config.method == "gradient_shap":
attr_method = GradientShap(model)
attributions = attr_method.attribute(
inputs,
baselines=baselines,
target=target,
n_samples=config.n_samples,
)
elif config.method == "gradcam":
if config.target_layer is None:
raise ValueError("GradCAM requires target_layer in config.")
# Resolve layer by name
layer = dict(model.named_modules())[config.target_layer]
attr_method = LayerGradCam(model, layer)
attributions = attr_method.attribute(inputs, target=target)
elif config.method == "saliency":
attr_method = Saliency(model)
attributions = attr_method.attribute(inputs, target=target)
else:
raise ValueError(f"Unknown method: {config.method}")
    # Optionally apply SmoothGrad noise averaging. NoiseTunnel recomputes
    # the attribution over noisy copies of the input, so the direct result
    # computed above is discarded when the tunnel is enabled.
    if config.noise_tunnel and config.method != "gradcam":
        noise_tunnel = NoiseTunnel(attr_method)
        nt_kwargs = {}
        if config.method in ("integrated_gradients", "gradient_shap"):
            nt_kwargs["baselines"] = baselines  # Saliency takes no baselines
        attributions = noise_tunnel.attribute(
            inputs,
            nt_type=config.noise_tunnel_type,
            nt_samples=config.noise_tunnel_samples,
            target=target,
            **nt_kwargs,
        )
return attributions
Comparing gradient-based methods:
| Method | Completeness? | Baseline Required? | Computational Cost | Strengths |
|---|---|---|---|---|
| Vanilla Gradients | No | No | 1 backward pass | Fast, simple |
| Integrated Gradients | Yes | Yes | $n_{\text{steps}}$ forward+backward | Axiomatically grounded, reliable |
| GradientSHAP | Yes (approx.) | Yes | $n_{\text{samples}}$ forward+backward | Combines IG + SHAP theory |
| GradCAM | No (spatial only) | No | 1 backward pass | Interpretable heatmaps for CNNs |
| SmoothGrad | No | No | $n_{\text{samples}}$ forward+backward | Noise reduction for any base method |
35.7 Concept-Based Explanations: Speaking the Domain Expert's Language
Feature-level explanations — "pixel (142, 87) contributed +0.03 to the prediction" or "feature lab_value_creatinine SHAP = -0.12" — are technically precise but often useless to domain experts. A clinician does not think in terms of individual pixel intensities or isolated lab values. She thinks in terms of concepts: "inflammation pattern," "fibrotic tissue architecture," "treatment-resistant phenotype." A credit analyst thinks in terms of "payment reliability," "income stability," "debt management pattern" — not individual features.
Concept-based explanations bridge the gap between feature-level attributions and human reasoning by computing model sensitivity to human-defined concepts.
TCAV: Testing with Concept Activation Vectors
Kim et al. (2018) introduced TCAV (Testing with Concept Activation Vectors), which quantifies the importance of a user-defined concept to a model's predictions. The method:
1. Define concepts. Collect a set of examples that represent the concept and a set of random examples that do not. For a medical imaging model, the "fibrotic tissue" concept might be represented by 50 image patches annotated as containing fibrosis.
2. Learn a Concept Activation Vector (CAV). Train a linear classifier in the activation space of a chosen internal layer to distinguish concept examples from random examples. The unit normal vector to the decision boundary is the CAV — a direction in activation space that corresponds to the concept.
3. Compute the conceptual sensitivity. For each input $x$ and target class $k$, compute the directional derivative of the model's output with respect to the CAV direction: $S_{C,k,l}(x) = \nabla h_l(x) \cdot v_C^l$, where $h_l(x)$ is the activation at layer $l$ and $v_C^l$ is the CAV.
4. Report the TCAV score. The fraction of class-$k$ inputs with positive conceptual sensitivity: $\text{TCAV}_{C,k} = \frac{|\{x \in X_k : S_{C,k,l}(x) > 0\}|}{|X_k|}$. A TCAV score significantly above 0.5 (tested via two-sided t-test against random CAVs) indicates that the concept is positively associated with the model's prediction.
@dataclass
class ConceptSet:
"""A collection of examples representing a human-defined concept.
Attributes:
name: Human-readable concept name (e.g., "fibrotic_tissue").
positive_examples: Tensor of examples containing the concept.
negative_examples: Tensor of random examples not containing it.
layer_name: Internal layer at which to learn the CAV.
"""
name: str
positive_examples: torch.Tensor
negative_examples: torch.Tensor
layer_name: str
@dataclass
class TCAVResult:
"""TCAV score and statistical test result for one concept-class pair.
Attributes:
concept_name: Name of the concept tested.
target_class: Class whose sensitivity to the concept was measured.
tcav_score: Fraction of target-class inputs with positive
conceptual sensitivity (range [0, 1]).
p_value: p-value from two-sided t-test against random CAVs.
significant: Whether the TCAV score is significantly different
from 0.5 at alpha = 0.05.
n_random_runs: Number of random CAV runs for the statistical test.
"""
concept_name: str
target_class: int
tcav_score: float
p_value: float
significant: bool
n_random_runs: int = 10
def compute_tcav(
model: torch.nn.Module,
concept_set: ConceptSet,
target_class: int,
X_target: torch.Tensor,
n_random_runs: int = 10,
) -> TCAVResult:
"""Compute TCAV score for a concept-class pair.
Args:
model: Trained PyTorch model.
concept_set: Positive and negative examples for the concept.
target_class: Class index to test sensitivity for.
X_target: Inputs of the target class to evaluate.
n_random_runs: Number of random CAV runs for the t-test.
Returns:
TCAVResult with score, p-value, and significance.
"""
from sklearn.linear_model import LogisticRegression
from scipy import stats
# Step 1: Extract activations at the target layer
activations = {}
def hook_fn(module, input, output):
activations["target"] = output.detach()
layer = dict(model.named_modules())[concept_set.layer_name]
handle = layer.register_forward_hook(hook_fn)
# Get activations for positive and negative concept examples
model.eval()
with torch.no_grad():
model(concept_set.positive_examples)
pos_acts = activations["target"].view(
concept_set.positive_examples.size(0), -1
).numpy()
model(concept_set.negative_examples)
neg_acts = activations["target"].view(
concept_set.negative_examples.size(0), -1
).numpy()
# Step 2: Train a linear classifier to learn the CAV
X_cav = np.concatenate([pos_acts, neg_acts], axis=0)
y_cav = np.concatenate([
np.ones(pos_acts.shape[0]),
np.zeros(neg_acts.shape[0]),
])
clf = LogisticRegression(max_iter=1000, solver="lbfgs")
clf.fit(X_cav, y_cav)
cav_vector = clf.coef_[0]
cav_vector = cav_vector / np.linalg.norm(cav_vector)
# Step 3: Compute conceptual sensitivity for target class inputs
handle.remove()
# Re-register hook, now we need gradients
layer_activations = {}
def grad_hook_fn(module, input, output):
output.retain_grad()
layer_activations["target"] = output
handle = layer.register_forward_hook(grad_hook_fn)
    sensitivities = []
    grad_vectors = []  # Flattened layer gradients, reused for random CAVs
    for i in range(X_target.size(0)):
        x_i = X_target[i:i+1].requires_grad_(True)
        output = model(x_i)
        if output.dim() > 1:
            score = output[0, target_class]
        else:
            score = output[0]
        score.backward(retain_graph=False)
        act_grad = layer_activations["target"].grad
        if act_grad is not None:
            grad_flat = act_grad.view(-1).detach().numpy()
            grad_vectors.append(grad_flat)
            sensitivities.append(float(np.dot(grad_flat, cav_vector)))
        model.zero_grad()
    handle.remove()
    tcav_score = np.mean(np.array(sensitivities) > 0)
    # Step 4: Statistical test against random CAVs. The stored gradient
    # vectors are reused here -- the sensitivities themselves are scalars,
    # so random directions must be dotted with the gradients, not with them.
    random_scores = []
    for _ in range(n_random_runs):
        random_cav = np.random.randn(len(cav_vector))
        random_cav = random_cav / np.linalg.norm(random_cav)
        random_sens = [np.dot(g, random_cav) for g in grad_vectors]
        random_scores.append(np.mean(np.array(random_sens) > 0))
    # Two-sided t-test: do the random-direction scores differ from the
    # observed concept score?
    _, p_value = stats.ttest_1samp(random_scores, tcav_score)
return TCAVResult(
concept_name=concept_set.name,
target_class=target_class,
tcav_score=float(tcav_score),
p_value=float(p_value),
significant=p_value < 0.05,
n_random_runs=n_random_runs,
)
Concept Bottleneck Models
Koh et al. (2020) proposed a more radical approach: instead of post-hoc testing concepts against a black-box model, build concepts into the model architecture. A concept bottleneck model (CBM) has two stages:
- Concept predictor: Maps raw inputs to a vector of concept activations $c = g(x) \in \mathbb{R}^m$, where each $c_j$ represents the predicted presence/absence or intensity of a human-defined concept.
- Task predictor: Maps concept activations to the final prediction $\hat{y} = h(c)$.
Because the prediction must pass through the concept bottleneck, the model's reasoning is inherently constrained to operate through human-interpretable concepts. The task predictor $h$ is typically a simple linear model or shallow MLP, so the contribution of each concept to the prediction is transparent.
Tradeoffs: CBMs sacrifice some accuracy (the bottleneck constrains the model to concepts humans have defined, potentially missing useful features that do not correspond to any defined concept) in exchange for genuine interpretability. They require concept annotations during training, which can be expensive. And the set of defined concepts may be incomplete — the model can only reason about concepts it was given.
Pharma application: In a drug response prediction task, a CBM might define concepts as: inflammation level, tumor mutation burden, immune infiltration score, prior treatment response, genetic risk markers. The task predictor $h$ is a logistic regression over these concepts. The clinician can inspect both the concept predictions ("the model thinks inflammation is high") and the task predictor weights ("high inflammation contributes +0.3 to the non-response probability"). If the concept predictions are wrong — the model says inflammation is high but the clinician disagrees — the clinician can override the concept value and get a revised prediction. This human-in-the-loop capability is unique to CBMs and not available with post-hoc explanation methods.
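A concept bottleneck can be sketched with two sklearn stages on synthetic data. Everything below — the data-generating process, the two concepts, the threshold — is invented for illustration; the point is that the intervention capability falls directly out of the architecture:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Synthetic data: 4 raw features, 2 binary concepts, binary outcome.
X = rng.normal(size=(500, 4))
concepts = (X[:, :2] + 0.1 * rng.normal(size=(500, 2)) > 0).astype(float)
y = (concepts[:, 0] + concepts[:, 1] + 0.1 * rng.normal(size=500) > 1).astype(int)

# Stage 1: concept predictors g(x) -- one classifier per concept.
concept_models = [LogisticRegression().fit(X, concepts[:, j]) for j in range(2)]

# Stage 2: task predictor h(c) sees ONLY the predicted concepts.
C_hat = np.column_stack([m.predict_proba(X)[:, 1] for m in concept_models])
task_model = LogisticRegression().fit(C_hat, y)

def predict(x, override=None):
    """Predict through the bottleneck; `override` maps concept index ->
    value, letting a human expert correct a wrong concept prediction."""
    c = np.array([m.predict_proba(x[None, :])[0, 1] for m in concept_models])
    if override:
        for j, v in override.items():
            c[j] = v
    return task_model.predict_proba(c[None, :])[0, 1]

x = X[0]
p_model = predict(x)
p_intervened = predict(x, override={0: 1.0})  # expert sets concept 0 "present"
print(p_model, p_intervened)
```

Because concept 0 is positively associated with the outcome in this synthetic setup, forcing it to 1.0 can only raise the predicted probability — the revised prediction is exactly the human-in-the-loop override described above.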
35.8 Counterfactual Explanations: "What Would Need to Change?"
Wachter, Mittelstadt, and Russell (2018) introduced counterfactual explanations, which answer the question: "What is the smallest change to the input that would change the model's decision?" Rather than explaining why a prediction was made (attribution), counterfactual explanations explain what would need to be different for the prediction to change (recourse).
For a credit application denied with prediction $\hat{y} = 0$ (deny), a counterfactual explanation might be: "If your debt-to-income ratio were 32% instead of 48% and your credit history were 5 years instead of 2 years, the application would have been approved." This is directly actionable — the applicant knows what to work on — and it satisfies the spirit (though not necessarily the letter) of ECOA adverse action requirements.
Formulation
Given an instance $x$ with prediction $f(x) = y_{\text{original}}$ and a desired outcome $y_{\text{desired}}$, find the counterfactual $x^{cf}$ that minimizes:
$$x^{cf} = \arg\min_{x'} d(x, x') \quad \text{subject to} \quad f(x') = y_{\text{desired}}$$
where $d$ is a distance function that measures the "cost" of changing $x$ to $x'$. In practice, the constraint is relaxed to a penalty:
$$x^{cf} = \arg\min_{x'} \lambda \cdot d(x, x') + \ell(f(x'), y_{\text{desired}})$$
where $\ell$ is a loss function (e.g., cross-entropy) penalizing predictions far from the desired outcome.
Practical Constraints
Raw counterfactual optimization can produce unrealistic results. Several constraints are typically imposed:
- Immutability. Some features cannot be changed: age, race, sex. The optimization must hold these fixed.
- Causal consistency. Changing income should also change expected credit limit, but not vice versa. The counterfactual should respect the causal graph (Chapters 15-19).
- Actionability. The counterfactual should suggest changes the applicant can actually make. "Increase your age by 10 years" is mathematically valid but useless.
- Sparsity. Humans prefer explanations with few changes. A counterfactual that changes 2 features is more useful than one that changes 12.
- Plausibility. The counterfactual should lie within the data manifold — it should look like a real applicant, not an impossible combination of features.
@dataclass
class CounterfactualConfig:
"""Configuration for counterfactual explanation generation."""
desired_class: int
max_iterations: int = 1000
learning_rate: float = 0.01
lambda_distance: float = 0.1
lambda_sparsity: float = 0.05
immutable_features: List[str] = field(default_factory=list)
feature_ranges: Optional[Dict[str, Tuple[float, float]]] = None
convergence_threshold: float = 0.5 # Predicted probability threshold
@dataclass
class CounterfactualExplanation:
"""A counterfactual explanation for a single instance.
Attributes:
original: Original feature values.
counterfactual: Modified feature values that change the prediction.
original_prediction: Model output for the original instance.
counterfactual_prediction: Model output for the counterfactual.
changes: Dictionary of {feature: (original_value, new_value)}.
distance: Distance between original and counterfactual.
n_features_changed: Number of features that differ.
"""
original: Dict[str, float]
counterfactual: Dict[str, float]
original_prediction: float
counterfactual_prediction: float
changes: Dict[str, Tuple[float, float]]
distance: float
n_features_changed: int
def summary(self) -> str:
"""Generate a human-readable summary of the counterfactual."""
lines = [
f"Original prediction: {self.original_prediction:.3f}",
f"Counterfactual prediction: {self.counterfactual_prediction:.3f}",
f"Number of changes: {self.n_features_changed}",
"Changes required:",
]
for feature, (orig, new) in self.changes.items():
direction = "increase" if new > orig else "decrease"
lines.append(
f" - {feature}: {orig:.2f} -> {new:.2f} ({direction} by {abs(new - orig):.2f})"
)
return "\n".join(lines)
def generate_counterfactual(
model: torch.nn.Module,
instance: torch.Tensor,
feature_names: List[str],
config: CounterfactualConfig,
) -> CounterfactualExplanation:
"""Generate a counterfactual explanation via gradient-based optimization.
Args:
model: Trained PyTorch model (binary classifier, outputs probability).
instance: Original input tensor, shape (1, p).
feature_names: List of feature names.
config: Counterfactual generation configuration.
Returns:
CounterfactualExplanation with the closest valid counterfactual.
"""
model.eval()
x_cf = instance.clone().detach().requires_grad_(True)
optimizer = torch.optim.Adam([x_cf], lr=config.learning_rate)
# Identify immutable feature indices
immutable_idx = [
feature_names.index(f)
for f in config.immutable_features
if f in feature_names
]
for step in range(config.max_iterations):
optimizer.zero_grad()
pred = model(x_cf)
if pred.dim() > 1:
pred = pred[0, config.desired_class]
else:
pred = pred[0]
        # Classification loss: push the predicted probability of the
        # desired class toward 1. The sigmoid guard is a heuristic for
        # models that output logits; prefer knowing the output semantics.
        target = torch.tensor(1.0)
        prob = torch.sigmoid(pred) if pred.min() < 0 else pred
        cls_loss = torch.nn.functional.binary_cross_entropy(prob, target)
# Distance loss: minimize change from original
distance_loss = torch.nn.functional.mse_loss(x_cf, instance)
# Sparsity loss: L1 penalty on changes
sparsity_loss = torch.mean(torch.abs(x_cf - instance))
total_loss = (
cls_loss
+ config.lambda_distance * distance_loss
+ config.lambda_sparsity * sparsity_loss
)
total_loss.backward()
# Zero gradients for immutable features
if x_cf.grad is not None and immutable_idx:
x_cf.grad[0, immutable_idx] = 0.0
optimizer.step()
# Enforce feature ranges
if config.feature_ranges is not None:
with torch.no_grad():
for fname, (lo, hi) in config.feature_ranges.items():
idx = feature_names.index(fname)
x_cf[0, idx] = x_cf[0, idx].clamp(lo, hi)
# Check convergence
with torch.no_grad():
current_pred = model(x_cf)
if current_pred.dim() > 1:
current_pred = current_pred[0, config.desired_class]
else:
current_pred = current_pred[0]
if current_pred.item() >= config.convergence_threshold:
break
# Build the explanation
original_np = instance.detach().numpy()[0]
cf_np = x_cf.detach().numpy()[0]
changes = {}
for i, fname in enumerate(feature_names):
if abs(original_np[i] - cf_np[i]) > 1e-4:
changes[fname] = (float(original_np[i]), float(cf_np[i]))
with torch.no_grad():
orig_pred = model(instance).item()
cf_pred = model(x_cf).item()
return CounterfactualExplanation(
original=dict(zip(feature_names, original_np.tolist())),
counterfactual=dict(zip(feature_names, cf_np.tolist())),
original_prediction=orig_pred,
counterfactual_prediction=cf_pred,
changes=changes,
distance=float(torch.norm(x_cf - instance).item()),
n_features_changed=len(changes),
)
35.9 Attention as Explanation — And Its Limits
The transformer architecture (Chapter 10) produces attention weights at every layer and head: $\alpha_{ij}$ indicates how much position $i$ attends to position $j$. It is tempting to interpret these weights as explanations — "the model focused on these tokens/items" — and attention visualization is widely used in practice for exactly this purpose.
The case for attention as explanation. Attention weights are readily available (no additional computation), they provide a natural notion of "where the model looked," and they are intuitive to non-technical audiences. For the StreamRec transformer session model (Chapter 10), visualizing which items in a user's history the model attends to when scoring a candidate item provides a compelling narrative: "We recommended this because you recently watched X and Y."
The case against attention as explanation. Jain and Wallace (2019) demonstrated that attention weights are not explanations in any rigorous sense:
- Attention weights do not correlate reliably with gradient-based feature importance. The items receiving the highest attention weights are not necessarily the items whose removal would most change the prediction.
- Alternative attention distributions can produce identical predictions. There exist adversarial attention distributions — dramatically different from the learned attention — that produce nearly identical outputs, because the value vectors compensate for the changed attention pattern.
- Multi-head attention distributes information across heads. Looking at a single head's attention pattern gives a partial and potentially misleading picture.
Wiegreffe and Pinter (2019) offered a partial rebuttal, showing that attention is a "weak form of explanation" — it provides meaningful information about model behavior, just not the precise causal attribution that Shapley values or integrated gradients provide.
Practical recommendation. Use attention visualization as a communication tool for non-technical audiences and as a debugging tool for model developers. Do not use it as the sole basis for regulatory compliance, fairness audits, or any context where faithful attribution is required. Pair attention visualization with integrated gradients or DeepSHAP for rigorous attribution.
@dataclass
class AttentionExplanation:
"""Attention-based explanation for a transformer model.
Attributes:
item_ids: List of item identifiers in the input sequence.
attention_weights: Attention from the query position to each
item, averaged across heads. Shape (n_items,).
layer_index: Which transformer layer's attention was used.
head_aggregation: How attention was aggregated across heads
("mean", "max", or specific head index).
top_k_items: List of (item_id, attention_weight) for the
top-k most-attended items.
"""
item_ids: List[str]
attention_weights: np.ndarray
layer_index: int
head_aggregation: str
top_k_items: List[Tuple[str, float]]
def extract_attention_explanation(
model: torch.nn.Module,
input_ids: torch.Tensor,
query_position: int = -1,
layer_index: int = -1,
top_k: int = 5,
head_aggregation: str = "mean",
) -> AttentionExplanation:
"""Extract attention-based explanation from a transformer model.
Args:
model: Transformer model that returns attention weights.
input_ids: Input token/item IDs, shape (1, seq_len).
query_position: Position to explain (default: last position).
layer_index: Which layer's attention to extract (default: last).
top_k: Number of top-attended items to return.
head_aggregation: "mean" to average across heads, "max" to take
the max, or an integer for a specific head.
Returns:
AttentionExplanation with weights and top items.
"""
model.eval()
with torch.no_grad():
outputs = model(input_ids, output_attentions=True)
# outputs.attentions is a tuple of (n_layers,) tensors,
# each of shape (batch, n_heads, seq_len, seq_len)
attention = outputs.attentions[layer_index] # (1, n_heads, seq, seq)
attention = attention[0] # (n_heads, seq, seq)
# Extract attention from query_position to all other positions
query_attention = attention[:, query_position, :] # (n_heads, seq)
if head_aggregation == "mean":
aggregated = query_attention.mean(dim=0) # (seq,)
elif head_aggregation == "max":
aggregated = query_attention.max(dim=0).values # (seq,)
else:
aggregated = query_attention[int(head_aggregation)] # (seq,)
aggregated_np = aggregated.numpy()
item_ids = [str(i) for i in input_ids[0].numpy()]
top_indices = np.argsort(aggregated_np)[::-1][:top_k]
top_items = [
(item_ids[idx], float(aggregated_np[idx]))
for idx in top_indices
]
return AttentionExplanation(
item_ids=item_ids,
attention_weights=aggregated_np,
layer_index=layer_index,
head_aggregation=head_aggregation,
top_k_items=top_items,
)
35.10 Explanation Infrastructure for Production Systems
Generating an explanation in a Jupyter notebook is an analysis task. Serving explanations to 15,000 credit applicants per day, with audit trails, regulatory documentation, and version tracking, is a software engineering task. This section covers the engineering infrastructure required to deliver explanations at production scale.
The Explanation API
In a production system, explanations are generated by an API endpoint that sits alongside — or within — the prediction API. The design must address several constraints:
- Latency. For synchronous explanations (e.g., real-time adverse action notices), the explanation must be generated within the prediction latency budget. TreeSHAP (1-10ms) fits comfortably; KernelSHAP (minutes) does not.
- Consistency. The explanation must correspond to the actual model that made the prediction. If the model was updated between prediction and explanation, the explanation is invalid.
- Determinism. For regulated applications, the same input must produce the same explanation every time. This rules out stochastic methods (LIME, KernelSHAP with random sampling) unless seeded deterministically.
- Audit trail. Every explanation must be logged with the input, the prediction, the model version, the explanation method, and the timestamp. In regulated industries, these logs must be retained for 2-7 years.
import hashlib
import json
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
@dataclass
class ExplanationRequest:
"""Incoming request for an explanation."""
request_id: str
model_id: str
model_version: str
instance: Dict[str, Any]
explanation_method: str # "treeshap", "deepshap", "counterfactual", "attention"
audience: str # "applicant", "underwriter", "regulator", "internal"
timestamp: str = field(
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
@dataclass
class ExplanationResponse:
"""Response containing the explanation and audit metadata."""
request_id: str
model_id: str
model_version: str
prediction: float
prediction_label: str
explanation_method: str
audience: str
top_factors: List[Dict[str, Any]] # [{name, value, contribution, direction}]
counterfactual: Optional[Dict[str, Any]] # For counterfactual explanations
natural_language: str # Human-readable explanation text
timestamp: str
computation_time_ms: float
explanation_hash: str # SHA-256 of the explanation for integrity verification
@staticmethod
def compute_hash(
request_id: str,
model_version: str,
top_factors: List[Dict[str, Any]],
) -> str:
"""Compute deterministic hash for audit integrity."""
payload = json.dumps(
{"request_id": request_id, "model_version": model_version,
"factors": top_factors},
sort_keys=True,
)
return hashlib.sha256(payload.encode()).hexdigest()
class ExplanationService:
"""Production explanation service with audit logging.
Serves explanations for multiple model types and audiences,
maintains an audit log, and generates natural language summaries.
"""
def __init__(
self,
model_registry: Dict[str, Any],
explainer_registry: Dict[str, Any],
audit_logger: Any,
nl_generator: Any,
):
self.model_registry = model_registry
self.explainer_registry = explainer_registry
self.audit_logger = audit_logger
self.nl_generator = nl_generator
def explain(self, request: ExplanationRequest) -> ExplanationResponse:
"""Generate an explanation for a prediction.
Args:
request: Explanation request with instance, method, and audience.
Returns:
ExplanationResponse with factors, NL summary, and audit metadata.
"""
start_time = time.monotonic()
# 1. Load model and explainer (same version that made the prediction)
model = self.model_registry[request.model_id][request.model_version]
explainer = self.explainer_registry[request.explanation_method]
# 2. Generate prediction (re-compute for consistency)
prediction = model.predict(request.instance)
# 3. Generate explanation
raw_explanation = explainer.explain(model, request.instance)
# 4. Format top factors for the specified audience
top_factors = self._format_factors(
raw_explanation, audience=request.audience
)
# 5. Generate counterfactual if requested
counterfactual = None
if request.explanation_method == "counterfactual":
counterfactual = raw_explanation.get("counterfactual")
# 6. Generate natural language explanation
nl_text = self.nl_generator.generate(
factors=top_factors,
prediction=prediction,
audience=request.audience,
)
computation_time_ms = (time.monotonic() - start_time) * 1000
response = ExplanationResponse(
request_id=request.request_id,
model_id=request.model_id,
model_version=request.model_version,
prediction=float(prediction),
prediction_label="approved" if prediction > 0.5 else "denied",
explanation_method=request.explanation_method,
audience=request.audience,
top_factors=top_factors,
counterfactual=counterfactual,
natural_language=nl_text,
timestamp=request.timestamp,
computation_time_ms=computation_time_ms,
explanation_hash=ExplanationResponse.compute_hash(
request.request_id, request.model_version, top_factors
),
)
# 7. Log to audit trail
self.audit_logger.log(response)
return response
def _format_factors(
self,
raw_explanation: Dict[str, Any],
audience: str,
) -> List[Dict[str, Any]]:
"""Format explanation factors for a specific audience.
Different audiences receive different granularity:
- "applicant": Top 4 factors, plain language, no values
- "underwriter": All factors, feature values, risk context
- "regulator": All factors, full SHAP values, statistical context
- "internal": Raw SHAP values, debug information
"""
shap_values = raw_explanation.get("shap_values", {})
sorted_factors = sorted(
shap_values.items(),
key=lambda item: abs(item[1]),
reverse=True,
)
if audience == "applicant":
n_factors = min(4, len(sorted_factors))
elif audience == "underwriter":
n_factors = min(10, len(sorted_factors))
else:
n_factors = len(sorted_factors)
factors = []
for name, value in sorted_factors[:n_factors]:
factors.append({
"name": name,
"contribution": float(value),
"direction": "increases risk" if value > 0 else "decreases risk",
"magnitude": "high" if abs(value) > 0.1 else "moderate"
if abs(value) > 0.03 else "low",
})
return factors
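The `nl_generator` dependency injected into the service can be as simple as deterministic, audience-specific templates over the formatted factors. A minimal sketch — the class name, templates, and wording are illustrative, not a compliance-reviewed notice:

```python
from typing import Any, Dict, List

class TemplateNLGenerator:
    """Render formatted factors as audience-appropriate text.
    Deterministic templates keep the output reproducible -- a property
    regulated deployments need from the explanation pipeline."""

    def generate(
        self,
        factors: List[Dict[str, Any]],
        prediction: float,
        audience: str,
    ) -> str:
        if audience == "applicant":
            # Plain language, no numeric contributions.
            items = "; ".join(f["name"].replace("_", " ") for f in factors)
            return f"The principal factors in this decision were: {items}."
        # Technical audiences get contributions and directions.
        lines = [f"Model score: {prediction:.3f}"]
        for f in factors:
            lines.append(
                f"  {f['name']}: {f['contribution']:+.4f} ({f['direction']})"
            )
        return "\n".join(lines)

factors = [
    {"name": "debt_to_income", "contribution": 0.21,
     "direction": "increases risk"},
    {"name": "credit_history_years", "contribution": -0.08,
     "direction": "decreases risk"},
]
gen = TemplateNLGenerator()
print(gen.generate(factors, prediction=0.62, audience="applicant"))
print(gen.generate(factors, prediction=0.62, audience="underwriter"))
```

Template-based generation, unlike free-form LLM summarization, is trivially deterministic and auditable, which is why it remains the default for adverse action wording.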
Audit Logging for Regulatory Compliance
In regulated industries, explanation audit logs serve as legal documentation. ECOA requires that adverse action notices be documentable and reproducible. The EU AI Act requires that high-risk AI systems (those listed in Annex III) automatically record logs of their operation (Article 12). The audit log must be:
- Immutable. Once written, log entries cannot be modified or deleted. Append-only storage (e.g., write-once cloud storage, blockchain-anchored timestamps) provides tamper evidence.
- Complete. Every prediction that results in an adverse action must have a corresponding explanation log entry.
- Reproducible. Given the same input, model version, and explanation method, the system must produce the same explanation. This requires deterministic seeding and version pinning.
- Retained. ECOA requires retention of adverse action records for 25 months. GDPR Article 22 records should be retained for the duration of the processing activity plus the limitation period. Internal policy should specify retention periods by jurisdiction.
@dataclass
class AuditLogEntry:
    """Immutable audit log entry for one explanation event.

    Designed for append-only storage with tamper-evident hashing.
    """
    entry_id: str
    timestamp: str
    request_id: str
    model_id: str
    model_version: str
    model_artifact_hash: str  # SHA-256 of the model artifact
    input_hash: str  # SHA-256 of the input features (not raw PII)
    prediction: float
    prediction_label: str
    explanation_method: str
    explanation_hash: str
    top_factors: List[Dict[str, Any]]
    counterfactual_summary: Optional[str]
    audience: str
    computation_time_ms: float
    regulatory_context: str  # "ECOA_adverse_action", "GDPR_Art22", "internal"
    previous_entry_hash: str  # Hash chain for tamper evidence
class ExplanationAuditLogger:
    """Append-only audit logger with hash chain integrity."""

    def __init__(self, storage_backend: Any):
        self.storage = storage_backend
        self._last_hash = "GENESIS"

    def log(self, response: ExplanationResponse) -> AuditLogEntry:
        """Create and persist an immutable audit log entry."""
        entry = AuditLogEntry(
            entry_id=f"audit-{response.request_id}",
            timestamp=response.timestamp,
            request_id=response.request_id,
            model_id=response.model_id,
            model_version=response.model_version,
            model_artifact_hash=self._get_model_hash(
                response.model_id, response.model_version
            ),
            input_hash="[computed from request]",
            prediction=response.prediction,
            prediction_label=response.prediction_label,
            explanation_method=response.explanation_method,
            explanation_hash=response.explanation_hash,
            top_factors=response.top_factors,
            counterfactual_summary=(
                json.dumps(response.counterfactual)
                if response.counterfactual else None
            ),
            audience=response.audience,
            computation_time_ms=response.computation_time_ms,
            regulatory_context=self._classify_regulatory_context(response),
            previous_entry_hash=self._last_hash,
        )
        # Compute this entry's hash for the chain
        entry_payload = json.dumps({
            "entry_id": entry.entry_id,
            "timestamp": entry.timestamp,
            "explanation_hash": entry.explanation_hash,
            "previous_hash": entry.previous_entry_hash,
        }, sort_keys=True)
        entry_hash = hashlib.sha256(entry_payload.encode()).hexdigest()
        self._last_hash = entry_hash
        self.storage.append(entry)
        return entry

    def _get_model_hash(self, model_id: str, version: str) -> str:
        """Retrieve the SHA-256 hash of the model artifact."""
        # Placeholder: production code would look up the artifact digest
        # recorded in the model registry at deployment time.
        return f"sha256:{model_id}:{version}"

    def _classify_regulatory_context(
        self, response: ExplanationResponse
    ) -> str:
        """Determine the regulatory context for retention policy."""
        if response.prediction_label == "denied":
            return "ECOA_adverse_action"
        if response.audience == "regulator":
            return "regulatory_review"
        return "internal"
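The hash chain above is only tamper-evident if someone recomputes it. A minimal verifier, assuming log entries are read back as dicts carrying the same fields the logger hashed:

```python
import hashlib
import json


def entry_hash(entry: dict) -> str:
    # Must hash exactly the same payload fields, in the same canonical
    # JSON form (sort_keys=True), as the logger did when appending.
    payload = json.dumps({
        "entry_id": entry["entry_id"],
        "timestamp": entry["timestamp"],
        "explanation_hash": entry["explanation_hash"],
        "previous_hash": entry["previous_entry_hash"],
    }, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


def verify_chain(entries: list) -> bool:
    """Walk the log in order; an edited, deleted, or reordered entry
    breaks the recomputed chain."""
    expected_prev = "GENESIS"
    for entry in entries:
        if entry["previous_entry_hash"] != expected_prev:
            return False
        expected_prev = entry_hash(entry)
    return True
```

Running this check on a schedule (and on every regulatory export) turns "the log is append-only" from a policy claim into a verifiable property.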
Natural Language Explanation Generation
End users and applicants do not read SHAP values. They read sentences. The final component of the explanation infrastructure is a natural language generator that translates structured explanations into human-readable text.
@dataclass
class NLExplanationTemplate:
    """Template for natural language explanation generation.

    Templates map feature names to human-readable factor descriptions
    and generate sentences from structured SHAP explanations.
    """
    feature_descriptions: Dict[str, str]  # feature_name -> plain English
    positive_template: str = "contributed to {outcome}"
    negative_template: str = "reduced the likelihood of {outcome}"

    def generate(
        self,
        factors: List[Dict[str, Any]],
        prediction: float,
        audience: str,
    ) -> str:
        """Generate a natural language explanation.

        Args:
            factors: Sorted list of explanation factors.
            prediction: Model prediction value.
            audience: Target audience identifier.

        Returns:
            Human-readable explanation string.
        """
        if audience == "applicant":
            return self._generate_applicant_explanation(factors, prediction)
        elif audience == "underwriter":
            return self._generate_underwriter_explanation(factors, prediction)
        else:
            return self._generate_technical_explanation(factors, prediction)

    def _generate_applicant_explanation(
        self,
        factors: List[Dict[str, Any]],
        prediction: float,
    ) -> str:
        """Generate a plain-language explanation for an applicant."""
        outcome = "approval" if prediction > 0.5 else "denial"
        lines = ["The primary factors in this decision were:"]
        for i, factor in enumerate(factors[:4], 1):
            feature_name = factor["name"]
            readable = self.feature_descriptions.get(
                feature_name, feature_name.replace("_", " ")
            )
            direction = factor["direction"]
            lines.append(f" {i}. Your {readable} ({direction})")
        if outcome == "denial":
            lines.append(
                "\nYou have the right to request additional information "
                "about this decision."
            )
        return "\n".join(lines)

    def _generate_underwriter_explanation(
        self,
        factors: List[Dict[str, Any]],
        prediction: float,
    ) -> str:
        """Generate a detailed explanation for an underwriter."""
        lines = [
            f"Model score: {prediction:.4f}",
            f"Recommendation: {'Approve' if prediction > 0.5 else 'Deny'}",
            "",
            "Contributing factors (sorted by magnitude):",
        ]
        for factor in factors:
            name = factor["name"]
            contrib = factor["contribution"]
            lines.append(
                f" {name}: {contrib:+.4f} ({factor['direction']}, "
                f"{factor['magnitude']} impact)"
            )
        return "\n".join(lines)

    def _generate_technical_explanation(
        self,
        factors: List[Dict[str, Any]],
        prediction: float,
    ) -> str:
        """Generate a full technical explanation."""
        lines = [
            f"Prediction: {prediction:.6f}",
            "",
            "SHAP Attribution (all features):",
        ]
        total_positive = sum(
            f["contribution"] for f in factors if f["contribution"] > 0
        )
        total_negative = sum(
            f["contribution"] for f in factors if f["contribution"] < 0
        )
        lines.append(f" Total positive attribution: {total_positive:+.4f}")
        lines.append(f" Total negative attribution: {total_negative:+.4f}")
        lines.append("")
        for factor in factors:
            lines.append(
                f" {factor['name']}: {factor['contribution']:+.6f}"
            )
        return "\n".join(lines)
35.11 The Regulatory Landscape
Explanation requirements are codified in several regulatory frameworks. Understanding the legal requirements is as important as understanding the technical methods.
ECOA and Regulation B (United States)
The Equal Credit Opportunity Act (15 U.S.C. 1691) and its implementing regulation (Regulation B, 12 CFR 1002) require creditors to provide "a statement of specific reasons for the action taken" when denying credit or taking other adverse action. The CFPB has issued guidance (CFPB Circular 2022-03) clarifying that the use of complex models, including machine learning, does not exempt creditors from this requirement.
Key requirements:
- Specificity. The reasons must identify the specific factors that contributed to the adverse action. "Your application did not meet our criteria" is insufficient.
- Accuracy. The stated reasons must accurately reflect the factors that actually influenced the decision. An explanation method that attributes importance to the wrong features — even if the prediction is correct — violates this requirement.
- Actionability. The reasons should, where possible, indicate what the applicant could change. The CFPB has indicated that counterfactual-style explanations ("if your balance were lower") are consistent with the regulation's intent.
- Number of reasons. The regulation permits listing the "principal reasons" — typically 2-4 factors. It does not require explaining the full model.
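One implementation detail worth making explicit: the "principal reasons" should be the factors that pushed toward the adverse action, not simply the largest attributions by magnitude. A factor that helped the applicant cannot be cited as a reason for denying them, however large its absolute SHAP value. A sketch, assuming the sign convention used in this chapter (positive SHAP pushes toward denial):

```python
def principal_reasons(shap_values: dict, n: int = 4) -> list:
    """Select up to n principal reasons for an adverse action.

    Filter by sign first (only risk-increasing factors qualify as
    reasons for a denial), then rank by contribution.
    """
    adverse = {name: v for name, v in shap_values.items() if v > 0}
    return sorted(adverse, key=adverse.get, reverse=True)[:n]
```

Compare this with a naive top-n by |SHAP|: a strongly favorable factor (large negative SHAP) would crowd a genuine adverse factor off the notice, which is exactly the accuracy failure Regulation B prohibits.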
GDPR Article 22 (European Union)
GDPR Article 22 gives individuals the right "not to be subject to a decision based solely on automated processing, including profiling, which produces legal effects concerning him or her or similarly significantly affects him or her." When automated processing is permitted (under explicit consent, contractual necessity, or member state law), Article 22(3) requires "suitable measures to safeguard the data subject's rights and freedoms and legitimate interests, at least the right to obtain human intervention ... to express his or her point of view and to contest the decision."
Recital 71 (interpretive, not legally binding) adds that the data subject should have the right "to obtain an explanation of the decision reached after such assessment." The scope of this "right to explanation" is debated among legal scholars. Separately, Articles 13-15 require controllers to provide "meaningful information about the logic involved" in automated decision-making; the Article 29 Working Party (now succeeded by the EDPB) interpreted this, and most practitioners read it, as requiring some form of feature importance or factor-based explanation.
EU AI Act (2024)
The EU AI Act classifies AI systems by risk level. High-risk systems (Annex III) — including credit scoring, employment decisions, and access to essential services — must satisfy transparency requirements (Article 13) including:
- Sufficient transparency to enable users to interpret the system's output and use it appropriately
- Information about the system's level of accuracy, robustness, and cybersecurity
- Logging capabilities to enable tracing of the AI system's operation during its lifecycle
These requirements do not mandate a specific explanation method but create a de facto requirement for explanation infrastructure: audit logs, documentation of explanation methodology, and the ability to produce explanations on demand.
Practical Compliance Strategy
| Requirement | Technical Solution | Implementation |
|---|---|---|
| ECOA adverse action reasons | TreeSHAP top-4 factors + counterfactual | ExplanationService with "applicant" audience |
| GDPR "meaningful information" | Global model description + local feature importance | Model card + per-prediction SHAP summary |
| EU AI Act logging | Immutable audit trail with hash chain | ExplanationAuditLogger |
| EU AI Act transparency | Documentation of explanation methodology | Technical appendix to model documentation |
| Internal governance | Explanation dashboards, drift monitoring | Grafana dashboards (Chapter 30 infrastructure) |
35.12 Honest Limitations: What Explanations Cannot Do
This chapter would be incomplete without an honest accounting of what explanations cannot do.
1. Post-hoc explanations are not the model. Every post-hoc explanation method — SHAP, LIME, integrated gradients — is a separate model that approximates the original model's behavior. The explanation is faithful to the degree that the approximation is good. For TreeSHAP, the approximation is exact (modulo the conditional vs. interventional distinction). For KernelSHAP, LIME, and gradient-based methods, the approximation can be imperfect, especially in regions of high nonlinearity, feature interaction, or out-of-distribution inputs.
2. Explanations can be manipulated. Slack et al. (2020) demonstrated that it is possible to build adversarial models that produce any desired SHAP or LIME explanation while using different features for the actual prediction. A model could attribute its decision to "credit history" in the explanation while actually relying on "zip code" (a proxy for race). This is a genuine adversarial risk in contexts where model developers have incentives to produce favorable-looking explanations.
3. Explanations do not explain causation. A SHAP value tells you which features the model used for a prediction. It does not tell you which features caused the outcome. High SHAP importance for "number of credit inquiries" means the model's prediction would be different if the inquiry count were different — not that the inquiries caused the default. This distinction (Chapters 15-19) is essential for actionable explanations.
4. Global explanations can mask local pathology. A summary plot showing that "income" is the most important feature globally may mask the fact that for a specific subgroup — say, applicants with thin credit files — "zip code" dominates. Global explanations are summaries; summaries lose information.
5. Natural language explanations add a layer of lossy compression. Translating structured attributions into sentences necessarily discards nuance. "Your debt-to-income ratio contributed to this decision" does not convey the same information as "debt_to_income: SHAP = +0.087, pushing the default probability from the baseline of 0.12 to 0.21." The natural language version is more useful to the recipient but less faithful to the model.
These limitations do not invalidate explanations. They define the boundaries within which explanations are trustworthy. A well-designed explanation system acknowledges these boundaries, documents them, and designs safeguards (e.g., periodic audits of explanation faithfulness, adversarial testing of explanation methods, comparison of multiple methods for consistency).
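The cross-method consistency check mentioned above can be automated cheaply. A sketch, assuming each explanation method returns a feature-to-attribution mapping for the same prediction:

```python
def topk_agreement(attr_a: dict, attr_b: dict, k: int = 5) -> float:
    """Jaccard overlap of the top-k feature sets from two attribution
    methods (e.g., TreeSHAP vs. KernelSHAP on the same prediction).

    Low agreement is a red flag worth routing to manual review; it does
    not identify which method is wrong, only that they disagree.
    """
    def top(attr):
        return set(sorted(attr, key=lambda f: abs(attr[f]), reverse=True)[:k])

    sa, sb = top(attr_a), top(attr_b)
    return len(sa & sb) / len(sa | sb) if (sa or sb) else 1.0
```

Tracking this metric over time also catches drift: a sudden drop in agreement after a model update often means the explanation pipeline, not the model, needs attention.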
35.13 Progressive Project: StreamRec Explanation Infrastructure
The StreamRec recommendation system has been built, deployed, monitored, audited for fairness, and trained with differential privacy over the course of this textbook. This milestone adds the final user-facing component: explanation infrastructure.
M16: Explanation Infrastructure
Deliverable 1: Global SHAP importance for the ranking model. Compute TreeSHAP values for the gradient-boosted re-ranking model (Chapter 24) on a representative sample of 10,000 user-item pairs. Produce a global summary plot showing the top 20 features by mean absolute SHAP value. Identify the three features that most differentiate high-engagement predictions from low-engagement predictions.
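Deliverable 1's ranking can be computed directly from a matrix of per-example SHAP values. A minimal sketch in pure Python (the feature names in the usage are illustrative):

```python
def global_importance(shap_rows, feature_names):
    """Mean absolute SHAP value per feature over a sample, sorted
    descending: the standard global-importance summary behind the
    top-20 plot.

    shap_rows: list of per-example SHAP vectors (one list of floats
    per user-item pair), column-aligned with feature_names.
    """
    n = len(shap_rows)
    mean_abs = [
        sum(abs(row[j]) for row in shap_rows) / n
        for j in range(len(feature_names))
    ]
    return sorted(zip(feature_names, mean_abs), key=lambda t: t[1], reverse=True)
```

To find the features that most differentiate high- from low-engagement predictions, run the same computation separately on the two subsamples and compare ranks.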
Deliverable 2: Attention visualization for the session transformer. For the transformer session model (Chapter 10), extract attention weights for 100 randomly sampled user sessions. Visualize which historical items the model attends to when scoring the top-recommended item. Identify common patterns: does the model attend to recent items, similar items, or items from the same content category?
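For the recency question in Deliverable 2, a one-line summary statistic over the extracted attention weights is often enough. A sketch, assuming weights are ordered oldest to newest within the session:

```python
def recency_attention_share(weights, k: int = 3) -> float:
    """Fraction of total attention mass on the k most recent session
    items. Averaged over many sessions, this answers "does the model
    attend to recent items?" without inspecting plots one by one.
    """
    total = sum(weights)
    return sum(weights[-k:]) / total if total else 0.0
```

Analogous shares can be computed for same-category items or for items embedded near the target, turning the three hypothesized patterns into three comparable numbers per session.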
Deliverable 3: Natural language explanation generation. Build a template-based explanation generator that translates SHAP values and attention patterns into user-facing explanation text. The generator should support three templates:
- "Because you watched [item A] and [item B]" (attention-based)
- "Popular in [content category] with viewers like you" (SHAP feature attribution)
- "Trending in [region]" (contextual feature attribution)
Deliverable 4: Audit logging. Integrate the ExplanationAuditLogger into the StreamRec serving pipeline. Every recommendation that is shown to a user should have a corresponding audit log entry containing: the user ID (hashed), the recommended item ID, the model version, the explanation method, the top-3 factors, and the timestamp. The audit log should be queryable by user, by item, by model version, and by time range.
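A sketch of the query side, using a toy in-memory store (production deployments would back this with write-once object storage or an append-only table; this sketch supports exact-match filters only, so a time-range query would additionally need comparison predicates):

```python
class InMemoryAuditStore:
    """Toy append-only store supporting Deliverable 4's query patterns."""

    def __init__(self):
        self._entries = []

    def append(self, entry: dict) -> None:
        # Store a copy so callers cannot mutate logged entries.
        self._entries.append(dict(entry))

    def query(self, **filters) -> list:
        # Exact-match filters, e.g. query(user_id_hash=..., model_version=...)
        return [
            e for e in self._entries
            if all(e.get(k) == v for k, v in filters.items())
        ]
```

The interface matters more than the implementation: `append` plus a filterable `query` is the minimum contract the serving pipeline and the audit tooling both depend on.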
@dataclass
class StreamRecExplanation:
    """Explanation for a single StreamRec recommendation.

    Combines SHAP attribution, attention signals, and natural language
    into a single explanation object for serving and logging.
    """
    user_id_hash: str
    item_id: str
    model_version: str
    ranking_score: float
    # SHAP attribution from the re-ranking model
    shap_top_features: List[Dict[str, Any]]
    shap_global_rank: Dict[str, int]  # Feature -> rank in global importance
    # Attention from the session transformer
    attended_items: List[Tuple[str, float]]  # (item_id, attention_weight)
    # Natural language
    explanation_text: str
    explanation_type: str  # "watch_history", "category", "trending"
    # Audit metadata
    timestamp: str
    explanation_hash: str
class StreamRecExplanationPipeline:
    """End-to-end explanation pipeline for StreamRec recommendations.

    Generates SHAP, attention, and NL explanations for each
    recommendation, with audit logging.
    """

    def __init__(
        self,
        ranking_model: Any,
        session_model: Any,
        shap_explainer: Any,
        nl_templates: Dict[str, str],
        audit_logger: ExplanationAuditLogger,
    ):
        self.ranking_model = ranking_model
        self.session_model = session_model
        self.shap_explainer = shap_explainer
        self.nl_templates = nl_templates
        self.audit_logger = audit_logger

    def explain_recommendation(
        self,
        user_features: Dict[str, Any],
        item_features: Dict[str, Any],
        session_history: List[str],
        user_id_hash: str,
        item_id: str,
    ) -> StreamRecExplanation:
        """Generate a complete explanation for one recommendation.

        Args:
            user_features: User feature dictionary.
            item_features: Item feature dictionary.
            session_history: List of item IDs in the user's recent session.
            user_id_hash: SHA-256 hash of the user ID (for audit logging).
            item_id: ID of the recommended item.

        Returns:
            StreamRecExplanation with SHAP, attention, and NL components.
        """
        # 1. SHAP explanation for the ranking model
        combined_features = {**user_features, **item_features}
        shap_values = self.shap_explainer.explain(
            self.ranking_model, combined_features
        )
        top_features = sorted(
            shap_values.items(),
            key=lambda x: abs(x[1]),
            reverse=True,
        )[:5]
        # 2. Attention explanation for the session model
        attention_weights = self._extract_session_attention(
            session_history, item_id
        )
        top_attended = sorted(
            attention_weights.items(),
            key=lambda x: x[1],
            reverse=True,
        )[:3]
        # 3. Select explanation type and generate NL
        explanation_type, nl_text = self._select_explanation(
            top_features, top_attended, item_features
        )
        # 4. Compute ranking score
        ranking_score = self.ranking_model.predict(combined_features)
        # 5. Build explanation object
        explanation = StreamRecExplanation(
            user_id_hash=user_id_hash,
            item_id=item_id,
            model_version=self.ranking_model.version,
            ranking_score=float(ranking_score),
            shap_top_features=[
                {"name": n, "value": float(v)} for n, v in top_features
            ],
            shap_global_rank={},  # Populated from precomputed global ranks
            attended_items=[(iid, float(w)) for iid, w in top_attended],
            explanation_text=nl_text,
            explanation_type=explanation_type,
            timestamp=datetime.now(timezone.utc).isoformat(),
            explanation_hash=hashlib.sha256(
                nl_text.encode()
            ).hexdigest()[:16],
        )
        return explanation

    def _extract_session_attention(
        self,
        session_history: List[str],
        target_item: str,
    ) -> Dict[str, float]:
        """Extract attention weights from the session transformer."""
        # A full implementation delegates to extract_attention_explanation;
        # this placeholder falls back to simple recency weights.
        # Returns {item_id: attention_weight}.
        weights = {}
        for i, item in enumerate(session_history):
            weights[item] = 1.0 / (len(session_history) - i)
        total = sum(weights.values())
        if total == 0:
            return {}
        return {k: v / total for k, v in weights.items()}

    def _select_explanation(
        self,
        top_features: List[Tuple[str, float]],
        top_attended: List[Tuple[str, float]],
        item_features: Dict[str, Any],
    ) -> Tuple[str, str]:
        """Select the most appropriate explanation type and generate NL."""
        # Heuristic: if the top attended items explain >60% of attention,
        # use watch_history; if a category feature dominates SHAP, use
        # category; otherwise fall back to trending.
        total_attention = sum(w for _, w in top_attended)
        if total_attention > 0.6 and len(top_attended) >= 2:
            items_str = " and ".join(
                [f"'{iid}'" for iid, _ in top_attended[:2]]
            )
            return (
                "watch_history",
                f"Because you recently watched {items_str}",
            )
        top_feature_name = top_features[0][0] if top_features else ""
        if "category" in top_feature_name.lower():
            category = item_features.get("category", "this category")
            return (
                "category",
                f"Popular in {category} with viewers like you",
            )
        return (
            "trending",
            "Trending in your region",
        )
35.14 Synthesis: An Explanation Methodology
The tools in this chapter are not interchangeable. Each serves a specific purpose, audience, and context. The table below summarizes when to use each method.
| Method | Scope | Model Type | Audience | Regulatory Use | Production Viable |
|---|---|---|---|---|---|
| TreeSHAP | Local + Global | Trees | All | ECOA, GDPR, EU AI Act | Yes (1-10ms) |
| DeepSHAP | Local | Neural nets | Technical | Internal, GDPR | Yes (10-100ms) |
| KernelSHAP | Local | Any | Auditors | Audit, research | No (minutes) |
| LIME | Local | Any | Non-technical | Limited (unstable) | No (unstable) |
| Integrated Gradients | Local | Neural nets | Technical | GDPR, EU AI Act | Yes (50-200ms) |
| GradCAM | Local (spatial) | CNNs | Clinical, visual | Domain-specific | Yes (fast) |
| TCAV | Global (concept) | Neural nets | Domain experts | Research, clinical | Batch only |
| Concept Bottleneck | Local + Global | CBM architectures | Domain experts | High-risk AI | Yes (inherent) |
| Counterfactual | Local | Any (differentiable) | End users | ECOA (actionable) | Depends on model |
| PDP/ALE | Global | Any | Technical | Model documentation | Batch only |
| Attention | Local | Transformers | Non-technical | Communication (not rigorous) | Yes (free) |
The recommended approach for production systems:
- Start with the regulatory requirement. If ECOA applies, TreeSHAP + counterfactuals. If GDPR Article 22 applies, global model description + local SHAP summary. If EU AI Act Annex III applies, full audit logging + explanation documentation.
- Choose the method that matches the model. TreeSHAP for tree ensembles, DeepSHAP or integrated gradients for neural networks, KernelSHAP for audit of any model.
- Format for the audience. Raw SHAP for internal analysis. Top-4 factors for applicants. Concept-level summaries for clinicians. Natural language for end users.
- Build the infrastructure. Explanation API, audit logging, version tracking, monitoring of explanation quality (Are the same features appearing as top contributors? Has the explanation distribution shifted?).
- Acknowledge the limitations. Document which method was used, what it measures, and what it does not measure. No explanation method is perfect. Honest documentation of limitations is itself a form of transparency.
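Steps 1 and 2 above can be sketched as a simple dispatch. The method names and context strings here are illustrative, and the neural-network split is one reading of the summary table:

```python
def select_explainer(model_family: str, context: str) -> str:
    """Hypothetical dispatch mirroring the methodology: match the
    explanation method to the model family, with KernelSHAP reserved
    for model-agnostic audits."""
    if model_family == "tree_ensemble":
        return "treeshap"  # exact and fast enough for online serving
    if model_family == "neural_net":
        # Regulatory contexts lean on integrated gradients' axiomatic
        # guarantees; internal debugging can use the faster DeepSHAP.
        if context in {"GDPR_Art22", "EU_AI_Act"}:
            return "integrated_gradients"
        return "deepshap"
    return "kernelshap"  # any model, but batch/audit use only
```

Encoding the policy as code rather than tribal knowledge also makes it auditable: the dispatch function itself becomes part of the explanation methodology documentation.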
Chapter Summary
Interpretability and explainability are not luxuries for academically interesting models. They are engineering requirements for production systems that affect people's lives — credit decisions, medical recommendations, content curation. This chapter has developed the technical foundations (Shapley values from cooperative game theory), the practical toolkit (SHAP, LIME, gradient methods, concept-based methods, counterfactuals), the engineering infrastructure (explanation APIs, audit logs, NL generation), and the regulatory context (ECOA, GDPR, EU AI Act) required to deliver explanations at scale.
The limitations are real. Post-hoc explanations approximate the model. Attention weights are not faithful attributions. Natural language compresses information. Explanations can be adversarially manipulated. These limitations do not excuse the absence of explanations — they define the standard of care for explanation systems: multiple methods for cross-validation, periodic faithfulness audits, adversarial testing, and honest documentation.
The StreamRec progressive project now has explanation infrastructure: SHAP for the ranking model, attention visualization for the session transformer, natural language for users, and audit logging for governance. The system can tell a user why it recommended something, tell a regulator how it works, and prove to an auditor what it said and when.
Knowing how your model is wrong requires knowing what your model thinks it knows. Explanations are how you find out.
References
- Apley, D. W., and Zhu, J. (2020). Visualizing the effects of predictor variables in black box supervised learning models. Journal of the Royal Statistical Society Series B, 82(4), 1059-1086.
- Goldstein, A., Kapelner, A., Bleich, J., and Pitkin, E. (2015). Peeking inside the black box: Visualizing statistical learning with plots of individual conditional expectation. Journal of Computational and Graphical Statistics, 24(1), 44-65.
- Jain, S., and Wallace, B. C. (2019). Attention is not explanation. NAACL-HLT.
- Kim, B., Wattenberg, M., Gilmer, J., Caruana, R., Welling, M., and Viégas, F. (2018). Interpretability beyond feature attribution: Quantitative testing with concept activation vectors (TCAV). ICML.
- Koh, P. W., Nguyen, T., Tang, Y. S., Mussmann, S., Pierson, E., Kim, B., and Liang, P. (2020). Concept bottleneck models. ICML.
- Lundberg, S. M., and Lee, S.-I. (2017). A unified approach to interpreting model predictions. NeurIPS.
- Lundberg, S. M., Erion, G., Chen, H., DeGrave, A., Prutkin, J. M., Nair, B., Katz, R., Himmelfarb, J., Bansal, N., and Lee, S.-I. (2020). From local explanations to global understanding with explainable AI for trees. Nature Machine Intelligence, 2(1), 56-67.
- Ribeiro, M. T., Singh, S., and Guestrin, C. (2016). "Why should I trust you?" Explaining the predictions of any classifier. KDD.
- Rudin, C. (2019). Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nature Machine Intelligence, 1(5), 206-215.
- Selvaraju, R. R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., and Batra, D. (2017). Grad-CAM: Visual explanations from deep networks via gradient-based localization. ICCV.
- Shrikumar, A., Greenside, P., and Kundaje, A. (2017). Learning important features through propagating activation differences. ICML.
- Simonyan, K., Vedaldi, A., and Zisserman, A. (2014). Deep inside convolutional networks: Visualising image classification models and saliency maps. ICLR Workshop.
- Slack, D., Hilgard, S., Jia, E., Singh, S., and Lakkaraju, H. (2020). Fooling LIME and SHAP: Adversarial attacks on post hoc explanation methods. AIES.
- Sundararajan, M., Taly, A., and Yan, Q. (2017). Axiomatic attribution for deep networks. ICML.
- Wachter, S., Mittelstadt, B., and Russell, C. (2018). Counterfactual explanations without opening the black box: Automated decisions and the GDPR. Harvard Journal of Law & Technology, 31(2).
- Wiegreffe, S., and Pinter, Y. (2019). Attention is not not explanation. EMNLP.