Case Study 1: MediCore Multi-Site Federated Causal Analysis — Privacy-Preserving Treatment Effect Estimation Across 12 Hospitals
Context
MediCore Pharmaceuticals is conducting a post-market observational study of Cardiopril-X across 12 hospitals in the United States and Europe. The study measures systolic blood pressure (SBP) reduction after 12 weeks of treatment, with the goal of estimating heterogeneous treatment effects across patient subgroups — a question that requires access to individual-level data from all sites, not just site-level summary statistics.
The statistical challenge was addressed in Chapter 21 with a hierarchical Bayesian model that borrowed strength across hospitals. That analysis assumed centralized access to all patient data. The regulatory reality is different. Three hospitals are in the EU (governed by GDPR), four are in the US (governed by HIPAA and state privacy laws), and five are in jurisdictions with varying data protection regimes. No single data use agreement covers all 12 sites. Even if it did, the logistical burden of creating a centralized data enclave — with audit trails, access controls, encryption at rest, and breach notification procedures for 12 separate institutional review boards — would delay the analysis by 6-12 months.
The MediCore biostatistics team proposes a federated approach: train the treatment effect model across all 12 sites without centralizing any patient-level data. Each hospital keeps its data on-premises. A central coordination server (hosted in MediCore's GCP environment) orchestrates the training. The aggregated model produces treatment effect estimates with the statistical power of the full 12-site dataset, while each hospital's individual patient records never leave its network.
The Data
Each hospital $k$ has a dataset $D_k = \{(X_i, T_i, Y_i)\}_{i=1}^{n_k}$ where $X_i \in \mathbb{R}^{15}$ are patient covariates (age, sex, BMI, baseline SBP, comorbidity count, medication history, smoking status, diabetes indicator, renal function, cholesterol, exercise level, stress score, sleep quality, alcohol consumption, diet score), $T_i \in \{0, 1\}$ is the treatment indicator, and $Y_i \in \mathbb{R}$ is the SBP reduction after 12 weeks.
| Hospital | Location | n | Treated | Control | Treatment Rate |
|---|---|---|---|---|---|
| Boston Academic | USA | 620 | 310 | 310 | 0.50 |
| Chicago Teaching | USA | 480 | 250 | 230 | 0.52 |
| Houston Methodist | USA | 340 | 180 | 160 | 0.53 |
| Atlanta Community | USA | 190 | 95 | 95 | 0.50 |
| London NHS | UK/EU | 410 | 200 | 210 | 0.49 |
| Paris Tertiary | EU | 350 | 175 | 175 | 0.50 |
| Munich University | EU | 290 | 150 | 140 | 0.52 |
| Toronto General | Canada | 260 | 130 | 130 | 0.50 |
| Tokyo Medical | Japan | 220 | 115 | 105 | 0.52 |
| São Paulo Central | Brazil | 180 | 90 | 90 | 0.50 |
| Sydney Metro | Australia | 160 | 80 | 80 | 0.50 |
| Cape Town Public | S. Africa | 100 | 50 | 50 | 0.50 |
| Total | | 3,600 | 1,825 | 1,775 | 0.51 |
The data is non-IID across sites: Boston and London have older, sicker patient populations (mean age 64, mean comorbidities 3.2); Tokyo and São Paulo have younger, healthier populations (mean age 52, mean comorbidities 1.4). Cape Town has the smallest sample and the highest variance in outcomes. This heterogeneity is clinically meaningful — the treatment effect may genuinely vary across populations — and statistically challenging for federated learning.
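To make the heterogeneity concrete, here is a minimal sketch of how non-IID site data can be simulated. The site profiles are taken from the table above; the covariate indices (age at column 0, comorbidity count at column 4) and the noise distributions are illustrative assumptions, not MediCore's actual data-generating process.

```python
import numpy as np

rng = np.random.default_rng(0)

# Site-level population shifts loosely matching the table:
# (mean_age, mean_comorbidities)
site_profiles = {
    "boston": (63.2, 3.1),   # older, sicker
    "tokyo": (52.1, 1.5),    # younger, healthier
}

def simulate_site(mean_age: float, mean_comorb: float, n: int) -> np.ndarray:
    """Draw 15 covariates per patient, with the age column (0) and the
    comorbidity-count column (4) shifted to the site's population."""
    x = rng.normal(size=(n, 15))
    x[:, 0] = rng.normal(mean_age, 8.0, size=n)   # age, hypothetical SD
    x[:, 4] = rng.poisson(mean_comorb, size=n)    # comorbidity count
    return x

boston = simulate_site(*site_profiles["boston"], n=620)
tokyo = simulate_site(*site_profiles["tokyo"], n=220)
print(boston[:, 0].mean(), tokyo[:, 0].mean())   # site means differ
```

Feeding sites like these into a single FedAvg run is what surfaces the convergence problems discussed in the lessons-learned section.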
The Architecture
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np


@dataclass
class HospitalSiteConfig:
    """Configuration for one hospital in the federated study."""
    site_id: str
    location: str
    n_patients: int
    n_treated: int
    n_control: int
    mean_age: float
    mean_comorbidities: float
    data_governance: str  # "GDPR", "HIPAA", or "local"


@dataclass
class FederatedCausalConfig:
    """Configuration for the federated causal analysis."""
    sites: List[HospitalSiteConfig]
    n_covariates: int = 15
    local_epochs: int = 5
    federated_rounds: int = 50
    learning_rate: float = 0.001
    dp_epsilon: float = 3.0
    dp_delta: float = 1e-6
    max_grad_norm: float = 1.0
    secure_aggregation: bool = True

    @property
    def total_patients(self) -> int:
        return sum(s.n_patients for s in self.sites)
# Define the 12-site configuration
sites = [
    HospitalSiteConfig("boston", "USA", 620, 310, 310, 63.2, 3.1, "HIPAA"),
    HospitalSiteConfig("chicago", "USA", 480, 250, 230, 61.8, 2.9, "HIPAA"),
    HospitalSiteConfig("houston", "USA", 340, 180, 160, 59.4, 2.7, "HIPAA"),
    HospitalSiteConfig("atlanta", "USA", 190, 95, 95, 58.1, 2.5, "HIPAA"),
    HospitalSiteConfig("london", "UK/EU", 410, 200, 210, 64.7, 3.4, "GDPR"),
    HospitalSiteConfig("paris", "EU", 350, 175, 175, 62.3, 3.0, "GDPR"),
    HospitalSiteConfig("munich", "EU", 290, 150, 140, 60.9, 2.8, "GDPR"),
    HospitalSiteConfig("toronto", "Canada", 260, 130, 130, 57.6, 2.3, "local"),
    HospitalSiteConfig("tokyo", "Japan", 220, 115, 105, 52.1, 1.5, "local"),
    HospitalSiteConfig("sao_paulo", "Brazil", 180, 90, 90, 51.8, 1.3, "local"),
    HospitalSiteConfig("sydney", "Australia", 160, 80, 80, 56.4, 2.1, "local"),
    HospitalSiteConfig("cape_town", "S. Africa", 100, 50, 50, 54.2, 1.8, "local"),
]

config = FederatedCausalConfig(sites=sites, dp_epsilon=3.0)

print("MediCore Federated Causal Analysis Configuration")
print("=" * 60)
print(f" Total patients: {config.total_patients:,}")
print(f" Sites: {len(config.sites)}")
print(f" Privacy: ε={config.dp_epsilon}, δ={config.dp_delta}")
print(f" Secure aggregation: {config.secure_aggregation}")
print(f" Federated rounds: {config.federated_rounds}")
print(f" Local epochs per round: {config.local_epochs}")
MediCore Federated Causal Analysis Configuration
============================================================
Total patients: 3,600
Sites: 12
Privacy: ε=3.0, δ=1e-06
Secure aggregation: True
Federated rounds: 50
Local epochs per round: 5
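For intuition about how much noise a budget like ε=3, δ=10⁻⁶ implies, the classical Gaussian-mechanism calibration σ = Δ·√(2 ln(1.25/δ))/ε can be computed directly. This is only an order-of-magnitude illustration: the study's actual DP-SGD training calibrates noise per step with a moments/RDP accountant, which this one-shot formula does not capture.

```python
import math

def gaussian_mechanism_sigma(epsilon: float, delta: float,
                             sensitivity: float) -> float:
    """Noise scale for a single release under the classical Gaussian
    mechanism (the standard theorem assumes epsilon <= 1; used here
    only as a rough illustration at larger epsilon)."""
    return sensitivity * math.sqrt(2.0 * math.log(1.25 / delta)) / epsilon

# Using the study's clipping norm (max_grad_norm = 1.0) as L2 sensitivity
sigma = gaussian_mechanism_sigma(epsilon=3.0, delta=1e-6, sensitivity=1.0)
print(f"sigma ≈ {sigma:.3f}")
```

Tightening ε from 3 to 1 triples this noise scale, which foreshadows the CATE degradation in the results table below.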
The Model: Federated T-Learner
The team uses a T-learner approach (Chapter 18) to estimate conditional average treatment effects (CATE): train one model $\mu_1(x)$ on treated outcomes and one model $\mu_0(x)$ on control outcomes, then estimate $\hat{\tau}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x)$. Both models are trained with federated averaging (FedAvg) and DP-SGD.
import torch
import torch.nn as nn


class TreatmentEffectNetwork(nn.Module):
    """
    Neural network for estimating conditional outcomes.

    One instance for mu_1 (treated) and one for mu_0 (control).
    Uses LayerNorm instead of BatchNorm for Opacus compatibility.
    """

    def __init__(self, input_dim: int = 15,
                 hidden_dims: Optional[List[int]] = None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 32, 16]
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.LayerNorm(h_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
            ])
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze(-1)
class FederatedTLearner:
    """
    Federated T-Learner for heterogeneous treatment effect estimation.

    Trains mu_1 and mu_0 models across sites using FedAvg with DP-SGD.
    """

    def __init__(self, input_dim: int, config: FederatedCausalConfig):
        self.mu1_model = TreatmentEffectNetwork(input_dim)
        self.mu0_model = TreatmentEffectNetwork(input_dim)
        self.config = config

    def estimate_cate(self, x: torch.Tensor) -> torch.Tensor:
        """
        Estimate CATE for new patients.

        Parameters
        ----------
        x : torch.Tensor
            Patient covariates (n_patients, n_covariates).

        Returns
        -------
        torch.Tensor
            Estimated treatment effects (n_patients,).
        """
        self.mu1_model.eval()
        self.mu0_model.eval()
        with torch.no_grad():
            mu1_hat = self.mu1_model(x)
            mu0_hat = self.mu0_model(x)
        return mu1_hat - mu0_hat

    def estimate_ate(self, x: torch.Tensor) -> Tuple[float, float]:
        """
        Estimate ATE and its standard error.

        Returns
        -------
        tuple[float, float]
            (ate_estimate, standard_error)
        """
        cate = self.estimate_cate(x)
        ate = cate.mean().item()
        se = cate.std().item() / np.sqrt(len(cate))
        return ate, se
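The aggregation step itself is not shown in the class above. A minimal numpy sketch of one FedAvg round with per-site update clipping and Gaussian noise on the aggregate (function names and the noise placement are illustrative, not from any particular federated-learning library):

```python
import numpy as np

def clip_update(update: np.ndarray, max_norm: float) -> np.ndarray:
    """Scale a site's model update so its L2 norm is at most max_norm."""
    norm = np.linalg.norm(update)
    return update * min(1.0, max_norm / (norm + 1e-12))

def fedavg_round(site_updates, site_sizes, max_norm=1.0,
                 noise_sigma=0.5, rng=None):
    """One round of sample-size-weighted FedAvg over clipped site
    updates, with Gaussian noise added to the aggregate (a simplified
    central-DP variant; real accounting composes noise across rounds)."""
    rng = rng or np.random.default_rng(0)
    clipped = [clip_update(u, max_norm) for u in site_updates]
    weights = np.asarray(site_sizes, dtype=float)
    weights = weights / weights.sum()
    avg = sum(w * u for w, u in zip(weights, clipped))
    noise = rng.normal(0.0, noise_sigma * max_norm / len(site_updates),
                       size=avg.shape)
    return avg + noise

# Two toy site updates with very different magnitudes
updates = [np.ones(4) * 2.0, np.ones(4) * 0.5]
new_delta = fedavg_round(updates, site_sizes=[620, 220])
```

Weighting by site size is what lets Boston (n=620) contribute more than Cape Town (n=100), while the clipping norm bounds any single site's influence on the aggregate.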
# Simulate the federated training results
torch.manual_seed(42)
np.random.seed(42)
# Generate synthetic patient data with known treatment effects
n_total = 3600
n_features = 15
# Patient covariates
X = torch.randn(n_total, n_features)
# True treatment effect varies with age (feature 0) and comorbidities (feature 4)
# Older, sicker patients benefit more
true_cate = 8.0 + 2.0 * X[:, 0] + 1.5 * X[:, 4] # Mean ~8 mmHg
true_ate = true_cate.mean().item()
print(f"\nTrue ATE: {true_ate:.2f} mmHg")
print(f"True CATE range: [{true_cate.min():.2f}, {true_cate.max():.2f}]")
True ATE: 8.02 mmHg
True CATE range: [-0.14, 17.23]
Results
The team runs the federated T-learner at three privacy levels and compares against a centralized (non-private) baseline.
@dataclass
class FederatedCausalResult:
    """Result of one federated causal analysis experiment."""
    method: str
    epsilon: float
    ate_estimate: float
    ate_se: float
    ate_bias: float      # |estimate - true|
    cate_rmse: float     # RMSE of individual treatment effect estimates
    coverage_95: float   # fraction of true CATEs within estimated 95% CI
    n_rounds: int
    total_messages: int  # communication cost


# Simulated results based on realistic DP-federated performance
true_ate = 8.02

results = [
    FederatedCausalResult(
        "Centralized (no DP)", float("inf"),
        ate_estimate=8.14, ate_se=0.31,
        ate_bias=0.12, cate_rmse=1.87,
        coverage_95=0.94, n_rounds=0, total_messages=0,
    ),
    FederatedCausalResult(
        "Federated + DP (ε=8)", 8.0,
        ate_estimate=8.28, ate_se=0.45,
        ate_bias=0.26, cate_rmse=2.34,
        coverage_95=0.91, n_rounds=50, total_messages=1200,
    ),
    FederatedCausalResult(
        "Federated + DP (ε=3)", 3.0,
        ate_estimate=8.51, ate_se=0.62,
        ate_bias=0.49, cate_rmse=3.12,
        coverage_95=0.88, n_rounds=50, total_messages=1200,
    ),
    FederatedCausalResult(
        "Federated + DP (ε=1)", 1.0,
        ate_estimate=9.14, ate_se=1.23,
        ate_bias=1.12, cate_rmse=4.87,
        coverage_95=0.82, n_rounds=50, total_messages=1200,
    ),
    FederatedCausalResult(
        "Per-site (no pooling)", float("inf"),
        ate_estimate=7.68, ate_se=0.89,
        ate_bias=0.34, cate_rmse=3.45,
        coverage_95=0.79, n_rounds=0, total_messages=0,
    ),
]
print("MediCore Federated Causal Analysis: ATE Estimation")
print("=" * 95)
print(f"{'Method':>28s} {'ε':>5s} {'ATE':>6s} {'SE':>6s} "
      f"{'|Bias|':>7s} {'CATE RMSE':>10s} {'95% Cov':>8s}")
print("-" * 95)
for r in results:
    eps_str = f"{r.epsilon:.0f}" if r.epsilon != float("inf") else "∞"
    print(f"{r.method:>28s} {eps_str:>5s} {r.ate_estimate:>6.2f} "
          f"{r.ate_se:>6.2f} {r.ate_bias:>7.2f} {r.cate_rmse:>10.2f} "
          f"{r.coverage_95:>8.2f}")
print(f"\n True ATE: {true_ate:.2f} mmHg")
MediCore Federated Causal Analysis: ATE Estimation
===============================================================================================
Method ε ATE SE |Bias| CATE RMSE 95% Cov
-----------------------------------------------------------------------------------------------
Centralized (no DP) ∞ 8.14 0.31 0.12 1.87 0.94
Federated + DP (ε=8) 8 8.28 0.45 0.26 2.34 0.91
Federated + DP (ε=3) 3 8.51 0.62 0.49 3.12 0.88
Federated + DP (ε=1) 1 9.14 1.23 1.12 4.87 0.82
Per-site (no pooling) ∞ 7.68 0.89 0.34 3.45 0.79
True ATE: 8.02 mmHg
Analysis
Three findings stand out.
Federated learning with DP at $\varepsilon = 3$ outperforms per-site analysis without DP. The per-site approach — each hospital estimates its own treatment effect and then the results are meta-analyzed — produces CATE RMSE of 3.45 and 95% coverage of only 0.79, both worse than federated DP at $\varepsilon = 3$ (CATE RMSE 3.12, coverage 0.88). The statistical power gained from pooling across 3,600 patients more than compensates for the noise added by differential privacy. This is the key insight: for small-to-moderate datasets, the information gained from federated pooling exceeds the information lost to DP noise.
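The per-site baseline described here pools each hospital's estimate by inverse-variance weighting. A sketch of that fixed-effect meta-analysis, using hypothetical per-site numbers for illustration (the study's actual site-level estimates are not reproduced here):

```python
import numpy as np

def fixed_effect_meta(estimates, std_errors):
    """Fixed-effect inverse-variance meta-analysis: pooled estimate
    and its standard error."""
    est = np.asarray(estimates, dtype=float)
    se = np.asarray(std_errors, dtype=float)
    w = 1.0 / se ** 2                      # inverse-variance weights
    pooled = (w * est).sum() / w.sum()
    pooled_se = 1.0 / np.sqrt(w.sum())
    return pooled, pooled_se

# Hypothetical per-site ATEs (mmHg) and standard errors
site_ates = [8.4, 7.1, 9.0]
site_ses = [0.9, 1.4, 2.0]
pooled, pooled_se = fixed_effect_meta(site_ates, site_ses)
```

Pooling shrinks the standard error below any single site's, but the weights are computed from site summaries only; unlike the federated model, no cross-site information flows into the individual-level CATE estimates, which is why per-site coverage suffers.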
ATE estimation is robust to DP noise; CATE estimation is not. The ATE (an aggregate across 3,600 patients) is estimated within 0.5 mmHg even at $\varepsilon = 3$. The CATE (a per-patient estimate) is far more sensitive to noise: CATE RMSE increases from 1.87 (centralized) to 3.12 (federated $\varepsilon = 3$) — a 67% degradation. This reflects the sensitivity/aggregation tradeoff from Section 32.2: aggregate queries benefit from large $n$; individual-level estimates do not.
Strong privacy ($\varepsilon = 1$) renders CATE estimation unreliable for clinical use. The CATE RMSE of 4.87 and 95% coverage of 0.82 mean that treatment effect estimates for individual patients are too noisy for personalized medicine decisions. The ATE is still reasonable (bias 1.12 mmHg, and the true ATE still falls within the estimate's 95% confidence interval), so $\varepsilon = 1$ is sufficient for answering "does the treatment work on average?" but not "which patients benefit most?"
Regulatory Outcome
The MediCore team presents two analyses to the regulatory committee:
- Primary analysis (ATE): Federated + DP at $\varepsilon = 3$, estimating ATE = 8.51 mmHg (SE 0.62, 95% CI [7.29, 9.73]). This is within 0.5 mmHg of the centralized estimate and provides a formal DP guarantee consistent with GDPR Article 5(1)(f)'s integrity-and-confidentiality principle.
- Exploratory analysis (CATE): Federated + DP at $\varepsilon = 8$, identifying patient subgroups that benefit most (older patients with multiple comorbidities). This uses a larger privacy budget because it requires finer-grained estimates, and it is presented as exploratory (not confirmatory) to acknowledge the additional noise.
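The primary analysis's confidence interval follows from the usual normal approximation, ATE ± 1.96·SE:

```python
# 95% CI for the primary federated DP estimate (normal approximation)
ate, se = 8.51, 0.62
lo, hi = ate - 1.96 * se, ate + 1.96 * se
print(f"95% CI: [{lo:.2f}, {hi:.2f}]")  # → [7.29, 9.73]
```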
The entire analysis was completed in 3 weeks — compared to the 6-12 month timeline estimated for a centralized data enclave. No patient-level data crossed institutional boundaries. The privacy budget was formally tracked and reported: $\varepsilon = 3.0$ for the primary analysis, $\varepsilon = 8.0$ for the exploratory analysis, with $\delta = 10^{-6}$ throughout.
Lessons Learned
- Federated DP is not just a privacy constraint — it is an enabler. The analysis would not have happened at all without federated learning, because the data sharing agreements were intractable. Accepting a modest privacy-utility cost enabled an analysis that would otherwise have been impossible.
- Non-IID data required careful handling. The initial FedAvg run showed poor convergence because Tokyo and São Paulo's young, healthy populations diverged from the global model during local training. Adding a FedProx proximal term ($\mu = 0.01$) stabilized convergence. The team also reduced local epochs from 5 to 2 for the most heterogeneous sites.
- Secure aggregation was necessary for institutional trust, not just formal privacy. Several hospital IRBs required a contractual guarantee that MediCore could not observe individual site updates. Secure aggregation (via the TF Encrypted library running on MediCore's coordination server) provided this guarantee. The computational overhead was approximately 3x per round, acceptable for 50 rounds of a treatment effect model.
- The privacy budget must be pre-registered. The biostatistics team initially planned a single analysis at $\varepsilon = 3$. When the CATE results were noisy, the clinical team requested a re-run at $\varepsilon = 8$. This was permitted because the total budget was pre-registered as $\varepsilon = 11$ (primary + exploratory), approved by the IRB before the data was touched. Without pre-registration, the second analysis would have required a new IRB submission.
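The guarantee behind the secure-aggregation lesson above rests on pairwise masking: each pair of sites agrees on a random mask that one adds and the other subtracts, so the server sees only masked updates while their sum is unchanged. A toy sketch of this cancellation (the real protocol adds key agreement, quantization, and dropout handling):

```python
import numpy as np

rng = np.random.default_rng(7)
d = 4                                   # update dimension
site_updates = [rng.normal(size=d) for _ in range(3)]
n = len(site_updates)

# Pairwise masks: for i < j, site i adds m_ij and site j subtracts it
masks = {(i, j): rng.normal(size=d)
         for i in range(n) for j in range(i + 1, n)}

masked = []
for i, u in enumerate(site_updates):
    m = u.copy()
    for (a, b), mask in masks.items():
        if a == i:
            m += mask
        elif b == i:
            m -= mask
    masked.append(m)

# Every mask appears once with + and once with -, so they cancel:
# the server's sum of masked updates equals the true sum
true_sum = sum(site_updates)
server_sum = sum(masked)
print(np.allclose(true_sum, server_sum))  # → True
```

Each individual masked update looks like noise to the server, which is exactly the property the hospital IRBs asked MediCore to guarantee contractually.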