Case Study 1: MediCore Synthetic EHR Data — Preserving Statistics Without Privacy Risk

Context

MediCore Pharmaceuticals maintains a longitudinal electronic health record (EHR) database covering 2.1 million patients across 340 hospitals. The data is invaluable for observational studies, drug safety monitoring, and clinical trial design — but sharing it is a regulatory and ethical minefield. HIPAA, GDPR, and institutional review boards impose strict constraints on de-identification, and even de-identified records can be re-identified through linkage attacks when the feature space is rich.

The data science team needs to enable three use cases without exposing real patient data:

  1. External collaboration. Share realistic patient-level data with academic partners who lack access to MediCore's EHR system, enabling them to develop and test analytical methods.
  2. Model development. Build and iterate on predictive models (e.g., adverse event prediction, treatment response estimation) using synthetic data, then validate on real data in a controlled environment.
  3. Regulatory submissions. Demonstrate to the FDA that a causal estimation method (Chapters 16-18) works correctly by showing results on synthetic data where the ground truth is known.

The solution: train a generative model on the real EHR data and produce a synthetic dataset that preserves statistical properties — marginal distributions, feature correlations, clinical patterns — while containing no real patient's record.

The Data

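The EHR data is tabular: each row is a patient, and columns represent demographics, diagnoses, lab values, medications, and outcomes. The real records cannot be shown here, so we simulate a stand-in dataset with realistic clinical structure. For the rest of the case study it plays the role of the real EHR data, despite the function's name, generate_synthetic_ehr.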

import numpy as np
import pandas as pd
from typing import Dict, Tuple
from sklearn.preprocessing import StandardScaler
from scipy import stats


def generate_synthetic_ehr(
    n_patients: int = 50000, seed: int = 42
) -> pd.DataFrame:
    """Generate a synthetic EHR dataset with realistic clinical correlations.

    Simulates patient records with demographics, lab values, diagnoses,
    and treatment outcomes. Feature correlations reflect known clinical
    relationships (e.g., age-HbA1c-diabetes, BMI-cholesterol-cardiovascular).

    Args:
        n_patients: Number of patient records.
        seed: Random seed.

    Returns:
        DataFrame with patient records.
    """
    rng = np.random.RandomState(seed)

    # Demographics
    age = rng.normal(55, 15, n_patients).clip(18, 95)
    sex = rng.binomial(1, 0.52, n_patients)  # 1 = female
    bmi = rng.normal(28, 6, n_patients).clip(15, 55)

    # Lab values with clinical correlations
    # HbA1c increases with age and BMI
    hba1c = (
        4.5
        + 0.02 * (age - 55)
        + 0.04 * (bmi - 28)
        + rng.normal(0, 0.8, n_patients)
    ).clip(3.5, 14.0)

    # Total cholesterol correlated with BMI and age
    cholesterol = (
        180
        + 0.5 * (age - 55)
        + 1.2 * (bmi - 28)
        + rng.normal(0, 30, n_patients)
    ).clip(100, 350)

    # Systolic BP correlated with age and BMI
    systolic_bp = (
        120
        + 0.4 * (age - 55)
        + 0.6 * (bmi - 28)
        + rng.normal(0, 15, n_patients)
    ).clip(80, 220)

    # Creatinine: mildly correlated with age; lower baseline in females (sex = 1)
    creatinine = (
        0.9
        - 0.15 * sex
        + 0.005 * (age - 55)
        + rng.normal(0, 0.2, n_patients)
    ).clip(0.4, 5.0)

    # Diagnoses (binary) — correlated with risk factors
    diabetes_prob = 1 / (1 + np.exp(-(
        -3.0 + 0.03 * age + 0.05 * bmi + 0.3 * (hba1c - 6.5)
    )))
    diabetes = rng.binomial(1, diabetes_prob)

    hypertension_prob = 1 / (1 + np.exp(-(
        -2.5 + 0.025 * age + 0.04 * bmi + 0.02 * (systolic_bp - 120)
    )))
    hypertension = rng.binomial(1, hypertension_prob)

    cvd_prob = 1 / (1 + np.exp(-(
        -4.0 + 0.04 * age + 0.03 * bmi + 0.005 * cholesterol
        + 0.5 * diabetes + 0.4 * hypertension
    )))
    cvd = rng.binomial(1, cvd_prob)

    # Treatment and outcome
    # Treatment more likely for sicker patients (confounding)
    treatment_prob = 1 / (1 + np.exp(-(
        -1.0 + 0.3 * diabetes + 0.2 * hypertension + 0.4 * cvd
        + 0.01 * (hba1c - 6.0)
    )))
    treatment = rng.binomial(1, treatment_prob)

    # Outcome: treatment has causal effect, confounded by severity
    outcome_prob = 1 / (1 + np.exp(-(
        -2.0 + 0.3 * diabetes + 0.2 * hypertension + 0.5 * cvd
        + 0.02 * age - 0.8 * treatment  # Treatment reduces risk
    )))
    adverse_event = rng.binomial(1, outcome_prob)

    return pd.DataFrame({
        "age": age.round(1),
        "sex": sex,
        "bmi": bmi.round(1),
        "hba1c": hba1c.round(2),
        "cholesterol": cholesterol.round(0),
        "systolic_bp": systolic_bp.round(0),
        "creatinine": creatinine.round(2),
        "diabetes": diabetes,
        "hypertension": hypertension,
        "cvd": cvd,
        "treatment": treatment,
        "adverse_event": adverse_event,
    })
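Because the simulator writes the treatment effect into the outcome model explicitly (a log-odds shift of -0.8), the true effect is known. Use case 3 depends on exactly this. The minimal sketch below shows why the ground truth matters. It uses a single hypothetical severity variable in place of the diabetes, hypertension, and CVD columns. When sicker patients are more likely to be treated, the naive treated-versus-untreated difference in event rates is biased upward relative to the true average treatment effect (ATE), which can be computed on the risk scale from the two potential-outcome risks.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function, as used in generate_synthetic_ehr."""
    return 1.0 / (1.0 + np.exp(-x))


def naive_vs_true_effect(n: int = 200_000, seed: int = 0) -> tuple:
    """Return (naive risk difference, true ATE) under confounding by severity."""
    rng = np.random.default_rng(seed)
    # Hypothetical single confounder standing in for diabetes/hypertension/cvd
    severity = rng.normal(0.0, 1.0, n)

    # Sicker patients are more likely to be treated (confounding)
    treated = rng.binomial(1, sigmoid(-0.5 + 1.5 * severity))

    # Potential-outcome risks; treatment shifts the log-odds by a known -0.8
    p0 = sigmoid(-1.0 + severity)         # risk if untreated
    p1 = sigmoid(-1.0 + severity - 0.8)   # risk if treated
    event = rng.binomial(1, np.where(treated == 1, p1, p0))

    true_ate = float(np.mean(p1 - p0))
    naive_rd = float(event[treated == 1].mean() - event[treated == 0].mean())
    return naive_rd, true_ate


naive_rd, true_ate = naive_vs_true_effect()
print(f"naive risk difference: {naive_rd:+.3f}")
print(f"true ATE:              {true_ate:+.3f}")
```

A causal method that adjusts for the confounder should recover something close to the true ATE. The naive comparison does not.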

The Generative Model: Tabular VAE

Image VAEs typically use convolutional architectures and pixel-wise losses. Tabular data requires a different approach: mixed data types (continuous, binary, categorical) need type-specific reconstruction losses, and feature correlations must be captured without spatial-locality assumptions. The model below handles continuous and binary features. Categorical columns could be handled with an additional softmax head.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class TabularVAE(nn.Module):
    """VAE for mixed-type tabular data (continuous + binary features).

    Uses MSE loss for continuous features and BCE loss for binary features.
    The architecture is a simple MLP with batch normalization and dropout.

    Args:
        n_continuous: Number of continuous features.
        n_binary: Number of binary features.
        hidden_dim: Hidden layer size.
        latent_dim: Latent space dimensionality.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        n_continuous: int = 6,
        n_binary: int = 6,
        hidden_dim: int = 256,
        latent_dim: int = 16,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_continuous = n_continuous
        self.n_binary = n_binary
        input_dim = n_continuous + n_binary

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.fc_continuous = nn.Linear(hidden_dim, n_continuous)
        self.fc_binary = nn.Linear(hidden_dim, n_binary)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterize(
        self, mu: torch.Tensor, log_var: torch.Tensor
    ) -> torch.Tensor:
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(
        self, z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode latent vector to continuous and binary outputs.

        Args:
            z: Latent vector, shape (batch, latent_dim).

        Returns:
            Tuple of (continuous_output, binary_logits).
        """
        h = self.decoder(z)
        continuous = self.fc_continuous(h)
        binary_logits = self.fc_binary(h)
        return continuous, binary_logits

    def forward(self, x: torch.Tensor) -> Dict:
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        continuous, binary_logits = self.decode(z)
        return {
            "continuous": continuous,
            "binary_logits": binary_logits,
            "mu": mu,
            "log_var": log_var,
        }


def tabular_vae_loss(
    outputs: Dict,
    x: torch.Tensor,
    n_continuous: int,
    beta: float = 1.0,
) -> Tuple[torch.Tensor, Dict[str, float]]:
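    """Compute the beta-VAE loss for the tabular VAE.

    Continuous features: squared error, summed over features.
    Binary features: binary cross-entropy with logits, summed over features.
    KL divergence: closed-form Gaussian KL, summed over latent dimensions.
    Each term is a per-record sum averaged over the batch, so the three
    terms share a scale and beta keeps its usual meaning.

    Args:
        outputs: Dictionary from TabularVAE.forward().
        x: Original input, shape (batch, n_continuous + n_binary).
        n_continuous: Number of continuous features.
        beta: KL weight.

    Returns:
        Tuple of (total_loss, loss_components_dict).
    """
    x_cont = x[:, :n_continuous]
    x_bin = x[:, n_continuous:]

    # reduction="none" keeps per-element losses: sum over features, mean over batch
    recon_cont = F.mse_loss(
        outputs["continuous"], x_cont, reduction="none"
    ).sum(dim=1).mean()
    recon_bin = F.binary_cross_entropy_with_logits(
        outputs["binary_logits"], x_bin, reduction="none"
    ).sum(dim=1).mean()
    kl = -0.5 * torch.mean(
        torch.sum(
            1 + outputs["log_var"] - outputs["mu"].pow(2) - outputs["log_var"].exp(),
            dim=1,
        )
    )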

    total = recon_cont + recon_bin + beta * kl
    components = {
        "recon_continuous": recon_cont.item(),
        "recon_binary": recon_bin.item(),
        "kl": kl.item(),
    }
    return total, components
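The KL term in tabular_vae_loss uses the closed form $\mathrm{KL}(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)) = \tfrac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2)$, written in terms of log_var $= \log \sigma^2$. As a sanity check that is not part of the pipeline, the NumPy sketch below compares that expression with a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$ for one hypothetical posterior. The two should agree up to sampling noise.

```python
import numpy as np


def kl_closed_form(mu: np.ndarray, log_var: np.ndarray) -> float:
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), same formula as tabular_vae_loss."""
    return float(-0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var)))


def kl_monte_carlo(mu: np.ndarray, log_var: np.ndarray,
                   n_samples: int = 200_000, seed: int = 0) -> float:
    """Estimate E_q[log q(z) - log p(z)] by sampling z from q."""
    rng = np.random.default_rng(seed)
    std = np.exp(0.5 * log_var)
    z = mu + std * rng.standard_normal((n_samples, mu.size))
    # Diagonal Gaussian log-densities; the log(2*pi) constants cancel
    log_q = -0.5 * np.sum(log_var + ((z - mu) / std) ** 2, axis=1)
    log_p = -0.5 * np.sum(z**2, axis=1)
    return float(np.mean(log_q - log_p))


# Hypothetical encoder output for a single record with a 3-dimensional latent
mu = np.array([0.5, -1.0, 0.0])
log_var = np.array([-0.5, 0.2, 0.0])
print(f"closed form: {kl_closed_form(mu, log_var):.4f}")
print(f"Monte Carlo: {kl_monte_carlo(mu, log_var):.4f}")
```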

Training and Synthetic Data Generation

def train_and_generate(
    real_df: pd.DataFrame,
    continuous_cols: list,
    binary_cols: list,
    latent_dim: int = 16,
    epochs: int = 100,
    batch_size: int = 512,
    n_synthetic: int = 50000,
    seed: int = 42,
) -> pd.DataFrame:
    """Train a tabular VAE and generate synthetic patient records.

    Args:
        real_df: Real EHR DataFrame.
        continuous_cols: Names of continuous columns.
        binary_cols: Names of binary columns.
        latent_dim: VAE latent dimensionality.
        epochs: Training epochs.
        batch_size: Batch size.
        n_synthetic: Number of synthetic records to generate.
        seed: Random seed.

    Returns:
        DataFrame of synthetic patient records.
    """
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Preprocess: standardize continuous, keep binary as-is
    scaler = StandardScaler()
    cont_data = scaler.fit_transform(real_df[continuous_cols].values)
    bin_data = real_df[binary_cols].values.astype(np.float32)
    all_data = np.hstack([cont_data, bin_data]).astype(np.float32)

    dataset = TensorDataset(torch.tensor(all_data))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Train
    n_cont = len(continuous_cols)
    n_bin = len(binary_cols)
    model = TabularVAE(n_cont, n_bin, hidden_dim=256, latent_dim=latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for (batch,) in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            outputs = model(batch)
            loss, _ = tabular_vae_loss(outputs, batch, n_cont, beta=0.5)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        if (epoch + 1) % 25 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {epoch_loss / len(loader):.4f}")

    # Generate synthetic data
    model.eval()
    with torch.no_grad():
        z = torch.randn(n_synthetic, latent_dim, device=device)
        cont_out, bin_logits = model.decode(z)
        cont_synth = cont_out.cpu().numpy()
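        # Sample binary features from their Bernoulli probabilities; thresholding
        # at 0.5 would erase any feature whose predicted probability stays below 0.5
        bin_synth = torch.bernoulli(torch.sigmoid(bin_logits)).cpu().numpy()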

    # Inverse transform continuous features
    cont_synth = scaler.inverse_transform(cont_synth)

    # Assemble DataFrame
    synth_df = pd.DataFrame(cont_synth, columns=continuous_cols)
    for i, col in enumerate(binary_cols):
        synth_df[col] = bin_synth[:, i].astype(int)

    # Post-processing: clip to realistic ranges
    synth_df["age"] = synth_df["age"].clip(18, 95).round(1)
    synth_df["bmi"] = synth_df["bmi"].clip(15, 55).round(1)
    synth_df["hba1c"] = synth_df["hba1c"].clip(3.5, 14.0).round(2)
    synth_df["cholesterol"] = synth_df["cholesterol"].clip(100, 350).round(0)
    synth_df["systolic_bp"] = synth_df["systolic_bp"].clip(80, 220).round(0)
    synth_df["creatinine"] = synth_df["creatinine"].clip(0.4, 5.0).round(2)

    return synth_df


# ---- Run the pipeline ----

real_ehr = generate_synthetic_ehr(n_patients=50000, seed=42)
continuous_cols = ["age", "bmi", "hba1c", "cholesterol", "systolic_bp", "creatinine"]
binary_cols = ["sex", "diabetes", "hypertension", "cvd", "treatment", "adverse_event"]

synth_ehr = train_and_generate(
    real_ehr, continuous_cols, binary_cols,
    latent_dim=16, epochs=100, n_synthetic=50000
)
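The pipeline has to turn the decoder's binary probabilities into 0/1 values, and the conversion rule visibly affects the marginals. The NumPy sketch below uses hypothetical decoder probabilities clustered around 0.3. Thresholding at 0.5 produces no positive records at all. One Bernoulli draw per record reproduces a prevalence close to the mean probability.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical decoder outputs: sigmoid probabilities for one binary feature,
# scattered around 0.3 and capped below 0.5
probs = np.clip(rng.normal(0.3, 0.05, size=100_000), 0.0, 0.45)

thresholded = (probs > 0.5).astype(int)   # deterministic rounding
sampled = rng.binomial(1, probs)          # one Bernoulli draw per record

print(f"mean decoder probability:      {probs.mean():.3f}")
print(f"prevalence after thresholding: {thresholded.mean():.3f}")
print(f"prevalence after sampling:     {sampled.mean():.3f}")
```

In PyTorch, the sampling version is torch.bernoulli(torch.sigmoid(bin_logits)).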

Evaluation: Does the Synthetic Data Preserve Clinical Structure?

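The evaluation protocol has four dimensions: marginal distributions, correlations, downstream utility, and privacy. The function below covers three of them: marginals, pairwise correlations, and a nearest-neighbor privacy check. Downstream utility requires training models on the synthetic data and testing them on held-out real records (see Limitation 4).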

def evaluate_synthetic_data(
    real_df: pd.DataFrame,
    synth_df: pd.DataFrame,
    continuous_cols: list,
    binary_cols: list,
) -> Dict[str, float]:
    """Comprehensive evaluation of synthetic data quality.

    Args:
        real_df: Real patient records.
        synth_df: Synthetic patient records.
        continuous_cols: Names of continuous columns.
        binary_cols: Names of binary columns.

    Returns:
        Dictionary of evaluation metrics.
    """
    metrics = {}

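    # 1. Marginal distributions (two-sample KS statistic for continuous features).
    # With tens of thousands of records, the KS p-values will be tiny even for
    # clinically negligible differences, so compare the statistics themselves.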
    print("=== Marginal Distribution Comparison ===")
    ks_stats = []
    for col in continuous_cols:
        ks_stat, p_value = stats.ks_2samp(real_df[col], synth_df[col])
        ks_stats.append(ks_stat)
        print(f"  {col:>15}: KS = {ks_stat:.4f}, p = {p_value:.4f}")

    # Proportion comparison for binary features
    print("\n=== Binary Feature Proportions ===")
    for col in binary_cols:
        real_prop = real_df[col].mean()
        synth_prop = synth_df[col].mean()
        print(f"  {col:>15}: real = {real_prop:.4f}, synth = {synth_prop:.4f}")

    metrics["mean_ks_stat"] = np.mean(ks_stats)

    # 2. Correlation preservation
    all_cols = continuous_cols + binary_cols
    real_corr = real_df[all_cols].corr()
    synth_corr = synth_df[all_cols].corr()
    corr_diff = (real_corr - synth_corr).abs()
    metrics["mean_corr_error"] = corr_diff.values[
        np.triu_indices(len(all_cols), k=1)
    ].mean()
    print(f"\n=== Correlation Preservation ===")
    print(f"  Mean absolute correlation error: {metrics['mean_corr_error']:.4f}")

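    # 3. Privacy: distance from each synthetic record to its nearest real record,
    # using standardized continuous features and the first 5,000 synthetic rows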
    from sklearn.neighbors import NearestNeighbors
    scaler = StandardScaler()
    real_scaled = scaler.fit_transform(real_df[continuous_cols].values)
    synth_scaled = scaler.transform(synth_df[continuous_cols].values)

    nn_model = NearestNeighbors(n_neighbors=1, algorithm="ball_tree")
    nn_model.fit(real_scaled)
    distances, _ = nn_model.kneighbors(synth_scaled[:5000])
    metrics["median_nn_distance"] = np.median(distances)
    metrics["min_nn_distance"] = np.min(distances)
    print(f"\n=== Privacy (Nearest-Neighbor Distance) ===")
    print(f"  Median distance to nearest real record: {metrics['median_nn_distance']:.4f}")
    print(f"  Minimum distance: {metrics['min_nn_distance']:.4f}")

    return metrics


metrics = evaluate_synthetic_data(real_ehr, synth_ehr, continuous_cols, binary_cols)
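A raw nearest-neighbor distance is hard to interpret in isolation. In a dense region of feature space, even previously unseen real patients sit close to training records. A common refinement compares two sets of distances against the training data: those from synthetic records and those from real records withheld from VAE training. If synthetic records are systematically closer than the withheld real ones, the generator may be reproducing training data. The sketch below assumes such a withheld holdout set exists (the pipeline above trains on all of real_ehr). The function name dcr_comparison ("distance to closest record") and the toy arrays are illustrative assumptions.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler


def dcr_comparison(train: np.ndarray, holdout: np.ndarray, synth: np.ndarray) -> dict:
    """Hypothetical helper: distance to closest training record, synthetic vs. holdout."""
    scaler = StandardScaler().fit(train)
    nn_model = NearestNeighbors(n_neighbors=1).fit(scaler.transform(train))
    d_synth, _ = nn_model.kneighbors(scaler.transform(synth))
    d_holdout, _ = nn_model.kneighbors(scaler.transform(holdout))
    p5_holdout = np.percentile(d_holdout[:, 0], 5)
    return {
        "median_dcr_synth": float(np.median(d_synth)),
        "median_dcr_holdout": float(np.median(d_holdout)),
        # Share of synthetic records closer to train than 95% of holdout records
        "frac_synth_below_holdout_p5": float(np.mean(d_synth[:, 0] < p5_holdout)),
    }


# Toy check: "synth_copied" memorizes training rows, "synth_fresh" does not
rng = np.random.default_rng(0)
train = rng.normal(size=(5000, 6))
holdout = rng.normal(size=(1000, 6))
synth_fresh = rng.normal(size=(1000, 6))
synth_copied = train[:1000] + rng.normal(scale=0.01, size=(1000, 6))

print(dcr_comparison(train, holdout, synth_fresh))
print(dcr_comparison(train, holdout, synth_copied))
```

For a generator that does not copy, the synthetic and holdout medians should be similar and the last fraction should be near 0.05. A fraction far above that is a warning sign.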

Why VAEs Work for Tabular EHR Data

Three properties make VAEs well-suited for this application:

  1. Mixed data types. The separate reconstruction heads (MSE for continuous, BCE for binary) handle the heterogeneous feature types that are ubiquitous in clinical data. GANs typically need careful architectural engineering to handle mixed types, and diffusion models for tabular data are still an active research area.

  2. Correlation preservation through the latent bottleneck. The latent space forces the model to learn a compressed representation that captures the joint distribution, not just individual marginals. The correlation between HbA1c and diabetes, for example, is preserved because both are functions of the same latent variables.

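  3. Privacy through the information bottleneck. The KL regularizer penalizes latent codes that carry record-specific detail; averaged over the training data, the KL term upper-bounds the mutual information between a record and its latent code. With $\beta > 0$, the model is pushed to learn general patterns (diabetes correlates with high HbA1c and BMI) rather than memorize specific patient records. This is not a formal privacy guarantee (see Exercise 12.25 for differential privacy). It reduces, but does not rule out, the chance that generated records closely reproduce real ones, which is why the nearest-neighbor distance check in the evaluation is still needed.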
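The information-bottleneck argument in point 3 can be made precise with a standard decomposition. Write $q(z \mid x)$ for the encoder, $p(z) = \mathcal{N}(0, I)$ for the prior, and $q(z) = \mathbb{E}_{x}[q(z \mid x)]$ for the aggregate posterior over the training records. The KL term averaged over records then splits into a mutual-information term and a non-negative KL term:

$$
\mathbb{E}_{x}\big[\mathrm{KL}(q(z \mid x)\,\|\,p(z))\big]
= I_q(x; z) + \mathrm{KL}(q(z)\,\|\,p(z))
\;\ge\; I_q(x; z).
$$

The bound limits the information the latent code carries about records on average, not about any single record, which is one reason it falls short of a formal privacy guarantee.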

Limitations and Caveats

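  1. No formal privacy guarantee. The VAE provides heuristic privacy through the information bottleneck, but does not satisfy differential privacy. For regulatory contexts requiring formal guarantees, train with DP-SGD (differentially private stochastic gradient descent). The BatchNorm layers in TabularVAE mix information across the records in a batch, and the standard per-example DP-SGD analysis does not account for that mixing. DP training therefore usually replaces them with per-example normalization such as LayerNorm or GroupNorm.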

  2. Temporal structure is lost. This approach treats each patient as an independent feature vector. Real EHR data is longitudinal: lab values change over time, diagnoses accumulate, treatments are adjusted. Modeling temporal structure requires sequential generative models (Chapter 9 covers RNNs; combining sequence models with VAEs is an active research area).

  3. Rare conditions are underrepresented. If a condition affects 0.1% of patients, the VAE may not generate it at all — the reconstruction loss is dominated by common patterns. Conditional generation or oversampling strategies (Exercise 12.27) address this partially.

  4. Downstream utility must be validated. High statistical similarity does not guarantee that models trained on synthetic data will perform identically to models trained on real data. Always validate on a held-out real dataset.
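A common way to carry out this validation is "train on synthetic, test on real" (TSTR). The same model is fitted once on real training data (TRTR) and once on synthetic data, and both fits are scored on a held-out real test set. The sketch below assumes scikit-learn's LogisticRegression as the downstream model and ROC AUC as the metric; both are illustrative choices. The toy records stand in for real_ehr and synth_ehr. One stand-in generator is faithful, and a hypothetical degraded one has lost two feature-outcome relationships.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score


def fit_and_score(X_train, y_train, X_test, y_test) -> float:
    """Fit on one dataset and report ROC AUC on a held-out real test set."""
    model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    return float(roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]))


def toy_records(n: int, coef: np.ndarray, rng: np.random.Generator):
    """Hypothetical records: standard-normal features, logistic binary outcome."""
    X = rng.normal(size=(n, coef.size))
    y = rng.binomial(1, 1.0 / (1.0 + np.exp(-(X @ coef - 0.5))))
    return X, y


rng = np.random.default_rng(0)
true_coef = np.array([1.0, -0.5, 0.8, 0.0])

X_train, y_train = toy_records(5000, true_coef, rng)  # real training split
X_test, y_test = toy_records(5000, true_coef, rng)    # held-out real test split
# Stand-ins for generator output: faithful vs. two relationships lost
X_good, y_good = toy_records(5000, true_coef, rng)
X_bad, y_bad = toy_records(5000, np.array([0.0, 0.0, 0.8, 0.0]), rng)

auc_trtr = fit_and_score(X_train, y_train, X_test, y_test)
auc_faithful = fit_and_score(X_good, y_good, X_test, y_test)
auc_degraded = fit_and_score(X_bad, y_bad, X_test, y_test)
print(f"TRTR AUC:                      {auc_trtr:.3f}")
print(f"TSTR AUC (faithful generator): {auc_faithful:.3f}")
print(f"TSTR AUC (degraded generator): {auc_degraded:.3f}")
```

In the case study, the features would be the other eleven columns and the target adverse_event. The test split would be real records withheld from VAE training. A TSTR score well below TRTR signals that the generator has lost structure the downstream model depends on.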