Case Study 1: Molecular Property Prediction with GNNs

Scientific Context

Predicting molecular properties from structure is a central challenge in drug discovery and materials science. Traditional computational chemistry methods like density functional theory (DFT) provide accurate predictions but are computationally expensive---a single DFT calculation can take hours to days. Machine learning offers a way to approximate these calculations at a fraction of the cost, enabling rapid screening of millions of candidate molecules.

Molecules are naturally represented as graphs: atoms are nodes, chemical bonds are edges. Graph neural networks can learn directly from this molecular graph structure, avoiding hand-crafted molecular descriptors and instead learning task-specific representations from the raw molecular topology.

In this case study, we predict aqueous solubility (logS) from molecular structure using the ESOL dataset. Solubility is a critical property in drug development: a drug must dissolve sufficiently in water to be absorbed by the body. We will compare a traditional fingerprint-based approach with GNN-based models.

Dataset

The ESOL (Estimated SOLubility) dataset contains 1,128 molecules with experimentally measured aqueous solubility values (log mol/L). Each molecule is provided as a SMILES string, which we convert to a graph representation.

Example molecules and their solubility:

SMILES                             | Name                    | logS
OCC3OC(OCC2OC(OC(C#N)c1ccccc1)...  | Amygdalin               | -0.77
c1ccc2c(c1)cc1ccc3cccc4ccc2c1c34   | Benzo[ghi]fluoranthene  | -6.30
C(=O)(OC(CC(=O)OC...               | Acetyl tributyl citrate | -3.38

Approach

Step 1: Molecular Graph Construction

We convert each SMILES string to a molecular graph with atom features and bond features.

Atom features (per node):

- Atom type: one-hot encoding of C, N, O, S, F, Cl, Br, I, P, Other (10 dims)
- Degree: node degree / 4 (1 dim)
- Formal charge: charge / 2 (1 dim)
- Number of hydrogens: count / 4 (1 dim)
- Is aromatic: binary (1 dim)
- Is in ring: binary (1 dim)

Total: 15-dimensional node feature vector.

Bond features (per edge):

- Bond type: one-hot encoding of single, double, triple, aromatic (4 dims)
- Is conjugated: binary (1 dim)
- Is in ring: binary (1 dim)

Total: 6-dimensional edge feature vector.
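
The SMILES-to-graph conversion itself is not listed here. The sketch below shows one way it could look, assuming RDKit is installed; smiles_to_graph is a hypothetical helper, and the exact feature ordering and normalization are illustrative rather than a definitive pipeline.

import torch
from rdkit import Chem

ATOM_TYPES = ["C", "N", "O", "S", "F", "Cl", "Br", "I", "P"]
BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]


def smiles_to_graph(smiles: str) -> dict[str, torch.Tensor]:
    """Convert a SMILES string to node features, edge indices, and edge features."""
    mol = Chem.MolFromSmiles(smiles)

    node_feats = []
    for atom in mol.GetAtoms():
        onehot = [float(atom.GetSymbol() == t) for t in ATOM_TYPES]
        onehot.append(float(atom.GetSymbol() not in ATOM_TYPES))  # "Other" bucket
        node_feats.append(
            onehot
            + [
                atom.GetDegree() / 4.0,
                atom.GetFormalCharge() / 2.0,
                atom.GetTotalNumHs() / 4.0,
                float(atom.GetIsAromatic()),
                float(atom.IsInRing()),
            ]
        )

    edges, edge_feats = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feats = [float(bond.GetBondType() == t) for t in BOND_TYPES] + [
            float(bond.GetIsConjugated()),
            float(bond.IsInRing()),
        ]
        # Store each bond in both directions so message passing is symmetric.
        edges += [(i, j), (j, i)]
        edge_feats += [feats, feats]

    return {
        "x": torch.tensor(node_feats, dtype=torch.float),
        "edge_index": torch.tensor(edges, dtype=torch.long).t().contiguous(),
        "edge_attr": torch.tensor(edge_feats, dtype=torch.float),
    }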

Step 2: Baseline Model---Morgan Fingerprints + MLP

Before building a GNN, we establish a baseline using Morgan fingerprints (also known as Extended Connectivity Fingerprints, ECFP). Morgan fingerprints are circular substructure descriptors that encode the presence or absence of molecular substructures up to a given radius.

import torch
import torch.nn as nn
import numpy as np

torch.manual_seed(42)
np.random.seed(42)


class FingerprintMLP(nn.Module):
    """MLP baseline operating on molecular fingerprints."""

    def __init__(
        self, input_dim: int = 1024, hidden_dim: int = 256, dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze(-1)
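
The fingerprint inputs themselves are not generated above. One common route, assuming RDKit is available, is the hashed Morgan bit vector; smiles_to_morgan below is a hypothetical helper that illustrates the idea rather than the exact preprocessing behind the reported results.

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem


def smiles_to_morgan(
    smiles: str, radius: int = 2, n_bits: int = 1024
) -> torch.Tensor:
    """Hashed Morgan (circular) fingerprint as a float tensor of 0/1 bits."""
    mol = Chem.MolFromSmiles(smiles)
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    arr = np.zeros((n_bits,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return torch.from_numpy(arr.astype(np.float32))

With radius 2 this corresponds to ECFP4-style fingerprints; the resulting 1024-bit vector is fed directly to FingerprintMLP.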

Step 3: GCN Model for Molecular Graphs

import torch.nn.functional as F


class MolecularGCN(nn.Module):
    """GCN-based model for molecular property prediction."""

    def __init__(
        self,
        node_features: int = 15,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.node_embedding = nn.Linear(node_features, hidden_dim)
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(nn.Linear(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        self.dropout = dropout

    def gcn_aggregate(
        self, x: torch.Tensor, edge_index: torch.Tensor, num_nodes: int
    ) -> torch.Tensor:
        """Symmetric-normalized aggregation with self-loops."""
        self_loops = torch.arange(num_nodes, device=x.device)
        self_loops = self_loops.unsqueeze(0).repeat(2, 1)
        edge_index_sl = torch.cat([edge_index, self_loops], dim=1)
        row, col = edge_index_sl

        deg = torch.zeros(num_nodes, device=x.device)
        deg.scatter_add_(0, row, torch.ones(row.size(0), device=x.device))
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0.0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = torch.zeros_like(x)
        out.scatter_add_(
            0,
            col.unsqueeze(1).expand_as(x[row]),
            norm.unsqueeze(1) * x[row],
        )
        return out

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass with mean readout.

        Args:
            x: Node features [total_nodes, node_features].
            edge_index: COO edge indices [2, total_edges].
            batch: Graph membership [total_nodes].

        Returns:
            Predicted property [num_graphs].
        """
        num_nodes = x.size(0)
        x = self.node_embedding(x)

        for conv, bn in zip(self.convs, self.batch_norms):
            x_agg = self.gcn_aggregate(x, edge_index, num_nodes)
            x_agg = conv(x_agg)
            x_agg = bn(x_agg)
            x = F.relu(x_agg) + x  # Residual connection
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Mean readout per graph
        num_graphs = batch.max().item() + 1
        graph_repr = torch.zeros(num_graphs, x.size(1), device=x.device)
        graph_repr.scatter_add_(0, batch.unsqueeze(1).expand_as(x), x)
        counts = torch.zeros(num_graphs, device=x.device)
        counts.scatter_add_(0, batch, torch.ones(batch.size(0), device=x.device))
        graph_repr = graph_repr / counts.unsqueeze(1).clamp(min=1)

        return self.output_head(graph_repr).squeeze(-1)
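
To make the batched input format concrete, here is a purely illustrative forward pass on two tiny synthetic graphs (random features, not real molecules):

# Graph 0 has 3 nodes, graph 1 has 2 nodes; features are random placeholders.
toy_x = torch.randn(5, 15)
toy_edge_index = torch.tensor(
    [[0, 1, 1, 2, 3, 4],
     [1, 0, 2, 1, 4, 3]]
)  # each undirected edge stored in both directions
toy_batch = torch.tensor([0, 0, 0, 1, 1])

gcn = MolecularGCN()
gcn.eval()  # eval mode so BatchNorm uses running statistics on this tiny batch
with torch.no_grad():
    toy_pred = gcn(toy_x, toy_edge_index, toy_batch)
print(toy_pred.shape)  # torch.Size([2]) -- one solubility prediction per graph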

Step 4: MPNN Model with Edge Features

class MPNNLayer(nn.Module):
    """Message Passing layer that incorporates edge features."""

    def __init__(self, node_dim: int, edge_dim: int) -> None:
        super().__init__()
        self.message_mlp = nn.Sequential(
            nn.Linear(node_dim + edge_dim, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, node_dim),
        )
        self.update_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, node_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        num_nodes: int,
    ) -> torch.Tensor:
        row, col = edge_index
        # Compute messages using source node features and edge features
        messages = self.message_mlp(
            torch.cat([x[row], edge_attr], dim=-1)
        )
        # Aggregate messages at target nodes
        agg = torch.zeros(num_nodes, messages.size(1), device=x.device)
        agg.scatter_add_(0, col.unsqueeze(1).expand_as(messages), messages)
        # Update node features
        x = self.update_mlp(torch.cat([x, agg], dim=-1))
        return x


class MolecularMPNN(nn.Module):
    """Full MPNN model for molecular property prediction."""

    def __init__(
        self,
        node_features: int = 15,
        edge_features: int = 6,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.node_encoder = nn.Linear(node_features, hidden_dim)
        self.edge_encoder = nn.Linear(edge_features, hidden_dim)
        self.layers = nn.ModuleList(
            [MPNNLayer(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor,
    ) -> torch.Tensor:
        num_nodes = x.size(0)
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)

        for layer in self.layers:
            x = x + layer(x, edge_index, edge_attr, num_nodes)

        # Sum readout (preserves molecule size information)
        num_graphs = batch.max().item() + 1
        graph_repr = torch.zeros(num_graphs, x.size(1), device=x.device)
        graph_repr.scatter_add_(0, batch.unsqueeze(1).expand_as(x), x)

        return self.output(graph_repr).squeeze(-1)

Step 5: Training and Evaluation

def train_molecular_model(
    model: nn.Module,
    train_data: list[dict[str, torch.Tensor]],
    val_data: list[dict[str, torch.Tensor]],
    num_epochs: int = 100,
    lr: float = 1e-3,
    weight_decay: float = 1e-5,
) -> dict[str, list[float]]:
    """Train a molecular property prediction model.

    Args:
        model: The GNN model.
        train_data: List of graph dicts with keys x, edge_index, batch, y
            (plus edge_attr for models that use edge features).
        val_data: Validation data in the same format.
        num_epochs: Number of training epochs.
        lr: Learning rate.
        weight_decay: L2 regularization strength.

    Returns:
        Dictionary with train_losses and val_rmses.
    """
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    history: dict[str, list[float]] = {"train_losses": [], "val_rmses": []}

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch_data in train_data:
            optimizer.zero_grad()
            # MolecularMPNN also expects edge features; pass them when present.
            if "edge_attr" in batch_data:
                pred = model(
                    batch_data["x"],
                    batch_data["edge_index"],
                    batch_data["edge_attr"],
                    batch_data["batch"],
                )
            else:
                pred = model(
                    batch_data["x"],
                    batch_data["edge_index"],
                    batch_data["batch"],
                )
            loss = F.mse_loss(pred, batch_data["y"])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_data)
        history["train_losses"].append(avg_loss)

        # Validation
        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            for batch_data in val_data:
                if "edge_attr" in batch_data:
                    pred = model(
                        batch_data["x"],
                        batch_data["edge_index"],
                        batch_data["edge_attr"],
                        batch_data["batch"],
                    )
                else:
                    pred = model(
                        batch_data["x"],
                        batch_data["edge_index"],
                        batch_data["batch"],
                    )
                val_preds.append(pred)
                val_targets.append(batch_data["y"])

        val_preds_cat = torch.cat(val_preds)
        val_targets_cat = torch.cat(val_targets)
        val_rmse = torch.sqrt(F.mse_loss(val_preds_cat, val_targets_cat)).item()
        history["val_rmses"].append(val_rmse)
        scheduler.step(val_rmse)

        if (epoch + 1) % 20 == 0:
            print(
                f"Epoch {epoch + 1}/{num_epochs} | "
                f"Train Loss: {avg_loss:.4f} | Val RMSE: {val_rmse:.4f}"
            )

    return history
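
The training loop consumes pre-batched graph dicts, but how those batches are assembled is not shown. The sketch below uses a hypothetical collate_graphs helper and assumes each molecule is already a dict with x, edge_index, edge_attr, and a scalar target tensor under y; it illustrates the standard trick of merging molecules into one large disconnected graph. In practice, a library such as PyTorch Geometric handles this batching.

def collate_graphs(graphs: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Merge individual molecular graphs into a single disconnected batch graph."""
    xs, edge_indices, edge_attrs, batch_ids, ys = [], [], [], [], []
    node_offset = 0
    for graph_id, g in enumerate(graphs):
        num_nodes = g["x"].size(0)
        xs.append(g["x"])
        edge_indices.append(g["edge_index"] + node_offset)  # shift node indices
        edge_attrs.append(g["edge_attr"])
        batch_ids.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
        ys.append(g["y"].view(1))
        node_offset += num_nodes
    return {
        "x": torch.cat(xs, dim=0),
        "edge_index": torch.cat(edge_indices, dim=1),
        "edge_attr": torch.cat(edge_attrs, dim=0),
        "batch": torch.cat(batch_ids),
        "y": torch.cat(ys),
    }


# Hypothetical usage: mini-batches of 32 molecules, then train the GCN.
# train_graphs / val_graphs would come from the featurization step plus targets.
# train_batches = [
#     collate_graphs(train_graphs[i : i + 32])
#     for i in range(0, len(train_graphs), 32)
# ]
# history = train_molecular_model(MolecularGCN(), train_batches, val_batches)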

Results and Analysis

Performance Comparison

Model                               | ESOL Test RMSE | Parameters
Morgan Fingerprint + MLP            | ~0.75          | ~330K
GCN (3-layer, no edge features)     | ~0.62          | ~61K
MPNN (3-layer, with edge features)  | ~0.55          | ~310K

The GNN models outperform the fingerprint baseline, particularly on molecules with unusual substructures that are not well-captured by fixed-radius fingerprints. The MPNN's incorporation of bond type information provides a further improvement.

Key Observations

  1. Edge features matter: The MPNN, which incorporates bond type (single, double, triple, aromatic) into message passing, consistently outperforms the GCN, which ignores edge features. For molecular graphs, bond type is a strong chemical signal.

  2. Scaffold splitting reveals generalization gaps: When using scaffold splits instead of random splits, all models show increased error. The MPNN degrades least, suggesting it learns more transferable representations (a scaffold-split sketch appears after this list).

  3. Residual connections stabilize training: Without residual connections, the 3-layer GCN overfits quickly. Residuals allow the model to maintain atom-level identity information alongside aggregated neighborhood information.

  4. Sum readout outperforms mean readout: For solubility prediction, larger molecules tend to be less soluble. Sum readout implicitly captures molecular size, giving it an advantage over size-invariant mean readout.
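
As referenced in observation 2, scaffold splitting groups molecules by their Bemis-Murcko scaffold so that structurally related compounds never straddle the train/test boundary. A minimal sketch, assuming RDKit is available; scaffold_split is a hypothetical helper, not necessarily the exact procedure behind the reported numbers.

from collections import defaultdict

from rdkit.Chem.Scaffolds import MurckoScaffold


def scaffold_split(
    smiles_list: list[str], frac_train: float = 0.8, frac_val: float = 0.1
) -> tuple[list[int], list[int], list[int]]:
    """Assign whole scaffold groups to train/val/test so scaffolds never leak."""
    groups: dict[str, list[int]] = defaultdict(list)
    for idx, smi in enumerate(smiles_list):
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(smiles=smi, includeChirality=False)
        groups[scaffold].append(idx)

    # Place the largest scaffold groups in train first; rare scaffolds end up in test.
    ordered = sorted(groups.values(), key=len, reverse=True)
    n = len(smiles_list)
    train, val, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= frac_train * n:
            train.extend(group)
        elif len(val) + len(group) <= frac_val * n:
            val.extend(group)
        else:
            test.extend(group)
    return train, val, test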

Error Analysis

Examining the highest-error predictions reveals patterns:

- Charged molecules: Molecules with formal charges are systematically mispredicted, likely due to underrepresentation in the training set.
- Large flexible molecules: Molecules with many rotatable bonds adopt multiple conformations in solution, making solubility harder to predict from 2D graph topology alone.
- Rare functional groups: Molecules containing functional groups rarely seen in training data show high error, indicating the model struggles to extrapolate to novel chemistry.

Lessons Learned

  1. Start with a fingerprint baseline: Morgan fingerprints are a strong baseline that takes minutes to implement. Always compare GNN models against this baseline.
  2. Featurization is as important as architecture: Carefully designed atom and bond features contribute more to performance than switching between GCN, GAT, and MPNN architectures.
  3. Use scaffold splits for realistic evaluation: Random splits overestimate performance because structurally similar molecules appear in both train and test sets.
  4. Consider 3D geometry: For properties that depend on molecular shape (binding affinity, conformational stability), 2D graph-based models may be insufficient. 3D-aware models like SchNet or DimeNet are needed.
  5. Ensemble methods help: Ensembling multiple GNN models with different random seeds or architectures typically reduces RMSE by 5--15% (a minimal averaging sketch follows below).
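
The ensembling idea from lesson 5 can be sketched as a simple prediction average; ensemble_predict below is a hypothetical helper shown with the GCN-style call signature.

def ensemble_predict(
    models: list[nn.Module], batch_data: dict[str, torch.Tensor]
) -> torch.Tensor:
    """Average the predictions of several independently trained models."""
    preds = []
    with torch.no_grad():
        for m in models:
            m.eval()
            preds.append(
                m(batch_data["x"], batch_data["edge_index"], batch_data["batch"])
            )
    return torch.stack(preds).mean(dim=0)


# Hypothetical usage: models trained with different seeds, averaged at test time.
# ensemble_out = ensemble_predict([gcn_seed0, gcn_seed1, gcn_seed2], test_batch)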