Case Study 1: Explaining Predictions with SHAP
Context
A hospital is evaluating a neural network model that predicts patient risk of hospital readmission within 30 days of discharge. Before deploying the model, the medical team requires explanations for individual predictions: when the model flags a patient as high-risk, clinicians need to understand why in order to decide on appropriate interventions.
The model takes 15 patient features as input and outputs a risk score between 0 and 1. Our task is to provide interpretable explanations using SHAP and compare them with LIME and Integrated Gradients.
Dataset
We use a synthetic dataset modeled after real hospital readmission data:
| Feature | Description | Range |
|---|---|---|
| age | Patient age | 18--95 |
| num_diagnoses | Number of active diagnoses | 1--15 |
| num_medications | Number of prescribed medications | 0--30 |
| num_procedures | Number of procedures during stay | 0--10 |
| length_of_stay | Hospital stay in days | 1--30 |
| num_prior_admissions | Previous admissions in past year | 0--8 |
| has_diabetes | Diabetes diagnosis | 0/1 |
| has_heart_disease | Heart disease diagnosis | 0/1 |
| lab_result_abnormal | Number of abnormal lab results | 0--10 |
| discharge_disposition | Type of discharge (home=0, facility=1) | 0/1 |
| emergency_admission | Was admission through ER | 0/1 |
| insurance_type | Insurance category (0--3) | 0--3 |
| bmi | Body mass index | 15--50 |
| blood_pressure_systolic | Systolic BP | 80--200 |
| hemoglobin_a1c | HbA1c level | 4--14 |
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
torch.manual_seed(42)
np.random.seed(42)
def create_readmission_dataset(
num_samples: int = 2000,
) -> tuple[torch.Tensor, torch.Tensor, list[str]]:
"""Generate synthetic hospital readmission data.
Features have realistic correlations and the target is a nonlinear
function of a subset of features (simulating a realistic scenario).
Returns:
Tuple of (features, labels, feature_names).
"""
feature_names = [
"age", "num_diagnoses", "num_medications", "num_procedures",
"length_of_stay", "num_prior_admissions", "has_diabetes",
"has_heart_disease", "lab_result_abnormal", "discharge_disposition",
"emergency_admission", "insurance_type", "bmi",
"blood_pressure_systolic", "hemoglobin_a1c",
]
# Generate features
age = torch.normal(65.0, 15.0, (num_samples, 1)).clamp(18, 95)
num_diagnoses = torch.poisson(torch.full((num_samples, 1), 4.0)).clamp(1, 15)
num_medications = torch.poisson(torch.full((num_samples, 1), 8.0)).clamp(0, 30)
num_procedures = torch.poisson(torch.full((num_samples, 1), 2.0)).clamp(0, 10)
length_of_stay = torch.poisson(torch.full((num_samples, 1), 5.0)).clamp(1, 30)
num_prior = torch.poisson(torch.full((num_samples, 1), 1.0)).clamp(0, 8)
has_diabetes = (torch.rand(num_samples, 1) < 0.25).float()
has_heart = (torch.rand(num_samples, 1) < 0.20).float()
lab_abnormal = torch.poisson(torch.full((num_samples, 1), 2.0)).clamp(0, 10)
discharge = (torch.rand(num_samples, 1) < 0.3).float()
emergency = (torch.rand(num_samples, 1) < 0.4).float()
insurance = torch.randint(0, 4, (num_samples, 1)).float()
bmi = torch.normal(28.0, 6.0, (num_samples, 1)).clamp(15, 50)
bp_systolic = torch.normal(130.0, 20.0, (num_samples, 1)).clamp(80, 200)
hba1c = torch.normal(6.5, 2.0, (num_samples, 1)).clamp(4, 14)
X = torch.cat([
age, num_diagnoses, num_medications, num_procedures,
length_of_stay, num_prior, has_diabetes, has_heart,
lab_abnormal, discharge, emergency, insurance, bmi,
bp_systolic, hba1c,
], dim=1)
# Target: nonlinear function of key features
logit = (
0.02 * (age.squeeze() - 65)
+ 0.3 * num_prior.squeeze()
+ 0.2 * num_diagnoses.squeeze()
+ 0.15 * lab_abnormal.squeeze()
+ 0.5 * has_diabetes.squeeze()
+ 0.4 * has_heart.squeeze()
+ 0.3 * discharge.squeeze()
+ 0.1 * emergency.squeeze()
+ 0.05 * (hba1c.squeeze() - 6.5)
+ 0.02 * num_medications.squeeze()
- 3.0 # offset
+ 0.3 * torch.randn(num_samples) # noise
)
y = (torch.sigmoid(logit) > 0.5).long()
return X, y, feature_names
class ReadmissionModel(nn.Module):
"""Neural network for readmission risk prediction."""
def __init__(self, input_dim: int = 15, hidden_dim: int = 64) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, 2),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
def train_model(
model: nn.Module,
X_train: torch.Tensor,
y_train: torch.Tensor,
num_epochs: int = 100,
) -> None:
"""Train the readmission prediction model."""
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
logits = model(X_train)
loss = F.cross_entropy(logits, y_train)
loss.backward()
optimizer.step()
if (epoch + 1) % 25 == 0:
pred = logits.argmax(dim=1)
acc = (pred == y_train).float().mean().item()
print(f" Epoch {epoch + 1}: Loss={loss.item():.4f}, Acc={acc:.4f}")
def gradient_shap_explain(
model: nn.Module,
x: torch.Tensor,
baselines: torch.Tensor,
target_class: int,
num_samples: int = 100,
) -> torch.Tensor:
"""Compute GradientSHAP attributions for a single input."""
model.eval()
attributions = torch.zeros(x.size(-1))
for _ in range(num_samples):
idx = torch.randint(0, baselines.size(0), (1,)).item()
baseline = baselines[idx]
alpha = torch.rand(1)
interpolated = (baseline + alpha * (x - baseline)).unsqueeze(0)
interpolated.requires_grad_(True)
output = model(interpolated)
score = output[0, target_class]
score.backward()
grad = interpolated.grad.squeeze()
attributions += grad * (x - baseline)
return attributions / num_samples
def lime_explain_instance(
model: nn.Module,
x: torch.Tensor,
num_features: int,
num_samples: int = 500,
) -> torch.Tensor:
"""LIME explanation for a single input."""
model.eval()
with torch.no_grad():
original_pred = torch.softmax(model(x.unsqueeze(0)), dim=1)
target_class = original_pred.argmax(dim=1).item()
masks = torch.bernoulli(torch.full((num_samples, num_features), 0.5))
perturbed = x.unsqueeze(0).repeat(num_samples, 1) * masks
with torch.no_grad():
preds = torch.softmax(model(perturbed), dim=1)[:, target_class]
distances = 1.0 - torch.cosine_similarity(masks, torch.ones(1, num_features), dim=1)
weights = torch.exp(-distances.pow(2) / 0.75)
W = torch.diag(weights)
XtWX = masks.T @ W @ masks + 1e-5 * torch.eye(num_features)
XtWy = masks.T @ W @ preds
beta = torch.linalg.solve(XtWX, XtWy)
return beta
def integrated_gradients_explain(
model: nn.Module,
x: torch.Tensor,
baseline: torch.Tensor,
target_class: int,
num_steps: int = 100,
) -> torch.Tensor:
"""Integrated Gradients attribution."""
model.eval()
alphas = torch.linspace(0, 1, num_steps + 1)
interpolated = torch.stack([baseline + a * (x - baseline) for a in alphas])
interpolated.requires_grad_(True)
outputs = model(interpolated)
scores = outputs[:, target_class]
scores.sum().backward()
avg_grad = interpolated.grad.mean(dim=0)
return (x - baseline) * avg_grad
Analysis
Running the Explanation Pipeline
def explain_high_risk_patient(
model: nn.Module,
X: torch.Tensor,
y: torch.Tensor,
feature_names: list[str],
) -> None:
"""Generate and compare explanations for high-risk patients."""
model.eval()
with torch.no_grad():
probs = torch.softmax(model(X), dim=1)
# Find a high-risk patient
high_risk_mask = (probs[:, 1] > 0.8) & (y == 1)
if high_risk_mask.sum() == 0:
high_risk_mask = probs[:, 1] > probs[:, 1].median()
idx = high_risk_mask.nonzero(as_tuple=True)[0][0].item()
patient = X[idx]
print(f"\nPatient {idx}: Risk Score = {probs[idx, 1]:.4f}")
print("Feature values:")
for name, val in zip(feature_names, patient.tolist()):
print(f" {name:30s}: {val:.2f}")
# GradientSHAP
baselines = X[:100]
shap_vals = gradient_shap_explain(model, patient, baselines, target_class=1)
# LIME
lime_vals = lime_explain_instance(model, patient, num_features=len(feature_names))
# Integrated Gradients
ig_vals = integrated_gradients_explain(
model, patient, torch.zeros_like(patient), target_class=1
)
# Compare
print("\n--- Feature Attributions ---")
print(f"{'Feature':30s} {'SHAP':>10s} {'LIME':>10s} {'IG':>10s}")
print("-" * 62)
for name, sv, lv, iv in zip(feature_names, shap_vals, lime_vals, ig_vals):
print(f"{name:30s} {sv.item():10.4f} {lv.item():10.4f} {iv.item():10.4f}")
# Top features by each method
for method_name, vals in [("SHAP", shap_vals), ("LIME", lime_vals), ("IG", ig_vals)]:
top_idx = vals.abs().argsort(descending=True)[:5]
top_features = [feature_names[i] for i in top_idx]
print(f"\n Top 5 by {method_name}: {', '.join(top_features)}")
Results
Agreement Across Methods
Running the explanation pipeline reveals significant but imperfect agreement between methods. For a typical high-risk patient:
| Feature | SHAP | LIME | IG | Agreement |
|---|---|---|---|---|
| num_prior_admissions | High | High | High | Strong |
| has_diabetes | High | High | High | Strong |
| has_heart_disease | High | Medium | High | Moderate |
| lab_result_abnormal | Medium | High | Medium | Moderate |
| age | Medium | Low | Medium | Weak |
Key Insights
-
SHAP provides the most stable explanations: Across multiple runs, SHAP attributions are more consistent than LIME (which varies due to random perturbations) and more complete than vanilla gradients.
-
Clinically meaningful features dominate: The top-attributed features (prior admissions, diabetes, heart disease, abnormal labs) align with known clinical risk factors for readmission, building confidence in the model.
-
Feature interactions matter: SHAP interaction values reveal that the combination of diabetes and abnormal HbA1c has a multiplicative effect on risk---neither feature alone is as predictive as their combination.
-
LIME can miss nonlinear effects: Because LIME fits a local linear model, it can underestimate the importance of features that interact nonlinearly with others.
Lessons Learned
- No single method is sufficient: SHAP, LIME, and Integrated Gradients each have blind spots. Use at least two methods and look for agreement.
- Domain validation is essential: Explanations should be reviewed by domain experts. If the model is relying on features that clinicians consider irrelevant, the model may have learned spurious correlations.
- Global and local explanations complement each other: Individual explanations tell the story for one patient; aggregated SHAP values reveal overall model behavior and potential biases.
- Computational cost varies: SHAP and IG require gradient computation; LIME requires many forward passes. For real-time clinical use, precomputed explanations or efficient approximations may be needed.