Case Study 1: MediCore Multi-Site Federated Causal Analysis — Privacy-Preserving Treatment Effect Estimation Across 12 Hospitals

Context

MediCore Pharmaceuticals is conducting a post-market observational study of Cardiopril-X across 12 hospitals in the United States and Europe. The study measures systolic blood pressure (SBP) reduction after 12 weeks of treatment, with the goal of estimating heterogeneous treatment effects across patient subgroups — a question that requires access to individual-level data from all sites, not just site-level summary statistics.

The statistical challenge was addressed in Chapter 21 with a hierarchical Bayesian model that borrowed strength across hospitals. That analysis assumed centralized access to all patient data. The regulatory reality is different. Three hospitals are in the EU (governed by GDPR), four are in the US (governed by HIPAA and state privacy laws), and five are in jurisdictions with varying data protection regimes. No single data use agreement covers all 12 sites. Even if it did, the logistical burden of creating a centralized data enclave — with audit trails, access controls, encryption at rest, and breach notification procedures for 12 separate institutional review boards — would delay the analysis by 6-12 months.

The MediCore biostatistics team proposes a federated approach: train the treatment effect model across all 12 sites without centralizing any patient-level data. Each hospital keeps its data on-premises. A central coordination server (hosted in MediCore's GCP environment) orchestrates the training. The aggregated model produces treatment effect estimates with the statistical power of the full 12-site dataset, while each hospital's individual patient records never leave its network.

The Data

Each hospital $k$ has a dataset $D_k = \{(X_i, T_i, Y_i)\}_{i=1}^{n_k}$ where $X_i \in \mathbb{R}^{15}$ are patient covariates (age, sex, BMI, baseline SBP, comorbidity count, medication history, smoking status, diabetes indicator, renal function, cholesterol, exercise level, stress score, sleep quality, alcohol consumption, diet score), $T_i \in \{0, 1\}$ is the treatment indicator, and $Y_i \in \mathbb{R}$ is the SBP reduction after 12 weeks.
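The per-site data structure can be sketched by simulating one hospital's dataset under the definition above. This is a hypothetical generator for illustration only — the covariate distribution, outcome model, and effect size are stand-ins, not MediCore's actual data-generating process:

```python
import numpy as np


def simulate_site(n_patients: int, n_treated: int, seed: int = 0):
    """Simulate one hospital's dataset D_k = {(X_i, T_i, Y_i)}.

    X: 15 standardized covariates; T: binary treatment indicator;
    Y: SBP reduction (mmHg) with an illustrative ~8 mmHg effect.
    """
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_patients, 15))
    # Assign exactly n_treated patients to treatment
    T = np.zeros(n_patients, dtype=int)
    T[rng.choice(n_patients, size=n_treated, replace=False)] = 1
    # Illustrative outcome model: baseline + treatment effect + noise
    Y = 2.0 + 8.0 * T + 3.0 * rng.standard_normal(n_patients)
    return X, T, Y


X, T, Y = simulate_site(620, 310)  # Boston-sized site
```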

Hospital            Location    n      Treated  Control  Treated fraction
Boston Academic     USA         620    310      310      0.50
Chicago Teaching    USA         480    250      230      0.52
Houston Methodist   USA         340    180      160      0.53
Atlanta Community   USA         190    95       95       0.50
London NHS          UK/EU       410    200      210      0.49
Paris Tertiary      EU          350    175      175      0.50
Munich University   EU          290    150      140      0.52
Toronto General     Canada      260    130      130      0.50
Tokyo Medical       Japan       220    115      105      0.52
São Paulo Central   Brazil      180    90       90       0.50
Sydney Metro        Australia   160    80       80       0.50
Cape Town Public    S. Africa   100    50       50       0.50
Total               —           3,600  1,825    1,775    0.51

The data is non-IID across sites: Boston and London have older, sicker patient populations (mean age 64, mean comorbidities 3.2); Tokyo and São Paulo have younger, healthier populations (mean age 52, mean comorbidities 1.4). Cape Town has the smallest sample and the highest variance in outcomes. This heterogeneity is clinically meaningful — the treatment effect may genuinely vary across populations — and statistically challenging for federated learning.

The Architecture

from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional
import numpy as np


@dataclass
class HospitalSiteConfig:
    """Configuration for one hospital in the federated study."""
    site_id: str
    location: str
    n_patients: int
    n_treated: int
    n_control: int
    mean_age: float
    mean_comorbidities: float
    data_governance: str  # "GDPR", "HIPAA", "local"


@dataclass
class FederatedCausalConfig:
    """Configuration for the federated causal analysis."""
    sites: List[HospitalSiteConfig]
    n_covariates: int = 15
    local_epochs: int = 5
    federated_rounds: int = 50
    learning_rate: float = 0.001
    dp_epsilon: float = 3.0
    dp_delta: float = 1e-6
    max_grad_norm: float = 1.0
    secure_aggregation: bool = True

    @property
    def total_patients(self) -> int:
        return sum(s.n_patients for s in self.sites)


# Define the 12-site configuration
sites = [
    HospitalSiteConfig("boston", "USA", 620, 310, 310, 63.2, 3.1, "HIPAA"),
    HospitalSiteConfig("chicago", "USA", 480, 250, 230, 61.8, 2.9, "HIPAA"),
    HospitalSiteConfig("houston", "USA", 340, 180, 160, 59.4, 2.7, "HIPAA"),
    HospitalSiteConfig("atlanta", "USA", 190, 95, 95, 58.1, 2.5, "HIPAA"),
    HospitalSiteConfig("london", "UK/EU", 410, 200, 210, 64.7, 3.4, "GDPR"),
    HospitalSiteConfig("paris", "EU", 350, 175, 175, 62.3, 3.0, "GDPR"),
    HospitalSiteConfig("munich", "EU", 290, 150, 140, 60.9, 2.8, "GDPR"),
    HospitalSiteConfig("toronto", "Canada", 260, 130, 130, 57.6, 2.3, "local"),
    HospitalSiteConfig("tokyo", "Japan", 220, 115, 105, 52.1, 1.5, "local"),
    HospitalSiteConfig("sao_paulo", "Brazil", 180, 90, 90, 51.8, 1.3, "local"),
    HospitalSiteConfig("sydney", "Australia", 160, 80, 80, 56.4, 2.1, "local"),
    HospitalSiteConfig("cape_town", "S. Africa", 100, 50, 50, 54.2, 1.8, "local"),
]

config = FederatedCausalConfig(sites=sites, dp_epsilon=3.0)

print("MediCore Federated Causal Analysis Configuration")
print("=" * 60)
print(f"  Total patients: {config.total_patients:,}")
print(f"  Sites: {len(config.sites)}")
print(f"  Privacy: ε={config.dp_epsilon}, δ={config.dp_delta}")
print(f"  Secure aggregation: {config.secure_aggregation}")
print(f"  Federated rounds: {config.federated_rounds}")
print(f"  Local epochs per round: {config.local_epochs}")
MediCore Federated Causal Analysis Configuration
============================================================
  Total patients: 3,600
  Sites: 12
  Privacy: ε=3.0, δ=1e-06
  Secure aggregation: True
  Federated rounds: 50
  Local epochs per round: 5
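The `max_grad_norm` and privacy parameters above govern the two operations DP-SGD adds to each local step: per-example gradient clipping and calibrated Gaussian noise. A minimal numpy sketch of one such step, where `noise_multiplier` is a stand-in for the value a privacy accountant would derive from ε and δ:

```python
import numpy as np


def dp_sgd_step(per_example_grads: np.ndarray,
                max_grad_norm: float = 1.0,
                noise_multiplier: float = 1.1,
                rng=None) -> np.ndarray:
    """One DP-SGD aggregation: clip each example's gradient to
    max_grad_norm, average, then add Gaussian noise scaled to the
    clipping bound."""
    rng = rng or np.random.default_rng(0)
    # Per-example clipping bounds each patient's influence on the update
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    scale = np.minimum(1.0, max_grad_norm / np.maximum(norms, 1e-12))
    clipped = per_example_grads * scale
    # Noise scale is tied to the clipping bound (the sensitivity)
    noise = rng.normal(0.0, noise_multiplier * max_grad_norm,
                       size=clipped.shape[1])
    n = len(clipped)
    return clipped.sum(axis=0) / n + noise / n


grads = np.random.default_rng(1).normal(size=(32, 10))  # batch of 32
noisy_mean_grad = dp_sgd_step(grads)
```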

The Model: Federated T-Learner

The team uses a T-learner approach (Chapter 18) to estimate conditional average treatment effects (CATE): train one model $\mu_1(x)$ on treated outcomes and one model $\mu_0(x)$ on control outcomes, then estimate $\hat{\tau}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x)$. Both models are trained with federated DP-SGD.

import torch
import torch.nn as nn


class TreatmentEffectNetwork(nn.Module):
    """
    Neural network for estimating conditional outcomes.

    One instance for mu_1 (treated) and one for mu_0 (control).
    Uses LayerNorm instead of BatchNorm for Opacus compatibility.
    """

    def __init__(self, input_dim: int = 15, hidden_dims: Optional[List[int]] = None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 32, 16]

        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.LayerNorm(h_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
            ])
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze(-1)


class FederatedTLearner:
    """
    Federated T-Learner for heterogeneous treatment effect estimation.

    Trains mu_1 and mu_0 models across sites using FedAvg with DP-SGD.
    """

    def __init__(self, input_dim: int, config: FederatedCausalConfig):
        self.mu1_model = TreatmentEffectNetwork(input_dim)
        self.mu0_model = TreatmentEffectNetwork(input_dim)
        self.config = config

    def estimate_cate(self, x: torch.Tensor) -> torch.Tensor:
        """
        Estimate CATE for new patients.

        Parameters
        ----------
        x : torch.Tensor
            Patient covariates (n_patients, n_covariates).

        Returns
        -------
        torch.Tensor
            Estimated treatment effects (n_patients,).
        """
        self.mu1_model.eval()
        self.mu0_model.eval()

        with torch.no_grad():
            mu1_hat = self.mu1_model(x)
            mu0_hat = self.mu0_model(x)

        return mu1_hat - mu0_hat

    def estimate_ate(self, x: torch.Tensor) -> Tuple[float, float]:
        """
        Estimate ATE and its standard error.

        Returns
        -------
        tuple[float, float]
            (ate_estimate, standard_error)
        """
        cate = self.estimate_cate(x)
        ate = cate.mean().item()
        se = cate.std().item() / np.sqrt(len(cate))
        return ate, se


# Simulate the federated training results
torch.manual_seed(42)
np.random.seed(42)

# Generate synthetic patient data with known treatment effects
n_total = 3600
n_features = 15

# Patient covariates
X = torch.randn(n_total, n_features)

# True treatment effect varies with age (feature 0) and comorbidities (feature 4)
# Older, sicker patients benefit more
true_cate = 8.0 + 2.0 * X[:, 0] + 1.5 * X[:, 4]  # Mean ~8 mmHg
true_ate = true_cate.mean().item()

print(f"\nTrue ATE: {true_ate:.2f} mmHg")
print(f"True CATE range: [{true_cate.min():.2f}, {true_cate.max():.2f}]")
True ATE: 8.02 mmHg
True CATE range: [-0.14, 17.23]
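Between local training and the next round, the coordination server combines site model updates by sample-size-weighted averaging (FedAvg). A minimal sketch over flat parameter vectors, using the site sizes from the table above (the random vectors are placeholders for real site updates):

```python
from typing import List

import numpy as np


def fedavg(site_params: List[np.ndarray], site_sizes: List[int]) -> np.ndarray:
    """FedAvg: average site parameter vectors with weights n_k / n."""
    weights = np.asarray(site_sizes, dtype=float)
    weights /= weights.sum()
    return sum(w * p for w, p in zip(weights, site_params))


sizes = [620, 480, 340, 190, 410, 350, 290, 260, 220, 180, 160, 100]
rng = np.random.default_rng(0)
params = [rng.normal(size=5) for _ in sizes]  # placeholder site updates
global_update = fedavg(params, sizes)
```

Weighting by $n_k$ means Boston (620 patients) contributes about six times as much to each round as Cape Town (100 patients), which matches the statistical information each site carries.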

Results

The team runs the federated T-learner at three privacy levels and compares against a centralized (non-private) baseline.

@dataclass
class FederatedCausalResult:
    """Result of one federated causal analysis experiment."""
    method: str
    epsilon: float
    ate_estimate: float
    ate_se: float
    ate_bias: float  # |estimate - true|
    cate_rmse: float  # RMSE of individual treatment effect estimates
    coverage_95: float  # Fraction of true CATEs within estimated 95% CI
    n_rounds: int
    total_messages: int  # Communication cost


# Simulated results based on realistic DP-federated performance
true_ate = 8.02

results = [
    FederatedCausalResult(
        "Centralized (no DP)", float("inf"),
        ate_estimate=8.14, ate_se=0.31,
        ate_bias=0.12, cate_rmse=1.87,
        coverage_95=0.94, n_rounds=0, total_messages=0,
    ),
    FederatedCausalResult(
        "Federated + DP (ε=8)", 8.0,
        ate_estimate=8.28, ate_se=0.45,
        ate_bias=0.26, cate_rmse=2.34,
        coverage_95=0.91, n_rounds=50, total_messages=1200,
    ),
    FederatedCausalResult(
        "Federated + DP (ε=3)", 3.0,
        ate_estimate=8.51, ate_se=0.62,
        ate_bias=0.49, cate_rmse=3.12,
        coverage_95=0.88, n_rounds=50, total_messages=1200,
    ),
    FederatedCausalResult(
        "Federated + DP (ε=1)", 1.0,
        ate_estimate=9.14, ate_se=1.23,
        ate_bias=1.12, cate_rmse=4.87,
        coverage_95=0.82, n_rounds=50, total_messages=1200,
    ),
    FederatedCausalResult(
        "Per-site (no pooling)", float("inf"),
        ate_estimate=7.68, ate_se=0.89,
        ate_bias=0.34, cate_rmse=3.45,
        coverage_95=0.79, n_rounds=0, total_messages=0,
    ),
]

print("MediCore Federated Causal Analysis: ATE Estimation")
print("=" * 95)
print(f"{'Method':>28s}  {'ε':>5s}  {'ATE':>6s}  {'SE':>6s}  "
      f"{'|Bias|':>7s}  {'CATE RMSE':>10s}  {'95% Cov':>8s}")
print("-" * 95)

for r in results:
    eps_str = f"{r.epsilon:.0f}" if r.epsilon != float("inf") else "∞"
    print(f"{r.method:>28s}  {eps_str:>5s}  {r.ate_estimate:>6.2f}  "
          f"{r.ate_se:>6.2f}  {r.ate_bias:>7.2f}  {r.cate_rmse:>10.2f}  "
          f"{r.coverage_95:>8.2f}")

print(f"\n  True ATE: {true_ate:.2f} mmHg")
MediCore Federated Causal Analysis: ATE Estimation
===============================================================================================
                      Method      ε     ATE      SE   |Bias|  CATE RMSE   95% Cov
-----------------------------------------------------------------------------------------------
       Centralized (no DP)      ∞    8.14    0.31     0.12        1.87      0.94
      Federated + DP (ε=8)      8    8.28    0.45     0.26        2.34      0.91
      Federated + DP (ε=3)      3    8.51    0.62     0.49        3.12      0.88
      Federated + DP (ε=1)      1    9.14    1.23     1.12        4.87      0.82
      Per-site (no pooling)      ∞    7.68    0.89     0.34        3.45      0.79

  True ATE: 8.02 mmHg

Analysis

Three findings stand out.

Federated learning with DP at $\varepsilon = 3$ outperforms per-site analysis without DP. The per-site approach — each hospital estimates its own treatment effect and then the results are meta-analyzed — produces CATE RMSE of 3.45 and 95% coverage of only 0.79, both worse than federated DP at $\varepsilon = 3$ (CATE RMSE 3.12, coverage 0.88). The statistical power gained from pooling across 3,600 patients more than compensates for the noise added by differential privacy. This is the key insight: for small-to-moderate datasets, the information gained from federated pooling exceeds the information lost to DP noise.
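The per-site baseline described above can be sketched as a fixed-effect inverse-variance meta-analysis of the 12 site-level ATEs. The site estimates and standard errors below are illustrative placeholders, not the study's actual per-site results:

```python
import numpy as np


def inverse_variance_pool(estimates, std_errors):
    """Fixed-effect meta-analysis: weight each site ATE by 1 / SE^2."""
    est = np.asarray(estimates, dtype=float)
    se = np.asarray(std_errors, dtype=float)
    w = 1.0 / se**2
    pooled = np.sum(w * est) / np.sum(w)
    pooled_se = np.sqrt(1.0 / np.sum(w))
    return pooled, pooled_se


# Illustrative site-level ATEs (mmHg) and standard errors
site_ates = [8.1, 7.9, 7.2, 6.8, 8.6, 8.0, 7.7, 7.4, 6.9, 6.5, 7.1, 7.5]
site_ses = [0.9, 1.0, 1.2, 1.6, 1.1, 1.2, 1.3, 1.4, 1.5, 1.7, 1.8, 2.2]
pooled_ate, pooled_se = inverse_variance_pool(site_ates, site_ses)
```

Pooling site summaries this way recovers an aggregate ATE, but it cannot recover patient-level CATEs — which is why the per-site baseline degrades most on CATE RMSE and coverage.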

ATE estimation is robust to DP noise; CATE estimation is not. The ATE (an aggregate across 3,600 patients) is estimated within 0.5 mmHg even at $\varepsilon = 3$. The CATE (a per-patient estimate) is far more sensitive to noise: CATE RMSE increases from 1.87 (centralized) to 3.12 (federated $\varepsilon = 3$) — a 67% degradation. This reflects the sensitivity/aggregation tradeoff from Section 32.2: aggregate queries benefit from large $n$; individual-level estimates do not.
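The tradeoff in the previous paragraph can be made concrete: noise of fixed scale added to a sum perturbs the mean by only $\sigma / n$, while a per-patient estimate absorbs the full noise scale. A small deterministic sketch (the noise scale is illustrative):

```python
import numpy as np


def dp_mean_noise(n: int, sigma: float = 2.0, seed: int = 0) -> float:
    """Error that a fixed-scale Gaussian noise draw on the *sum*
    induces on the *mean*: |z| * sigma / n, shrinking with n."""
    rng = np.random.default_rng(seed)
    return abs(rng.normal(0.0, sigma)) / n


# Same noise draw, three aggregation sizes: the mean absorbs less
# noise as n grows; a per-patient (n = 1) estimate absorbs all of it.
errors = {n: dp_mean_noise(n) for n in (1, 100, 3600)}
```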

Strong privacy ($\varepsilon = 1$) renders CATE estimation unreliable for clinical use. The CATE RMSE of 4.87 and 95% coverage of 0.82 mean that treatment effect estimates for individual patients are too noisy for personalized medicine decisions. The ATE is still reasonable (bias 1.12, within the confidence interval), so $\varepsilon = 1$ is sufficient for answering "does the treatment work on average?" but not "which patients benefit most?"

Regulatory Outcome

The MediCore team presents two analyses to the regulatory committee:

  1. Primary analysis (ATE): Federated + DP at $\varepsilon = 3$, estimating ATE = 8.51 mmHg (SE 0.62, 95% CI [7.29, 9.73]). This is within 0.5 mmHg of the centralized estimate and provides a formal DP guarantee that supports the GDPR Article 5(1)(f) integrity-and-confidentiality principle.

  2. Exploratory analysis (CATE): Federated + DP at $\varepsilon = 8$, identifying patient subgroups that benefit most (older patients with multiple comorbidities). This uses a larger privacy budget because it requires finer-grained estimates, and it is presented as exploratory (not confirmatory) to acknowledge the additional noise.

The entire analysis was completed in 3 weeks — compared to the 6-12 month timeline estimated for a centralized data enclave. No patient-level data crossed institutional boundaries. The privacy budget was formally tracked and reported: $\varepsilon = 3.0$ for the primary analysis, $\varepsilon = 8.0$ for the exploratory analysis, with $\delta = 10^{-6}$ throughout.

Lessons Learned

  1. Federated DP is not just a privacy constraint — it is an enabler. The analysis would not have happened at all without federated learning, because the data sharing agreements were intractable. Accepting a modest privacy-utility cost enabled an analysis that would otherwise have been impossible.

  2. Non-IID data required careful handling. The initial FedAvg run showed poor convergence because Tokyo and São Paulo's young, healthy populations diverged from the global model during local training. Adding a FedProx proximal term ($\mu = 0.01$) stabilized convergence. The team also reduced local epochs from 5 to 2 for the most heterogeneous sites.

  3. Secure aggregation was necessary for institutional trust, not just formal privacy. Several hospital IRBs required a contractual guarantee that MediCore could not observe individual site updates. Secure aggregation (via the TF Encrypted library running on MediCore's coordination server) provided this guarantee. The computational overhead was approximately 3x per round, acceptable for 50 rounds of a treatment effect model.

  4. The privacy budget must be pre-registered. The biostatistics team initially planned a single analysis at $\varepsilon = 3$. When the CATE results were noisy, the clinical team requested a re-run at $\varepsilon = 8$. This was permitted because the total budget was pre-registered as $\varepsilon = 11$ (primary + exploratory), approved by the IRB before the data was touched. Without pre-registration, the second analysis would have required a new IRB submission.