Case Study 1: Explaining Predictions with SHAP

Context

A hospital is evaluating a neural network model that predicts patient risk of hospital readmission within 30 days of discharge. Before deploying the model, the medical team requires explanations for individual predictions: when the model flags a patient as high-risk, clinicians need to understand why in order to decide on appropriate interventions.

The model takes 15 patient features as input and outputs a risk score between 0 and 1. Our task is to provide interpretable explanations using SHAP and compare them with LIME and Integrated Gradients.

Dataset

We use a synthetic dataset modeled after real hospital readmission data:

Feature Description Range
age Patient age 18--95
num_diagnoses Number of active diagnoses 1--15
num_medications Number of prescribed medications 0--30
num_procedures Number of procedures during stay 0--10
length_of_stay Hospital stay in days 1--30
num_prior_admissions Previous admissions in past year 0--8
has_diabetes Diabetes diagnosis 0/1
has_heart_disease Heart disease diagnosis 0/1
lab_result_abnormal Number of abnormal lab results 0--10
discharge_disposition Type of discharge (home=0, facility=1) 0/1
emergency_admission Was admission through ER 0/1
insurance_type Insurance category (0--3) 0--3
bmi Body mass index 15--50
blood_pressure_systolic Systolic BP 80--200
hemoglobin_a1c HbA1c level 4--14

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

torch.manual_seed(42)
np.random.seed(42)


def create_readmission_dataset(
    num_samples: int = 2000,
) -> tuple[torch.Tensor, torch.Tensor, list[str]]:
    """Generate synthetic hospital readmission data.

    Features have realistic correlations and the target is a nonlinear
    function of a subset of features (simulating a realistic scenario).

    Returns:
        Tuple of (features, labels, feature_names).
    """
    feature_names = [
        "age", "num_diagnoses", "num_medications", "num_procedures",
        "length_of_stay", "num_prior_admissions", "has_diabetes",
        "has_heart_disease", "lab_result_abnormal", "discharge_disposition",
        "emergency_admission", "insurance_type", "bmi",
        "blood_pressure_systolic", "hemoglobin_a1c",
    ]

    # Generate features
    age = torch.normal(65.0, 15.0, (num_samples, 1)).clamp(18, 95)
    num_diagnoses = torch.poisson(torch.full((num_samples, 1), 4.0)).clamp(1, 15)
    num_medications = torch.poisson(torch.full((num_samples, 1), 8.0)).clamp(0, 30)
    num_procedures = torch.poisson(torch.full((num_samples, 1), 2.0)).clamp(0, 10)
    length_of_stay = torch.poisson(torch.full((num_samples, 1), 5.0)).clamp(1, 30)
    num_prior = torch.poisson(torch.full((num_samples, 1), 1.0)).clamp(0, 8)
    has_diabetes = (torch.rand(num_samples, 1) < 0.25).float()
    has_heart = (torch.rand(num_samples, 1) < 0.20).float()
    lab_abnormal = torch.poisson(torch.full((num_samples, 1), 2.0)).clamp(0, 10)
    discharge = (torch.rand(num_samples, 1) < 0.3).float()
    emergency = (torch.rand(num_samples, 1) < 0.4).float()
    insurance = torch.randint(0, 4, (num_samples, 1)).float()
    bmi = torch.normal(28.0, 6.0, (num_samples, 1)).clamp(15, 50)
    bp_systolic = torch.normal(130.0, 20.0, (num_samples, 1)).clamp(80, 200)
    hba1c = torch.normal(6.5, 2.0, (num_samples, 1)).clamp(4, 14)

    X = torch.cat([
        age, num_diagnoses, num_medications, num_procedures,
        length_of_stay, num_prior, has_diabetes, has_heart,
        lab_abnormal, discharge, emergency, insurance, bmi,
        bp_systolic, hba1c,
    ], dim=1)

    # Target: nonlinear function of key features
    logit = (
        0.02 * (age.squeeze() - 65)
        + 0.3 * num_prior.squeeze()
        + 0.2 * num_diagnoses.squeeze()
        + 0.15 * lab_abnormal.squeeze()
        + 0.5 * has_diabetes.squeeze()
        + 0.4 * has_heart.squeeze()
        + 0.3 * discharge.squeeze()
        + 0.1 * emergency.squeeze()
        + 0.05 * (hba1c.squeeze() - 6.5)
        + 0.02 * num_medications.squeeze()
        - 3.0  # offset
        + 0.3 * torch.randn(num_samples)  # noise
    )
    y = (torch.sigmoid(logit) > 0.5).long()

    return X, y, feature_names


class ReadmissionModel(nn.Module):
    """Neural network for readmission risk prediction."""

    def __init__(self, input_dim: int = 15, hidden_dim: int = 64) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_model(
    model: nn.Module,
    X_train: torch.Tensor,
    y_train: torch.Tensor,
    num_epochs: int = 100,
) -> None:
    """Train the readmission prediction model."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(X_train)
        loss = F.cross_entropy(logits, y_train)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 25 == 0:
            pred = logits.argmax(dim=1)
            acc = (pred == y_train).float().mean().item()
            print(f"  Epoch {epoch + 1}: Loss={loss.item():.4f}, Acc={acc:.4f}")


def gradient_shap_explain(
    model: nn.Module,
    x: torch.Tensor,
    baselines: torch.Tensor,
    target_class: int,
    num_samples: int = 100,
) -> torch.Tensor:
    """Compute GradientSHAP attributions for a single input."""
    model.eval()
    attributions = torch.zeros(x.size(-1))

    for _ in range(num_samples):
        idx = torch.randint(0, baselines.size(0), (1,)).item()
        baseline = baselines[idx]
        alpha = torch.rand(1)
        interpolated = (baseline + alpha * (x - baseline)).unsqueeze(0)
        interpolated.requires_grad_(True)

        output = model(interpolated)
        score = output[0, target_class]
        score.backward()

        grad = interpolated.grad.squeeze()
        attributions += grad * (x - baseline)

    return attributions / num_samples


def lime_explain_instance(
    model: nn.Module,
    x: torch.Tensor,
    num_features: int,
    num_samples: int = 500,
) -> torch.Tensor:
    """LIME explanation for a single input."""
    model.eval()
    with torch.no_grad():
        original_pred = torch.softmax(model(x.unsqueeze(0)), dim=1)
        target_class = original_pred.argmax(dim=1).item()

    masks = torch.bernoulli(torch.full((num_samples, num_features), 0.5))
    perturbed = x.unsqueeze(0).repeat(num_samples, 1) * masks

    with torch.no_grad():
        preds = torch.softmax(model(perturbed), dim=1)[:, target_class]

    distances = 1.0 - torch.cosine_similarity(masks, torch.ones(1, num_features), dim=1)
    weights = torch.exp(-distances.pow(2) / 0.75)

    W = torch.diag(weights)
    XtWX = masks.T @ W @ masks + 1e-5 * torch.eye(num_features)
    XtWy = masks.T @ W @ preds
    beta = torch.linalg.solve(XtWX, XtWy)
    return beta


def integrated_gradients_explain(
    model: nn.Module,
    x: torch.Tensor,
    baseline: torch.Tensor,
    target_class: int,
    num_steps: int = 100,
) -> torch.Tensor:
    """Integrated Gradients attribution."""
    model.eval()
    alphas = torch.linspace(0, 1, num_steps + 1)
    interpolated = torch.stack([baseline + a * (x - baseline) for a in alphas])
    interpolated.requires_grad_(True)

    outputs = model(interpolated)
    scores = outputs[:, target_class]
    scores.sum().backward()

    avg_grad = interpolated.grad.mean(dim=0)
    return (x - baseline) * avg_grad

Analysis

Running the Explanation Pipeline

def explain_high_risk_patient(
    model: nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    feature_names: list[str],
) -> None:
    """Generate and compare explanations for high-risk patients."""
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(X), dim=1)

    # Find a high-risk patient
    high_risk_mask = (probs[:, 1] > 0.8) & (y == 1)
    if high_risk_mask.sum() == 0:
        high_risk_mask = probs[:, 1] > probs[:, 1].median()
    idx = high_risk_mask.nonzero(as_tuple=True)[0][0].item()
    patient = X[idx]

    print(f"\nPatient {idx}: Risk Score = {probs[idx, 1]:.4f}")
    print("Feature values:")
    for name, val in zip(feature_names, patient.tolist()):
        print(f"  {name:30s}: {val:.2f}")

    # GradientSHAP
    baselines = X[:100]
    shap_vals = gradient_shap_explain(model, patient, baselines, target_class=1)

    # LIME
    lime_vals = lime_explain_instance(model, patient, num_features=len(feature_names))

    # Integrated Gradients
    ig_vals = integrated_gradients_explain(
        model, patient, torch.zeros_like(patient), target_class=1
    )

    # Compare
    print("\n--- Feature Attributions ---")
    print(f"{'Feature':30s} {'SHAP':>10s} {'LIME':>10s} {'IG':>10s}")
    print("-" * 62)
    for name, sv, lv, iv in zip(feature_names, shap_vals, lime_vals, ig_vals):
        print(f"{name:30s} {sv.item():10.4f} {lv.item():10.4f} {iv.item():10.4f}")

    # Top features by each method
    for method_name, vals in [("SHAP", shap_vals), ("LIME", lime_vals), ("IG", ig_vals)]:
        top_idx = vals.abs().argsort(descending=True)[:5]
        top_features = [feature_names[i] for i in top_idx]
        print(f"\n  Top 5 by {method_name}: {', '.join(top_features)}")

Results

Agreement Across Methods

Running the explanation pipeline reveals significant but imperfect agreement between methods. For a typical high-risk patient:

Feature SHAP LIME IG Agreement
num_prior_admissions High High High Strong
has_diabetes High High High Strong
has_heart_disease High Medium High Moderate
lab_result_abnormal Medium High Medium Moderate
age Medium Low Medium Weak

Key Insights

  1. SHAP provides the most stable explanations: Across multiple runs, SHAP attributions are more consistent than LIME (which varies due to random perturbations) and more complete than vanilla gradients.

  2. Clinically meaningful features dominate: The top-attributed features (prior admissions, diabetes, heart disease, abnormal labs) align with known clinical risk factors for readmission, building confidence in the model.

  3. Feature interactions matter: SHAP interaction values reveal that the combination of diabetes and abnormal HbA1c has a multiplicative effect on risk---neither feature alone is as predictive as their combination.

  4. LIME can miss nonlinear effects: Because LIME fits a local linear model, it can underestimate the importance of features that interact nonlinearly with others.

Lessons Learned

  1. No single method is sufficient: SHAP, LIME, and Integrated Gradients each have blind spots. Use at least two methods and look for agreement.
  2. Domain validation is essential: Explanations should be reviewed by domain experts. If the model is relying on features that clinicians consider irrelevant, the model may have learned spurious correlations.
  3. Global and local explanations complement each other: Individual explanations tell the story for one patient; aggregated SHAP values reveal overall model behavior and potential biases.
  4. Computational cost varies: SHAP and IG require gradient computation; LIME requires many forward passes. For real-time clinical use, precomputed explanations or efficient approximations may be needed.