Case Study 2: MediCore Molecular Property Prediction — Drug Molecules as Graphs

Context

MediCore Pharmaceuticals screens thousands of candidate molecules for drug development. Before any compound enters animal testing (a $2M+ process per molecule), the medicinal chemistry team needs computational predictions of molecular properties: aqueous solubility (will the drug dissolve in blood?), lipophilicity (will it cross cell membranes?), toxicity (will it poison the patient?), and binding affinity to target proteins (will it actually work?).

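The traditional approach is molecular fingerprinting: convert each molecule to a fixed-length binary vector using hand-crafted rules (Morgan fingerprints, MACCS keys), then train a random forest or gradient-boosted model. This works, and fingerprint-based models have been a standard industry tool for decades, but the representation is lossy. MACCS keys test for a fixed list of predefined substructures. Morgan fingerprints hash every atom's neighborhood up to a fixed radius (typically 2) and fold the results into a fixed number of bits. Either way, the vector records which local fragments are present but not how those fragments are arranged relative to one another, and folding lets unrelated fragments collide on the same bit. Two molecules assembled from the same local fragments in a different arrangement (for example, the same functional groups attached at distant positions of a larger scaffold) can therefore receive nearly identical fingerprints yet have very different properties.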
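To make the "lossy" point concrete, here is a toy sketch of the Morgan idea in plain Python. It is a simplification, not RDKit's algorithm: real Morgan/ECFP fingerprints use richer atom invariants, bond information, duplicate removal, and a specific hash function, and `morgan_like_fingerprint` is a hypothetical name. It keeps the two properties discussed above: each identifier describes a neighborhood of bounded radius, and folding into `n_bits` lets different environments share a bit.

```python
from typing import Dict, List, Tuple


def morgan_like_fingerprint(
    atoms: List[int],
    bonds: List[Tuple[int, int]],
    radius: int = 2,
    n_bits: int = 64,
) -> List[int]:
    """Toy Morgan/ECFP-style bit vector (hypothetical helper, not RDKit's).

    atoms: atomic number of each heavy atom; bonds: undirected (i, j) pairs.
    Each round re-hashes every atom's identifier together with the sorted
    identifiers of its neighbors, so round r describes a radius-r environment.
    """
    neighbors: Dict[int, List[int]] = {i: [] for i in range(len(atoms))}
    for i, j in bonds:
        neighbors[i].append(j)
        neighbors[j].append(i)

    # Radius 0: identifier from the atom's own invariants (element, degree)
    ids = [hash((a, len(neighbors[i]))) for i, a in enumerate(atoms)]
    bits = [0] * n_bits
    for ident in ids:
        bits[ident % n_bits] = 1

    for _ in range(radius):
        ids = [
            hash((ids[i], tuple(sorted(ids[j] for j in neighbors[i]))))
            for i in range(len(atoms))
        ]
        for ident in ids:
            # Folding: unrelated environments can land on the same bit
            bits[ident % n_bits] = 1
    return bits


# Ethanol heavy atoms written in two different atom orders: C-C-O and O-C-C
fp_a = morgan_like_fingerprint([6, 6, 8], [(0, 1), (1, 2)])
fp_b = morgan_like_fingerprint([8, 6, 6], [(0, 1), (1, 2)])
print(fp_a == fp_b)  # True: same molecule, same bits
```

Both atom orderings give identical bits, as they should for the same molecule. Note that nothing in the output records where in the molecule each environment occurs.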

A molecule is a graph: atoms are nodes, bonds are edges. A GNN operating on this native representation can learn task-specific structural features that fingerprints miss. This case study builds a GIN-based molecular property predictor and compares it to the fingerprint baseline.

Molecular Graphs in PyTorch Geometric

Each molecule is represented as a graph where node features encode atomic properties and edge features encode bond properties.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool
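from typing import Dict, List, Optional, Tuple


# Feature vocabularies for one-hot encoding of standard molecular graphs.
# one_hot() below reserves one extra "other" slot per feature for values
# outside the vocabulary (e.g., RDKit's "S" or "OTHER" hybridization).
ATOM_FEATURES = {
    "atomic_num": list(range(1, 119)),       # 118 elements
    "chirality": ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW",
                   "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"],
    "degree": list(range(11)),                # 0-10
    "formal_charge": list(range(-5, 6)),      # -5 to 5
    "num_hs": list(range(9)),                 # 0-8
    "hybridization": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "UNSPECIFIED"],
    "aromatic": [False, True],
}

BOND_FEATURES = {
    "bond_type": ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"],
    "stereo": ["STEREONONE", "STEREOZ", "STEREOE",
               "STEREOCIS", "STEREOTRANS", "STEREOANY"],
    "conjugated": [False, True],
}

# Total one-hot widths: vocabulary size + 1 "other" slot per feature
ATOM_DIM = sum(len(v) + 1 for v in ATOM_FEATURES.values())  # 168
BOND_DIM = sum(len(v) + 1 for v in BOND_FEATURES.values())  # 15


def one_hot(value, choices: list) -> List[float]:
    """One-hot encode `value` against `choices`, with a trailing 'other' slot."""
    encoding = [0.0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else len(choices)
    encoding[index] = 1.0
    return encoding


def mol_to_graph(smiles: str) -> Optional[Data]:
    """Convert a SMILES string to a PyG molecular graph.

    Creates a graph where:
    - Nodes = atoms with one-hot features for atomic number, chirality,
      degree, formal charge, number of Hs, hybridization, and aromaticity
      (ATOM_DIM columns).
    - Edges = bonds with one-hot features for bond type, stereo, and
      conjugation (BOND_DIM columns), stored once per direction.

    Args:
        smiles: SMILES string representation of the molecule.

    Returns:
        PyG Data object with node features, edge indices, and edge features,
        or None if the SMILES is invalid.
    """
    from rdkit import Chem  # imported lazily: only this function needs RDKit

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Node features: concatenated one-hot blocks, ATOM_DIM columns per atom
    atom_features = []
    for atom in mol.GetAtoms():
        features = (
            one_hot(atom.GetAtomicNum(), ATOM_FEATURES["atomic_num"])
            + one_hot(str(atom.GetChiralTag()), ATOM_FEATURES["chirality"])
            + one_hot(atom.GetTotalDegree(), ATOM_FEATURES["degree"])
            + one_hot(atom.GetFormalCharge(), ATOM_FEATURES["formal_charge"])
            + one_hot(atom.GetTotalNumHs(), ATOM_FEATURES["num_hs"])
            + one_hot(str(atom.GetHybridization()), ATOM_FEATURES["hybridization"])
            + one_hot(atom.GetIsAromatic(), ATOM_FEATURES["aromatic"])
        )
        atom_features.append(features)

    x = torch.tensor(atom_features, dtype=torch.float)

    # Edge features and indices (BOND_DIM one-hot columns per directed edge)
    if mol.GetNumBonds() == 0:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, BOND_DIM), dtype=torch.float)
    else:
        src, dst, bond_features = [], [], []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            bf = (
                one_hot(str(bond.GetBondType()), BOND_FEATURES["bond_type"])
                + one_hot(str(bond.GetStereo()), BOND_FEATURES["stereo"])
                + one_hot(bond.GetIsConjugated(), BOND_FEATURES["conjugated"])
            )
            # Undirected: add both directions
            src.extend([i, j])
            dst.extend([j, i])
            bond_features.extend([bf, bf])

        edge_index = torch.tensor([src, dst], dtype=torch.long)
        edge_attr = torch.tensor(bond_features, dtype=torch.float)
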
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
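For readers without RDKit at hand, the edge layout that mol_to_graph produces can be reproduced from a plain bond list. The NumPy sketch below (`bonds_to_edge_index` is a hypothetical helper, not a PyG function) builds the same interleaved, both-directions `edge_index` for ethanol's three heavy atoms.

```python
import numpy as np


def bonds_to_edge_index(bonds):
    """Hypothetical helper: undirected bonds [(i, j), ...] -> (2, 2 * n_bonds)
    COO array with both directions, interleaved per bond as in mol_to_graph."""
    edge_index = np.zeros((2, 2 * len(bonds)), dtype=np.int64)
    for k, (i, j) in enumerate(bonds):
        edge_index[:, 2 * k] = (i, j)       # i -> j
        edge_index[:, 2 * k + 1] = (j, i)   # j -> i
    return edge_index


# Ethanol heavy atoms: C0-C1, C1-O2 (hydrogens implicit, as RDKit parses SMILES)
edge_index = bonds_to_edge_index([(0, 1), (1, 2)])
print(edge_index)
# [[0 1 1 2]
#  [1 0 2 1]]

# Node degree = number of outgoing directed edges per atom
print(np.bincount(edge_index[0], minlength=3))  # [1 2 1]
```

Storing each bond in both directions is what lets a message-passing layer treat the graph as undirected while only ever summing over incoming edges.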

Synthetic Molecular Dataset

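To avoid an RDKit dependency and keep the example reproducible, we construct a synthetic molecular dataset that loosely imitates drug-like molecules. Each graph has 10-50 atoms drawn mostly from C, N, O, and S with occasional halogens. Bonds form a random tree plus a few ring-closing bonds, which are only added between atoms with fewer than four bonds; this is a rough nod to valence, not a chemically valid model, since the tree backbone itself is uncapped. The binary target is a noisy function of heteroatom fraction, aromatic fraction, atom count, and ring count.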

def generate_molecular_dataset(
    n_molecules: int = 2000,
    seed: int = 42,
) -> List[Data]:
    """Generate a synthetic molecular graph dataset.

    Creates graphs that simulate drug-like molecules:
    - Node features: [atomic_num, degree, charge, n_hydrogens,
      hybridization, is_aromatic] (6 raw numeric features).
    - Graph-level target: a binary label (e.g., toxic / non-toxic)
      that depends on ring count, heteroatom fraction, aromatic
      fraction, and atom count (a molecular-weight proxy), plus noise.

    Args:
        n_molecules: Number of molecules to generate.
        seed: Random seed.

    Returns:
        List of PyG Data objects with graph-level labels.
    """
    rng = np.random.RandomState(seed)
    molecules = []

    for idx in range(n_molecules):
        # Variable-size molecules (10-50 atoms)
        n_atoms = rng.randint(10, 51)

        # Node features: simulate atomic properties
        # atomic_num: mostly C(6), N(7), O(8), S(16), with occasional others
        atomic_nums = rng.choice(
            [6, 6, 6, 6, 7, 7, 8, 8, 16, 9, 17, 35],
            size=n_atoms,
        )
        degrees = np.zeros(n_atoms, dtype=int)
        charges = np.zeros(n_atoms, dtype=int)
        n_hydrogens = rng.randint(0, 4, size=n_atoms)
        hybridization = rng.choice([2, 3], size=n_atoms)  # sp2 or sp3
        is_aromatic = rng.binomial(1, 0.3, size=n_atoms)

        # Generate edges: tree backbone + random additional bonds
        edges_src = []
        edges_dst = []

        # Backbone: connect atoms sequentially (tree structure)
        for i in range(1, n_atoms):
            parent = rng.randint(0, i)
            edges_src.extend([parent, i])
            edges_dst.extend([i, parent])
            degrees[parent] += 1
            degrees[i] += 1

        # Add ring-closing bonds (create cycles). Skip self-loops, bonds that
        # already exist, and atoms that already have 4 bonds; count only the
        # bonds actually added, since each one closes exactly one ring.
        existing = set(zip(edges_src, edges_dst))
        n_extra = rng.randint(0, min(n_atoms // 3, 10))
        n_rings = 0
        for _ in range(n_extra):
            i = rng.randint(0, n_atoms)
            j = rng.randint(0, n_atoms)
            if (i != j and (i, j) not in existing
                    and degrees[i] < 4 and degrees[j] < 4):
                edges_src.extend([i, j])
                edges_dst.extend([j, i])
                existing.update([(i, j), (j, i)])
                degrees[i] += 1
                degrees[j] += 1
                n_rings += 1

        # Build feature matrix
        x = torch.tensor(
            np.column_stack([
                atomic_nums.astype(float),
                degrees.astype(float),
                charges.astype(float),
                n_hydrogens.astype(float),
                hybridization.astype(float),
                is_aromatic.astype(float),
            ]),
            dtype=torch.float,
        )

        edge_index = torch.tensor([edges_src, edges_dst], dtype=torch.long)

        # Target: structural property classification
        # "Toxic" if: many heteroatoms + aromatic + ring-rich + large molecule
        # (n_rings, counted above, is the number of ring-closing bonds added)
        heteroatom_frac = (atomic_nums != 6).sum() / n_atoms
        aromatic_frac = is_aromatic.sum() / n_atoms
        size_factor = n_atoms / 50.0

        toxicity_score = (
            0.3 * heteroatom_frac
            + 0.25 * aromatic_frac
            + 0.2 * (n_rings / 10.0)
            + 0.25 * size_factor
            + rng.normal(0, 0.1)
        )
        label = int(toxicity_score > 0.35)

        data = Data(x=x, edge_index=edge_index, y=torch.tensor([label]))
        molecules.append(data)

    # Print dataset statistics
    labels = [d.y.item() for d in molecules]
    sizes = [d.num_nodes for d in molecules]
    print(f"Dataset: {n_molecules} molecules")
    print(f"  Atom count: {np.mean(sizes):.1f} +/- {np.std(sizes):.1f}")
    print(f"  Class balance: {np.mean(labels):.3f} positive")
    return molecules
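The ring term in the synthetic label is only meaningful if it counts rings the graph actually contains. For a connected graph, each bond added on top of a spanning tree closes exactly one independent cycle, so the cyclomatic number E - N + C (C = number of connected components) equals the number of ring-closing bonds, provided self-loops and duplicate bonds are never counted. A small sanity check, using a hypothetical union-find helper and a simplified version of the generator's construction (no degree cap):

```python
import numpy as np


def cyclomatic_number(n_nodes, edges):
    """Independent cycle count E - N + C of an undirected simple graph.

    `edges` lists each undirected edge once; C (connected components) is
    found with a small union-find. Hypothetical helper for sanity checks.
    """
    parent = list(range(n_nodes))

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    for i, j in edges:
        parent[find(i)] = find(j)
    n_components = len({find(a) for a in range(n_nodes)})
    return len(edges) - n_nodes + n_components


# Random tree (parent, child), then ring-closing bonds; count only new bonds.
rng = np.random.RandomState(0)
n_atoms = 20
edges = {(rng.randint(0, i), i) for i in range(1, n_atoms)}
ring_bonds = 0
for _ in range(6):
    i, j = rng.randint(0, n_atoms), rng.randint(0, n_atoms)
    if i != j and (i, j) not in edges and (j, i) not in edges:
        edges.add((i, j))
        ring_bonds += 1

print(cyclomatic_number(n_atoms, sorted(edges)), ring_bonds)  # the two agree
```

The same quantity, computed as `n_edges - (n_atoms - 1)`, is one of the descriptors the baseline later in this case study uses.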

The Model: GIN for Molecular Property Prediction

We use a Graph Isomorphism Network (GIN) because molecular property prediction requires distinguishing subtle structural differences, and GIN's sum-plus-MLP update makes it as expressive as the 1-WL test, the upper bound for standard message passing architectures. Moving a hydroxyl group (-OH) from the 2-position to the 4-position of a substituted benzene ring can change a molecule's biological activity substantially; the GNN must capture these topological distinctions.
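For reference, the GIN node update is h_v' = MLP((1 + eps) * h_v + sum of h_u over the neighbors u of v), which is what PyG's GINConv computes; train_eps=True makes eps learnable. The NumPy sketch below implements one such layer with a fixed two-layer ReLU MLP (random weights, no biases, for illustration only) and checks the property message passing relies on: relabeling the atoms permutes the outputs in the same way.

```python
import numpy as np


def gin_layer(h, edge_index, eps, w1, w2):
    """One GIN update: h_v' = MLP((1 + eps) * h_v + sum of h_u over neighbors u).

    h: (n_nodes, d_in) features; edge_index: (2, n_edges) directed edges
    (row 0 = source, row 1 = target); w1, w2: weights of a 2-layer ReLU MLP
    (biases omitted for brevity).
    """
    agg = np.zeros_like(h)
    np.add.at(agg, edge_index[1], h[edge_index[0]])  # sum messages into targets
    z = (1.0 + eps) * h + agg
    return np.maximum(z @ w1, 0.0) @ w2


rng = np.random.default_rng(0)
h = rng.normal(size=(3, 4))                          # 3 atoms, 4 features each
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])  # path 0-1-2, both directions
w1, w2 = rng.normal(size=(4, 8)), rng.normal(size=(8, 4))
out = gin_layer(h, edge_index, eps=0.1, w1=w1, w2=w2)

# Relabel atoms: new atom k is old atom perm[k]; edges are renumbered to match.
perm = np.array([2, 0, 1])
new_index_of_old = np.argsort(perm)
out_perm = gin_layer(h[perm], new_index_of_old[edge_index], eps=0.1, w1=w1, w2=w2)
print(np.allclose(out_perm, out[perm]))  # True: outputs permute with the atoms
```

Sum aggregation followed by an MLP is the design choice that matters here: unlike a mean or max, a sum keeps track of how many neighbors of each kind an atom has.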

class MolecularGIN(nn.Module):
    """GIN model for molecular property prediction.

    Uses sum aggregation (maximally expressive under 1-WL) with MLP
    update functions. Final graph representation uses concatenation
    of sum and mean pooling across all layers for multi-scale features.

    Args:
        in_channels: Number of atom features.
        hidden_channels: Hidden layer dimensionality.
        out_channels: Number of output classes (2 for binary).
        num_layers: Number of GIN layers.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        in_channels: int = 6,
        hidden_channels: int = 128,
        out_channels: int = 2,
        num_layers: int = 5,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # First layer
        mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )
        self.convs.append(GINConv(mlp, train_eps=True))
        self.batch_norms.append(nn.BatchNorm1d(hidden_channels))

        # Subsequent layers
        for _ in range(num_layers - 1):
            mlp = nn.Sequential(
                nn.Linear(hidden_channels, hidden_channels),
                nn.BatchNorm1d(hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels),
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))

        # Multi-scale readout: concatenate sum and mean from all layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels * num_layers * 2, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass: GIN layers -> multi-scale readout -> classifier.

        Args:
            x: Node features, shape (total_nodes, in_channels).
            edge_index: Edge indices, shape (2, total_edges).
            batch: Batch assignment vector, shape (total_nodes,).

        Returns:
            Graph-level logits, shape (num_graphs, out_channels).
        """
        layer_embeddings = []

        for conv, bn in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            layer_embeddings.append(x)

        # Multi-scale graph readout
        graph_features = []
        for emb in layer_embeddings:
            graph_features.append(global_add_pool(emb, batch))
            graph_features.append(global_mean_pool(emb, batch))

        graph_repr = torch.cat(graph_features, dim=1)
        return self.classifier(graph_repr)
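The readout depends on PyG's `batch` vector, which records the graph index of every node in a mini-batch. This NumPy sketch reproduces what global_add_pool and global_mean_pool return (not how PyG implements them; `pool_by_graph` is a hypothetical name) and shows why both are kept: the mean is blind to molecule size, the sum is not.

```python
import numpy as np


def pool_by_graph(x, batch, num_graphs):
    """Sum and mean of node rows per graph, the quantities global_add_pool
    and global_mean_pool return. x: (n_nodes, d); batch: graph id per node."""
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    means = sums / np.maximum(counts, 1)   # guard against empty graphs
    return sums, means


# Mini-batch of two graphs: 2 atoms and 4 atoms, all with the same feature row.
x = np.ones((6, 3))
batch = np.array([0, 0, 1, 1, 1, 1])
sums, means = pool_by_graph(x, batch, num_graphs=2)
print(sums[:, 0], means[:, 0])  # [2. 4.] [1. 1.]: only the sum sees size

# One layer's readout is [sum | mean], so 2 * hidden_channels wide; over
# num_layers layers the classifier input is hidden_channels * num_layers * 2.
readout = np.concatenate([sums, means], axis=1)
print(readout.shape)  # (2, 6)
```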

Training and Evaluation Pipeline

from sklearn.model_selection import StratifiedKFold


def train_molecular_gin(
    dataset: List[Data],
    hidden_channels: int = 128,
    num_layers: int = 5,
    epochs: int = 100,
    batch_size: int = 64,
    lr: float = 1e-3,
    n_folds: int = 5,
) -> Dict[str, float]:
    """Train and evaluate a GIN model for molecular property prediction.

    Uses stratified K-fold cross-validation, the common protocol for small
    graph-classification benchmarks such as the TU datasets (MoleculeNet-
    style molecular benchmarks often use scaffold splits instead).

    Args:
        dataset: List of PyG Data objects with graph-level labels.
        hidden_channels: Hidden dimensionality.
        num_layers: Number of GIN layers.
        epochs: Training epochs per fold.
        batch_size: Batch size.
        lr: Learning rate.
        n_folds: Number of cross-validation folds.

    Returns:
        Dictionary with mean and std accuracy, AUROC, and F1.
    """
    from sklearn.metrics import roc_auc_score, f1_score

    labels = np.array([d.y.item() for d in dataset])
    in_channels = dataset[0].x.size(1)
    n_classes = len(np.unique(labels))

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = {"accuracy": [], "auroc": [], "f1": []}

    for fold, (train_idx, test_idx) in enumerate(skf.split(labels, labels)):
        train_data = [dataset[i] for i in train_idx]
        test_data = [dataset[i] for i in test_idx]

        train_loader = PyGDataLoader(
            train_data, batch_size=batch_size, shuffle=True
        )
        test_loader = PyGDataLoader(
            test_data, batch_size=batch_size
        )

        model = MolecularGIN(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=n_classes,
            num_layers=num_layers,
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs
        )

        # Training
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                optimizer.zero_grad()
                out = model(batch.x, batch.edge_index, batch.batch)
                loss = F.cross_entropy(out, batch.y.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            scheduler.step()

        # Evaluation
        model.eval()
        all_preds = []
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for batch in test_loader:
                out = model(batch.x, batch.edge_index, batch.batch)
                probs = F.softmax(out, dim=1)[:, 1]
                preds = out.argmax(dim=1)
                all_preds.extend(preds.tolist())
                all_probs.extend(probs.tolist())
                all_labels.extend(batch.y.view(-1).tolist())

        acc = np.mean(np.array(all_preds) == np.array(all_labels))
        auroc = roc_auc_score(all_labels, all_probs)
        f1 = f1_score(all_labels, all_preds)

        fold_results["accuracy"].append(acc)
        fold_results["auroc"].append(auroc)
        fold_results["f1"].append(f1)
        print(
            f"  Fold {fold+1} | Accuracy: {acc:.4f} | "
            f"AUROC: {auroc:.4f} | F1: {f1:.4f}"
        )

    # Summary
    for metric in ["accuracy", "auroc", "f1"]:
        values = fold_results[metric]
        print(f"{metric}: {np.mean(values):.4f} +/- {np.std(values):.4f}")

    summary = {}
    for m, values in fold_results.items():
        summary[f"{m}_mean"] = float(np.mean(values))
        summary[f"{m}_std"] = float(np.std(values))
    return summary


# ---- Run the pipeline ----
molecules = generate_molecular_dataset(n_molecules=2000, seed=42)
results = train_molecular_gin(molecules, hidden_channels=128, num_layers=5)

Comparison with Fingerprint Baselines

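To quantify the value of graph-native representations, we compare against a stand-in for the industry-standard fingerprint approach. The synthetic graphs have no SMILES, so real Morgan fingerprints cannot be computed; instead, a gradient-boosted classifier is trained on hand-crafted graph descriptors: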

def fingerprint_baseline(
    dataset: List[Data],
    n_folds: int = 5,
) -> Dict[str, float]:
    """Gradient-boosting baseline on hand-crafted graph descriptors.

    Stands in for a fingerprint + tree-ensemble workflow: the synthetic
    graphs have no SMILES, so real Morgan fingerprints cannot be computed.
    Instead it uses size, bond count, element fractions, degree
    statistics, aromatic fraction, and a ring count.

    Args:
        dataset: List of PyG Data objects.
        n_folds: Number of cross-validation folds.

    Returns:
        Dictionary with mean accuracy, AUROC, and F1.
    """
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.metrics import roc_auc_score, f1_score

    # Extract structural features per molecule
    features = []
    labels = []
    for data in dataset:
        x = data.x.numpy()
        n_atoms = x.shape[0]
        n_edges = data.edge_index.size(1) // 2  # Undirected

        feat = [
            n_atoms,                              # Molecular size
            n_edges,                              # Bond count
            n_edges / max(n_atoms, 1),            # Bonds per atom (half the mean degree)
            (x[:, 0] == 6).sum() / n_atoms,       # Carbon fraction
            (x[:, 0] == 7).sum() / n_atoms,       # Nitrogen fraction
            (x[:, 0] == 8).sum() / n_atoms,       # Oxygen fraction
            (x[:, 0] == 16).sum() / n_atoms,      # Sulfur fraction
            x[:, 1].mean(),                       # Mean degree
            x[:, 1].max(),                        # Max degree
            x[:, 5].mean(),                       # Aromatic fraction
            n_edges - (n_atoms - 1),              # Independent cycles (E - N + 1, connected graph)
        ]
        features.append(feat)
        labels.append(data.y.item())

    X = np.array(features)
    y = np.array(labels)

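    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = {"accuracy": [], "auroc": [], "f1": []}

    for train_idx, test_idx in skf.split(X, y):
        clf = GradientBoostingClassifier(
            n_estimators=200, max_depth=5, random_state=42
        )
        clf.fit(X[train_idx], y[train_idx])
        preds = clf.predict(X[test_idx])
        probs = clf.predict_proba(X[test_idx])[:, 1]  # P(class 1)
        fold_results["accuracy"].append(np.mean(preds == y[test_idx]))
        fold_results["auroc"].append(roc_auc_score(y[test_idx], probs))
        fold_results["f1"].append(f1_score(y[test_idx], preds))

    for metric, values in fold_results.items():
        print(f"Fingerprint baseline {metric}: "
              f"{np.mean(values):.4f} +/- {np.std(values):.4f}")
    return {f"{m}_mean": float(np.mean(v)) for m, v in fold_results.items()}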


fingerprint_baseline(molecules)

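Expected approximate comparison (illustrative ranges, not measured results; actual numbers depend on the seed and training details):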

Method                        Accuracy    AUROC
GBM on structural features    0.72-0.76   0.78-0.82
GIN (5 layers, 128 hidden)    0.78-0.83   0.84-0.89

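By these ranges, the GIN outperforms the fingerprint-style baseline by roughly 5-8 percentage points in accuracy and 5-7 points in AUROC. On this toy task, however, the label is a noisy function of global quantities (heteroatom fraction, aromatic fraction, atom count, ring count) that the baseline's descriptors capture almost directly, so it does not exercise ring connectivity or substituent positioning, and the observed gap may be smaller. The structural argument applies to real molecular data: there, the gap is expected to be largest where topological details such as ring connectivity and substituent positioning determine the property, which is exactly the information that fingerprints capture only partially.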

Why GNNs Outperform Fingerprints

  1. Learned features vs. hand-crafted features. Fingerprints encode substructures with a fixed, predefined scheme. GINs learn task-specific structural features: for toxicity prediction, the model can learn features that respond to electrophilic centers or aromatic nitro groups; for solubility, features that track hydrophilic groups in context. The representation adapts to the task.

  2. Context-dependent atom representations. In an element-count descriptor, a nitrogen atom is encoded the same way regardless of its neighborhood; a Morgan fingerprint adds context, but only up to its fixed radius and only as a hashed bit. In a GIN, after 5 message passing rounds, a nitrogen's representation can reflect its structural context up to 5 bonds away, so it differs depending on whether the nitrogen is part of an amine, an amide, a nitro group, or a heterocyclic ring. The same element in different contexts produces different representations.

  3. Graph-level readout preserves structural information. The multi-scale readout combines sum pooling (size-dependent counts, such as how many atoms sit in a given local environment) with mean pooling (size-normalized composition), taken after every layer, from atom-level features in early layers to wider neighborhoods in later ones. A folded binary fingerprint instead hashes all substructures into a fixed number of bits, so distinct substructures can collide on one bit and counts are reduced to presence or absence.
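Point 2 can be made concrete with 1-WL color refinement, the test whose discriminative power GIN can match when its MLPs are injective. In glycinamide (H2N-CH2-C(=O)-NH2, hydrogens and bond orders ignored), the amine and amide nitrogens are indistinguishable by element and by immediate neighbor; they separate only in round 2, once the carbonyl oxygen enters their neighborhood. The `color_refinement` helper is a hypothetical sketch, not a library function.

```python
from typing import List, Tuple


def color_refinement(
    labels: List[str], bonds: List[Tuple[int, int]], rounds: int
) -> List[List[int]]:
    """1-WL color refinement; returns every atom's color after each round.

    Round 0 colors atoms by element only. Each later round assigns a new color
    to each distinct (own color, sorted neighbor colors) signature.
    """
    nbrs = {i: [] for i in range(len(labels))}
    for i, j in bonds:
        nbrs[i].append(j)
        nbrs[j].append(i)
    palette = {lab: k for k, lab in enumerate(sorted(set(labels)))}
    colors = [palette[lab] for lab in labels]
    history = [colors]
    for _ in range(rounds):
        sigs = [(colors[i], tuple(sorted(colors[j] for j in nbrs[i])))
                for i in range(len(labels))]
        palette = {s: k for k, s in enumerate(sorted(set(sigs)))}
        colors = [palette[s] for s in sigs]
        history.append(colors)
    return history


# Glycinamide, heavy atoms only: N0-C1-C2(=O3)-N4
history = color_refinement(["N", "C", "C", "O", "N"],
                           [(0, 1), (1, 2), (2, 3), (2, 4)], rounds=2)
for r, colors in enumerate(history):
    same = colors[0] == colors[4]
    print(f"round {r}: amine N and amide N {'match' if same else 'differ'}")
# round 0: amine N and amide N match
# round 1: amine N and amide N match
# round 2: amine N and amide N differ
```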

From Benchmark to Production at MediCore

Moving from academic benchmarks to production molecular screening requires several enhancements:

  1. Edge features. Real molecular graphs have bond types (single, double, triple, aromatic), stereochemistry, and conformational information. Incorporating edge features into message passing (via edge-conditioned convolutions or attention over bond types) improves predictions for stereo-sensitive properties.

  2. 3D geometry. Many molecular properties depend on the 3D conformation, not just the 2D topology. Geometric GNNs such as SchNet, DimeNet, and SphereNet use atomic coordinates through interatomic distances (DimeNet adds bond angles, and SphereNet also torsion angles), so their predictions are invariant to rotations and translations of the molecule.

  3. Uncertainty quantification. MediCore's chemists need to know not just whether a molecule is predicted toxic, but how confident the prediction is. Ensembles of GNNs, MC-dropout, or evidential deep learning (Chapter 34) provide uncertainty estimates, whose calibration should be checked on held-out data before chemists rely on them.

  4. Multi-task learning. Real drug candidates must satisfy multiple property constraints simultaneously (solubility, toxicity, binding, metabolic stability). Training a single GNN with multiple output heads — one per property — enables shared molecular representations and improves data efficiency.
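For the edge-feature enhancement in point 1, one common option is the GINE variant, available in PyG as GINEConv, which adds an edge embedding to each neighbor's features before a ReLU: (1 + eps) * h_v + sum of ReLU(h_u + e_uv). The NumPy sketch shows only that aggregation step, with an assumed projection matrix `w_e` from bond one-hots to the node dimension; a GIN-style MLP would follow it. With plain GIN aggregation the two toy molecules below (C-O with a single vs. a double bond) would be identical; with bond features they are not.

```python
import numpy as np


def gine_aggregate(h, edge_index, edge_attr, w_e, eps=0.0):
    """GINE-style aggregation: (1 + eps) * h_v + sum_u ReLU(h_u + e_uv @ w_e).

    h: (n_nodes, d); edge_index: (2, n_edges); edge_attr: (n_edges, d_edge);
    w_e: (d_edge, d) projection of bond features (an assumed learned matrix).
    Returns the pre-MLP aggregate.
    """
    src, dst = edge_index
    messages = np.maximum(h[src] + edge_attr @ w_e, 0.0)  # bond-aware messages
    agg = np.zeros_like(h)
    np.add.at(agg, dst, messages)
    return (1.0 + eps) * h + agg


h = np.array([[1.0, 0.0], [0.0, 1.0]])     # toy one-hot atoms: C, O
edge_index = np.array([[0, 1], [1, 0]])    # one bond, both directions
single = np.array([[1.0, 0.0]] * 2)        # bond one-hot: single
double = np.array([[0.0, 1.0]] * 2)        # bond one-hot: double
w_e = np.array([[0.5, -0.5], [-0.5, 0.5]])

a_single = gine_aggregate(h, edge_index, single, w_e)
a_double = gine_aggregate(h, edge_index, double, w_e)
print(np.allclose(a_single, a_double))     # False: bond order now matters
```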
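For point 2, SchNet-style models feed interatomic distances, usually for atom pairs within a cutoff radius, into their message functions. Because distances do not change when a molecule is rotated or translated, anything computed from them is invariant too. The NumPy sketch below builds such a distance graph (the helper name and the random coordinates are illustrative) and checks that a random rotation plus translation leaves it unchanged.

```python
import numpy as np


def radius_graph_with_distances(pos, cutoff):
    """Directed edges between distinct atoms closer than `cutoff`, plus their
    lengths. pos: (n_atoms, 3) coordinates, e.g. in angstroms."""
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)
    src, dst = np.nonzero(mask)
    return np.stack([src, dst]), dist[src, dst]


rng = np.random.default_rng(0)
pos = rng.normal(scale=1.5, size=(5, 3))          # toy coordinates

# Random proper rotation (orthogonal, det = +1) plus a translation
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(q) < 0:
    q[:, 0] = -q[:, 0]
moved = pos @ q.T + np.array([1.0, -2.0, 0.5])

# Large cutoff so both calls return all 20 ordered pairs in the same order
ei_a, d_a = radius_graph_with_distances(pos, cutoff=100.0)
ei_b, d_b = radius_graph_with_distances(moved, cutoff=100.0)
print(np.array_equal(ei_a, ei_b), np.allclose(d_a, d_b))  # True True
```

A realistic cutoff (a few angstroms) keeps the graph sparse; the very large one here only makes the comparison independent of rounding near the cutoff boundary.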
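For point 4, the usual setup is a shared GNN encoder with one output logit per property. Assay tables are typically sparse, so the loss should ignore missing labels. A NumPy sketch of such a masked multi-task loss, assuming missing labels are stored as NaN (`masked_multitask_bce` is a hypothetical name):

```python
import numpy as np


def masked_multitask_bce(logits, labels):
    """Mean binary cross-entropy over observed (molecule, task) entries.

    logits, labels: (n_molecules, n_tasks). Missing labels are NaN, an assumed
    convention for compounds that were not assayed for every property.
    """
    mask = ~np.isnan(labels)
    y = np.where(mask, labels, 0.0)
    # Stable BCE-with-logits: max(z, 0) - z * y + log(1 + exp(-|z|))
    per_entry = (np.maximum(logits, 0.0) - logits * y
                 + np.log1p(np.exp(-np.abs(logits))))
    return float((per_entry * mask).sum() / max(mask.sum(), 1))


# 3 molecules x 2 tasks (e.g. toxicity, solubility class); two labels missing
labels = np.array([[1.0, np.nan], [0.0, 1.0], [np.nan, 0.0]])
logits = np.array([[2.0, 5.0], [-1.0, 0.5], [9.0, -3.0]])
loss = masked_multitask_bce(logits, labels)

# Logits for missing labels have no effect on the loss
logits_changed = logits.copy()
logits_changed[0, 1], logits_changed[2, 0] = -50.0, 50.0
print(np.isclose(loss, masked_multitask_bce(logits_changed, labels)))  # True
```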

Lessons for Practice

  1. When data is naturally a graph, use a graph model. The history of molecular machine learning — from QSAR descriptors to Morgan fingerprints to GNNs — is a progression from lossy featurizations toward native structural representations. Each step recovered information that the previous step discarded.

  2. Expressiveness matters for structural tasks. GIN's sum aggregation (1-WL expressiveness) tends to outperform mean-style aggregation, such as GCN's degree-normalized averaging, for molecular property prediction: a mean maps some different neighbor multisets to the same value (for example, {C, O} and {C, C, O, O}), whereas a sum followed by an MLP can be injective on multisets. The WL expressiveness hierarchy is not merely theoretical; it often shows up as empirical performance gaps on molecular benchmarks.

  3. Domain knowledge complements learned representations. Strong production molecular models often combine GNN representations with hand-crafted chemical descriptors (molecular weight, LogP, topological polar surface area). The GNN captures structure; the descriptors capture properties that are hard to learn from structure alone.
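The multiset argument in the expressiveness lesson can be checked numerically. The toy neighborhoods below use one-hot element features; sum aggregation separates {C, O} from {C, C, O, O}, while mean and max aggregation map them to the same vector.

```python
import numpy as np

# One-hot element features [C, O] for the neighbors of two atoms:
# atom A has neighbors {C, O}; atom B has neighbors {C, C, O, O}.
neighbors_a = np.array([[1.0, 0.0], [0.0, 1.0]])
neighbors_b = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

AGGREGATORS = {
    "sum": lambda m: m.sum(axis=0),
    "mean": lambda m: m.mean(axis=0),
    "max": lambda m: m.max(axis=0),
}

results = {}
for name, agg in AGGREGATORS.items():
    results[name] = not np.allclose(agg(neighbors_a), agg(neighbors_b))
    print(f"{name:>4}: {'distinguished' if results[name] else 'identical'}")
#  sum: distinguished
# mean: identical
#  max: identical
```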