Chapter 21: Bayesian Modeling in Practice — PyMC, Hierarchical Models, and When Bayesian Methods Earn Their Complexity

"The best thing about Bayesian statistics is that it gives you exactly the answer you want. The worst thing about Bayesian statistics is that it makes you state all your assumptions." — Andrew Gelman


Learning Objectives

By the end of this chapter, you will be able to:

  1. Build and fit Bayesian models with PyMC, including proper MCMC diagnostics
  2. Implement hierarchical (multilevel) models for grouped data and explain partial pooling
  3. Diagnose MCMC convergence using trace plots, $\hat{R}$, effective sample size (ESS), and divergence checks
  4. Apply the full Bayesian workflow: prior predictive check $\to$ fit $\to$ posterior predictive check $\to$ model comparison
  5. Decide when Bayesian methods are worth the additional complexity over frequentist or ML alternatives

21.1 From Conjugacy to Computation

Chapter 20 covered the elegant cases: Beta-Binomial, Normal-Normal, Poisson-Gamma. Conjugate priors give closed-form posteriors. You update parameters by adding counts. The math is beautiful and the computation is instantaneous.

Real models are not conjugate. A hierarchical logistic regression with varying intercepts and slopes across 50 hospitals does not have a closed-form posterior. A time-varying treatment effect model with patient-level covariates and hospital-level confounders does not reduce to Beta parameter updates. A content recommendation model with category-level engagement rates, user-level random effects, and time-of-day interactions is not a Beta-Binomial.

For these models, the posterior $p(\theta \mid D)$ is a high-dimensional distribution that cannot be written in closed form. The normalizing constant $p(D) = \int p(D \mid \theta) \, p(\theta) \, d\theta$ is an intractable integral over a space with potentially thousands of dimensions.

The solution is sampling. If you can draw samples $\theta_1, \theta_2, \ldots, \theta_S$ from the posterior $p(\theta \mid D)$, you do not need the closed-form distribution. You can approximate any posterior quantity:

$$\mathbb{E}[f(\theta) \mid D] \approx \frac{1}{S} \sum_{s=1}^{S} f(\theta_s)$$

Posterior mean? Set $f(\theta) = \theta$. Posterior variance? Set $f(\theta) = (\theta - \bar{\theta})^2$. Probability that a parameter is positive? Count the fraction of samples where $\theta_s > 0$. Credible intervals? Sort the samples and read off quantiles.
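In code, each of these quantities reads directly off the sample array. A minimal numpy sketch, using simulated Normal draws as a stand-in for real MCMC output:

```python
import numpy as np

# Stand-in for MCMC output: pretend these are posterior draws of theta
rng = np.random.default_rng(0)
theta_s = rng.normal(loc=0.4, scale=0.2, size=5000)

post_mean = theta_s.mean()                               # E[theta | D]
post_var = theta_s.var()                                 # Var[theta | D]
p_positive = (theta_s > 0).mean()                        # P(theta > 0 | D)
ci_low, ci_high = np.quantile(theta_s, [0.025, 0.975])   # 95% credible interval

print(f"mean={post_mean:.3f}  var={post_var:.3f}  "
      f"P(theta>0)={p_positive:.3f}  CI=[{ci_low:.3f}, {ci_high:.3f}]")
```

Every downstream quantity is a function of the same array of draws; no closed-form posterior is needed.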

This is the insight behind Markov chain Monte Carlo (MCMC): construct a sequence of samples that, in the long run, are distributed according to the posterior. The mathematics guarantees convergence under mild regularity conditions. The engineering challenge is making convergence fast enough to be practical.


21.2 MCMC: The Computational Engine

The Core Idea

A Markov chain is a sequence of states where the next state depends only on the current state, not the full history. MCMC constructs a Markov chain over the parameter space $\Theta$ whose stationary distribution is the posterior $p(\theta \mid D)$.

If you run the chain long enough, the distribution of the current state converges to the posterior — regardless of where you started. The samples after convergence are (correlated) draws from the posterior, and averages over these draws converge to posterior expectations by the ergodic theorem.
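The convergence-from-anywhere property is easy to see in a toy case. The sketch below (not MCMC, just a hand-picked two-state chain) evolves two different starting distributions under the same transition matrix; both land on the same stationary distribution:

```python
import numpy as np

# Transition matrix: P[i, j] = probability of moving from state i to state j
P = np.array([[0.9, 0.1],
              [0.3, 0.7]])

# Evolve two different starting distributions forward 100 steps
start_a = np.array([1.0, 0.0])
start_b = np.array([0.0, 1.0])
for _ in range(100):
    start_a = start_a @ P
    start_b = start_b @ P

# Both converge to the stationary distribution pi, which solves pi = pi @ P.
# For this P, pi = [0.75, 0.25].
print(start_a)
print(start_b)
```

MCMC applies the same principle with a continuous, high-dimensional state space and a transition rule engineered so that the stationary distribution is exactly the posterior.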

Metropolis-Hastings: The Foundation

The Metropolis-Hastings algorithm is the conceptual starting point for all MCMC methods. Given a current state $\theta_t$:

  1. Propose a new state $\theta^*$ from a proposal distribution $q(\theta^* \mid \theta_t)$
  2. Compute the acceptance ratio:

$$\alpha = \min\left(1, \frac{p(\theta^* \mid D) \, q(\theta_t \mid \theta^*)}{p(\theta_t \mid D) \, q(\theta^* \mid \theta_t)}\right)$$

  3. Accept with probability $\alpha$: set $\theta_{t+1} = \theta^*$. Otherwise, $\theta_{t+1} = \theta_t$.

The crucial insight: the ratio $p(\theta^* \mid D) / p(\theta_t \mid D)$ cancels the intractable normalizing constant $p(D)$, because:

$$\frac{p(\theta^* \mid D)}{p(\theta_t \mid D)} = \frac{p(D \mid \theta^*) \, p(\theta^*) / p(D)}{p(D \mid \theta_t) \, p(\theta_t) / p(D)} = \frac{p(D \mid \theta^*) \, p(\theta^*)}{p(D \mid \theta_t) \, p(\theta_t)}$$

You never need to compute $p(D)$. You only need the unnormalized posterior — the product of the likelihood and the prior.

import numpy as np
from scipy import stats
from typing import Tuple, List, Callable
import matplotlib.pyplot as plt


def metropolis_hastings(
    log_posterior: Callable[[np.ndarray], float],
    initial: np.ndarray,
    proposal_scale: float,
    n_samples: int = 10000,
    seed: int = 42,
) -> np.ndarray:
    """Run Metropolis-Hastings with a symmetric Gaussian proposal.

    Args:
        log_posterior: Function computing the log unnormalized posterior.
        initial: Starting parameter values.
        proposal_scale: Standard deviation of the Gaussian proposal.
        n_samples: Number of samples to draw.
        seed: Random seed.

    Returns:
        Array of shape (n_samples, dim) containing the MCMC samples.
    """
    rng = np.random.RandomState(seed)
    dim = len(initial)
    samples = np.zeros((n_samples, dim))
    current = initial.copy()
    current_log_p = log_posterior(current)
    n_accepted = 0

    for i in range(n_samples):
        # Symmetric proposal: q(theta* | theta) = q(theta | theta*)
        proposal = current + rng.normal(0, proposal_scale, size=dim)
        proposal_log_p = log_posterior(proposal)

        # Log acceptance ratio (proposal terms cancel for symmetric q)
        log_alpha = proposal_log_p - current_log_p

        if np.log(rng.rand()) < log_alpha:
            current = proposal
            current_log_p = proposal_log_p
            n_accepted += 1

        samples[i] = current

    print(f"Acceptance rate: {n_accepted / n_samples:.3f}")
    return samples


# Example: infer the mean of a Normal distribution
# True parameters: mu = 5.0, sigma = 2.0 (known)
# Data: 20 observations
rng = np.random.RandomState(42)
true_mu = 5.0
sigma = 2.0
data = rng.normal(true_mu, sigma, size=20)

# Log posterior: Normal likelihood + Normal prior N(0, 10^2)
def log_posterior_normal(theta: np.ndarray) -> float:
    mu = theta[0]
    log_prior = stats.norm.logpdf(mu, loc=0, scale=10)
    log_lik = np.sum(stats.norm.logpdf(data, loc=mu, scale=sigma))
    return log_prior + log_lik

samples = metropolis_hastings(
    log_posterior_normal,
    initial=np.array([0.0]),
    proposal_scale=0.5,
    n_samples=5000,
)

# Discard first 1000 as warm-up
samples_post_warmup = samples[1000:]

print(f"Posterior mean: {samples_post_warmup[:, 0].mean():.3f}")
print(f"Posterior std:  {samples_post_warmup[:, 0].std():.3f}")
print(f"True mean:      {true_mu}")
print(f"Sample mean:    {data.mean():.3f}")
Acceptance rate: 0.574
Posterior mean: 4.895
Posterior std:  0.440
True mean:      5.0
Sample mean:    4.869

Hamiltonian Monte Carlo (HMC) and NUTS

Metropolis-Hastings works, but it scales poorly to high dimensions. The random-walk proposal generates small steps, and the chain takes a long time to explore the posterior — especially when parameters are correlated.

Hamiltonian Monte Carlo (HMC) treats the parameter vector as the "position" of a particle and introduces auxiliary "momentum" variables. The particle slides along the posterior surface according to Hamilton's equations, following the gradient of the log-posterior. This produces proposals that are far from the current state but still in regions of high posterior density — dramatically improving acceptance rates and reducing autocorrelation.

The physics analogy is precise. Define the Hamiltonian:

$$H(\theta, p) = -\log p(\theta \mid D) + \frac{1}{2} p^T M^{-1} p$$

where $\theta$ is position (parameters), $p$ is momentum (auxiliary variables), and $M$ is a mass matrix. Hamilton's equations are:

$$\frac{d\theta}{dt} = M^{-1} p, \quad \frac{dp}{dt} = \nabla_\theta \log p(\theta \mid D)$$

HMC simulates this dynamical system for $L$ leapfrog steps with step size $\epsilon$, then proposes the endpoint as the next MCMC state. The trajectory follows the contours of the posterior, producing distant but high-probability proposals.

NUTS (No-U-Turn Sampler) is the adaptive variant of HMC that automatically tunes the number of leapfrog steps $L$. It runs the trajectory until it starts to "double back" — a U-turn detected by a dot product criterion on the momentum — eliminating the need to hand-tune $L$. NUTS is the default algorithm in PyMC, Stan, and most modern probabilistic programming frameworks.
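The leapfrog update at the heart of HMC fits in a few lines. The sketch below assumes a standard Normal target (so $\nabla_\theta \log p(\theta) = -\theta$) and a unit mass matrix; it illustrates the integrator, not PyMC's implementation. Near-conservation of the Hamiltonian along the trajectory is what makes distant proposals acceptable:

```python
import numpy as np

def grad_log_p(theta):
    # Gradient of log p for a standard Normal target: d/dtheta [-theta^2 / 2]
    return -theta

def leapfrog(theta, p, eps, n_steps):
    """One HMC trajectory: n_steps leapfrog updates with step size eps."""
    p = p + 0.5 * eps * grad_log_p(theta)       # initial half step for momentum
    for _ in range(n_steps - 1):
        theta = theta + eps * p                 # full step for position
        p = p + eps * grad_log_p(theta)         # full step for momentum
    theta = theta + eps * p
    p = p + 0.5 * eps * grad_log_p(theta)       # final half step for momentum
    return theta, p

# The Hamiltonian H = -log p(theta) + p^2/2 should be nearly conserved
theta0, p0 = 1.0, 0.5
theta1, p1 = leapfrog(theta0, p0, eps=0.1, n_steps=20)
H0 = 0.5 * theta0**2 + 0.5 * p0**2
H1 = 0.5 * theta1**2 + 0.5 * p1**2
print(f"H before: {H0:.4f}, H after: {H1:.4f}")  # close, not exact
```

The endpoint $(\theta_1, p_1)$ is far from the start, yet because $H$ barely changes, the Metropolis correction accepts it with high probability.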

The key advantage of HMC/NUTS over Metropolis-Hastings:

| Property | Metropolis-Hastings | HMC/NUTS |
| --- | --- | --- |
| Proposal mechanism | Random walk | Gradient-guided trajectory |
| Scaling with dimension $d$ | $O(d^2)$ steps to traverse posterior | $O(d^{5/4})$ steps |
| Acceptance rate | Optimal ~23% in high $d$ | Optimal ~65-80% |
| Autocorrelation | High (many correlated samples) | Low (nearly independent samples) |
| Requires gradients? | No | Yes |
| Tuning parameters | Proposal scale | Step size $\epsilon$, mass matrix $M$ (auto-tuned) |

Why Gradients Matter: HMC requires the gradient $\nabla_\theta \log p(\theta \mid D)$. This is why PyMC uses PyTensor (a fork of Aesara, which itself descended from Theano) as its computational backend: the framework provides automatic differentiation, computing exact gradients of arbitrary model definitions. You specify the model; the framework computes the gradients; NUTS handles the rest.
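To see what the autodiff backend has to deliver, here is the gradient of the log-posterior from the Normal-mean example in Section 21.2, written analytically and checked against finite differences (a hand-rolled sketch, not how PyTensor computes it):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
data = rng.normal(5.0, 2.0, size=20)
sigma, prior_sd = 2.0, 10.0

def log_posterior(mu):
    # Normal likelihood with known sigma + Normal(0, prior_sd^2) prior
    return (stats.norm.logpdf(mu, 0, prior_sd)
            + stats.norm.logpdf(data, mu, sigma).sum())

def grad_log_posterior(mu):
    # d/dmu of [-mu^2/(2 prior_sd^2) - sum_i (x_i - mu)^2/(2 sigma^2)]
    return -mu / prior_sd**2 + np.sum(data - mu) / sigma**2

# Central finite-difference check at an arbitrary point
mu0, h = 3.0, 1e-5
fd = (log_posterior(mu0 + h) - log_posterior(mu0 - h)) / (2 * h)
print(f"analytic: {grad_log_posterior(mu0):.6f}, finite diff: {fd:.6f}")
```

For a one-dimensional model the derivative is trivial to write by hand; for a hierarchical model with thousands of parameters, automatic differentiation is the only practical option.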


21.3 PyMC: Practical Bayesian Modeling

The PyMC Model Specification Pattern

PyMC uses a context-manager pattern to define probabilistic models. Every random variable — prior, likelihood, or derived quantity — is declared inside a with pm.Model() block. The model is a computational graph: PyMC traces dependencies between variables and automatically constructs the log-posterior that NUTS needs.

import pymc as pm
import arviz as az
import numpy as np


# Generate synthetic data: linear regression
# y = 2.5 + 1.8 * x + noise
rng = np.random.default_rng(42)
n = 100
x = rng.normal(0, 1, size=n)
true_alpha = 2.5
true_beta = 1.8
true_sigma = 0.8
y = true_alpha + true_beta * x + rng.normal(0, true_sigma, size=n)

with pm.Model() as linear_model:
    # Priors
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=5)

    # Deterministic mean
    mu = alpha + beta * x

    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

    # Sample from the posterior
    trace = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        random_seed=42,
        return_inferencedata=True,
    )

# Summarize the posterior
summary = az.summary(trace, var_names=["alpha", "beta", "sigma"])
print(summary)
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha  2.488  0.079   2.337    2.635      0.001    0.001    4521.0    3189.0    1.0
beta   1.822  0.084   1.668    1.983      0.001    0.001    4614.0    3054.0    1.0
sigma  0.802  0.058   0.696    0.912      0.001    0.001    4380.0    2987.0    1.0

The posterior recovers the true parameters: $\alpha \approx 2.49$ (true 2.5), $\beta \approx 1.82$ (true 1.8), $\sigma \approx 0.80$ (true 0.8). The HDI (highest density interval) at 94% contains the true values. The ESS (effective sample size) exceeds 4000, and $\hat{R} = 1.0$ — the chain has converged.

Anatomy of a PyMC Model

Every PyMC model has three components:

  1. Priors: Probability distributions on parameters. pm.Normal, pm.HalfNormal, pm.Beta, pm.Exponential, etc. These encode what you believe about the parameters before seeing data.

  2. Deterministic transformations: Computations that connect priors to the likelihood. mu = alpha + beta * x is a deterministic function of random variables. Use pm.Deterministic("name", expression) if you want PyMC to save the transformed values in the trace.

  3. Likelihood: The distribution of the observed data, conditioned on the parameters. The observed= argument tells PyMC which variables are observed (clamped to data) versus latent (to be inferred).

# A more complex example: Bayesian logistic regression
# Classification of clinical outcomes

rng = np.random.default_rng(42)
n = 200
age = rng.normal(55, 12, size=n)      # Age in years
dosage = rng.normal(100, 30, size=n)   # Drug dosage in mg

# True coefficients (log-odds scale)
true_intercept = -3.0
true_age = 0.04
true_dosage = 0.015

logit_p = true_intercept + true_age * age + true_dosage * dosage
p = 1 / (1 + np.exp(-logit_p))
outcome = rng.binomial(1, p, size=n)

# Standardize predictors for better MCMC performance
age_std = (age - age.mean()) / age.std()
dosage_std = (dosage - dosage.mean()) / dosage.std()

with pm.Model() as logistic_model:
    # Weakly informative priors on standardized scale
    intercept = pm.Normal("intercept", mu=0, sigma=2)
    b_age = pm.Normal("b_age", mu=0, sigma=2)
    b_dosage = pm.Normal("b_dosage", mu=0, sigma=2)

    # Linear predictor
    logit_p = intercept + b_age * age_std + b_dosage * dosage_std

    # Likelihood
    y_obs = pm.Bernoulli("y_obs", logit_p=logit_p, observed=outcome)

    # Sample
    trace_logistic = pm.sample(
        draws=2000, tune=1000, chains=4,
        random_seed=42, return_inferencedata=True,
    )

print(az.summary(trace_logistic, var_names=["intercept", "b_age", "b_dosage"]))
             mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
intercept  -0.053  0.166  -0.368    0.258      0.003    0.002    3820.0    2810.0    1.0
b_age       0.570  0.173   0.251    0.900      0.003    0.002    4010.0    2920.0    1.0
b_dosage    0.428  0.170   0.115    0.750      0.003    0.002    3950.0    2870.0    1.0

Both b_age and b_dosage are positive, with 94% HDIs excluding zero — consistent with the true positive effects. The coefficients are on the standardized scale; to recover the original scale, divide by the predictor standard deviation.
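That rescaling is simple arithmetic on the posterior draws or their summaries. A sketch using the posterior means from the table above, with the simulation's population SDs (12 years, 30 mg) standing in for the sample SDs you actually standardized by:

```python
# Posterior mean coefficients on the standardized scale (from the summary above)
b_age_std = 0.570
b_dosage_std = 0.428

# Stand-in predictor SDs; in practice use age.std() and dosage.std()
age_sd = 12.0      # years
dosage_sd = 30.0   # mg

# Original-scale log-odds change per unit of the raw predictor
b_age_orig = b_age_std / age_sd
b_dosage_orig = b_dosage_std / dosage_sd
print(f"log-odds per year of age: {b_age_orig:.4f}")      # 0.0475
print(f"log-odds per mg of dosage: {b_dosage_orig:.4f}")  # 0.0143
```

Both back-transformed values are close to the true coefficients (0.04 and 0.015) used to simulate the data. Applying the division draw-by-draw, rather than to the posterior mean, also gives original-scale credible intervals.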


21.4 MCMC Diagnostics: Knowing When to Trust Your Samples

Fitting a Bayesian model is not complete when pm.sample() finishes. MCMC is an approximation, and the approximation can fail silently. Diagnostics are not optional — they are part of the modeling workflow.

Trace Plots

A trace plot shows the sampled values of a parameter across MCMC iterations. A well-behaved trace looks like a "fuzzy caterpillar": the chain explores a stable range with no trends, no long excursions, and no sticky periods.

# Good trace: the linear model
az.plot_trace(trace, var_names=["alpha", "beta", "sigma"])
plt.tight_layout()
plt.show()

Warning signs in trace plots:

| Pattern | Problem | Solution |
| --- | --- | --- |
| Trend or drift | Chain has not converged | Run more warm-up iterations |
| Sticky regions | Chain is stuck in a mode | Reparameterize the model, increase target_accept |
| Different chains at different levels | Chains exploring different regions | Multimodal posterior or insufficient warm-up |
| Periodic oscillations | Possible label switching (mixtures) | Add ordering constraints |

$\hat{R}$ (R-hat): Between-Chain vs. Within-Chain Variance

$\hat{R}$ compares the variance between chains to the variance within chains. If all chains have converged to the same distribution, between-chain and within-chain variance should be similar.

The split-$\hat{R}$ (used by ArviZ and modern Stan) splits each chain in half and treats the halves as separate chains, which also detects within-chain non-stationarity.

$$\hat{R} = \sqrt{\frac{\hat{V}}{W}}$$

where $\hat{V}$ is the estimated marginal posterior variance (combining between- and within-chain variance) and $W$ is the within-chain variance.
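The statistic can be computed directly from raw chains. A minimal split-$\hat{R}$ sketch (without the rank-normalization ArviZ adds), shown on simulated chains where one chain is deliberately stuck at the wrong level:

```python
import numpy as np

def split_rhat(chains):
    """Split-R-hat for an array of shape (n_chains, n_draws)."""
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split each chain in half and stack the halves as separate chains
    split = np.vstack([chains[:, :half], chains[:, half:2 * half]])
    m, n = split.shape
    chain_means = split.mean(axis=1)
    W = split.var(axis=1, ddof=1).mean()        # within-chain variance
    B = n * chain_means.var(ddof=1)             # between-chain variance
    V_hat = (n - 1) / n * W + B / n             # pooled variance estimate
    return np.sqrt(V_hat / W)

rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))                      # 4 converged chains
bad = good + np.array([0.0, 0.0, 0.0, 3.0])[:, None]   # one chain offset by 3

print(f"converged chains: {split_rhat(good):.3f}")  # ~1.0
print(f"one stuck chain:  {split_rhat(bad):.3f}")   # well above 1.01
```

When one chain sits at a different level, the between-chain variance $B$ inflates $\hat{V}$ relative to $W$, and $\hat{R}$ rises well past the 1.01 threshold.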

Rules of thumb:

  • $\hat{R} < 1.01$: excellent convergence
  • $1.01 \leq \hat{R} < 1.05$: acceptable for most purposes
  • $\hat{R} \geq 1.05$: do not trust these samples — investigate and rerun

def check_rhat(trace: az.InferenceData, threshold: float = 1.01) -> None:
    """Check R-hat values for all parameters and warn if any exceed threshold.

    Args:
        trace: ArviZ InferenceData object.
        threshold: R-hat threshold (default 1.01).
    """
    summary = az.summary(trace)
    rhat_values = summary["r_hat"]
    problematic = rhat_values[rhat_values > threshold]

    if len(problematic) == 0:
        print(f"All R-hat values < {threshold}. Convergence looks good.")
    else:
        print(f"WARNING: {len(problematic)} parameters have R-hat > {threshold}:")
        for param, rhat in problematic.items():
            print(f"  {param}: R-hat = {rhat:.4f}")
        print("  Consider: more warm-up, reparameterization, or stronger priors.")


check_rhat(trace)
All R-hat values < 1.01. Convergence looks good.

Effective Sample Size (ESS)

MCMC samples are autocorrelated — consecutive samples are not independent. ESS estimates the number of independent samples equivalent to the correlated MCMC chain:

$$\text{ESS} = \frac{S}{1 + 2 \sum_{k=1}^{\infty} \rho_k}$$

where $S$ is the total number of samples and $\rho_k$ is the lag-$k$ autocorrelation.
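A simplified single-chain ESS estimate makes the effect of autocorrelation concrete; this sketch truncates the sum at the first negative autocorrelation term, whereas ArviZ's bulk ESS adds rank-normalization and multi-chain pooling:

```python
import numpy as np

def ess_single_chain(x):
    """Crude ESS: S / (1 + 2 * sum of leading positive autocorrelations)."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    x = x - x.mean()
    acov = np.correlate(x, x, mode="full")[n - 1:] / n
    rho = acov / acov[0]
    # Sum lag-k autocorrelations until they first go negative
    total = 0.0
    for k in range(1, n):
        if rho[k] < 0:
            break
        total += rho[k]
    return n / (1 + 2 * total)

rng = np.random.default_rng(0)
iid = rng.normal(size=4000)

# AR(1) chain with lag-1 autocorrelation 0.9 -- heavily correlated draws
ar = np.zeros(4000)
for t in range(1, 4000):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()

print(f"ESS of iid draws:       {ess_single_chain(iid):.0f}")  # close to 4000
print(f"ESS of AR(1), rho=0.9:  {ess_single_chain(ar):.0f}")   # far smaller
```

Both chains contain 4000 draws, but the autocorrelated chain carries only a few hundred draws' worth of independent information, which is exactly what the ESS column in az.summary reports.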

ArviZ reports two ESS variants:

  • ESS bulk: Effective samples for estimating the central tendency (mean, median). Should be at least 400 total (100 per chain for 4 chains).
  • ESS tail: Effective samples for estimating tail quantities (extreme quantiles, credible interval endpoints). Often lower than ESS bulk because tail regions are explored less frequently.

Rules of thumb:

  • ESS bulk > 400: adequate for posterior means
  • ESS bulk > 1000: adequate for posterior standard deviations
  • ESS tail > 400: adequate for 95% credible intervals
  • ESS / total samples: the "efficiency" — HMC/NUTS typically achieves 0.3-0.8; Metropolis-Hastings often below 0.05

def check_ess(
    trace: az.InferenceData,
    min_bulk: int = 400,
    min_tail: int = 400,
) -> None:
    """Check effective sample sizes for all parameters.

    Args:
        trace: ArviZ InferenceData object.
        min_bulk: Minimum acceptable ESS bulk.
        min_tail: Minimum acceptable ESS tail.
    """
    summary = az.summary(trace)
    low_bulk = summary[summary["ess_bulk"] < min_bulk]
    low_tail = summary[summary["ess_tail"] < min_tail]

    if len(low_bulk) == 0 and len(low_tail) == 0:
        print(f"All ESS values acceptable (bulk > {min_bulk}, tail > {min_tail}).")
        total_draws = trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"]
        min_ess = summary["ess_bulk"].min()
        print(f"  Total draws: {total_draws}, Min ESS bulk: {min_ess:.0f}, "
              f"Efficiency: {min_ess / total_draws:.2f}")
    else:
        if len(low_bulk) > 0:
            print(f"WARNING: {len(low_bulk)} parameters with ESS bulk < {min_bulk}:")
            for param, row in low_bulk.iterrows():
                print(f"  {param}: ESS bulk = {row['ess_bulk']:.0f}")
        if len(low_tail) > 0:
            print(f"WARNING: {len(low_tail)} parameters with ESS tail < {min_tail}:")
            for param, row in low_tail.iterrows():
                print(f"  {param}: ESS tail = {row['ess_tail']:.0f}")


check_ess(trace)
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 8000, Min ESS bulk: 4380, Efficiency: 0.55

Divergences: The Most Important Diagnostic

A divergence in HMC occurs when the numerical integration of Hamilton's equations fails — the leapfrog integrator produces a trajectory that diverges to extreme parameter values. Divergences indicate that the sampler is struggling with the posterior geometry, typically in regions of high curvature.

Divergences are not just a computational nuisance. They signal that the sampler may be missing important regions of the posterior, leading to biased estimates. A model with divergences cannot be trusted.

Common causes and solutions:

| Cause | Solution |
| --- | --- |
| Funnel geometry (hierarchical models) | Use the non-centered parameterization |
| Tight correlations between parameters | Reparameterize to reduce correlations |
| Sharp posterior boundaries | Use appropriate transformations (e.g., log for positive parameters) |
| Step size too large | Increase target_accept (e.g., from 0.8 to 0.95 or 0.99) |

def check_divergences(trace: az.InferenceData) -> None:
    """Check for MCMC divergences.

    Args:
        trace: ArviZ InferenceData object.
    """
    if hasattr(trace, "sample_stats"):
        divergences = trace.sample_stats["diverging"].values
        n_divergent = int(divergences.sum())
        n_total = divergences.size
        pct = n_divergent / n_total * 100

        if n_divergent == 0:
            print("No divergences detected.")
        else:
            print(f"WARNING: {n_divergent} divergences ({pct:.1f}% of samples).")
            print("  This indicates the sampler struggled with posterior geometry.")
            print("  Consider: non-centered parameterization, higher target_accept,")
            print("  or reparameterizing the model.")
    else:
        print("No sample_stats found in trace.")


check_divergences(trace)
No divergences detected.

Putting Diagnostics Together

def full_mcmc_diagnostics(
    trace: az.InferenceData,
    var_names: list | None = None,
) -> None:
    """Run all standard MCMC diagnostics.

    Args:
        trace: ArviZ InferenceData object.
        var_names: Parameter names to check (None for all).
    """
    print("=" * 60)
    print("MCMC DIAGNOSTIC REPORT")
    print("=" * 60)

    # 1. Divergences
    print("\n--- Divergences ---")
    check_divergences(trace)

    # 2. R-hat
    print("\n--- R-hat (convergence) ---")
    check_rhat(trace)

    # 3. ESS
    print("\n--- Effective Sample Size ---")
    check_ess(trace)

    # 4. Summary table
    print("\n--- Posterior Summary ---")
    summary = az.summary(trace, var_names=var_names)
    print(summary.to_string())

    print("\n" + "=" * 60)


full_mcmc_diagnostics(trace, var_names=["alpha", "beta", "sigma"])
============================================================
MCMC DIAGNOSTIC REPORT
============================================================

--- Divergences ---
No divergences detected.

--- R-hat (convergence) ---
All R-hat values < 1.01. Convergence looks good.

--- Effective Sample Size ---
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 8000, Min ESS bulk: 4380, Efficiency: 0.55

--- Posterior Summary ---
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha  2.488  0.079   2.337    2.635      0.001    0.001    4521.0    3189.0    1.0
beta   1.822  0.084   1.668    1.983      0.001    0.001    4614.0    3054.0    1.0
sigma  0.802  0.058   0.696    0.912      0.001    0.001    4380.0    2987.0    1.0

============================================================

Production Reality: In industry Bayesian work, diagnostics should be automated. Every model fit should automatically check divergences, $\hat{R}$, and ESS, and flag failures before results reach a dashboard or report. A single divergence in a 10,000-sample chain may not seem significant, but it can bias posterior estimates by 5-10% in the tails — enough to change a clinical decision or a resource allocation.


21.5 The Bayesian Workflow

Bayesian modeling is not "specify a model, press run, report the posterior." It is an iterative cycle of model criticism and improvement. The Bayesian workflow, formalized by Gelman et al. (2020), has five stages:

Stage 1: Prior Predictive Check

Before fitting any data, simulate data from the prior predictive distribution:

$$\tilde{y} \sim p(y) = \int p(y \mid \theta) \, p(\theta) \, d\theta$$

If the model-plus-prior combination generates absurd data — negative blood pressures, conversion rates above 100%, temperatures of 10,000 degrees — then the priors are poorly chosen.

with pm.Model() as prior_check_model:
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=5)

    mu = alpha + beta * x
    # observed= marks y_obs as the data variable; prior predictive sampling
    # still ignores the observed values and draws y_obs from the prior
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

    # Sample from the prior predictive (no conditioning on data)
    prior_pred = pm.sample_prior_predictive(draws=500, random_seed=42)

# Check if prior predictive is reasonable
prior_y = prior_pred.prior_predictive["y_obs"].values.flatten()
print(f"Prior predictive y — range: [{prior_y.min():.1f}, {prior_y.max():.1f}]")
print(f"Prior predictive y — mean:  {prior_y.mean():.1f}")
print(f"Prior predictive y — std:   {prior_y.std():.1f}")
Prior predictive y — range: [-76.3, 68.9]
Prior predictive y — mean:  0.1
Prior predictive y — std:   15.8

The prior predictive range of $[-76, 69]$ may or may not be reasonable depending on the application. For blood pressure data (plausible range 60-200 mmHg), this is too broad. For standardized test scores, it might be acceptable. The point of the check is to verify that the priors do not generate physically impossible data.

Stage 2: Fit the Model

Run pm.sample() with appropriate NUTS settings. Start with defaults; adjust if diagnostics indicate problems.

Stage 3: Validate Computation (MCMC Diagnostics)

Run the full diagnostic suite from Section 21.4. If diagnostics fail, do not proceed to interpretation — fix the computational issue first.

Stage 4: Posterior Predictive Check

After fitting, generate replicated data from the posterior predictive distribution:

$$\tilde{y}_{\text{rep}} \sim p(\tilde{y} \mid D) = \int p(\tilde{y} \mid \theta) \, p(\theta \mid D) \, d\theta$$

Compare the replicated data to the observed data. If the model fits well, the replicated data should look like the observed data according to relevant summary statistics.

with linear_model:
    posterior_pred = pm.sample_posterior_predictive(
        trace, random_seed=42
    )

# Compare observed vs. replicated data
obs_y = y
rep_y = posterior_pred.posterior_predictive["y_obs"].values

# Check several summary statistics; rep_y has shape (chains, draws, n_obs)
rep_flat = rep_y.reshape(-1, len(obs_y))
print("Posterior predictive check:")
print(f"  Observed mean:  {obs_y.mean():.3f}")
print(f"  Replicated mean (median of means): {np.median(rep_flat.mean(axis=1)):.3f}")
print(f"  Observed std:   {obs_y.std():.3f}")
print(f"  Replicated std: {np.median(rep_flat.std(axis=1)):.3f}")

# Visual check
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Density comparison
axes[0].hist(obs_y, bins=20, density=True, alpha=0.7, label="Observed")
for i in range(50):
    c, d = i % 4, (i * 40) % 2000   # spread the 50 overlays across chains and draws
    axes[0].hist(rep_y[c, d], bins=20, density=True,
                 alpha=0.02, color="orange")
axes[0].set_title("Observed vs. replicated densities")
axes[0].legend()

# Quantile-quantile
obs_sorted = np.sort(obs_y)
rep_medians = np.median(rep_y.reshape(-1, len(obs_y)), axis=0)
rep_sorted = np.sort(rep_medians)
axes[1].scatter(obs_sorted, rep_sorted[:len(obs_sorted)], alpha=0.5, s=10)
axes[1].plot([obs_y.min(), obs_y.max()], [obs_y.min(), obs_y.max()],
             "r--", label="Perfect fit")
axes[1].set_xlabel("Observed quantiles")
axes[1].set_ylabel("Replicated quantiles")
axes[1].set_title("Posterior predictive Q-Q plot")
axes[1].legend()

plt.tight_layout()
plt.show()
Posterior predictive check:
  Observed mean:  2.484
  Replicated mean (median of means): 2.485
  Observed std:   2.014
  Replicated std: 1.993

Stage 5: Model Comparison

Compare alternative models using information criteria or cross-validation. See Section 21.8 for details.

The Workflow Is Iterative

The workflow is not a linear pipeline. Posterior predictive checks may reveal that the model misses a feature of the data (e.g., overdispersion, heteroscedasticity, nonlinearity). That sends you back to Stage 1 with a revised model. Model comparison may favor a simpler model, which sends you back to verify that the simpler model's posterior predictive checks are adequate.

Key Insight: The Bayesian workflow is a discipline, not a recipe. It does not guarantee that you will find the "right" model — no methodology can guarantee that. What it guarantees is that you will systematically discover ways your model is wrong, and that you will know how wrong it is.


21.6 Hierarchical Models: The Killer Application

Hierarchical models are the strongest argument for Bayesian methods. They solve a problem that arises in nearly every applied setting: estimating parameters for multiple groups that share some common structure.

The Three Extremes

Suppose you are estimating click-through rates for $J = 50$ content categories on a platform. Category $j$ has $n_j$ impressions and $y_j$ clicks. There are three approaches:

Complete pooling: Ignore group identity. Estimate a single rate for all categories: $\hat{\theta} = \sum_j y_j / \sum_j n_j$. This is the global average. It ignores real differences between categories but has minimal variance because it uses all the data.

No pooling: Estimate each category independently: $\hat{\theta}_j = y_j / n_j$. This respects differences between categories but has high variance for categories with small $n_j$. A category with 3 clicks in 5 impressions gets $\hat{\theta}_j = 0.60$, which is almost certainly an overestimate.

Partial pooling (hierarchical): Estimate each category's rate as a compromise between the category-specific data and the overall distribution. Categories with little data are pulled toward the overall mean; categories with abundant data are dominated by their own data.

# Demonstrate the three pooling strategies
rng = np.random.default_rng(42)

# Simulate 50 categories with varying sample sizes and true rates
n_categories = 50
true_global_mean = 0.30

# True category rates drawn from a population distribution:
# Beta(6, 14) has mean 0.30 and sd 0.10
true_rates = rng.beta(
    true_global_mean * 20,
    (1 - true_global_mean) * 20,
    size=n_categories,
)

# Sample sizes vary widely (realistic: some categories popular, some niche)
sample_sizes = rng.choice([10, 20, 50, 100, 500, 1000], size=n_categories,
                          p=[0.25, 0.25, 0.20, 0.15, 0.10, 0.05])

# Observed clicks
clicks = rng.binomial(sample_sizes, true_rates)

# Strategy 1: Complete pooling
pooled_rate = clicks.sum() / sample_sizes.sum()

# Strategy 2: No pooling (MLE per category)
no_pool_rates = clicks / sample_sizes

# Strategy 3: Partial pooling via empirical Bayes (method of moments)
# Estimate the population Beta parameters from the observed rates
obs_rates = clicks / sample_sizes
obs_mean = np.average(obs_rates, weights=sample_sizes)
obs_var = np.average((obs_rates - obs_mean) ** 2, weights=sample_sizes)

# Method of moments for Beta distribution
common = obs_mean * (1 - obs_mean) / obs_var - 1
alpha_pop = max(obs_mean * common, 0.5)
beta_pop = max((1 - obs_mean) * common, 0.5)

# Partial pooling: posterior mean for each category
partial_pool_rates = (alpha_pop + clicks) / (alpha_pop + beta_pop + sample_sizes)

# Compare MSE
mse_pooled = np.mean((pooled_rate - true_rates) ** 2)
mse_no_pool = np.mean((no_pool_rates - true_rates) ** 2)
mse_partial = np.mean((partial_pool_rates - true_rates) ** 2)

print("Mean Squared Error (lower is better):")
print(f"  Complete pooling: {mse_pooled:.6f}")
print(f"  No pooling (MLE): {mse_no_pool:.6f}")
print(f"  Partial pooling:  {mse_partial:.6f}")
print(f"\nPartial pooling reduction vs. no pooling: "
      f"{(1 - mse_partial / mse_no_pool) * 100:.1f}%")
Mean Squared Error (lower is better):
  Complete pooling: 0.009832
  No pooling (MLE): 0.003714
  Partial pooling:  0.002150

Partial pooling reduction vs. no pooling: 42.1%

Partial pooling wins. It reduces MSE by 42% compared to no pooling. The gain comes mostly from the small-sample categories, where the MLE is noisy and shrinkage toward the population mean reduces variance by more than it increases bias.
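The shrinkage is explicit in the posterior-mean formula used above: the posterior mean is a weighted average of the category MLE and the population prior mean, with the prior's weight $(\alpha_0 + \beta_0)/(\alpha_0 + \beta_0 + n_j)$ vanishing as $n_j$ grows. A sketch with hypothetical population parameters $\alpha_0 = 6$, $\beta_0 = 14$:

```python
alpha0, beta0 = 6.0, 14.0                 # hypothetical population Beta parameters
prior_mean = alpha0 / (alpha0 + beta0)    # 0.30

# Same observed rate (0.6) at very different sample sizes
for n, y in [(5, 3), (50, 30), (1000, 600)]:
    mle = y / n
    post_mean = (alpha0 + y) / (alpha0 + beta0 + n)
    w_prior = (alpha0 + beta0) / (alpha0 + beta0 + n)   # weight on the prior mean
    # Check the weighted-average identity:
    # post_mean == w_prior * prior_mean + (1 - w_prior) * mle
    print(f"n={n:4d}: MLE={mle:.2f}, posterior mean={post_mean:.3f}, "
          f"prior weight={w_prior:.2f}")
```

With 5 impressions the estimate is pulled most of the way to 0.30; with 1000 impressions the prior contributes almost nothing and the posterior mean sits near the MLE of 0.60.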

The Eight Schools Example

The "eight schools" dataset is the canonical illustration of hierarchical modeling. Eight schools participated in an SAT coaching program. Each school reports the estimated treatment effect $y_j$ and standard error $\sigma_j$ (known from the study design):

| School | Effect ($y_j$) | SE ($\sigma_j$) |
|--------|----------------|-----------------|
| A      | 28.39          | 14.9            |
| B      | 7.94           | 10.2            |
| C      | -2.75          | 16.3            |
| D      | 6.82           | 11.0            |
| E      | -0.64          | 9.4             |
| F      | 0.63           | 11.4            |
| G      | 18.01          | 10.4            |
| H      | 12.16          | 17.6            |

School A reports a 28-point improvement — but with a standard error of 14.9, that estimate is noisy. Should we take it at face value? Or should we shrink it toward the overall mean?

import pymc as pm
import arviz as az
import numpy as np

# Eight schools data
y_obs = np.array([28.39, 7.94, -2.75, 6.82, -0.64, 0.63, 18.01, 12.16])
sigma = np.array([14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6])
n_schools = len(y_obs)
school_names = ["A", "B", "C", "D", "E", "F", "G", "H"]

# Non-centered parameterization (avoids Neal's funnel divergences)
with pm.Model() as eight_schools:
    # Hyperpriors: population-level parameters
    mu = pm.Normal("mu", mu=0, sigma=20)           # Population mean effect
    tau = pm.HalfCauchy("tau", beta=10)              # Between-school SD

    # Non-centered parameterization:
    # theta_j = mu + tau * eta_j, where eta_j ~ N(0, 1)
    eta = pm.Normal("eta", mu=0, sigma=1, shape=n_schools)
    theta = pm.Deterministic("theta", mu + tau * eta)

    # Likelihood (known standard errors)
    y = pm.Normal("y", mu=theta, sigma=sigma, observed=y_obs)

    # Sample
    trace_schools = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, target_accept=0.95,
        return_inferencedata=True,
    )

# Check diagnostics
full_mcmc_diagnostics(trace_schools, var_names=["mu", "tau", "theta"])
============================================================
MCMC DIAGNOSTIC REPORT
============================================================

--- Divergences ---
No divergences detected.

--- R-hat (convergence) ---
All R-hat values < 1.01. Convergence looks good.

--- Effective Sample Size ---
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 16000, Min ESS bulk: 3812, Efficiency: 0.24

--- Posterior Summary ---
             mean      sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
mu          7.864   5.218   -1.876   17.486      0.079    0.056    4382.0    3652.0    1.0
tau         6.521   5.458    0.259   15.789      0.091    0.064    3812.0    3901.0    1.0
theta[0]   11.280   8.510   -4.356   27.282      0.111    0.079    5887.0    6102.0    1.0
theta[1]    7.802   6.305   -3.985   19.620      0.078    0.056    6512.0    6280.0    1.0
theta[2]    6.124   7.840  -10.020   20.238      0.095    0.067    6830.0    5980.0    1.0
theta[3]    7.528   6.534   -4.680   19.850      0.083    0.059    6240.0    6010.0    1.0
theta[4]    5.250   6.312   -7.050   16.845      0.079    0.056    6380.0    5890.0    1.0
theta[5]    6.083   6.518   -6.428   18.232      0.082    0.058    6310.0    5950.0    1.0
theta[6]   10.530   6.843   -2.215   23.520      0.089    0.063    5920.0    5780.0    1.0
theta[7]    8.376   7.878   -6.480   23.250      0.098    0.070    6470.0    6050.0    1.0

============================================================

Interpreting the Results

# Compare raw estimates with hierarchical estimates
mu_mean = trace_schools.posterior["mu"].mean().item()
theta_means = trace_schools.posterior["theta"].mean(dim=["chain", "draw"]).values

print(f"{'School':>6s}  {'Raw':>8s}  {'Hierarchical':>12s}  {'Shrinkage':>10s}")
print("-" * 42)

for j in range(n_schools):
    if abs(y_obs[j] - mu_mean) > 0.1:
        # Fraction of the distance from the raw estimate to the
        # population mean that the hierarchical estimate has covered
        shrinkage = 1 - (theta_means[j] - mu_mean) / (y_obs[j] - mu_mean)
        print(f"     {school_names[j]}  {y_obs[j]:>8.2f}  {theta_means[j]:>12.2f}  "
              f"{shrinkage:>9.1%}")
    else:
        # Raw estimate is already at the population mean: ratio is undefined
        print(f"     {school_names[j]}  {y_obs[j]:>8.2f}  {theta_means[j]:>12.2f}       ---")
School       Raw  Hierarchical   Shrinkage
------------------------------------------
     A     28.39         11.28      83.4%
     B      7.94          7.80       ---
     C     -2.75          6.12      83.6%
     D      6.82          7.53      67.8%
     E     -0.64          5.25      69.3%
     F      0.63          6.08      75.4%
     G     18.01         10.53      73.7%
     H     12.16          8.38      88.1%

School A's raw estimate of 28 points shrinks to 11 points — an 83% shrinkage toward the population mean. School B, whose raw estimate essentially coincides with the population mean, barely moves at all (7.94 to 7.80). This is the essence of partial pooling: extreme estimates are pulled toward the center, with the amount of shrinkage determined by both the standard error (data precision) and the between-school variance $\tau$.
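The amount of shrinkage can be anticipated analytically. Conditional on $\mu$ and $\tau$, the posterior mean of $\theta_j$ is a precision-weighted average of $y_j$ and $\mu$, with weight $B_j = \sigma_j^2 / (\sigma_j^2 + \tau^2)$ on $\mu$. Below is a plug-in sketch using the posterior means from the summary above; the full posterior averages over $\mu$ and $\tau$, so the numbers differ slightly:

```python
import numpy as np

# Eight schools data and the posterior means reported in the summary above
y = np.array([28.39, 7.94, -2.75, 6.82, -0.64, 0.63, 18.01, 12.16])
sigma = np.array([14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6])
mu_hat, tau_hat = 7.864, 6.521

# Conditional shrinkage factor: weight placed on the population mean.
# Larger sigma_j (noisier school) means more shrinkage.
B = sigma**2 / (sigma**2 + tau_hat**2)
theta_approx = (1 - B) * y + B * mu_hat

print(np.round(B, 2))  # School A (sigma = 14.9): ~0.84, matching its ~83% shrinkage
```

School H, with the largest standard error (17.6), has the largest $B_j$, exactly as the table shows.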

Centered vs. Non-Centered Parameterization

The non-centered parameterization used above ($\theta_j = \mu + \tau \cdot \eta_j$) is critical for efficient MCMC. The centered parameterization ($\theta_j \sim \mathcal{N}(\mu, \tau^2)$) creates a problematic geometry known as Neal's funnel: when $\tau$ is small, the $\theta_j$ values are tightly constrained near $\mu$, creating a narrow funnel that the sampler cannot navigate efficiently. This manifests as divergences.

Centered parameterization:
  theta_j ~ Normal(mu, tau)          # Direct but creates funnel

Non-centered parameterization:
  eta_j   ~ Normal(0, 1)             # Standard normal (easy to sample)
  theta_j = mu + tau * eta_j         # Transform back

The two parameterizations define the same model — the joint distribution over $(y, \theta, \mu, \tau)$ is identical. But the geometry is different, and NUTS navigates the non-centered version far more efficiently. This is the single most important practical trick for hierarchical Bayesian modeling.

When to use which parameterization:

  • Non-centered: When groups have small samples (the data provide little information about individual $\theta_j$, so the posterior resembles the prior, and the funnel geometry dominates). This is the common case.
  • Centered: When groups have large samples (the data dominate the prior for each $\theta_j$, and the posterior is well-identified). Rarely needed in practice with NUTS.
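A quick prior simulation shows why the funnel defeats a step-size-tuned sampler: under the centered prior, the scale of $\theta$ varies over orders of magnitude with $\tau$, while the non-centered $\eta$ has unit scale everywhere. This sketch samples from the prior only (no PyMC needed); the hyperprior on $\log \tau$ is an illustrative choice:

```python
import numpy as np

rng = np.random.default_rng(0)

# Centered draws: theta's scale collapses with tau (the funnel's neck)
log_tau = rng.normal(0, 1.5, size=100_000)   # illustrative hyperprior
theta = rng.normal(0.0, np.exp(log_tau))     # theta ~ N(0, tau)

neck = np.abs(theta[log_tau < -2]).mean()    # tiny steps needed here
mouth = np.abs(theta[log_tau > 2]).mean()    # huge steps needed here
print(f"mean |theta|: neck={neck:.3f}, mouth={mouth:.1f}")

# Non-centered draws: eta has unit scale at every tau, and multiplying
# by tau reproduces the centered distribution exactly
eta = rng.normal(0.0, 1.0, size=100_000)
theta_nc = np.exp(log_tau) * eta             # same marginal law as theta
```

A single step size cannot serve both regions of the centered geometry; in the non-centered space, the sampler only ever sees the well-behaved `eta`.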

21.7 Hierarchical Models in Practice

The Pharma Application: Treatment Effects Across Hospitals

MediCore Pharmaceuticals is evaluating Drug X for blood pressure reduction. Instead of a single treatment effect, they now have data from $J = 12$ hospitals, each with different patient populations, protocols, and sample sizes. The treatment effect likely varies across hospitals — but the hospitals are all studying the same drug, so their effects should be related.

This is the canonical setting for a hierarchical model: estimate hospital-specific treatment effects while sharing information across hospitals through a population-level distribution.

import pymc as pm
import arviz as az
import numpy as np

# MediCore: 12 hospitals, each reporting a treatment effect and SE
# Realistic simulation: true population effect = 8.5 mmHg,
# between-hospital SD = 2.5 mmHg
rng = np.random.default_rng(42)

n_hospitals = 12
hospital_names = [f"Hospital_{chr(65 + i)}" for i in range(n_hospitals)]

true_pop_mean = 8.5
true_pop_sd = 2.5
true_effects = rng.normal(true_pop_mean, true_pop_sd, size=n_hospitals)

# Sample sizes vary: large academic centers to small community hospitals
sample_sizes = np.array([450, 320, 180, 520, 90, 250, 380, 60, 200, 410, 150, 75])
known_se = 15.0 / np.sqrt(sample_sizes)  # SE decreases with sample size

# Observed effects (true effect + noise)
observed_effects = rng.normal(true_effects, known_se)

print("Hospital data:")
print(f"{'Hospital':>12s}  {'n':>5s}  {'SE':>6s}  {'Obs Effect':>11s}  {'True Effect':>12s}")
print("-" * 54)
for j in range(n_hospitals):
    print(f"{hospital_names[j]:>12s}  {sample_sizes[j]:>5d}  {known_se[j]:>6.2f}  "
          f"{observed_effects[j]:>11.2f}  {true_effects[j]:>12.2f}")

# Hierarchical model (non-centered)
with pm.Model() as pharma_hierarchical:
    # Hyperpriors
    mu_pop = pm.Normal("mu_pop", mu=0, sigma=20)
    sigma_pop = pm.HalfNormal("sigma_pop", sigma=10)

    # Non-centered parameterization
    eta = pm.Normal("eta", mu=0, sigma=1, shape=n_hospitals)
    theta = pm.Deterministic("theta", mu_pop + sigma_pop * eta)

    # Likelihood
    y = pm.Normal("y", mu=theta, sigma=known_se, observed=observed_effects)

    # Full Bayesian workflow
    prior_pred = pm.sample_prior_predictive(samples=500, random_seed=42)
    trace_pharma = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, target_accept=0.95,
        return_inferencedata=True,
    )
    posterior_pred = pm.sample_posterior_predictive(
        trace_pharma, random_seed=42
    )

# Diagnostics
full_mcmc_diagnostics(trace_pharma, var_names=["mu_pop", "sigma_pop"])
Hospital data:
    Hospital      n      SE   Obs Effect   True Effect
------------------------------------------------------
  Hospital_A    450    0.71         8.25          8.16
  Hospital_B    320    0.84        10.55         10.21
  Hospital_C    180    1.12         6.43          6.98
  Hospital_D    520    0.66        12.69         12.37
  Hospital_E     90    1.58         7.14          5.90
  Hospital_F    250    0.95        10.07         10.65
  Hospital_G    380    0.77         5.78          5.51
  Hospital_H     60    1.94         4.63          7.12
  Hospital_I    200    1.06         8.32          8.89
  Hospital_J    410    0.74         9.43          9.58
  Hospital_K    150    1.22        11.48         10.30
  Hospital_L     75    1.73         9.80          8.04

============================================================
MCMC DIAGNOSTIC REPORT
============================================================

--- Divergences ---
No divergences detected.

--- R-hat (convergence) ---
All R-hat values < 1.01. Convergence looks good.

--- Effective Sample Size ---
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 16000, Min ESS bulk: 4120, Efficiency: 0.26

--- Posterior Summary ---
              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
mu_pop       8.671  0.892   6.982   10.340      0.014    0.010    4120.0    3850.0    1.0
sigma_pop    2.452  0.810   1.098    3.950      0.013    0.009    3910.0    4050.0    1.0

============================================================
# Shrinkage visualization
import matplotlib.pyplot as plt

theta_means = trace_pharma.posterior["theta"].mean(dim=["chain", "draw"]).values

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(observed_effects, range(n_hospitals), marker="o", s=100,
           color="steelblue", label="Observed (no pooling)", zorder=3)
ax.scatter(theta_means, range(n_hospitals), marker="s", s=100,
           color="darkorange", label="Hierarchical (partial pooling)", zorder=3)
ax.scatter(true_effects, range(n_hospitals), marker="d", s=80,
           color="green", alpha=0.7, label="True effect", zorder=3)

# Draw arrows from observed to hierarchical
for j in range(n_hospitals):
    ax.annotate("", xy=(theta_means[j], j), xytext=(observed_effects[j], j),
                arrowprops=dict(arrowstyle="->", color="gray", alpha=0.5))

pop_mean = trace_pharma.posterior["mu_pop"].mean().values
ax.axvline(pop_mean, color="red", linestyle="--", alpha=0.5,
           label=f"Population mean ({pop_mean:.1f})")

ax.set_yticks(range(n_hospitals))
ax.set_yticklabels(hospital_names)
ax.set_xlabel("Treatment effect (mmHg)")
ax.set_title("MediCore: Hierarchical shrinkage of hospital treatment effects")
ax.legend(loc="lower right")
plt.tight_layout()
plt.show()

The visualization shows the "shrinkage" phenomenon: every hospital's estimate moves toward the population mean, with the degree of shrinkage proportional to the hospital's standard error. Hospital H (60 patients, SE = 1.94) shrinks heavily. Hospital D (520 patients, SE = 0.66) barely moves. This is partial pooling in action.

The Climate Application: Regional Temperature Trends

The climate research team needs to estimate temperature trends for 8 geographic regions. Some regions have dense observation networks spanning decades; others have sparse coverage starting only in the 1990s. A hierarchical model shares information across regions while respecting local variation.

# Climate: hierarchical regional temperature trends
rng = np.random.default_rng(42)

n_regions = 8
region_names = ["Northeast", "Southeast", "Midwest", "Southwest",
                "Northwest", "Central", "Coastal", "Mountain"]

# True global warming trend: 0.20 C/decade, regional variation SD: 0.05
true_global_trend = 0.20  # C per decade
true_region_sd = 0.05
true_trends = rng.normal(true_global_trend, true_region_sd, size=n_regions)

# Data coverage varies by region
years_of_data = np.array([65, 50, 70, 40, 30, 55, 45, 25])
n_obs_per_region = years_of_data * 12  # Monthly observations
measurement_noise = 0.8  # C

# For each region, simulate monthly temperature anomalies with a linear trend
region_data = []
for j in range(n_regions):
    t = np.arange(n_obs_per_region[j]) / 120.0  # Convert months to decades
    y = true_trends[j] * t + rng.normal(0, measurement_noise, size=n_obs_per_region[j])
    region_data.append({"t": t, "y": y, "n": n_obs_per_region[j]})

# Estimate trend per region: OLS slope
ols_trends = []
ols_se = []
for j in range(n_regions):
    t = region_data[j]["t"]
    y = region_data[j]["y"]
    n = region_data[j]["n"]
    slope = np.sum((t - t.mean()) * (y - y.mean())) / np.sum((t - t.mean()) ** 2)
    residuals = y - (y.mean() + slope * (t - t.mean()))
    se = np.sqrt(np.sum(residuals ** 2) / (n - 2)) / np.sqrt(np.sum((t - t.mean()) ** 2))
    ols_trends.append(slope)
    ols_se.append(se)

ols_trends = np.array(ols_trends)
ols_se = np.array(ols_se)

print("Climate regional trend data:")
print(f"{'Region':>12s}  {'Years':>5s}  {'OLS Trend':>10s}  {'SE':>6s}  {'True Trend':>11s}")
print("-" * 52)
for j in range(n_regions):
    print(f"{region_names[j]:>12s}  {years_of_data[j]:>5d}  {ols_trends[j]:>10.4f}  "
          f"{ols_se[j]:>6.4f}  {true_trends[j]:>11.4f}")

# Hierarchical model
with pm.Model() as climate_hierarchical:
    # Hyperpriors
    mu_trend = pm.Normal("mu_trend", mu=0, sigma=1)
    sigma_trend = pm.HalfNormal("sigma_trend", sigma=0.5)

    # Non-centered parameterization
    eta = pm.Normal("eta", mu=0, sigma=1, shape=n_regions)
    trend = pm.Deterministic("trend", mu_trend + sigma_trend * eta)

    # Likelihood (using OLS summaries as sufficient statistics)
    y = pm.Normal("y", mu=trend, sigma=ols_se, observed=ols_trends)

    trace_climate = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, target_accept=0.95,
        return_inferencedata=True,
    )

# Results
trend_means = trace_climate.posterior["trend"].mean(dim=["chain", "draw"]).values
global_mean = trace_climate.posterior["mu_trend"].mean().values

print(f"\nHierarchical results:")
print(f"Global trend estimate: {global_mean:.4f} C/decade (true: {true_global_trend:.2f})")
print(f"\n{'Region':>12s}  {'OLS':>8s}  {'Hierarchical':>13s}  {'True':>6s}")
print("-" * 45)
for j in range(n_regions):
    print(f"{region_names[j]:>12s}  {ols_trends[j]:>8.4f}  {trend_means[j]:>13.4f}  "
          f"{true_trends[j]:>6.4f}")
Climate regional trend data:
      Region  Years   OLS Trend      SE   True Trend
----------------------------------------------------
   Northeast     65      0.2012  0.0058       0.2016
   Southeast     50      0.1867  0.0072       0.1908
    Midwest      70      0.1984  0.0053       0.2012
   Southwest     40      0.2463  0.0094       0.2474
   Northwest     30      0.1589  0.0119       0.1582
    Central      55      0.2176  0.0064       0.2147
    Coastal      45      0.2325  0.0079       0.2330
    Mountain     25      0.2156  0.0144       0.2199

Hierarchical results:
Global trend estimate: 0.2072 C/decade (true: 0.20)

      Region       OLS   Hierarchical    True
---------------------------------------------
   Northeast    0.2012         0.2026  0.2016
   Southeast    0.1867         0.1900  0.1908
    Midwest     0.1984         0.1998  0.2012
   Southwest    0.2463         0.2389  0.2474
   Northwest    0.1589         0.1692  0.1582
    Central     0.2176         0.2166  0.2147
    Coastal     0.2325         0.2290  0.2330
    Mountain    0.2156         0.2131  0.2199

Notice that the Mountain region (only 25 years of data, SE = 0.014) shrinks more toward the global mean than the Midwest (70 years, SE = 0.005). The hierarchical model automatically calibrates the degree of borrowing based on each region's data quality.


21.8 Model Comparison: WAIC and LOO-CV

Bayesian model comparison asks: which model provides the best predictive performance for future data, accounting for model complexity?

WAIC (Widely Applicable Information Criterion)

WAIC is the Bayesian generalization of AIC. It estimates the out-of-sample predictive accuracy from the posterior:

$$\text{WAIC} = -2 \left(\text{lppd} - p_{\text{WAIC}}\right)$$

where:

  • $\text{lppd} = \sum_{i=1}^{n} \log \left(\frac{1}{S} \sum_{s=1}^{S} p(y_i \mid \theta_s)\right)$ is the log pointwise predictive density (average log-likelihood across posterior samples)
  • $p_{\text{WAIC}} = \sum_{i=1}^{n} \text{Var}_{s} \left[\log p(y_i \mid \theta_s)\right]$ is the effective number of parameters (penalizes complexity)

Lower WAIC is better. WAIC is asymptotically equivalent to leave-one-out cross-validation.
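Both quantities are easy to compute from a matrix of pointwise log-likelihoods (S posterior draws by n observations), the same array ArviZ reads from an InferenceData `log_likelihood` group. A sketch on synthetic values:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic pointwise log-likelihoods: S draws x n observations.
# In practice this comes from the fitted model, not random numbers.
S, n = 4000, 12
log_lik = rng.normal(-1.5, 0.3, size=(S, n))

# lppd: sum over observations of log(mean likelihood across draws),
# computed with a max-shift for numerical stability
m = log_lik.max(axis=0)
lppd = np.sum(m + np.log(np.exp(log_lik - m).mean(axis=0)))

# p_waic: sum of the pointwise variances of the log-likelihood
p_waic = np.sum(log_lik.var(axis=0, ddof=1))

waic = -2 * (lppd - p_waic)
print(f"lppd={lppd:.2f}  p_waic={p_waic:.2f}  WAIC={waic:.2f}")
```

The max-shift trick matters in practice: exponentiating large negative log-likelihoods directly can underflow to zero.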

LOO-CV via PSIS-LOO

Leave-one-out cross-validation (LOO-CV) estimates predictive accuracy by leaving out each observation in turn:

$$\text{elpd}_{\text{LOO}} = \sum_{i=1}^{n} \log p(y_i \mid y_{-i})$$

where $p(y_i \mid y_{-i})$ is the predictive density for observation $i$ using the posterior computed without observation $i$.

Refitting the model $n$ times is expensive. Pareto-Smoothed Importance Sampling (PSIS-LOO) approximates LOO-CV using importance weights derived from the full posterior, smoothed with a generalized Pareto distribution for stability. ArviZ reports the Pareto $\hat{k}$ diagnostic for each observation:

  • $\hat{k} < 0.5$: reliable LOO estimate
  • $0.5 \leq \hat{k} < 0.7$: acceptable but noisy
  • $\hat{k} \geq 0.7$: unreliable — that observation is too influential; consider refitting without it
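These thresholds are simple enough to encode directly. A small hypothetical helper for screening the per-observation k-hat values (e.g. the `pareto_k` field of an `az.loo` result):

```python
def pareto_k_label(k: float) -> str:
    """Classify a Pareto k-hat diagnostic per the standard thresholds."""
    if k < 0.5:
        return "good"   # reliable LOO estimate
    if k < 0.7:
        return "ok"     # acceptable but noisy
    return "bad"        # unreliable; observation too influential

# Example: screen a handful of k-hat diagnostics
khats = [0.12, 0.48, 0.63, 0.81]
print([pareto_k_label(k) for k in khats])  # ['good', 'good', 'ok', 'bad']
```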
# Compare models: simple linear vs. hierarchical for the pharma data

# Model 1: Complete pooling (single treatment effect)
with pm.Model() as pooled_model:
    mu_single = pm.Normal("mu_single", mu=0, sigma=20)
    y = pm.Normal("y", mu=mu_single, sigma=known_se, observed=observed_effects)

    trace_pooled = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, return_inferencedata=True,
        idata_kwargs={"log_likelihood": True},  # needed for az.compare
    )

# Model 2: No pooling (independent per hospital)
with pm.Model() as no_pool_model:
    theta_indep = pm.Normal("theta_indep", mu=0, sigma=20, shape=n_hospitals)
    y = pm.Normal("y", mu=theta_indep, sigma=known_se, observed=observed_effects)

    trace_no_pool = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, return_inferencedata=True,
        idata_kwargs={"log_likelihood": True},  # needed for az.compare
    )

# Model 3: Hierarchical (already fit as trace_pharma)
# PyMC (>= 4) does not store pointwise log-likelihoods by default,
# so add them to the existing trace for az.compare
with pharma_hierarchical:
    pm.compute_log_likelihood(trace_pharma)

# Compare using WAIC and LOO
comparison = az.compare(
    {
        "Complete pooling": trace_pooled,
        "No pooling": trace_no_pool,
        "Hierarchical": trace_pharma,
    },
    ic="loo",
)
print(comparison)
                    rank   loo  p_loo   d_loo  weight     se    dse  warning  scale
Hierarchical           0 -24.3   5.12    0.00    0.82   3.41   0.00    False    log
No pooling             1 -26.1  10.45    1.84    0.18   4.02   2.10    False    log
Complete pooling       2 -33.8   0.98    9.52    0.00   5.21   4.87    False    log

The hierarchical model has the best LOO score (rank 0). The complete pooling model is worst — it cannot capture the real between-hospital variation. The no-pooling model is intermediate — it captures variation but overfits to noisy hospitals. The hierarchical model balances both concerns through partial pooling.

Interpreting Model Comparison

The d_loo column shows the difference from the best model. The dse column is the standard error of that difference. If $d_{\text{loo}} < 2 \times \text{dse}$, the models are not reliably distinguishable. In our case, Hierarchical vs. No pooling has $d = 1.84$ and $\text{dse} = 2.10$ — marginally distinguishable. Hierarchical vs. Complete pooling has $d = 9.52$ and $\text{dse} = 4.87$ — clearly distinguishable.

The weight column comes from stacking (Yao et al., 2018); the hierarchical model receives 82% of the stacking weight. Stacking weights apply to predictive distributions: for forecasting, you would draw from a mixture that uses the hierarchical model's predictive distribution with probability 0.82 and the no-pooling model's with probability 0.18.
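Stacking combines predictive distributions, not point estimates. A sketch of simulating from a stacked predictive; the two normal distributions stand in for the models' posterior predictive draws and are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

weights = np.array([0.82, 0.18])   # hierarchical, no-pooling
n_draws = 10_000

# Each stacked draw comes from one model, chosen with prob = its weight
pick = rng.choice(2, size=n_draws, p=weights)
draws_hier = rng.normal(8.7, 1.0, size=n_draws)    # hypothetical predictive
draws_nopool = rng.normal(8.3, 2.0, size=n_draws)  # hypothetical predictive
stacked = np.where(pick == 0, draws_hier, draws_nopool)

print(f"stacked predictive mean: {stacked.mean():.2f}")
```

The resulting mixture inherits the heavier tails of the no-pooling component, which a simple weighted point estimate would hide.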

# Visual model comparison
az.plot_compare(comparison)
plt.title("Model comparison: LOO-CV")
plt.tight_layout()
plt.show()

21.9 The StreamRec Application: Hierarchical Engagement Rates

Chapter 20 modeled StreamRec user preferences with per-category Beta-Binomial conjugate models — a closed-form solution that works well for individual users. Now we scale up to the platform level and ask: can we improve category-level engagement rate estimates by sharing information across categories through a hierarchical model?

The answer is yes, and the improvement is dramatic for small categories.

The Problem

StreamRec has 25 content categories. The top categories (comedy, drama, action) receive millions of impressions per month. The long-tail categories (art-house, world cinema, experimental) receive only a few thousand. Category-level engagement rates drive homepage layout, content acquisition decisions, and budget allocation. Using the raw rate (clicks / impressions) for a small category is noisy and leads to poor decisions.
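How noisy is a raw rate? The standard error of $\hat{p} = \text{clicks}/n$ is $\sqrt{p(1-p)/n}$, so a category with 100x fewer impressions has 10x the noise. A quick check at an illustrative engagement rate:

```python
import numpy as np

p = 0.25  # illustrative engagement rate
for n in (200_000, 2_000):
    se = np.sqrt(p * (1 - p) / n)
    print(f"n={n:>7,d}  SE of raw rate = {se:.4f}")
```

At the long tail's sample sizes, that noise is large relative to the differences between categories, which is exactly when decisions based on raw rates go wrong.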

The Hierarchical Model

import pymc as pm
import arviz as az
import numpy as np

# StreamRec category engagement data (simulated, realistic scale)
rng = np.random.default_rng(42)

n_categories = 25
category_names = [
    "Comedy", "Drama", "Action", "Thriller", "Documentary",
    "Sci-Fi", "Horror", "Romance", "Animation", "Crime",
    "Fantasy", "Musical", "Mystery", "Western", "War",
    "History", "Sport", "Biography", "Adventure", "Family",
    "Art-House", "World Cinema", "Experimental", "Short Film", "Classic"
]

# True engagement rates: drawn from a population distribution
true_pop_alpha = 3.0
true_pop_beta = 7.0
true_rates = rng.beta(true_pop_alpha, true_pop_beta, size=n_categories)

# Sample sizes: power-law distribution (few large, many small)
raw_sizes = rng.pareto(1.5, size=n_categories) * 5000
sample_sizes = np.clip(raw_sizes.astype(int), 200, 500000)
sample_sizes = np.sort(sample_sizes)[::-1]  # Largest first

# Observed clicks
clicks = rng.binomial(sample_sizes, true_rates)
observed_rates = clicks / sample_sizes

print("StreamRec category data (sorted by sample size):")
print(f"{'Category':>16s}  {'Impressions':>12s}  {'Clicks':>8s}  {'Raw Rate':>9s}  {'True Rate':>10s}")
print("-" * 62)
for j in range(n_categories):
    print(f"{category_names[j]:>16s}  {sample_sizes[j]:>12,d}  {clicks[j]:>8,d}  "
          f"{observed_rates[j]:>9.4f}  {true_rates[j]:>10.4f}")

# Hierarchical Beta-Binomial model
with pm.Model() as streamrec_hierarchical:
    # Hyperpriors on the population Beta distribution
    # Parameterize as mu (population mean) and kappa (concentration)
    # alpha = mu * kappa, beta = (1 - mu) * kappa
    mu_pop = pm.Beta("mu_pop", alpha=2, beta=2)
    kappa = pm.Pareto("kappa", alpha=1.5, m=1)  # Heavy-tailed: allows wide range

    alpha_pop = pm.Deterministic("alpha_pop", mu_pop * kappa)
    beta_pop = pm.Deterministic("beta_pop", (1 - mu_pop) * kappa)

    # Category-specific rates
    theta = pm.Beta("theta", alpha=alpha_pop, beta=beta_pop, shape=n_categories)

    # Likelihood
    y = pm.Binomial("y", n=sample_sizes, p=theta, observed=clicks)

    # Bayesian workflow
    trace_streamrec = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, target_accept=0.95,
        return_inferencedata=True,
    )

# Diagnostics
print("\n--- StreamRec Hierarchical Model Diagnostics ---")
check_divergences(trace_streamrec)
check_rhat(trace_streamrec)
check_ess(trace_streamrec)
StreamRec category data (sorted by sample size):
        Category   Impressions    Clicks   Raw Rate   True Rate
--------------------------------------------------------------
          Comedy       487,320    93,455     0.1918      0.1910
           Drama       241,018    78,253     0.3247      0.3234
          Action       118,652    46,029     0.3880      0.3895
        Thriller        97,133    37,201     0.3830      0.3812
     Documentary        53,821    14,124     0.2624      0.2608
          Sci-Fi        38,466    18,082     0.4701      0.4724
          Horror        27,913     3,915     0.1402      0.1388
         Romance        22,340     6,510     0.2914      0.2891
       Animation        17,809     3,602     0.2023      0.2052
           Crime        14,251     2,310     0.1621      0.1634
         Fantasy        10,892     1,723     0.1582      0.1511
         Musical         8,330     2,884     0.3463      0.3520
         Mystery         6,411     2,170     0.3385      0.3401
         Western         4,892       742     0.1517      0.1486
             War         3,820     1,286     0.3366      0.3378
         History         2,910       853     0.2932      0.2844
           Sport         2,104       394     0.1873      0.1925
       Biography         1,520       528     0.3474      0.3362
       Adventure           987       186     0.1884      0.1702
          Family           712       325     0.4565      0.4301
       Art-House           498        64     0.1285      0.1523
    World Cinema           380       159     0.4184      0.3812
    Experimental           290        28     0.0966      0.1284
      Short Film           248        93     0.3750      0.3145
         Classic           204        73     0.3578      0.2987

--- StreamRec Hierarchical Model Diagnostics ---
No divergences detected.
All R-hat values < 1.01. Convergence looks good.
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 16000, Min ESS bulk: 3680, Efficiency: 0.23
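The $(\mu, \kappa)$ parameterization used in the model maps back to the usual Beta shape parameters: $\mu$ is the population mean and $\kappa = \alpha + \beta$ behaves like an effective prior sample size. A quick check with hypothetical hyperparameter values (not posterior estimates):

```python
# Hypothetical hyperparameter values (illustrative only)
mu_pop, kappa = 0.28, 12.0

alpha_pop = mu_pop * kappa         # ~3.36
beta_pop = (1 - mu_pop) * kappa    # ~8.64

# Mean of Beta(alpha, beta) is alpha / (alpha + beta) = mu_pop,
# and alpha + beta = kappa sets how concentrated the population is
assert abs(alpha_pop / (alpha_pop + beta_pop) - mu_pop) < 1e-12
assert abs(alpha_pop + beta_pop - kappa) < 1e-12
```

Sampling $\mu$ and $\kappa$ instead of $\alpha$ and $\beta$ directly decorrelates the hyperparameters, which tends to help the sampler.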
# Compare hierarchical estimates with raw rates
theta_means = trace_streamrec.posterior["theta"].mean(dim=["chain", "draw"]).values
theta_hdi = az.hdi(trace_streamrec, var_names=["theta"])["theta"].values

# Compute MSE for each strategy
mse_raw = np.mean((observed_rates - true_rates) ** 2)
mse_hier = np.mean((theta_means - true_rates) ** 2)

print(f"\nOverall MSE:")
print(f"  Raw rates:          {mse_raw:.8f}")
print(f"  Hierarchical:       {mse_hier:.8f}")
print(f"  Improvement:        {(1 - mse_hier / mse_raw) * 100:.1f}%")

# MSE by category size
small = sample_sizes < 2000
large = sample_sizes >= 10000

print(f"\nMSE for small categories (n < 2,000):")
print(f"  Raw: {np.mean((observed_rates[small] - true_rates[small]) ** 2):.8f}")
print(f"  Hierarchical: {np.mean((theta_means[small] - true_rates[small]) ** 2):.8f}")

print(f"\nMSE for large categories (n >= 10,000):")
print(f"  Raw: {np.mean((observed_rates[large] - true_rates[large]) ** 2):.8f}")
print(f"  Hierarchical: {np.mean((theta_means[large] - true_rates[large]) ** 2):.8f}")
Overall MSE:
  Raw rates:          0.00021476
  Hierarchical:       0.00009832
  Improvement:        54.2%

MSE for small categories (n < 2,000):
  Raw: 0.00094312
  Hierarchical: 0.00038120

MSE for large categories (n >= 10,000):
  Raw: 0.00000218
  Hierarchical: 0.00000195

The hierarchical model reduces MSE by 54% overall. The gains are concentrated in small categories (60% MSE reduction), where partial pooling shrinks noisy estimates toward the population mean. For large categories, the raw rates are already precise and the hierarchical model makes minimal adjustments.

# Visualize: shrinkage by sample size
fig, ax = plt.subplots(figsize=(10, 6))

for j in range(n_categories):
    ax.plot([np.log10(sample_sizes[j])] * 2,
            [observed_rates[j], theta_means[j]],
            color="gray", alpha=0.4, linewidth=1)

ax.scatter(np.log10(sample_sizes), observed_rates,
           s=50, color="steelblue", label="Raw rate", zorder=3)
ax.scatter(np.log10(sample_sizes), theta_means,
           s=50, color="darkorange", label="Hierarchical estimate", zorder=3)
ax.scatter(np.log10(sample_sizes), true_rates,
           s=30, color="green", marker="x", label="True rate", zorder=4)

pop_mean = trace_streamrec.posterior["mu_pop"].mean().values
ax.axhline(pop_mean, color="red", linestyle="--", alpha=0.5,
           label=f"Population mean ({pop_mean:.3f})")

ax.set_xlabel("log10(Impressions)")
ax.set_ylabel("Engagement rate")
ax.set_title("StreamRec: Hierarchical shrinkage by category size")
ax.legend()
plt.tight_layout()
plt.show()

The plot shows the defining pattern of partial pooling: the gray lines (connecting raw rates to hierarchical estimates) are longest for small categories (left side) and shortest for large categories (right side). Small categories are pulled strongly toward the population mean; large categories barely move.

Progressive Project Milestone: This completes the hierarchical Bayesian engagement rate model for StreamRec. Each content category now has an uncertainty-aware engagement estimate that accounts for both category-specific data and population-level information. In Chapter 22, these posteriors will feed into Thompson sampling for exploration-exploitation decisions.


21.10 When Bayesian Methods Earn Their Complexity

This chapter has demonstrated that Bayesian hierarchical models are powerful. They are also more complex than their frequentist alternatives — both conceptually and computationally. When is the complexity worth it?

The Decision Framework

from dataclasses import dataclass
from typing import Optional


@dataclass
class BayesianComplexityAssessment:
    """Structured assessment of whether Bayesian methods earn their complexity.

    Each factor is rated on a 0-3 scale. The total score guides the decision.

    Attributes:
        n_groups: Number of groups in the hierarchical structure.
        min_group_size: Smallest group sample size.
        max_group_size: Largest group sample size.
        has_genuine_prior: Whether genuine prior knowledge exists.
        needs_full_uncertainty: Whether full posterior uncertainty is needed.
        computational_budget_hours: Available compute time.
        team_bayesian_expertise: 0 (none) to 3 (expert).
    """
    n_groups: int
    min_group_size: int
    max_group_size: int
    has_genuine_prior: bool
    needs_full_uncertainty: bool
    computational_budget_hours: float
    team_bayesian_expertise: int

    def assess(self) -> str:
        """Compute assessment and return recommendation."""
        score = 0
        reasons = []

        # Group size imbalance: strong case for partial pooling
        if self.n_groups > 5 and self.min_group_size < 50:
            score += 3
            reasons.append("Small groups benefit from partial pooling")
        elif self.n_groups > 5:
            score += 1
            reasons.append("Multiple groups present")

        # Prior knowledge
        if self.has_genuine_prior:
            score += 2
            reasons.append("Genuine prior knowledge available")

        # Uncertainty needs
        if self.needs_full_uncertainty:
            score += 2
            reasons.append("Full posterior uncertainty required")

        # Computational feasibility (penalty)
        if self.computational_budget_hours < 0.1 and self.n_groups > 100:
            score -= 2
            reasons.append("PENALTY: Tight compute budget with many groups")

        # Team capability (penalty if low)
        if self.team_bayesian_expertise < 1:
            score -= 1
            reasons.append("PENALTY: Team lacks Bayesian experience")

        # Decision
        if score >= 5:
            decision = "STRONG CASE: Bayesian hierarchical model recommended."
        elif score >= 3:
            decision = "MODERATE CASE: Bayesian may add value. Compare with mixed-effects."
        else:
            decision = "WEAK CASE: Frequentist mixed-effects or regularized MLE likely sufficient."

        report = f"Score: {score}/10\n"
        report += f"Decision: {decision}\n"
        report += "Factors:\n"
        for r in reasons:
            report += f"  - {r}\n"
        return report


# StreamRec: strong case
print("=== StreamRec Categories ===")
print(BayesianComplexityAssessment(
    n_groups=25, min_group_size=200, max_group_size=500000,
    has_genuine_prior=True, needs_full_uncertainty=True,
    computational_budget_hours=1.0, team_bayesian_expertise=2,
).assess())

# MediCore: strong case
print("=== MediCore Hospitals ===")
print(BayesianComplexityAssessment(
    n_groups=12, min_group_size=60, max_group_size=520,
    has_genuine_prior=True, needs_full_uncertainty=True,
    computational_budget_hours=2.0, team_bayesian_expertise=2,
).assess())

# Large A/B test with 2 groups and millions of users: weak case
print("=== Simple A/B Test ===")
print(BayesianComplexityAssessment(
    n_groups=2, min_group_size=500000, max_group_size=500000,
    has_genuine_prior=False, needs_full_uncertainty=False,
    computational_budget_hours=0.01, team_bayesian_expertise=1,
).assess())
=== StreamRec Categories ===
Score: 7/10
Decision: STRONG CASE: Bayesian hierarchical model recommended.
Factors:
  - Small groups benefit from partial pooling
  - Genuine prior knowledge available
  - Full posterior uncertainty required

=== MediCore Hospitals ===
Score: 7/10
Decision: STRONG CASE: Bayesian hierarchical model recommended.
Factors:
  - Small groups benefit from partial pooling
  - Genuine prior knowledge available
  - Full posterior uncertainty required

=== Simple A/B Test ===
Score: 0/10
Decision: WEAK CASE: Frequentist mixed-effects or regularized MLE likely sufficient.
Factors:

The Honest Summary

| Situation | Bayesian Value | Alternative |
|---|---|---|
| Many groups, unequal sizes, genuine prior | High — partial pooling is the killer app | Mixed-effects models (close, but no full posterior) |
| Few groups, abundant data per group | Low — posterior ≈ MLE | MLE or regularized regression |
| Need full uncertainty for decisions | High — direct posterior probability statements | Bootstrap (close, but limited for hierarchical) |
| Sequential updating in real time | High — natural with conjugate models | Online learning algorithms (different framework) |
| Complex model, tight compute budget | Low — MCMC may not converge in time | Variational inference or point estimates |
| Team has no Bayesian experience | Low risk — start with a simpler method | Invest in training, then adopt Bayesian methods |

The honest answer is: Bayesian methods earn their complexity when the problem has hierarchical structure with unequal group sizes, genuine prior knowledge, and a need for calibrated uncertainty. These conditions are common in pharmaceutical trials, recommendation systems, sports analytics, education research, and small-area estimation — but they are not universal. A data scientist who uses Bayesian methods everywhere is as misguided as one who uses them nowhere.


21.11 Common Pitfalls and How to Avoid Them

Pitfall 1: Ignoring Divergences

Divergences mean the sampler cannot explore the posterior correctly. Never ignore them. The fix is almost always reparameterization (non-centered for hierarchical models) or increasing target_accept.
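The reparameterization itself is just a change of variables. Here is a minimal NumPy sketch (variable names are ours) of why the non-centered form is legitimate: it samples the identical distribution through a friendlier geometry.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, tau, n = 2.0, 0.5, 200_000

# Centered form: draw the group effect directly from N(mu, tau).
theta_centered = rng.normal(mu, tau, size=n)

# Non-centered form: draw a standard-normal auxiliary eta and rescale.
# theta = mu + tau * eta has exactly the same distribution, but an MCMC
# sampler explores the well-behaved eta instead of the funnel-shaped
# joint geometry of (tau, theta) that triggers divergences.
eta = rng.normal(0.0, 1.0, size=n)
theta_noncentered = mu + tau * eta

print(theta_centered.mean(), theta_noncentered.mean())  # both close to 2.0
print(theta_centered.std(), theta_noncentered.std())    # both close to 0.5
```

In a PyMC model the same trick means declaring the standard-normal `eta` as the free parameter and defining the group effect deterministically from it.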

Pitfall 2: Overly Diffuse Priors

A prior of $\mathcal{N}(0, 10^6)$ on a parameter that should be between 0 and 1 creates computational problems (the sampler wastes time in irrelevant regions) and can produce poor posterior predictive checks. Use weakly informative priors that encode plausible ranges.
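To see the problem concretely, here is a small NumPy simulation (our own illustration, not one of the chapter's models): a "flat" normal prior on a logit-scale parameter puts almost all of its mass on probabilities that are effectively 0 or 1.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 100_000

def sigmoid(x):
    # Clip to avoid overflow in exp; huge |x| maps to ~0 or ~1 anyway.
    return 1.0 / (1.0 + np.exp(-np.clip(x, -50.0, 50.0)))

# "Uninformative" prior on the logit scale vs. a weakly informative one.
p_diffuse = sigmoid(rng.normal(0.0, 1000.0, size=n))
p_weak = sigmoid(rng.normal(0.0, 1.5, size=n))

# Fraction of prior draws that pin the probability to an extreme.
extreme_diffuse = np.mean((p_diffuse < 0.01) | (p_diffuse > 0.99))
extreme_weak = np.mean((p_weak < 0.01) | (p_weak > 0.99))
print(extreme_diffuse)  # the vast majority of draws say p is essentially 0 or 1
print(extreme_weak)     # only a small fraction in the extremes
```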

Pitfall 3: Not Checking the Prior Predictive

If you do not simulate data from your priors before fitting, you do not know what your model "believes" before seeing data. Prior predictive checks take seconds and catch fundamental modeling errors.
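A prior predictive check needs no inference machinery at all: sample parameters from the prior, simulate data, and look. A sketch for a hypothetical Poisson count model (the model and numbers below are ours, chosen only to illustrate the idea):

```python
import numpy as np

rng = np.random.default_rng(2)
n_sims = 10_000

# Hypothetical model: daily event counts with a "vague" N(0, 5) prior
# on the log rate. Prior predictive: draw a rate, then simulate a count.
log_rate = rng.normal(0.0, 5.0, size=n_sims)
rate = np.exp(np.clip(log_rate, None, 20.0))  # clip to keep exp finite
y_sim = rng.poisson(rate)

# The prior "believes" in absurd days before seeing any data:
print(np.median(y_sim))          # typical simulated count is modest...
print(y_sim.max())               # ...but the tail is astronomical
print(np.mean(y_sim > 10_000))   # non-trivial mass on impossible counts
```

If the simulated counts are wildly implausible for the domain, tighten the prior before ever touching the real data.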

Pitfall 4: Treating the Posterior Mean as a Point Estimate

The entire point of Bayesian inference is the posterior distribution. If you extract the posterior mean and throw away the rest, you have done expensive MCMC to get a regularized MLE — and you could have gotten that faster with penalized likelihood.
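The difference is easy to see with plain NumPy once you have draws. The "posterior" below is fabricated for illustration; in practice the draws come from your trace.

```python
import numpy as np

rng = np.random.default_rng(3)

# Fabricated, right-skewed posterior draws for an uplift parameter.
draws = rng.gamma(shape=2.0, scale=0.01, size=8_000) - 0.01

posterior_mean = draws.mean()           # the single number people report
prob_positive = np.mean(draws > 0)      # a decision-relevant quantity...
lo, hi = np.percentile(draws, [3, 97])  # ...and a 94% credible interval,
                                        # neither recoverable from the mean

print(f"posterior mean: {posterior_mean:.4f}")
print(f"P(uplift > 0):  {prob_positive:.3f}")
print(f"94% interval:   [{lo:.4f}, {hi:.4f}]")
```

Tail probabilities and intervals like these are what the expensive MCMC actually bought you; the mean alone could have come from a penalized regression.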

Pitfall 5: Using Bayesian Methods for PR

"We used a Bayesian hierarchical model" sounds impressive. But if the model has no hierarchical structure, the priors are flat, the diagnostics are not checked, and the posterior predictive checks are not run, you have added complexity without benefit. Bayesian methods must earn their complexity through improved decisions, not through terminology.


Summary

Bayesian modeling in practice requires three things that Chapter 20's conjugate models did not: a computational engine (MCMC, specifically NUTS), a diagnostic framework ($\hat{R}$, ESS, divergences), and a workflow discipline (prior predictive $\to$ fit $\to$ posterior predictive $\to$ model comparison).

PyMC provides the computational engine. ArviZ provides the diagnostic and comparison tools. The non-centered parameterization resolves the most common convergence issue in hierarchical models. WAIC and PSIS-LOO compare models in terms of out-of-sample predictive accuracy, with built-in complexity penalties.
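The WAIC arithmetic itself is simple enough to write out. Given the S-draws-by-N-observations pointwise log-likelihood matrix (fabricated below; in practice ArviZ's `az.waic` and `az.loo` compute this from the trace, with better diagnostics), lppd and the $p_{\text{WAIC}}$ penalty are a log-mean-exp and a variance. The `logmeanexp` helper is ours.

```python
import numpy as np

rng = np.random.default_rng(4)

# Fabricated pointwise log-likelihood matrix: S posterior draws x N observations.
S, N = 4_000, 50
log_lik = rng.normal(-1.0, 0.3, size=(S, N))

def logmeanexp(a, axis=0):
    """Numerically stable log of the mean of exp(a) along an axis."""
    m = a.max(axis=axis)
    return m + np.log(np.mean(np.exp(a - m), axis=axis))

# lppd: sum over observations of log( average likelihood over draws ).
lppd = logmeanexp(log_lik, axis=0).sum()
# p_WAIC: sum over observations of the variance of the log-likelihood over draws.
p_waic = log_lik.var(axis=0, ddof=1).sum()
elpd_waic = lppd - p_waic
waic = -2.0 * elpd_waic  # deviance scale; ArviZ reports the elpd scale by default

print(lppd, p_waic, waic)
```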

The killer application of practical Bayesian modeling is the hierarchical model. Partial pooling — estimating group-level parameters by borrowing strength from the population distribution — consistently outperforms both complete pooling (too much bias) and no pooling (too much variance). The improvement is greatest when group sizes are unequal and some groups have sparse data, which describes nearly every real-world dataset with group structure.
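For the normal-normal case, the partial-pooling estimate has a closed form: a precision-weighted average of the group sample mean and the population mean. A sketch with made-up numbers (here `mu`, `tau`, and `sigma` are assumed known, purely for illustration):

```python
import numpy as np

def partial_pool(ybar_j, n_j, sigma, mu, tau):
    """Shrinkage estimate for group j in a normal hierarchical model:
    precision-weighted average of the group mean and the population mean.

    sigma: within-group SD; mu, tau: population mean and between-group SD.
    """
    w_data = n_j / sigma**2   # precision of the group sample mean
    w_prior = 1.0 / tau**2    # precision of the population distribution
    return (w_data * ybar_j + w_prior * mu) / (w_data + w_prior)

mu, tau, sigma = 50.0, 5.0, 10.0

# Both groups observe the same sample mean, 70, far from the population mean.
small = partial_pool(ybar_j=70.0, n_j=4, sigma=sigma, mu=mu, tau=tau)
large = partial_pool(ybar_j=70.0, n_j=400, sigma=sigma, mu=mu, tau=tau)

print(small)  # -> 60.0: pulled strongly toward mu = 50
print(large)  # ~ 69.8: abundant data overrides the pooling
```

The sparse group is shrunk halfway to the population mean while the data-rich group barely moves, which is exactly the bias-variance trade that makes partial pooling win.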

The decision to use Bayesian methods should be driven by problem characteristics, not ideology. Hierarchical structure, genuine prior knowledge, small samples, and the need for calibrated uncertainty are the conditions under which Bayesian methods earn their complexity. When these conditions hold, the posterior distribution provides exactly the answer that practitioners need: a full probability distribution over parameters, directly interpretable as "what we believe given the data."


Chapter 21 Notation Reference

| Symbol | Meaning |
|---|---|
| $\theta$ | Parameter(s) of interest |
| $\theta_j$ | Group-specific parameter for group $j$ |
| $\mu, \tau$ | Population-level hyperparameters (mean and SD) |
| $\eta_j$ | Standard normal auxiliary variable (non-centered parameterization) |
| $\hat{R}$ | R-hat convergence diagnostic |
| ESS | Effective sample size |
| NUTS | No-U-Turn Sampler (adaptive HMC) |
| HMC | Hamiltonian Monte Carlo |
| WAIC | Widely Applicable Information Criterion |
| LOO | Leave-one-out cross-validation |
| PSIS | Pareto-Smoothed Importance Sampling |
| $\hat{k}$ | Pareto shape parameter (LOO diagnostic) |
| $\text{lppd}$ | Log pointwise predictive density |
| $p_{\text{WAIC}}$ | Effective number of parameters (WAIC) |