> "The best thing about Bayesian statistics is that it gives you exactly the answer you want. The worst thing about Bayesian statistics is that it makes you state all your assumptions."
In This Chapter
- Learning Objectives
- 21.1 From Conjugacy to Computation
- 21.2 MCMC: The Computational Engine
- 21.3 PyMC: Practical Bayesian Modeling
- 21.4 MCMC Diagnostics: Knowing When to Trust Your Samples
- 21.5 The Bayesian Workflow
- 21.6 Hierarchical Models: The Killer Application
- 21.7 Hierarchical Models in Practice
- 21.8 Model Comparison: WAIC and LOO-CV
- 21.9 The StreamRec Application: Hierarchical Engagement Rates
- 21.10 When Bayesian Methods Earn Their Complexity
- 21.11 Common Pitfalls and How to Avoid Them
- Summary
- Chapter 21 Notation Reference
Chapter 21: Bayesian Modeling in Practice — PyMC, Hierarchical Models, and When Bayesian Methods Earn Their Complexity
"The best thing about Bayesian statistics is that it gives you exactly the answer you want. The worst thing about Bayesian statistics is that it makes you state all your assumptions." — Andrew Gelman
Learning Objectives
By the end of this chapter, you will be able to:
- Build and fit Bayesian models with PyMC, including proper MCMC diagnostics
- Implement hierarchical (multilevel) models for grouped data and explain partial pooling
- Diagnose MCMC convergence using trace plots, $\hat{R}$, effective sample size (ESS), and divergence checks
- Apply the full Bayesian workflow: prior predictive check $\to$ fit $\to$ posterior predictive check $\to$ model comparison
- Decide when Bayesian methods are worth the additional complexity over frequentist or ML alternatives
21.1 From Conjugacy to Computation
Chapter 20 covered the elegant cases: Beta-Binomial, Normal-Normal, Poisson-Gamma. Conjugate priors give closed-form posteriors. You update parameters by adding counts. The math is beautiful and the computation is instantaneous.
Real models are not conjugate. A hierarchical logistic regression with varying intercepts and slopes across 50 hospitals does not have a closed-form posterior. A time-varying treatment effect model with patient-level covariates and hospital-level confounders does not reduce to Beta parameter updates. A content recommendation model with category-level engagement rates, user-level random effects, and time-of-day interactions is not a Beta-Binomial.
For these models, the posterior $p(\theta \mid D)$ is a high-dimensional distribution that cannot be written in closed form. The normalizing constant $p(D) = \int p(D \mid \theta) \, p(\theta) \, d\theta$ is an intractable integral over a space with potentially thousands of dimensions.
The solution is sampling. If you can draw samples $\theta_1, \theta_2, \ldots, \theta_S$ from the posterior $p(\theta \mid D)$, you do not need the closed-form distribution. You can approximate any posterior quantity:
$$\mathbb{E}[f(\theta) \mid D] \approx \frac{1}{S} \sum_{s=1}^{S} f(\theta_s)$$
Posterior mean? Set $f(\theta) = \theta$. Posterior variance? Set $f(\theta) = (\theta - \bar{\theta})^2$. Probability that a parameter is positive? Count the fraction of samples where $\theta_s > 0$. Credible intervals? Sort the samples and read off quantiles.
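Everything in the paragraph above reduces to array operations once you have samples. A minimal sketch, using draws from a known Normal as a stand-in for MCMC output:

```python
import numpy as np

# Toy demonstration: given draws from a posterior (here a known Normal,
# standing in for MCMC output), estimate posterior quantities by averaging.
rng = np.random.default_rng(0)
theta_samples = rng.normal(loc=1.0, scale=0.5, size=100_000)

post_mean = theta_samples.mean()                      # E[theta | D]
post_var = ((theta_samples - post_mean) ** 2).mean()  # Var[theta | D]
prob_positive = (theta_samples > 0).mean()            # P(theta > 0 | D)
ci_low, ci_high = np.quantile(theta_samples, [0.025, 0.975])  # 95% credible interval

print(post_mean, post_var, prob_positive, (ci_low, ci_high))
```

With enough samples, each estimate converges to the corresponding posterior quantity; the only error is Monte Carlo error, which shrinks at rate $1/\sqrt{S}$.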
This is the insight behind Markov chain Monte Carlo (MCMC): construct a sequence of samples that, in the long run, are distributed according to the posterior. The mathematics guarantees convergence. The engineering challenge is making convergence fast enough to be practical.
21.2 MCMC: The Computational Engine
The Core Idea
A Markov chain is a sequence of states where the next state depends only on the current state, not the full history. MCMC constructs a Markov chain over the parameter space $\Theta$ whose stationary distribution is the posterior $p(\theta \mid D)$.
If you run the chain long enough, the distribution of the current state converges to the posterior — regardless of where you started. The samples after convergence are (correlated) draws from the posterior, and averages over these draws converge to posterior expectations by the ergodic theorem.
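A toy illustration of convergence to a stationary distribution, using a two-state chain with a hand-picked transition matrix (illustrative numbers only, not an MCMC sampler):

```python
import numpy as np

# A two-state Markov chain with row-stochastic transition matrix P.
# Regardless of the starting distribution, repeated application of P
# converges to the stationary distribution pi satisfying pi = pi @ P.
P = np.array([[0.9, 0.1],
              [0.3, 0.7]])

dist = np.array([1.0, 0.0])  # start fully in state 0
for _ in range(100):
    dist = dist @ P

# Solving pi = pi P with pi summing to 1 gives
# 0.1 * pi_0 = 0.3 * pi_1, so pi = (0.75, 0.25)
print(dist)  # → approximately [0.75, 0.25]
```

MCMC works the same way, except the state space is the continuous parameter space $\Theta$ and the transition rule is engineered so that the stationary distribution is the posterior.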
Metropolis-Hastings: The Foundation
The Metropolis-Hastings algorithm is the conceptual starting point for all MCMC methods. Given a current state $\theta_t$:
- Propose a new state $\theta^*$ from a proposal distribution $q(\theta^* \mid \theta_t)$
- Compute the acceptance ratio:
$$\alpha = \min\left(1, \frac{p(\theta^* \mid D) \, q(\theta_t \mid \theta^*)}{p(\theta_t \mid D) \, q(\theta^* \mid \theta_t)}\right)$$
- Accept with probability $\alpha$: set $\theta_{t+1} = \theta^*$. Otherwise, $\theta_{t+1} = \theta_t$.
The crucial insight: the ratio $p(\theta^* \mid D) / p(\theta_t \mid D)$ cancels the intractable normalizing constant $p(D)$, because:
$$\frac{p(\theta^* \mid D)}{p(\theta_t \mid D)} = \frac{p(D \mid \theta^*) \, p(\theta^*) / p(D)}{p(D \mid \theta_t) \, p(\theta_t) / p(D)} = \frac{p(D \mid \theta^*) \, p(\theta^*)}{p(D \mid \theta_t) \, p(\theta_t)}$$
You never need to compute $p(D)$. You only need the unnormalized posterior — the product of the likelihood and the prior.
```python
import numpy as np
from scipy import stats
from typing import Callable
import matplotlib.pyplot as plt


def metropolis_hastings(
    log_posterior: Callable[[np.ndarray], float],
    initial: np.ndarray,
    proposal_scale: float,
    n_samples: int = 10000,
    seed: int = 42,
) -> np.ndarray:
    """Run Metropolis-Hastings with a symmetric Gaussian proposal.

    Args:
        log_posterior: Function computing the log unnormalized posterior.
        initial: Starting parameter values.
        proposal_scale: Standard deviation of the Gaussian proposal.
        n_samples: Number of samples to draw.
        seed: Random seed.

    Returns:
        Array of shape (n_samples, dim) containing the MCMC samples.
    """
    rng = np.random.RandomState(seed)
    dim = len(initial)
    samples = np.zeros((n_samples, dim))
    current = initial.copy()
    current_log_p = log_posterior(current)
    n_accepted = 0
    for i in range(n_samples):
        # Symmetric proposal: q(theta* | theta) = q(theta | theta*)
        proposal = current + rng.normal(0, proposal_scale, size=dim)
        proposal_log_p = log_posterior(proposal)
        # Log acceptance ratio (proposal terms cancel for symmetric q)
        log_alpha = proposal_log_p - current_log_p
        if np.log(rng.rand()) < log_alpha:
            current = proposal
            current_log_p = proposal_log_p
            n_accepted += 1
        samples[i] = current
    print(f"Acceptance rate: {n_accepted / n_samples:.3f}")
    return samples


# Example: infer the mean of a Normal distribution
# True parameters: mu = 5.0, sigma = 2.0 (known)
# Data: 20 observations
rng = np.random.RandomState(42)
true_mu = 5.0
sigma = 2.0
data = rng.normal(true_mu, sigma, size=20)


# Log posterior: Normal likelihood + Normal prior N(0, 10^2)
def log_posterior_normal(theta: np.ndarray) -> float:
    mu = theta[0]
    log_prior = stats.norm.logpdf(mu, loc=0, scale=10)
    log_lik = np.sum(stats.norm.logpdf(data, loc=mu, scale=sigma))
    return log_prior + log_lik


samples = metropolis_hastings(
    log_posterior_normal,
    initial=np.array([0.0]),
    proposal_scale=0.5,
    n_samples=5000,
)

# Discard first 1000 as warm-up
samples_post_warmup = samples[1000:]
print(f"Posterior mean: {samples_post_warmup[:, 0].mean():.3f}")
print(f"Posterior std: {samples_post_warmup[:, 0].std():.3f}")
print(f"True mean: {true_mu}")
print(f"Sample mean: {data.mean():.3f}")
```

```
Acceptance rate: 0.574
Posterior mean: 4.895
Posterior std: 0.440
True mean: 5.0
Sample mean: 4.869
```
Hamiltonian Monte Carlo (HMC) and NUTS
Metropolis-Hastings works, but it scales poorly to high dimensions. The random-walk proposal generates small steps, and the chain takes a long time to explore the posterior — especially when parameters are correlated.
Hamiltonian Monte Carlo (HMC) treats the parameter vector as the "position" of a particle and introduces auxiliary "momentum" variables. The particle slides along the posterior surface according to Hamilton's equations, following the gradient of the log-posterior. This produces proposals that are far from the current state but still in regions of high posterior density — dramatically improving acceptance rates and reducing autocorrelation.
The physics analogy is precise. Define the Hamiltonian:
$$H(\theta, p) = -\log p(\theta \mid D) + \frac{1}{2} p^T M^{-1} p$$
where $\theta$ is position (parameters), $p$ is momentum (auxiliary variables), and $M$ is a mass matrix. Hamilton's equations are:
$$\frac{d\theta}{dt} = M^{-1} p, \quad \frac{dp}{dt} = \nabla_\theta \log p(\theta \mid D)$$
HMC simulates this dynamical system for $L$ leapfrog steps with step size $\epsilon$, then proposes the endpoint as the next MCMC state. The trajectory follows the contours of the posterior, producing distant but high-probability proposals.
NUTS (No-U-Turn Sampler) is the adaptive variant of HMC that automatically tunes the number of leapfrog steps $L$. It runs the trajectory until it starts to "double back" — a U-turn detected by a dot product criterion on the momentum — eliminating the need to hand-tune $L$. NUTS is the default algorithm in PyMC, Stan, and most modern probabilistic programming frameworks.
The key advantage of HMC/NUTS over Metropolis-Hastings:
| Property | Metropolis-Hastings | HMC/NUTS |
|---|---|---|
| Proposal mechanism | Random walk | Gradient-guided trajectory |
| Cost per independent sample in dimension $d$ | $O(d^2)$ | $O(d^{5/4})$ |
| Acceptance rate | Optimal ~23% in high $d$ | Optimal ~65-80% |
| Autocorrelation | High (many correlated samples) | Low (nearly independent samples) |
| Requires gradients? | No | Yes |
| Tuning parameters | Proposal scale | Step size $\epsilon$, mass matrix $M$ (auto-tuned) |
Why Gradients Matter: HMC requires the gradient $\nabla_\theta \log p(\theta \mid D)$. This is why PyMC uses PyTensor (a descendant of Theano, by way of Aesara) as its computational backend: it provides automatic differentiation, computing exact gradients of arbitrary model definitions. You specify the model; the framework computes the gradients; NUTS handles the rest.
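To make the leapfrog mechanics concrete, here is a minimal sketch of a single HMC trajectory for a standard-Normal target. The identity mass matrix, step size, and step count are illustrative choices for this sketch, not PyMC's implementation:

```python
import numpy as np

def leapfrog(theta, p, grad_log_post, step_size, n_steps):
    """One HMC trajectory: n_steps leapfrog steps of Hamilton's equations
    (identity mass matrix assumed for simplicity)."""
    theta, p = theta.copy(), p.copy()
    p += 0.5 * step_size * grad_log_post(theta)   # half step for momentum
    for _ in range(n_steps - 1):
        theta += step_size * p                    # full step for position
        p += step_size * grad_log_post(theta)     # full step for momentum
    theta += step_size * p
    p += 0.5 * step_size * grad_log_post(theta)   # final half step
    return theta, p

# Target: standard Normal, log p(theta) = -theta^2 / 2, gradient = -theta
grad = lambda th: -th

theta0 = np.array([1.0])
p0 = np.array([0.5])
theta1, p1 = leapfrog(theta0, p0, grad, step_size=0.1, n_steps=20)

# The leapfrog integrator nearly conserves the Hamiltonian, which is
# what keeps HMC acceptance rates high even for distant proposals.
H = lambda th, p: 0.5 * th @ th + 0.5 * p @ p
print(H(theta0, p0), H(theta1, p1))  # nearly equal
```

The near-conservation of $H$ is the point: the endpoint can be far from the start, yet the Metropolis correction $\min(1, e^{H(\theta_t, p_t) - H(\theta^*, p^*)})$ accepts it with high probability.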
21.3 PyMC: Practical Bayesian Modeling
The PyMC Model Specification Pattern
PyMC uses a context-manager pattern to define probabilistic models. Every random variable — prior, likelihood, or derived quantity — is declared inside a with pm.Model() block. The model is a computational graph: PyMC traces dependencies between variables and automatically constructs the log-posterior that NUTS needs.
```python
import pymc as pm
import arviz as az
import numpy as np

# Generate synthetic data: linear regression
# y = 2.5 + 1.8 * x + noise
rng = np.random.default_rng(42)
n = 100
x = rng.normal(0, 1, size=n)
true_alpha = 2.5
true_beta = 1.8
true_sigma = 0.8
y = true_alpha + true_beta * x + rng.normal(0, true_sigma, size=n)

with pm.Model() as linear_model:
    # Priors
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=5)
    # Deterministic mean
    mu = alpha + beta * x
    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
    # Sample from the posterior
    trace = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        random_seed=42,
        return_inferencedata=True,
    )

# Summarize the posterior
summary = az.summary(trace, var_names=["alpha", "beta", "sigma"])
print(summary)
```

```
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha  2.488  0.079   2.337    2.635      0.001    0.001    4521.0    3189.0    1.0
beta   1.822  0.084   1.668    1.983      0.001    0.001    4614.0    3054.0    1.0
sigma  0.802  0.058   0.696    0.912      0.001    0.001    4380.0    2987.0    1.0
```
The posterior recovers the true parameters: $\alpha \approx 2.49$ (true 2.5), $\beta \approx 1.82$ (true 1.8), $\sigma \approx 0.80$ (true 0.8). The HDI (highest density interval) at 94% contains the true values. The ESS (effective sample size) exceeds 4000, and $\hat{R} = 1.0$ — the chain has converged.
Anatomy of a PyMC Model
Every PyMC model has three components:
1. Priors: probability distributions on parameters (`pm.Normal`, `pm.HalfNormal`, `pm.Beta`, `pm.Exponential`, etc.). These encode what you believe about the parameters before seeing data.
2. Deterministic transformations: computations that connect priors to the likelihood. `mu = alpha + beta * x` is a deterministic function of random variables. Use `pm.Deterministic("name", expression)` if you want PyMC to save the transformed values in the trace.
3. Likelihood: the distribution of the observed data, conditioned on the parameters. The `observed=` argument tells PyMC which variables are observed (clamped to data) versus latent (to be inferred).
```python
# A more complex example: Bayesian logistic regression
# Classification of clinical outcomes
rng = np.random.default_rng(42)
n = 200
age = rng.normal(55, 12, size=n)      # Age in years
dosage = rng.normal(100, 30, size=n)  # Drug dosage in mg

# True coefficients (log-odds scale)
true_intercept = -3.0
true_age = 0.04
true_dosage = 0.015
logit_p = true_intercept + true_age * age + true_dosage * dosage
p = 1 / (1 + np.exp(-logit_p))
outcome = rng.binomial(1, p, size=n)

# Standardize predictors for better MCMC performance
age_std = (age - age.mean()) / age.std()
dosage_std = (dosage - dosage.mean()) / dosage.std()

with pm.Model() as logistic_model:
    # Weakly informative priors on standardized scale
    intercept = pm.Normal("intercept", mu=0, sigma=2)
    b_age = pm.Normal("b_age", mu=0, sigma=2)
    b_dosage = pm.Normal("b_dosage", mu=0, sigma=2)
    # Linear predictor
    logit_p = intercept + b_age * age_std + b_dosage * dosage_std
    # Likelihood
    y_obs = pm.Bernoulli("y_obs", logit_p=logit_p, observed=outcome)
    # Sample
    trace_logistic = pm.sample(
        draws=2000, tune=1000, chains=4,
        random_seed=42, return_inferencedata=True,
    )

print(az.summary(trace_logistic, var_names=["intercept", "b_age", "b_dosage"]))
```

```
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
intercept -0.053  0.166  -0.368    0.258      0.003    0.002    3820.0    2810.0    1.0
b_age      0.570  0.173   0.251    0.900      0.003    0.002    4010.0    2920.0    1.0
b_dosage   0.428  0.170   0.115    0.750      0.003    0.002    3950.0    2870.0    1.0
```
Both b_age and b_dosage are positive, with 94% HDIs excluding zero — consistent with the true positive effects. The coefficients are on the standardized scale; to recover the original scale, divide by the predictor standard deviation.
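As a sketch of that back-transformation: if `x_std = (x - mean) / sd`, then a coefficient `b_std` on the standardized scale corresponds to `b_std / sd` per original unit (the intercept absorbs the shift). The numbers below (posterior mean 0.57, sd 0.17 for `b_age`, predictor sd 12) are illustrative stand-ins for the actual trace, drawn as synthetic samples:

```python
import numpy as np

# Converting a coefficient on a standardized predictor back to the
# original scale: b_orig = b_std / sd of the raw predictor.
rng = np.random.default_rng(0)
age_sd = 12.0  # sd of the raw age predictor (illustrative)
b_age_std = rng.normal(0.57, 0.17, size=4000)  # stand-in for posterior draws

b_age_orig = b_age_std / age_sd  # per-year effect on the log-odds scale

print(np.mean(b_age_orig))  # roughly 0.57 / 12 ≈ 0.0475
```

Working draw-by-draw like this preserves the full posterior: the rescaled draws give credible intervals on the original scale for free.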
21.4 MCMC Diagnostics: Knowing When to Trust Your Samples
Fitting a Bayesian model is not complete when pm.sample() finishes. MCMC is an approximation, and the approximation can fail silently. Diagnostics are not optional — they are part of the modeling workflow.
Trace Plots
A trace plot shows the sampled values of a parameter across MCMC iterations. A well-behaved trace looks like a "fuzzy caterpillar": the chain explores a stable range with no trends, no long excursions, and no sticky periods.
```python
# Good trace: the linear model
az.plot_trace(trace, var_names=["alpha", "beta", "sigma"])
plt.tight_layout()
plt.show()
```
Warning signs in trace plots:
| Pattern | Problem | Solution |
|---|---|---|
| Trend or drift | Chain has not converged | Run more warm-up iterations |
| Sticky regions | Chain is stuck in a mode | Reparameterize the model, increase target_accept |
| Different chains at different levels | Chains exploring different regions | Multimodal posterior or insufficient warm-up |
| Periodic oscillations | Possible label switching (mixtures) | Add ordering constraints |
$\hat{R}$ (R-hat): Between-Chain vs. Within-Chain Variance
$\hat{R}$ compares the variance between chains to the variance within chains. If all chains have converged to the same distribution, between-chain and within-chain variance should be similar.
The split-$\hat{R}$ (used by ArviZ and modern Stan) splits each chain in half and treats the halves as separate chains, which also detects within-chain non-stationarity.
$$\hat{R} = \sqrt{\frac{\hat{V}}{W}}$$
where $\hat{V}$ is the estimated marginal posterior variance (combining between- and within-chain variance) and $W$ is the within-chain variance.
Rules of thumb:
- $\hat{R} < 1.01$: excellent convergence
- $1.01 \leq \hat{R} < 1.05$: acceptable for most purposes
- $\hat{R} \geq 1.05$: do not trust these samples — investigate and rerun
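A hand-rolled version of split-$\hat{R}$ following the formula above. This is a simplified sketch; ArviZ's implementation additionally rank-normalizes the draws:

```python
import numpy as np

def split_rhat(chains: np.ndarray) -> float:
    """Split-R-hat for one parameter.

    Args:
        chains: Array of shape (n_chains, n_draws).
    """
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split each chain in half and stack the halves as separate chains
    splits = np.vstack([chains[:, :half], chains[:, half:2 * half]])
    m, n = splits.shape
    chain_means = splits.mean(axis=1)
    W = splits.var(axis=1, ddof=1).mean()   # within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    V_hat = (n - 1) / n * W + B / n         # marginal posterior variance
    return float(np.sqrt(V_hat / W))

rng = np.random.default_rng(0)
good = rng.normal(0, 1, size=(4, 1000))               # all chains agree
bad = good + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain offset

print(split_rhat(good))  # close to 1.0
print(split_rhat(bad))   # well above 1.01
```

In practice you read these values from `az.summary`; the point of the sketch is to show why a chain stuck in a different region inflates $\hat{V}$ relative to $W$.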
```python
def check_rhat(trace: az.InferenceData, threshold: float = 1.01) -> None:
    """Check R-hat values for all parameters and warn if any exceed threshold.

    Args:
        trace: ArviZ InferenceData object.
        threshold: R-hat threshold (default 1.01).
    """
    summary = az.summary(trace)
    rhat_values = summary["r_hat"]
    problematic = rhat_values[rhat_values > threshold]
    if len(problematic) == 0:
        print(f"All R-hat values < {threshold}. Convergence looks good.")
    else:
        print(f"WARNING: {len(problematic)} parameters have R-hat > {threshold}:")
        for param, rhat in problematic.items():
            print(f"  {param}: R-hat = {rhat:.4f}")
        print("  Consider: more warm-up, reparameterization, or stronger priors.")


check_rhat(trace)
```

```
All R-hat values < 1.01. Convergence looks good.
```
Effective Sample Size (ESS)
MCMC samples are autocorrelated — consecutive samples are not independent. ESS estimates the number of independent samples equivalent to the correlated MCMC chain:
$$\text{ESS} = \frac{S}{1 + 2 \sum_{k=1}^{\infty} \rho_k}$$
where $S$ is the total number of samples and $\rho_k$ is the lag-$k$ autocorrelation.
ArviZ reports two ESS variants:
- ESS bulk: Effective samples for estimating the central tendency (mean, median). Should be at least 400 total (100 per chain for 4 chains).
- ESS tail: Effective samples for estimating tail quantities (extreme quantiles, credible interval endpoints). Often lower than ESS bulk because tail regions are explored less frequently.
Rules of thumb:
- ESS bulk > 400: adequate for posterior means
- ESS bulk > 1000: adequate for posterior standard deviations
- ESS tail > 400: adequate for 95% credible intervals
- ESS / total samples: the "efficiency" — HMC/NUTS typically achieves 0.3-0.8; Metropolis-Hastings often below 0.05
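A crude single-chain ESS estimate following this formula, truncating the autocorrelation sum at the first non-positive term (a simplification of the rule used by ArviZ), checked against an AR(1) chain whose ESS is known in closed form:

```python
import numpy as np

def ess_autocorr(x: np.ndarray, max_lag: int = 200) -> float:
    """Crude single-chain ESS: S / (1 + 2 * sum of rho_k), truncating
    the sum at the first non-positive autocorrelation estimate."""
    x = x - x.mean()
    n = len(x)
    var = np.dot(x, x) / n
    acf_sum = 0.0
    for k in range(1, max_lag):
        rho = np.dot(x[:-k], x[k:]) / (n * var)  # lag-k autocorrelation
        if rho <= 0:
            break
        acf_sum += rho
    return n / (1 + 2 * acf_sum)

rng = np.random.default_rng(0)
# AR(1) chain: x_t = phi * x_{t-1} + noise, with rho_k = phi^k, so
# ESS = S * (1 - phi) / (1 + phi) = 20000 * 0.1 / 1.9 ≈ 1053
phi = 0.9
noise = rng.normal(size=20_000)
x = np.empty_like(noise)
x[0] = noise[0]
for t in range(1, len(x)):
    x[t] = phi * x[t - 1] + noise[t]

ess = ess_autocorr(x)
print(ess)  # in the neighborhood of 1000
```

ArviZ's `ess_bulk` and `ess_tail` use rank-normalized, multi-chain refinements of this computation, but the intuition is the same: high autocorrelation means each draw carries little new information.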
```python
def check_ess(
    trace: az.InferenceData,
    min_bulk: int = 400,
    min_tail: int = 400,
) -> None:
    """Check effective sample sizes for all parameters.

    Args:
        trace: ArviZ InferenceData object.
        min_bulk: Minimum acceptable ESS bulk.
        min_tail: Minimum acceptable ESS tail.
    """
    summary = az.summary(trace)
    low_bulk = summary[summary["ess_bulk"] < min_bulk]
    low_tail = summary[summary["ess_tail"] < min_tail]
    if len(low_bulk) == 0 and len(low_tail) == 0:
        print(f"All ESS values acceptable (bulk > {min_bulk}, tail > {min_tail}).")
        total_draws = trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"]
        min_ess = summary["ess_bulk"].min()
        print(f"  Total draws: {total_draws}, Min ESS bulk: {min_ess:.0f}, "
              f"Efficiency: {min_ess / total_draws:.2f}")
    else:
        if len(low_bulk) > 0:
            print(f"WARNING: {len(low_bulk)} parameters with ESS bulk < {min_bulk}:")
            for param, row in low_bulk.iterrows():
                print(f"  {param}: ESS bulk = {row['ess_bulk']:.0f}")
        if len(low_tail) > 0:
            print(f"WARNING: {len(low_tail)} parameters with ESS tail < {min_tail}:")
            for param, row in low_tail.iterrows():
                print(f"  {param}: ESS tail = {row['ess_tail']:.0f}")


check_ess(trace)
```

```
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 8000, Min ESS bulk: 4380, Efficiency: 0.55
```
Divergences: The Most Important Diagnostic
A divergence in HMC occurs when the numerical integration of Hamilton's equations fails — the leapfrog integrator produces a trajectory that diverges to extreme parameter values. Divergences indicate that the sampler is struggling with the posterior geometry, typically in regions of high curvature.
Divergences are not just a computational nuisance. They signal that the sampler may be missing important regions of the posterior, leading to biased estimates. A model with divergences cannot be trusted.
Common causes and solutions:
| Cause | Solution |
|---|---|
| Funnel geometry (hierarchical models) | Use the non-centered parameterization |
| Tight correlations between parameters | Reparameterize to reduce correlations |
| Sharp posterior boundaries | Use appropriate transformations (e.g., log for positive parameters) |
| Step size too large | Increase target_accept (e.g., from 0.8 to 0.95 or 0.99) |
```python
def check_divergences(trace: az.InferenceData) -> None:
    """Check for MCMC divergences.

    Args:
        trace: ArviZ InferenceData object.
    """
    if hasattr(trace, "sample_stats"):
        divergences = trace.sample_stats["diverging"].values
        n_divergent = int(divergences.sum())
        n_total = divergences.size
        pct = n_divergent / n_total * 100
        if n_divergent == 0:
            print("No divergences detected.")
        else:
            print(f"WARNING: {n_divergent} divergences ({pct:.1f}% of samples).")
            print("  This indicates the sampler struggled with posterior geometry.")
            print("  Consider: non-centered parameterization, higher target_accept,")
            print("  or reparameterizing the model.")
    else:
        print("No sample_stats found in trace.")


check_divergences(trace)
```

```
No divergences detected.
```
Putting Diagnostics Together
```python
def full_mcmc_diagnostics(
    trace: az.InferenceData,
    var_names: list | None = None,
) -> None:
    """Run all standard MCMC diagnostics.

    Args:
        trace: ArviZ InferenceData object.
        var_names: Parameter names to check (None for all).
    """
    print("=" * 60)
    print("MCMC DIAGNOSTIC REPORT")
    print("=" * 60)
    # 1. Divergences
    print("\n--- Divergences ---")
    check_divergences(trace)
    # 2. R-hat
    print("\n--- R-hat (convergence) ---")
    check_rhat(trace)
    # 3. ESS
    print("\n--- Effective Sample Size ---")
    check_ess(trace)
    # 4. Summary table
    print("\n--- Posterior Summary ---")
    summary = az.summary(trace, var_names=var_names)
    print(summary.to_string())
    print("\n" + "=" * 60)


full_mcmc_diagnostics(trace, var_names=["alpha", "beta", "sigma"])
```

```
============================================================
MCMC DIAGNOSTIC REPORT
============================================================

--- Divergences ---
No divergences detected.

--- R-hat (convergence) ---
All R-hat values < 1.01. Convergence looks good.

--- Effective Sample Size ---
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 8000, Min ESS bulk: 4380, Efficiency: 0.55

--- Posterior Summary ---
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha  2.488  0.079   2.337    2.635      0.001    0.001    4521.0    3189.0    1.0
beta   1.822  0.084   1.668    1.983      0.001    0.001    4614.0    3054.0    1.0
sigma  0.802  0.058   0.696    0.912      0.001    0.001    4380.0    2987.0    1.0

============================================================
```
Production Reality: In industry Bayesian work, diagnostics should be automated. Every model fit should automatically check divergences, $\hat{R}$, and ESS, and flag failures before results reach a dashboard or report. A single divergence in a 10,000-sample chain may not seem significant, but it can bias posterior estimates by 5-10% in the tails — enough to change a clinical decision or a resource allocation.
21.5 The Bayesian Workflow
Bayesian modeling is not "specify a model, press run, report the posterior." It is an iterative cycle of model criticism and improvement. The Bayesian workflow, formalized by Gelman et al. (2020), has five stages:
Stage 1: Prior Predictive Check
Before fitting any data, simulate data from the prior predictive distribution:
$$\tilde{y} \sim p(y) = \int p(y \mid \theta) \, p(\theta) \, d\theta$$
If the model-plus-prior combination generates absurd data — negative blood pressures, conversion rates above 100%, temperatures of 10,000 degrees — then the priors are poorly chosen.
```python
with pm.Model() as prior_check_model:
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=5)
    mu = alpha + beta * x
    # observed= only fixes the shape and registers y_obs as the data
    # variable; prior predictive sampling ignores the observed values
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)
    # Sample from the prior predictive (no data conditioning)
    prior_pred = pm.sample_prior_predictive(draws=500, random_seed=42)

# Check if prior predictive is reasonable
prior_y = prior_pred.prior_predictive["y_obs"].values.flatten()
print(f"Prior predictive y — range: [{prior_y.min():.1f}, {prior_y.max():.1f}]")
print(f"Prior predictive y — mean: {prior_y.mean():.1f}")
print(f"Prior predictive y — std: {prior_y.std():.1f}")
```

```
Prior predictive y — range: [-76.3, 68.9]
Prior predictive y — mean: 0.1
Prior predictive y — std: 15.8
```
The prior predictive range of $[-76, 69]$ may or may not be reasonable depending on the application. For blood pressure data (plausible range 60-200 mmHg), this is too broad. For standardized test scores, it might be acceptable. The point of the check is to verify that the priors do not generate physically impossible data.
Stage 2: Fit the Model
Run pm.sample() with appropriate NUTS settings. Start with defaults; adjust if diagnostics indicate problems.
Stage 3: Validate Computation (MCMC Diagnostics)
Run the full diagnostic suite from Section 21.4. If diagnostics fail, do not proceed to interpretation — fix the computational issue first.
Stage 4: Posterior Predictive Check
After fitting, generate replicated data from the posterior predictive distribution:
$$\tilde{y}_{\text{rep}} \sim p(\tilde{y} \mid D) = \int p(\tilde{y} \mid \theta) \, p(\theta \mid D) \, d\theta$$
Compare the replicated data to the observed data. If the model fits well, the replicated data should look like the observed data according to relevant summary statistics.
```python
with linear_model:
    posterior_pred = pm.sample_posterior_predictive(
        trace, random_seed=42
    )

# Compare observed vs. replicated data
obs_y = y
rep_y = posterior_pred.posterior_predictive["y_obs"].values  # (chains, draws, n)
rep_flat = rep_y.reshape(-1, len(obs_y))                     # (chains * draws, n)

# Check several summary statistics
print("Posterior predictive check:")
print(f"  Observed mean: {obs_y.mean():.3f}")
print(f"  Replicated mean (median of means): "
      f"{np.median(rep_flat.mean(axis=1)):.3f}")
print(f"  Observed std: {obs_y.std():.3f}")
print(f"  Replicated std: "
      f"{np.median(rep_flat.std(axis=1)):.3f}")

# Visual check
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Density comparison: observed data against 50 replicated datasets
axes[0].hist(obs_y, bins=20, density=True, alpha=0.7, label="Observed")
idx = np.random.default_rng(0).choice(rep_flat.shape[0], size=50, replace=False)
for i in idx:
    axes[0].hist(rep_flat[i], bins=20, density=True, alpha=0.02, color="orange")
axes[0].set_title("Observed vs. replicated densities")
axes[0].legend()

# Quantile-quantile: observed quantiles vs. pointwise replicated medians
obs_sorted = np.sort(obs_y)
rep_sorted = np.sort(np.median(rep_flat, axis=0))
axes[1].scatter(obs_sorted, rep_sorted, alpha=0.5, s=10)
axes[1].plot([obs_y.min(), obs_y.max()], [obs_y.min(), obs_y.max()],
             "r--", label="Perfect fit")
axes[1].set_xlabel("Observed quantiles")
axes[1].set_ylabel("Replicated quantiles")
axes[1].set_title("Posterior predictive Q-Q plot")
axes[1].legend()
plt.tight_layout()
plt.show()
```

```
Posterior predictive check:
  Observed mean: 2.484
  Replicated mean (median of means): 2.485
  Observed std: 2.014
  Replicated std: 1.993
```
Stage 5: Model Comparison
Compare alternative models using information criteria or cross-validation. See Section 21.8 for details.
The Workflow Is Iterative
The workflow is not a linear pipeline. Posterior predictive checks may reveal that the model misses a feature of the data (e.g., overdispersion, heteroscedasticity, nonlinearity). That sends you back to Stage 1 with a revised model. Model comparison may favor a simpler model, which sends you back to verify that the simpler model's posterior predictive checks are adequate.
Key Insight: The Bayesian workflow is a discipline, not a recipe. It does not guarantee that you will find the "right" model — no methodology can guarantee that. What it guarantees is that you will systematically discover ways your model is wrong, and that you will know how wrong it is.
21.6 Hierarchical Models: The Killer Application
Hierarchical models are the strongest argument for Bayesian methods. They solve a problem that arises in nearly every applied setting: estimating parameters for multiple groups that share some common structure.
The Three Extremes
Suppose you are estimating click-through rates for $J = 50$ content categories on a platform. Category $j$ has $n_j$ impressions and $y_j$ clicks. There are three approaches:
Complete pooling: Ignore group identity. Estimate a single rate for all categories: $\hat{\theta} = \sum_j y_j / \sum_j n_j$. This is the global average. It ignores real differences between categories but has minimal variance because it uses all the data.
No pooling: Estimate each category independently: $\hat{\theta}_j = y_j / n_j$. This respects differences between categories but has high variance for categories with small $n_j$. A category with 3 clicks in 5 impressions gets $\hat{\theta}_j = 0.60$, which is almost certainly an overestimate.
Partial pooling (hierarchical): Estimate each category's rate as a compromise between the category-specific data and the overall distribution. Categories with little data are pulled toward the overall mean; categories with abundant data are dominated by their own data.
```python
# Demonstrate the three pooling strategies
rng = np.random.default_rng(42)

# Simulate 50 categories with varying sample sizes and true rates
n_categories = 50
true_global_mean = 0.30
true_global_sd = 0.10

# True category rates drawn from a population distribution
true_rates = rng.beta(
    true_global_mean * 20,
    (1 - true_global_mean) * 20,
    size=n_categories,
)

# Sample sizes vary widely (realistic: some categories popular, some niche)
sample_sizes = rng.choice([10, 20, 50, 100, 500, 1000], size=n_categories,
                          p=[0.25, 0.25, 0.20, 0.15, 0.10, 0.05])

# Observed clicks
clicks = rng.binomial(sample_sizes, true_rates)

# Strategy 1: Complete pooling
pooled_rate = clicks.sum() / sample_sizes.sum()

# Strategy 2: No pooling (MLE per category)
no_pool_rates = clicks / sample_sizes

# Strategy 3: Partial pooling via empirical Bayes (method of moments)
# Estimate the population Beta parameters from the observed rates
obs_rates = clicks / sample_sizes
obs_mean = np.average(obs_rates, weights=sample_sizes)
obs_var = np.average((obs_rates - obs_mean) ** 2, weights=sample_sizes)

# Method of moments for Beta distribution
common = obs_mean * (1 - obs_mean) / obs_var - 1
alpha_pop = max(obs_mean * common, 0.5)
beta_pop = max((1 - obs_mean) * common, 0.5)

# Partial pooling: posterior mean for each category
partial_pool_rates = (alpha_pop + clicks) / (alpha_pop + beta_pop + sample_sizes)

# Compare MSE
mse_pooled = np.mean((pooled_rate - true_rates) ** 2)
mse_no_pool = np.mean((no_pool_rates - true_rates) ** 2)
mse_partial = np.mean((partial_pool_rates - true_rates) ** 2)

print("Mean Squared Error (lower is better):")
print(f"  Complete pooling: {mse_pooled:.6f}")
print(f"  No pooling (MLE): {mse_no_pool:.6f}")
print(f"  Partial pooling:  {mse_partial:.6f}")
print(f"\nPartial pooling reduction vs. no pooling: "
      f"{(1 - mse_partial / mse_no_pool) * 100:.1f}%")
```

```
Mean Squared Error (lower is better):
  Complete pooling: 0.009832
  No pooling (MLE): 0.003714
  Partial pooling:  0.002150

Partial pooling reduction vs. no pooling: 42.1%
```
Partial pooling wins. It reduces MSE by 42% compared to no pooling. The gain comes entirely from the small-sample categories, where the MLE is noisy and shrinkage toward the population mean reduces variance more than it increases bias.
The Eight Schools Example
The "eight schools" dataset is the canonical illustration of hierarchical modeling. Eight schools participated in an SAT coaching program. Each school reports the estimated treatment effect $y_j$ and standard error $\sigma_j$ (known from the study design):
| School | Effect ($y_j$) | SE ($\sigma_j$) |
|---|---|---|
| A | 28.39 | 14.9 |
| B | 7.94 | 10.2 |
| C | -2.75 | 16.3 |
| D | 6.82 | 11.0 |
| E | -0.64 | 9.4 |
| F | 0.63 | 11.4 |
| G | 18.01 | 10.4 |
| H | 12.16 | 17.6 |
School A reports a 28-point improvement — but with a standard error of 14.9, that estimate is noisy. Should we take it at face value? Or should we shrink it toward the overall mean?
```python
import pymc as pm
import arviz as az
import numpy as np

# Eight schools data
y_obs = np.array([28.39, 7.94, -2.75, 6.82, -0.64, 0.63, 18.01, 12.16])
sigma = np.array([14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6])
n_schools = len(y_obs)
school_names = ["A", "B", "C", "D", "E", "F", "G", "H"]

# Non-centered parameterization (avoids Neal's funnel divergences)
with pm.Model() as eight_schools:
    # Hyperpriors: population-level parameters
    mu = pm.Normal("mu", mu=0, sigma=20)  # Population mean effect
    tau = pm.HalfCauchy("tau", beta=10)   # Between-school SD
    # Non-centered parameterization:
    # theta_j = mu + tau * eta_j, where eta_j ~ N(0, 1)
    eta = pm.Normal("eta", mu=0, sigma=1, shape=n_schools)
    theta = pm.Deterministic("theta", mu + tau * eta)
    # Likelihood (known standard errors)
    y = pm.Normal("y", mu=theta, sigma=sigma, observed=y_obs)
    # Sample
    trace_schools = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, target_accept=0.95,
        return_inferencedata=True,
    )

# Check diagnostics
full_mcmc_diagnostics(trace_schools, var_names=["mu", "tau", "theta"])
```

```
============================================================
MCMC DIAGNOSTIC REPORT
============================================================

--- Divergences ---
No divergences detected.

--- R-hat (convergence) ---
All R-hat values < 1.01. Convergence looks good.

--- Effective Sample Size ---
All ESS values acceptable (bulk > 400, tail > 400).
  Total draws: 16000, Min ESS bulk: 3812, Efficiency: 0.24

--- Posterior Summary ---
            mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
mu         7.864  5.218  -1.876   17.486      0.079    0.056    4382.0    3652.0    1.0
tau        6.521  5.458   0.259   15.789      0.091    0.064    3812.0    3901.0    1.0
theta[0]  11.280  8.510  -4.356   27.282      0.111    0.079    5887.0    6102.0    1.0
theta[1]   7.802  6.305  -3.985   19.620      0.078    0.056    6512.0    6280.0    1.0
theta[2]   6.124  7.840 -10.020   20.238      0.095    0.067    6830.0    5980.0    1.0
theta[3]   7.528  6.534  -4.680   19.850      0.083    0.059    6240.0    6010.0    1.0
theta[4]   5.250  6.312  -7.050   16.845      0.079    0.056    6380.0    5890.0    1.0
theta[5]   6.083  6.518  -6.428   18.232      0.082    0.058    6310.0    5950.0    1.0
theta[6]  10.530  6.843  -2.215   23.520      0.089    0.063    5920.0    5780.0    1.0
theta[7]   8.376  7.878  -6.480   23.250      0.098    0.070    6470.0    6050.0    1.0

============================================================
```
Interpreting the Results
# Compare raw estimates with hierarchical estimates
mu_hat = trace_schools.posterior["mu"].mean().item()  # posterior mean of mu (~7.86)
theta_means = trace_schools.posterior["theta"].mean(dim=["chain", "draw"]).values
print(f"{'School':>6s} {'Raw':>8s} {'Hierarchical':>12s} {'Shrinkage':>10s}")
print("-" * 42)
for j in range(n_schools):
    if abs(y_obs[j] - mu_hat) > 0.1:
        # Fraction of the distance to the population mean that was removed
        shrinkage = 1 - (theta_means[j] - mu_hat) / (y_obs[j] - mu_hat)
        print(f"   {school_names[j]} {y_obs[j]:>8.2f} {theta_means[j]:>12.2f} "
              f"{shrinkage:>9.1%}")
    else:
        # Raw estimate already sits at the population mean; the ratio is unstable
        print(f"   {school_names[j]} {y_obs[j]:>8.2f} {theta_means[j]:>12.2f} "
              f"{'---':>9s}")
School      Raw Hierarchical  Shrinkage
------------------------------------------
   A    28.39        11.28      83.4%
   B     7.94         7.80        ---
   C    -2.75         6.12      83.6%
   D     6.82         7.53      67.8%
   E    -0.64         5.25      69.3%
   F     0.63         6.08      75.4%
   G    18.01        10.53      73.7%
   H    12.16         8.38      88.1%
School A's raw estimate of 28 points shrinks to 11 points, about 83% of the way toward the population mean. School B's raw estimate already sits essentially at the population mean, so it barely moves (the shrinkage ratio is not well defined there, which is why the code prints `---`). This is the essence of partial pooling: extreme estimates are pulled toward the center, with the amount of shrinkage determined by both the standard error (data precision) and the between-school variance $\tau$.
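The shrinkage percentages can be sanity-checked analytically. Conditional on $\mu$ and $\tau$, the posterior mean of $\theta_j$ shrinks the raw estimate by the factor $B_j = \sigma_j^2 / (\sigma_j^2 + \tau^2)$. Plugging in the posterior mean of $\tau$ from the summary above (about 6.5) is only a rough plug-in check, since it ignores the posterior uncertainty in $\tau$, but it reproduces the pattern:

```python
import numpy as np

# Rough analytic check: conditional on (mu, tau), the posterior mean of
# theta_j shrinks the raw estimate by B_j = sigma_j^2 / (sigma_j^2 + tau^2).
# tau_hat is the posterior mean of tau from the summary above (a plug-in
# approximation that ignores uncertainty in tau).
sigma = np.array([14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6])
tau_hat = 6.52
B = sigma**2 / (sigma**2 + tau_hat**2)
for name, b in zip("ABCDEFGH", B):
    print(f"School {name}: predicted shrinkage {b:.1%}")
```

The schools with the largest standard errors (A, C, H) are predicted to shrink the most, matching the fitted table; School E, with the smallest standard error, is predicted to shrink the least.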
Centered vs. Non-Centered Parameterization
The non-centered parameterization used above ($\theta_j = \mu + \tau \cdot \eta_j$) is critical for efficient MCMC. The centered parameterization ($\theta_j \sim \mathcal{N}(\mu, \tau^2)$) creates a problematic geometry known as Neal's funnel: when $\tau$ is small, the $\theta_j$ values are tightly constrained near $\mu$, creating a narrow funnel that the sampler cannot navigate efficiently. This manifests as divergences.
Centered parameterization:
theta_j ~ Normal(mu, tau) # Direct but creates funnel
Non-centered parameterization:
eta_j ~ Normal(0, 1) # Standard normal (easy to sample)
theta_j = mu + tau * eta_j # Transform back
The two parameterizations define the same model — the joint distribution over $(y, \theta, \mu, \tau)$ is identical. But the geometry is different, and NUTS navigates the non-centered version far more efficiently. This is the single most important practical trick for hierarchical Bayesian modeling.
When to use which parameterization:
- Non-centered: When groups have small samples (the data provide little information about individual $\theta_j$, so the posterior resembles the prior, and the funnel geometry dominates). This is the common case.
- Centered: When groups have large samples (the data dominate the prior for each $\theta_j$, and the posterior is well-identified). Rarely needed in practice with NUTS.
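The funnel geometry itself is easy to visualize without PyMC. A minimal NumPy sketch (a stand-in prior on $\log\tau$, not the fitted eight-schools posterior) shows why the centered parameterization is hard to sample: the conditional scale of $\theta$ collapses as $\tau \to 0$, so no single step size works both in the neck and in the mouth of the funnel.

```python
import numpy as np

# Neal's funnel in miniature: draw (log_tau, theta) from a centered prior
# and note how theta's spread collapses as tau -> 0. The narrow neck is what
# the centered sampler must navigate, and where divergences occur.
# (Illustrative prior scales, not the eight-schools model.)
rng = np.random.default_rng(42)
log_tau = rng.normal(0, 1.5, size=50_000)   # stand-in prior on log(tau)
theta = rng.normal(0, np.exp(log_tau))      # theta | tau ~ N(0, tau)

small_tau = np.abs(theta[log_tau < -2])     # neck of the funnel
large_tau = np.abs(theta[log_tau > 2])      # mouth of the funnel
print(f"typical |theta| when tau is tiny:  {small_tau.mean():.3f}")
print(f"typical |theta| when tau is large: {large_tau.mean():.3f}")
```

The non-centered version samples $\eta \sim \mathcal{N}(0, 1)$ instead, whose geometry is the same friendly unit ball regardless of $\tau$, and recovers $\theta$ deterministically afterward.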
21.7 Hierarchical Models in Practice
The Pharma Application: Treatment Effects Across Hospitals
MediCore Pharmaceuticals is evaluating Drug X for blood pressure reduction. Instead of a single treatment effect, they now have data from $J = 12$ hospitals, each with different patient populations, protocols, and sample sizes. The treatment effect likely varies across hospitals — but the hospitals are all studying the same drug, so their effects should be related.
This is the canonical setting for a hierarchical model: estimate hospital-specific treatment effects while sharing information across hospitals through a population-level distribution.
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
# MediCore: 12 hospitals, each reporting a treatment effect and SE
# Realistic simulation: true population effect = 8.5 mmHg,
# between-hospital SD = 2.5 mmHg
rng = np.random.default_rng(42)
n_hospitals = 12
hospital_names = [f"Hospital_{chr(65 + i)}" for i in range(n_hospitals)]
true_pop_mean = 8.5
true_pop_sd = 2.5
true_effects = rng.normal(true_pop_mean, true_pop_sd, size=n_hospitals)
# Sample sizes vary: large academic centers to small community hospitals
sample_sizes = np.array([450, 320, 180, 520, 90, 250, 380, 60, 200, 410, 150, 75])
known_se = 15.0 / np.sqrt(sample_sizes) # SE decreases with sample size
# Observed effects (true effect + noise)
observed_effects = rng.normal(true_effects, known_se)
print("Hospital data:")
print(f"{'Hospital':>12s} {'n':>5s} {'SE':>6s} {'Obs Effect':>11s} {'True Effect':>12s}")
print("-" * 54)
for j in range(n_hospitals):
print(f"{hospital_names[j]:>12s} {sample_sizes[j]:>5d} {known_se[j]:>6.2f} "
f"{observed_effects[j]:>11.2f} {true_effects[j]:>12.2f}")
# Hierarchical model (non-centered)
with pm.Model() as pharma_hierarchical:
# Hyperpriors
mu_pop = pm.Normal("mu_pop", mu=0, sigma=20)
sigma_pop = pm.HalfNormal("sigma_pop", sigma=10)
# Non-centered parameterization
eta = pm.Normal("eta", mu=0, sigma=1, shape=n_hospitals)
theta = pm.Deterministic("theta", mu_pop + sigma_pop * eta)
# Likelihood
y = pm.Normal("y", mu=theta, sigma=known_se, observed=observed_effects)
# Full Bayesian workflow
prior_pred = pm.sample_prior_predictive(samples=500, random_seed=42)
trace_pharma = pm.sample(
draws=4000, tune=2000, chains=4,
random_seed=42, target_accept=0.95,
return_inferencedata=True,
)
posterior_pred = pm.sample_posterior_predictive(
trace_pharma, random_seed=42
)
# Diagnostics
full_mcmc_diagnostics(trace_pharma, var_names=["mu_pop", "sigma_pop"])
Hospital data:
Hospital n SE Obs Effect True Effect
------------------------------------------------------
Hospital_A 450 0.71 8.25 8.16
Hospital_B 320 0.84 10.55 10.21
Hospital_C 180 1.12 6.43 6.98
Hospital_D 520 0.66 12.69 12.37
Hospital_E 90 1.58 7.14 5.90
Hospital_F 250 0.95 10.07 10.65
Hospital_G 380 0.77 5.78 5.51
Hospital_H 60 1.94 4.63 7.12
Hospital_I 200 1.06 8.32 8.89
Hospital_J 410 0.74 9.43 9.58
Hospital_K 150 1.22 11.48 10.30
Hospital_L 75 1.73 9.80 8.04
============================================================
MCMC DIAGNOSTIC REPORT
============================================================
--- Divergences ---
No divergences detected.
--- R-hat (convergence) ---
All R-hat values < 1.01. Convergence looks good.
--- Effective Sample Size ---
All ESS values acceptable (bulk > 400, tail > 400).
Total draws: 16000, Min ESS bulk: 4120, Efficiency: 0.26
--- Posterior Summary ---
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
mu_pop 8.671 0.892 6.982 10.340 0.014 0.010 4120.0 3850.0 1.0
sigma_pop 2.452 0.810 1.098 3.950 0.013 0.009 3910.0 4050.0 1.0
============================================================
# Shrinkage visualization
theta_means = trace_pharma.posterior["theta"].mean(dim=["chain", "draw"]).values
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(observed_effects, range(n_hospitals), marker="o", s=100,
color="steelblue", label="Observed (no pooling)", zorder=3)
ax.scatter(theta_means, range(n_hospitals), marker="s", s=100,
color="darkorange", label="Hierarchical (partial pooling)", zorder=3)
ax.scatter(true_effects, range(n_hospitals), marker="d", s=80,
color="green", alpha=0.7, label="True effect", zorder=3)
# Draw arrows from observed to hierarchical
for j in range(n_hospitals):
ax.annotate("", xy=(theta_means[j], j), xytext=(observed_effects[j], j),
arrowprops=dict(arrowstyle="->", color="gray", alpha=0.5))
pop_mean = trace_pharma.posterior["mu_pop"].mean().values
ax.axvline(pop_mean, color="red", linestyle="--", alpha=0.5,
label=f"Population mean ({pop_mean:.1f})")
ax.set_yticks(range(n_hospitals))
ax.set_yticklabels(hospital_names)
ax.set_xlabel("Treatment effect (mmHg)")
ax.set_title("MediCore: Hierarchical shrinkage of hospital treatment effects")
ax.legend(loc="lower right")
plt.tight_layout()
plt.show()
The visualization shows the "shrinkage" phenomenon: every hospital's estimate moves toward the population mean, with the degree of shrinkage proportional to the hospital's standard error. Hospital H (60 patients, SE = 1.94) shrinks heavily. Hospital D (520 patients, SE = 0.66) barely moves. This is partial pooling in action.
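The practical payoff of having the full posterior is that MediCore can make direct probability statements about decision-relevant thresholds. With the real trace you would compute `(trace_pharma.posterior["mu_pop"] > 5).mean()`; the sketch below instead approximates the posterior of `mu_pop` by a normal matching the reported summary (mean 8.67, sd 0.89), a plug-in stand-in for the actual draws.

```python
import numpy as np

# Sketch: turn the posterior into a decision-ready probability.
# With the real trace you would use the MCMC draws directly:
#   prob = (trace_pharma.posterior["mu_pop"] > 5).mean()
# Here we approximate the posterior of mu_pop by a normal matching the
# reported summary (mean 8.67, sd 0.89) -- a stand-in for the draws.
rng = np.random.default_rng(42)
mu_draws = rng.normal(8.67, 0.89, size=100_000)
prob_clinical = (mu_draws > 5.0).mean()
print(f"P(population effect > 5 mmHg) = {prob_clinical:.4f}")
```

No p-value gymnastics required: the question "how likely is a clinically meaningful effect?" is answered by counting posterior draws.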
The Climate Application: Regional Temperature Trends
The climate research team needs to estimate temperature trends for 8 geographic regions. Some regions have dense observation networks spanning decades; others have sparse coverage starting only in the 1990s. A hierarchical model shares information across regions while respecting local variation.
# Climate: hierarchical regional temperature trends
rng = np.random.default_rng(42)
n_regions = 8
region_names = ["Northeast", "Southeast", "Midwest", "Southwest",
"Northwest", "Central", "Coastal", "Mountain"]
# True global warming trend: 0.20 C/decade, regional variation SD: 0.05
true_global_trend = 0.20 # C per decade
true_region_sd = 0.05
true_trends = rng.normal(true_global_trend, true_region_sd, size=n_regions)
# Data coverage varies by region
years_of_data = np.array([65, 50, 70, 40, 30, 55, 45, 25])
n_obs_per_region = years_of_data * 12 # Monthly observations
measurement_noise = 0.8 # C
# For each region, simulate monthly temperature anomalies with a linear trend
region_data = []
for j in range(n_regions):
t = np.arange(n_obs_per_region[j]) / 120.0 # Convert months to decades
y = true_trends[j] * t + rng.normal(0, measurement_noise, size=n_obs_per_region[j])
region_data.append({"t": t, "y": y, "n": n_obs_per_region[j]})
# Estimate trend per region: OLS slope
ols_trends = []
ols_se = []
for j in range(n_regions):
t = region_data[j]["t"]
y = region_data[j]["y"]
n = region_data[j]["n"]
slope = np.sum((t - t.mean()) * (y - y.mean())) / np.sum((t - t.mean()) ** 2)
residuals = y - (y.mean() + slope * (t - t.mean()))
se = np.sqrt(np.sum(residuals ** 2) / (n - 2)) / np.sqrt(np.sum((t - t.mean()) ** 2))
ols_trends.append(slope)
ols_se.append(se)
ols_trends = np.array(ols_trends)
ols_se = np.array(ols_se)
print("Climate regional trend data:")
print(f"{'Region':>12s} {'Years':>5s} {'OLS Trend':>10s} {'SE':>6s} {'True Trend':>11s}")
print("-" * 52)
for j in range(n_regions):
print(f"{region_names[j]:>12s} {years_of_data[j]:>5d} {ols_trends[j]:>10.4f} "
f"{ols_se[j]:>6.4f} {true_trends[j]:>11.4f}")
# Hierarchical model
with pm.Model() as climate_hierarchical:
# Hyperpriors
mu_trend = pm.Normal("mu_trend", mu=0, sigma=1)
sigma_trend = pm.HalfNormal("sigma_trend", sigma=0.5)
# Non-centered parameterization
eta = pm.Normal("eta", mu=0, sigma=1, shape=n_regions)
trend = pm.Deterministic("trend", mu_trend + sigma_trend * eta)
# Likelihood (using OLS summaries as sufficient statistics)
y = pm.Normal("y", mu=trend, sigma=ols_se, observed=ols_trends)
trace_climate = pm.sample(
draws=4000, tune=2000, chains=4,
random_seed=42, target_accept=0.95,
return_inferencedata=True,
)
# Results
trend_means = trace_climate.posterior["trend"].mean(dim=["chain", "draw"]).values
global_mean = trace_climate.posterior["mu_trend"].mean().values
print(f"\nHierarchical results:")
print(f"Global trend estimate: {global_mean:.4f} C/decade (true: {true_global_trend:.2f})")
print(f"\n{'Region':>12s} {'OLS':>8s} {'Hierarchical':>13s} {'True':>6s}")
print("-" * 45)
for j in range(n_regions):
print(f"{region_names[j]:>12s} {ols_trends[j]:>8.4f} {trend_means[j]:>13.4f} "
f"{true_trends[j]:>6.4f}")
Climate regional trend data:
Region Years OLS Trend SE True Trend
----------------------------------------------------
Northeast 65 0.2012 0.0058 0.2016
Southeast 50 0.1867 0.0072 0.1908
Midwest 70 0.1984 0.0053 0.2012
Southwest 40 0.2463 0.0094 0.2474
Northwest 30 0.1589 0.0119 0.1582
Central 55 0.2176 0.0064 0.2147
Coastal 45 0.2325 0.0079 0.2330
Mountain 25 0.2156 0.0144 0.2199
Hierarchical results:
Global trend estimate: 0.2072 C/decade (true: 0.20)
Region OLS Hierarchical True
---------------------------------------------
Northeast 0.2012 0.2026 0.2016
Southeast 0.1867 0.1900 0.1908
Midwest 0.1984 0.1998 0.2012
Southwest 0.2463 0.2389 0.2474
Northwest 0.1589 0.1692 0.1582
Central 0.2176 0.2166 0.2147
Coastal 0.2325 0.2290 0.2330
Mountain 0.2156 0.2131 0.2199
Notice that the Mountain region (only 25 years of data, SE = 0.014) shrinks more toward the global mean than the Midwest (70 years, SE = 0.005). The hierarchical model automatically calibrates the degree of borrowing based on each region's data quality.
21.8 Model Comparison: WAIC and LOO-CV
Bayesian model comparison asks: which model provides the best predictive performance for future data, accounting for model complexity?
WAIC (Widely Applicable Information Criterion)
WAIC is the Bayesian generalization of AIC. It estimates the out-of-sample predictive accuracy from the posterior:
$$\text{WAIC} = -2 \left(\text{lppd} - p_{\text{WAIC}}\right)$$
where:
- $\text{lppd} = \sum_{i=1}^{n} \log \left(\frac{1}{S} \sum_{s=1}^{S} p(y_i \mid \theta_s)\right)$ is the log pointwise predictive density: for each observation, the likelihood is averaged over posterior samples and then logged, and the results are summed over observations
- $p_{\text{WAIC}} = \sum_{i=1}^{n} \text{Var}_{s} \left[\log p(y_i \mid \theta_s)\right]$ is the effective number of parameters (penalizes complexity)
Lower WAIC is better. WAIC is asymptotically equivalent to leave-one-out cross-validation.
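The formula can be computed directly from an $S \times n$ matrix of pointwise log-likelihoods, which is essentially what ArviZ does internally. A self-contained toy example (simulated normal data with exact posterior draws, not one of this chapter's models):

```python
import numpy as np

# WAIC from the definition, given log p(y_i | theta_s) for S posterior draws
# and n observations. Toy setup: y_i ~ N(theta, 1); with a flat-ish prior the
# posterior of theta is approximately N(ybar, 1/n), which we sample exactly.
rng = np.random.default_rng(42)
n, S = 50, 4000
y = rng.normal(2.0, 1.0, size=n)
theta_draws = rng.normal(y.mean(), 1 / np.sqrt(n), size=S)

# log p(y_i | theta_s): shape (S, n)
loglik = -0.5 * np.log(2 * np.pi) - 0.5 * (y[None, :] - theta_draws[:, None]) ** 2

# lppd_i = log( (1/S) * sum_s p(y_i | theta_s) ), via a log-sum-exp trick
m = loglik.max(axis=0)
lppd = np.sum(m + np.log(np.exp(loglik - m).mean(axis=0)))
p_waic = np.sum(loglik.var(axis=0))   # effective number of parameters
waic = -2 * (lppd - p_waic)
print(f"lppd = {lppd:.2f}, p_waic = {p_waic:.2f}, WAIC = {waic:.2f}")
```

With a single unknown mean, $p_{\text{WAIC}}$ comes out close to 1, matching the intuition that it estimates the effective number of parameters.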
LOO-CV via PSIS-LOO
Leave-one-out cross-validation (LOO-CV) estimates predictive accuracy by leaving out each observation in turn:
$$\text{elpd}_{\text{LOO}} = \sum_{i=1}^{n} \log p(y_i \mid y_{-i})$$
where $p(y_i \mid y_{-i})$ is the predictive density for observation $i$ using the posterior computed without observation $i$.
Refitting the model $n$ times is expensive. Pareto-Smoothed Importance Sampling (PSIS-LOO) approximates LOO-CV using importance weights derived from the full posterior, smoothed with a generalized Pareto distribution for stability. ArviZ reports the Pareto $\hat{k}$ diagnostic for each observation:
- $\hat{k} < 0.5$: reliable LOO estimate
- $0.5 \leq \hat{k} < 0.7$: acceptable but noisy
- $\hat{k} \geq 0.7$: unreliable — that observation is too influential; consider refitting without it
# Compare models: simple linear vs. hierarchical for the pharma data
# Model 1: Complete pooling (single treatment effect)
with pm.Model() as pooled_model:
    mu_single = pm.Normal("mu_single", mu=0, sigma=20)
    y = pm.Normal("y", mu=mu_single, sigma=known_se, observed=observed_effects)
    trace_pooled = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, return_inferencedata=True,
        idata_kwargs={"log_likelihood": True},  # az.compare needs pointwise log-likelihoods
    )
# Model 2: No pooling (independent per hospital)
with pm.Model() as no_pool_model:
    theta_indep = pm.Normal("theta_indep", mu=0, sigma=20, shape=n_hospitals)
    y = pm.Normal("y", mu=theta_indep, sigma=known_se, observed=observed_effects)
    trace_no_pool = pm.sample(
        draws=4000, tune=2000, chains=4,
        random_seed=42, return_inferencedata=True,
        idata_kwargs={"log_likelihood": True},
    )
# Model 3: Hierarchical (already fit as trace_pharma)
# az.compare needs pointwise log-likelihoods; add them to the earlier trace
with pharma_hierarchical:
    pm.compute_log_likelihood(trace_pharma)
# Compare using WAIC and LOO
comparison = az.compare(
{
"Complete pooling": trace_pooled,
"No pooling": trace_no_pool,
"Hierarchical": trace_pharma,
},
ic="loo",
)
print(comparison)
rank loo p_loo d_loo weight se dse warning scale
Hierarchical 0 -24.3 5.12 0.00 0.82 3.41 0.00 False log
No pooling 1 -26.1 10.45 1.84 0.18 4.02 2.10 False log
Complete pooling 2 -33.8 0.98 9.52 0.00 5.21 4.87 False log
The hierarchical model has the best LOO score (rank 0). The complete pooling model is worst — it cannot capture the real between-hospital variation. The no-pooling model is intermediate — it captures variation but overfits to noisy hospitals. The hierarchical model balances both concerns through partial pooling.
Interpreting Model Comparison
The d_loo column shows the difference in predictive score from the best model, and dse is the standard error of that difference. A common rule of thumb: if $d_{\text{loo}} < 2 \times \text{dse}$, the models are not reliably distinguishable. By that rule, Hierarchical vs. No pooling ($d = 1.84$, $\text{dse} = 2.10$) is not reliably distinguishable, though the hierarchical model's lower p_loo makes it the more parsimonious choice. Hierarchical vs. Complete pooling ($d = 9.52$, $\text{dse} = 4.87$) sits right at the $2 \times \text{dse}$ threshold: a much stronger signal, though with only 12 observations none of these comparisons is decisive.
The weight column from stacking (Yao et al., 2018) shows that the hierarchical model receives 82% of the stacking weight. This means that, for predictions, you would weight the hierarchical model's predictions at 82% and the no-pooling model at 18%.
# Visual model comparison
az.plot_compare(comparison)
plt.title("Model comparison: LOO-CV")
plt.tight_layout()
plt.show()
21.9 The StreamRec Application: Hierarchical Engagement Rates
Chapter 20 modeled StreamRec user preferences with per-category Beta-Binomial conjugate models — a closed-form solution that works well for individual users. Now we scale up to the platform level and ask: can we improve category-level engagement rate estimates by sharing information across categories through a hierarchical model?
The answer is yes, and the improvement is dramatic for small categories.
The Problem
StreamRec has 25 content categories. The top categories (comedy, drama, action) receive millions of impressions per month. The long-tail categories (art-house, world cinema, experimental) receive only a few thousand. Category-level engagement rates drive homepage layout, content acquisition decisions, and budget allocation. Using the raw rate (clicks / impressions) for a small category is noisy and leads to poor decisions.
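The noise problem is easy to quantify before fitting anything: the standard error of a raw rate is $\sqrt{p(1-p)/n}$. A quick sketch with illustrative sizes:

```python
import numpy as np

# Why raw rates mislead for small categories: the standard error of
# clicks / impressions is sqrt(p * (1 - p) / n). Illustrative sizes only.
p = 0.20                                  # a typical engagement rate
for n in [500_000, 5_000, 250]:
    se = np.sqrt(p * (1 - p) / n)
    print(f"n={n:>7,d}: raw rate = 0.200 +/- {se:.4f} "
          f"({se / p:.0%} relative error)")
```

At 250 impressions the raw rate carries roughly a 13% relative error, far too noisy to drive acquisition budgets; at 500,000 impressions the same rate is known to a fraction of a percent.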
The Hierarchical Model
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
# StreamRec category engagement data (simulated, realistic scale)
rng = np.random.default_rng(42)
n_categories = 25
category_names = [
"Comedy", "Drama", "Action", "Thriller", "Documentary",
"Sci-Fi", "Horror", "Romance", "Animation", "Crime",
"Fantasy", "Musical", "Mystery", "Western", "War",
"History", "Sport", "Biography", "Adventure", "Family",
"Art-House", "World Cinema", "Experimental", "Short Film", "Classic"
]
# True engagement rates: drawn from a population distribution
true_pop_alpha = 3.0
true_pop_beta = 7.0
true_rates = rng.beta(true_pop_alpha, true_pop_beta, size=n_categories)
# Sample sizes: power-law distribution (few large, many small)
raw_sizes = rng.pareto(1.5, size=n_categories) * 5000
sample_sizes = np.clip(raw_sizes.astype(int), 200, 500000)
sample_sizes = np.sort(sample_sizes)[::-1] # Largest first
# Observed clicks
clicks = rng.binomial(sample_sizes, true_rates)
observed_rates = clicks / sample_sizes
print("StreamRec category data (sorted by sample size):")
print(f"{'Category':>16s} {'Impressions':>12s} {'Clicks':>8s} {'Raw Rate':>9s} {'True Rate':>10s}")
print("-" * 62)
for j in range(n_categories):
print(f"{category_names[j]:>16s} {sample_sizes[j]:>12,d} {clicks[j]:>8,d} "
f"{observed_rates[j]:>9.4f} {true_rates[j]:>10.4f}")
# Hierarchical Beta-Binomial model
with pm.Model() as streamrec_hierarchical:
# Hyperpriors on the population Beta distribution
# Parameterize as mu (population mean) and kappa (concentration)
# alpha = mu * kappa, beta = (1 - mu) * kappa
mu_pop = pm.Beta("mu_pop", alpha=2, beta=2)
kappa = pm.Pareto("kappa", alpha=1.5, m=1) # Heavy-tailed: allows wide range
alpha_pop = pm.Deterministic("alpha_pop", mu_pop * kappa)
beta_pop = pm.Deterministic("beta_pop", (1 - mu_pop) * kappa)
# Category-specific rates
theta = pm.Beta("theta", alpha=alpha_pop, beta=beta_pop, shape=n_categories)
# Likelihood
y = pm.Binomial("y", n=sample_sizes, p=theta, observed=clicks)
# Bayesian workflow
trace_streamrec = pm.sample(
draws=4000, tune=2000, chains=4,
random_seed=42, target_accept=0.95,
return_inferencedata=True,
)
# Diagnostics
print("\n--- StreamRec Hierarchical Model Diagnostics ---")
check_divergences(trace_streamrec)
check_rhat(trace_streamrec)
check_ess(trace_streamrec)
StreamRec category data (sorted by sample size):
Category Impressions Clicks Raw Rate True Rate
--------------------------------------------------------------
Comedy 487,320 93,455 0.1918 0.1910
Drama 241,018 78,253 0.3247 0.3234
Action 118,652 46,029 0.3880 0.3895
Thriller 97,133 37,201 0.3830 0.3812
Documentary 53,821 14,124 0.2624 0.2608
Sci-Fi 38,466 18,082 0.4701 0.4724
Horror 27,913 3,915 0.1402 0.1388
Romance 22,340 6,510 0.2914 0.2891
Animation 17,809 3,602 0.2023 0.2052
Crime 14,251 2,310 0.1621 0.1634
Fantasy 10,892 1,723 0.1582 0.1511
Musical 8,330 2,884 0.3463 0.3520
Mystery 6,411 2,170 0.3385 0.3401
Western 4,892 742 0.1517 0.1486
War 3,820 1,286 0.3366 0.3378
History 2,910 853 0.2932 0.2844
Sport 2,104 394 0.1873 0.1925
Biography 1,520 528 0.3474 0.3362
Adventure 987 186 0.1884 0.1702
Family 712 325 0.4565 0.4301
Art-House 498 64 0.1285 0.1523
World Cinema 380 159 0.4184 0.3812
Experimental 290 28 0.0966 0.1284
Short Film 248 93 0.3750 0.3145
Classic 204 73 0.3578 0.2987
--- StreamRec Hierarchical Model Diagnostics ---
No divergences detected.
All R-hat values < 1.01. Convergence looks good.
All ESS values acceptable (bulk > 400, tail > 400).
Total draws: 16000, Min ESS bulk: 3680, Efficiency: 0.23
# Compare hierarchical estimates with raw rates
theta_means = trace_streamrec.posterior["theta"].mean(dim=["chain", "draw"]).values
theta_hdi = az.hdi(trace_streamrec, var_names=["theta"])["theta"].values
# Compute MSE for each strategy
mse_raw = np.mean((observed_rates - true_rates) ** 2)
mse_hier = np.mean((theta_means - true_rates) ** 2)
print(f"\nOverall MSE:")
print(f" Raw rates: {mse_raw:.8f}")
print(f" Hierarchical: {mse_hier:.8f}")
print(f" Improvement: {(1 - mse_hier / mse_raw) * 100:.1f}%")
# MSE by category size
small = sample_sizes < 2000
large = sample_sizes >= 10000
print(f"\nMSE for small categories (n < 2,000):")
print(f" Raw: {np.mean((observed_rates[small] - true_rates[small]) ** 2):.8f}")
print(f" Hierarchical: {np.mean((theta_means[small] - true_rates[small]) ** 2):.8f}")
print(f"\nMSE for large categories (n >= 10,000):")
print(f" Raw: {np.mean((observed_rates[large] - true_rates[large]) ** 2):.8f}")
print(f" Hierarchical: {np.mean((theta_means[large] - true_rates[large]) ** 2):.8f}")
Overall MSE:
Raw rates: 0.00021476
Hierarchical: 0.00009832
Improvement: 54.2%
MSE for small categories (n < 2,000):
Raw: 0.00094312
Hierarchical: 0.00038120
MSE for large categories (n >= 10,000):
Raw: 0.00000218
Hierarchical: 0.00000195
The hierarchical model reduces MSE by 54% overall. The gains are concentrated in small categories (60% MSE reduction), where partial pooling shrinks noisy estimates toward the population mean. For large categories, the raw rates are already precise and the hierarchical model makes minimal adjustments.
# Visualize: shrinkage by sample size
fig, ax = plt.subplots(figsize=(10, 6))
for j in range(n_categories):
ax.plot([np.log10(sample_sizes[j])] * 2,
[observed_rates[j], theta_means[j]],
color="gray", alpha=0.4, linewidth=1)
ax.scatter(np.log10(sample_sizes), observed_rates,
s=50, color="steelblue", label="Raw rate", zorder=3)
ax.scatter(np.log10(sample_sizes), theta_means,
s=50, color="darkorange", label="Hierarchical estimate", zorder=3)
ax.scatter(np.log10(sample_sizes), true_rates,
s=30, color="green", marker="x", label="True rate", zorder=4)
pop_mean = trace_streamrec.posterior["mu_pop"].mean().values
ax.axhline(pop_mean, color="red", linestyle="--", alpha=0.5,
label=f"Population mean ({pop_mean:.3f})")
ax.set_xlabel("log10(Impressions)")
ax.set_ylabel("Engagement rate")
ax.set_title("StreamRec: Hierarchical shrinkage by category size")
ax.legend()
plt.tight_layout()
plt.show()
The plot shows the defining pattern of partial pooling: the gray lines (connecting raw rates to hierarchical estimates) are longest for small categories (left side) and shortest for large categories (right side). Small categories are pulled strongly toward the population mean; large categories barely move.
Progressive Project Milestone: This completes the hierarchical Bayesian engagement rate model for StreamRec. Each content category now has an uncertainty-aware engagement estimate that accounts for both category-specific data and population-level information. In Chapter 22, these posteriors will feed into Thompson sampling for exploration-exploitation decisions.
21.10 When Bayesian Methods Earn Their Complexity
This chapter has demonstrated that Bayesian hierarchical models are powerful. They are also more complex than their frequentist alternatives — both conceptually and computationally. When is the complexity worth it?
The Decision Framework
from dataclasses import dataclass
from typing import Optional
@dataclass
class BayesianComplexityAssessment:
"""Structured assessment of whether Bayesian methods earn their complexity.
Each factor is rated on a 0-3 scale. The total score guides the decision.
Attributes:
n_groups: Number of groups in the hierarchical structure.
min_group_size: Smallest group sample size.
max_group_size: Largest group sample size.
has_genuine_prior: Whether genuine prior knowledge exists.
needs_full_uncertainty: Whether full posterior uncertainty is needed.
computational_budget_hours: Available compute time.
team_bayesian_expertise: 0 (none) to 3 (expert).
"""
n_groups: int
min_group_size: int
max_group_size: int
has_genuine_prior: bool
needs_full_uncertainty: bool
computational_budget_hours: float
team_bayesian_expertise: int
def assess(self) -> str:
"""Compute assessment and return recommendation."""
score = 0
reasons = []
# Group size imbalance: strong case for partial pooling
        if self.n_groups > 5 and self.min_group_size < 500:
score += 3
reasons.append("Small groups benefit from partial pooling")
elif self.n_groups > 5:
score += 1
reasons.append("Multiple groups present")
# Prior knowledge
if self.has_genuine_prior:
score += 2
reasons.append("Genuine prior knowledge available")
# Uncertainty needs
if self.needs_full_uncertainty:
score += 2
reasons.append("Full posterior uncertainty required")
# Computational feasibility (penalty)
if self.computational_budget_hours < 0.1 and self.n_groups > 100:
score -= 2
reasons.append("PENALTY: Tight compute budget with many groups")
# Team capability (penalty if low)
if self.team_bayesian_expertise < 1:
score -= 1
reasons.append("PENALTY: Team lacks Bayesian experience")
# Decision
if score >= 5:
decision = "STRONG CASE: Bayesian hierarchical model recommended."
elif score >= 3:
decision = "MODERATE CASE: Bayesian may add value. Compare with mixed-effects."
else:
decision = "WEAK CASE: Frequentist mixed-effects or regularized MLE likely sufficient."
report = f"Score: {score}/10\n"
report += f"Decision: {decision}\n"
report += "Factors:\n"
for r in reasons:
report += f" - {r}\n"
return report
# StreamRec: strong case
print("=== StreamRec Categories ===")
print(BayesianComplexityAssessment(
n_groups=25, min_group_size=200, max_group_size=500000,
has_genuine_prior=True, needs_full_uncertainty=True,
computational_budget_hours=1.0, team_bayesian_expertise=2,
).assess())
# MediCore: strong case
print("=== MediCore Hospitals ===")
print(BayesianComplexityAssessment(
n_groups=12, min_group_size=60, max_group_size=520,
has_genuine_prior=True, needs_full_uncertainty=True,
computational_budget_hours=2.0, team_bayesian_expertise=2,
).assess())
# Large A/B test with 2 groups and millions of users: weak case
print("=== Simple A/B Test ===")
print(BayesianComplexityAssessment(
n_groups=2, min_group_size=500000, max_group_size=500000,
has_genuine_prior=False, needs_full_uncertainty=False,
computational_budget_hours=0.01, team_bayesian_expertise=1,
).assess())
=== StreamRec Categories ===
Score: 7/10
Decision: STRONG CASE: Bayesian hierarchical model recommended.
Factors:
- Small groups benefit from partial pooling
- Genuine prior knowledge available
- Full posterior uncertainty required
=== MediCore Hospitals ===
Score: 7/10
Decision: STRONG CASE: Bayesian hierarchical model recommended.
Factors:
- Small groups benefit from partial pooling
- Genuine prior knowledge available
- Full posterior uncertainty required
=== Simple A/B Test ===
Score: 0/10
Decision: WEAK CASE: Frequentist mixed-effects or regularized MLE likely sufficient.
Factors:
The Honest Summary
| Situation | Bayesian Value | Alternative |
|---|---|---|
| Many groups, unequal sizes, genuine prior | High — partial pooling is the killer app | Mixed-effects models (close, but no full posterior) |
| Few groups, abundant data per group | Low — posterior ≈ MLE | MLE or regularized regression |
| Need full uncertainty for decisions | High — direct posterior probability statements | Bootstrap (close, but limited for hierarchical) |
| Sequential updating in real time | High — natural with conjugate models | Online learning algorithms (different framework) |
| Complex model, tight compute budget | Low — MCMC may not converge in time | Variational inference or point estimates |
| Team has no Bayesian experience | Low at first; start with a simpler method | Invest in training, then adopt Bayesian methods |
The honest answer is: Bayesian methods earn their complexity when the problem has hierarchical structure with unequal group sizes, genuine prior knowledge, and a need for calibrated uncertainty. These conditions are common in pharmaceutical trials, recommendation systems, sports analytics, education research, and small-area estimation — but they are not universal. A data scientist who uses Bayesian methods everywhere is as misguided as one who uses them nowhere.
21.11 Common Pitfalls and How to Avoid Them
Pitfall 1: Ignoring Divergences
Divergences mean the sampler cannot explore the posterior correctly. Never ignore them. The fix is almost always reparameterization (non-centered for hierarchical models) or increasing target_accept.
Pitfall 2: Overly Diffuse Priors
A prior of $\mathcal{N}(0, 10^6)$ on a parameter that should be between 0 and 1 creates computational problems (the sampler wastes time in irrelevant regions) and can produce poor posterior predictive checks. Use weakly informative priors that encode plausible ranges.
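A prior predictive simulation makes this pitfall concrete. On the log-odds scale (illustrative scales, not one of this chapter's models), a "flat" Normal(0, 10) prior is anything but uninformative about probabilities:

```python
import numpy as np

# What an "uninformative" prior actually believes: a Normal(0, 10) prior on a
# log-odds parameter pushes nearly all implied probabilities toward 0 or 1.
# (Illustrative scales; not a model from this chapter.)
rng = np.random.default_rng(42)
for prior_sd in [1.5, 10.0]:
    logits = rng.normal(0, prior_sd, size=100_000)
    p = 1 / (1 + np.exp(-logits))
    extreme = np.mean((p < 0.01) | (p > 0.99))
    print(f"prior sd = {prior_sd:>4}: {extreme:.0%} of implied "
          f"probabilities are < 0.01 or > 0.99")
```

Roughly two-thirds of the probabilities implied by the sd-10 prior land below 0.01 or above 0.99; a Normal(0, 1.5) prior keeps nearly all of them in a plausible range.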
Pitfall 3: Not Checking the Prior Predictive
If you do not simulate data from your priors before fitting, you do not know what your model "believes" before seeing data. Prior predictive checks take seconds and catch fundamental modeling errors.
Pitfall 4: Treating the Posterior Mean as a Point Estimate
The entire point of Bayesian inference is the posterior distribution. If you extract the posterior mean and throw away the rest, you have done expensive MCMC to get a regularized MLE — and you could have gotten that faster with penalized likelihood.
Pitfall 5: Using Bayesian Methods for PR
"We used a Bayesian hierarchical model" sounds impressive. But if the model has no hierarchical structure, the priors are flat, the diagnostics are not checked, and the posterior predictive checks are not run, you have added complexity without benefit. Bayesian methods must earn their complexity through improved decisions, not through terminology.
Summary
Bayesian modeling in practice requires three things that Chapter 20's conjugate models did not: a computational engine (MCMC, specifically NUTS), a diagnostic framework ($\hat{R}$, ESS, divergences), and a workflow discipline (prior predictive $\to$ fit $\to$ posterior predictive $\to$ model comparison).
PyMC provides the computational engine. ArviZ provides the diagnostic and comparison tools. The non-centered parameterization resolves the most common convergence issue in hierarchical models. WAIC and PSIS-LOO compare models in terms of out-of-sample predictive accuracy, with built-in complexity penalties.
The killer application of practical Bayesian modeling is the hierarchical model. Partial pooling — estimating group-level parameters by borrowing strength from the population distribution — consistently outperforms both complete pooling (too much bias) and no pooling (too much variance). The improvement is greatest when group sizes are unequal and some groups have sparse data, which describes nearly every real-world dataset with group structure.
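The "borrowing strength" described above has a closed form in the normal-normal case: each group estimate is a precision-weighted average of its own sample mean and the population mean, so sparse groups shrink harder. A sketch with assumed values (two groups with identical raw means but very different sample sizes):

```python
import numpy as np

mu, tau, sigma = 0.10, 0.05, 0.30  # population mean, between-group SD, within-group SD
group_means = np.array([0.30, 0.30])  # identical raw estimates...
n = np.array([5, 500])                # ...from very different amounts of data

# Precision weight: how much each group's own data outweighs the population
w = (n / sigma**2) / (n / sigma**2 + 1 / tau**2)
theta_hat = w * group_means + (1 - w) * mu

print(theta_hat)  # sparse group shrinks hard toward mu; large group barely moves
```

The sparse group's estimate is pulled most of the way to the population mean while the well-measured group keeps nearly its raw value, which is exactly the unequal-group-size behavior the summary describes.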
The decision to use Bayesian methods should be driven by problem characteristics, not ideology. Hierarchical structure, genuine prior knowledge, small samples, and the need for calibrated uncertainty are the conditions under which Bayesian methods earn their complexity. When these conditions hold, the posterior distribution provides exactly the answer that practitioners need: a full probability distribution over parameters, directly interpretable as "what we believe given the data."
Chapter 21 Notation Reference
| Symbol | Meaning |
|---|---|
| $\theta$ | Parameter(s) of interest |
| $\theta_j$ | Group-specific parameter for group $j$ |
| $\mu, \tau$ | Population-level hyperparameters (mean and SD) |
| $\eta_j$ | Standard normal auxiliary variable (non-centered parameterization) |
| $\hat{R}$ | R-hat convergence diagnostic |
| ESS | Effective sample size |
| NUTS | No-U-Turn Sampler (adaptive HMC) |
| HMC | Hamiltonian Monte Carlo |
| WAIC | Widely Applicable Information Criterion |
| LOO | Leave-one-out cross-validation |
| PSIS | Pareto-Smoothed Importance Sampling |
| $\hat{k}$ | Pareto shape parameter (LOO diagnostic) |
| $\text{lppd}$ | Log pointwise predictive density |
| $p_{\text{WAIC}}$ | Effective number of parameters (WAIC) |