Case Study 1: Molecular Property Prediction with GNNs
Scientific Context
Predicting molecular properties from structure is a central challenge in drug discovery and materials science. Traditional computational chemistry methods like density functional theory (DFT) provide accurate predictions but are computationally expensive---a single DFT calculation can take hours to days. Machine learning offers a way to approximate these calculations at a fraction of the cost, enabling rapid screening of millions of candidate molecules.
Molecules are naturally represented as graphs: atoms are nodes, chemical bonds are edges. Graph neural networks can learn directly from this molecular graph structure, avoiding hand-crafted molecular descriptors and instead learning task-specific representations from the raw molecular topology.
In this case study, we predict aqueous solubility (logS) from molecular structure using the ESOL dataset. Solubility is a critical property in drug development: a drug must dissolve sufficiently in water to be absorbed by the body. We will compare a traditional fingerprint-based approach with GNN-based models.
Dataset
The ESOL (Estimated SOLubility) dataset contains 1,128 molecules with experimentally measured aqueous solubility values (log mol/L). Each molecule is provided as a SMILES string, which we convert to a graph representation.
Example molecules and their solubility:
| SMILES | Name | logS |
|---|---|---|
| OCC3OC(OCC2OC(OC(C#N)c1ccccc1)... | Amygdalin | -0.77 |
| c1ccc2c(c1)cc1ccc3cccc4ccc2c1c34 | Benzo[ghi]fluoranthene | -6.30 |
| C(=O)(OC(CC(=O)OC... | Acetyl tributyl citrate | -3.38 |
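As a minimal sketch, the dataset can be loaded from a local CSV before featurization; the file name esol.csv and the column names smiles and logS below are placeholders, so adjust them to match your copy of ESOL:

import pandas as pd

# Hypothetical local copy of ESOL; column names vary between releases.
df = pd.read_csv("esol.csv")
smiles_list = df["smiles"].tolist()
targets = df["logS"].astype(float).tolist()
print(f"Loaded {len(smiles_list)} molecules")  # expect 1,128 for ESOL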
Approach
Step 1: Molecular Graph Construction
We convert each SMILES string to a molecular graph with atom features and bond features.
Atom features (per node):
- Atom type: one-hot encoding of C, N, O, S, F, Cl, Br, I, P, Other (10 dims)
- Degree: node degree / 4 (1 dim)
- Formal charge: charge / 2 (1 dim)
- Number of hydrogens: count / 4 (1 dim)
- Is aromatic: binary (1 dim)
- Is in ring: binary (1 dim)
Total: 15-dimensional node feature vector.
Bond features (per edge):
- Bond type: one-hot encoding of single, double, triple, aromatic (4 dims)
- Is conjugated: binary (1 dim)
- Is in ring: binary (1 dim)
Total: 6-dimensional edge feature vector.
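One way this featurization might look with RDKit is sketched below; the feature order follows the lists above, and each bond is added in both directions so message passing treats the molecular graph as undirected (this is one reasonable implementation, not the only one):

from rdkit import Chem
import torch

ATOM_TYPES = ["C", "N", "O", "S", "F", "Cl", "Br", "I", "P"]
BOND_TYPES = [
    Chem.BondType.SINGLE,
    Chem.BondType.DOUBLE,
    Chem.BondType.TRIPLE,
    Chem.BondType.AROMATIC,
]

def smiles_to_graph(smiles: str) -> dict[str, torch.Tensor]:
    """Convert a SMILES string to node features, edge indices, and edge features."""
    mol = Chem.MolFromSmiles(smiles)
    node_feats = []
    for atom in mol.GetAtoms():
        onehot = [0.0] * (len(ATOM_TYPES) + 1)  # final slot is "Other"
        symbol = atom.GetSymbol()
        idx = ATOM_TYPES.index(symbol) if symbol in ATOM_TYPES else len(ATOM_TYPES)
        onehot[idx] = 1.0
        node_feats.append(onehot + [
            atom.GetDegree() / 4.0,
            atom.GetFormalCharge() / 2.0,
            atom.GetTotalNumHs() / 4.0,
            float(atom.GetIsAromatic()),
            float(atom.IsInRing()),
        ])
    edges, edge_feats = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feat = [float(bond.GetBondType() == t) for t in BOND_TYPES]
        feat += [float(bond.GetIsConjugated()), float(bond.IsInRing())]
        # Add both directions so the graph is effectively undirected.
        edges += [(i, j), (j, i)]
        edge_feats += [feat, feat]
    return {
        "x": torch.tensor(node_feats, dtype=torch.float),
        "edge_index": torch.tensor(edges, dtype=torch.long).t().contiguous(),
        "edge_attr": torch.tensor(edge_feats, dtype=torch.float),
    }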
Step 2: Baseline Model---Morgan Fingerprints + MLP
Before building a GNN, we establish a baseline using Morgan fingerprints (also known as Extended Connectivity Fingerprints, ECFP). Morgan fingerprints are circular substructure descriptors that encode the presence or absence of molecular substructures up to a given radius.
import torch
import torch.nn as nn
import numpy as np
from typing import Any
torch.manual_seed(42)
np.random.seed(42)
class FingerprintMLP(nn.Module):
"""MLP baseline operating on molecular fingerprints."""
def __init__(
self, input_dim: int = 1024, hidden_dim: int = 256, dropout: float = 0.2
) -> None:
super().__init__()
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.network(x).squeeze(-1)
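To produce inputs for this baseline, the fingerprints can be computed with RDKit; a sketch using radius 2 and 1,024 bits (matching input_dim above), with an untrained forward pass purely as a shape check:

from rdkit import Chem
from rdkit.Chem import AllChem

def morgan_fingerprint(smiles: str, radius: int = 2, n_bits: int = 1024) -> torch.Tensor:
    """Binary Morgan (ECFP-like) fingerprint as a float tensor."""
    mol = Chem.MolFromSmiles(smiles)
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    return torch.tensor(list(fp), dtype=torch.float)

model = FingerprintMLP(input_dim=1024)
fp = morgan_fingerprint("c1ccccc1O")   # phenol
print(model(fp.unsqueeze(0)).shape)    # torch.Size([1])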
Step 3: GCN Model for Molecular Graphs
import torch.nn.functional as F
class MolecularGCN(nn.Module):
"""GCN-based model for molecular property prediction."""
def __init__(
self,
node_features: int = 15,
hidden_dim: int = 128,
num_layers: int = 3,
dropout: float = 0.2,
) -> None:
super().__init__()
self.node_embedding = nn.Linear(node_features, hidden_dim)
self.convs = nn.ModuleList()
self.batch_norms = nn.ModuleList()
for _ in range(num_layers):
self.convs.append(nn.Linear(hidden_dim, hidden_dim))
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
self.output_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1),
)
self.dropout = dropout
def gcn_aggregate(
self, x: torch.Tensor, edge_index: torch.Tensor, num_nodes: int
) -> torch.Tensor:
"""Symmetric-normalized aggregation with self-loops."""
self_loops = torch.arange(num_nodes, device=x.device)
self_loops = self_loops.unsqueeze(0).repeat(2, 1)
edge_index_sl = torch.cat([edge_index, self_loops], dim=1)
row, col = edge_index_sl
deg = torch.zeros(num_nodes, device=x.device)
deg.scatter_add_(0, row, torch.ones(row.size(0), device=x.device))
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0.0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
out = torch.zeros_like(x)
out.scatter_add_(
0,
col.unsqueeze(1).expand_as(x[row]),
norm.unsqueeze(1) * x[row],
)
return out
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
batch: torch.Tensor,
) -> torch.Tensor:
"""Forward pass with mean readout.
Args:
x: Node features [total_nodes, node_features].
edge_index: COO edge indices [2, total_edges].
batch: Graph membership [total_nodes].
Returns:
Predicted property [num_graphs].
"""
num_nodes = x.size(0)
x = self.node_embedding(x)
for conv, bn in zip(self.convs, self.batch_norms):
x_agg = self.gcn_aggregate(x, edge_index, num_nodes)
x_agg = conv(x_agg)
x_agg = bn(x_agg)
x = F.relu(x_agg) + x # Residual connection
x = F.dropout(x, p=self.dropout, training=self.training)
# Mean readout per graph
num_graphs = batch.max().item() + 1
graph_repr = torch.zeros(num_graphs, x.size(1), device=x.device)
graph_repr.scatter_add_(0, batch.unsqueeze(1).expand_as(x), x)
counts = torch.zeros(num_graphs, device=x.device)
counts.scatter_add_(0, batch, torch.ones(batch.size(0), device=x.device))
graph_repr = graph_repr / counts.unsqueeze(1).clamp(min=1)
return self.output_head(graph_repr).squeeze(-1)
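A quick shape check on a toy batch of two small graphs with random node features (in practice x and edge_index come from the featurization in Step 1):

model = MolecularGCN(node_features=15, hidden_dim=128, num_layers=3)

# Nodes 0-2 belong to graph 0, nodes 3-4 to graph 1; edges listed in both directions.
x = torch.randn(5, 15)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
batch = torch.tensor([0, 0, 0, 1, 1])

model.eval()  # eval mode so BatchNorm uses running stats on this tiny batch
with torch.no_grad():
    print(model(x, edge_index, batch).shape)  # torch.Size([2])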
Step 4: MPNN Model with Edge Features
class MPNNLayer(nn.Module):
"""Message Passing layer that incorporates edge features."""
def __init__(self, node_dim: int, edge_dim: int) -> None:
super().__init__()
self.message_mlp = nn.Sequential(
nn.Linear(node_dim + edge_dim, node_dim),
nn.ReLU(),
nn.Linear(node_dim, node_dim),
)
self.update_mlp = nn.Sequential(
nn.Linear(node_dim * 2, node_dim),
nn.ReLU(),
nn.Linear(node_dim, node_dim),
)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: torch.Tensor,
num_nodes: int,
) -> torch.Tensor:
row, col = edge_index
# Compute messages using source node features and edge features
messages = self.message_mlp(
torch.cat([x[row], edge_attr], dim=-1)
)
# Aggregate messages at target nodes
agg = torch.zeros(num_nodes, messages.size(1), device=x.device)
agg.scatter_add_(0, col.unsqueeze(1).expand_as(messages), messages)
# Update node features
x = self.update_mlp(torch.cat([x, agg], dim=-1))
return x
class MolecularMPNN(nn.Module):
"""Full MPNN model for molecular property prediction."""
def __init__(
self,
node_features: int = 15,
edge_features: int = 6,
hidden_dim: int = 128,
num_layers: int = 3,
dropout: float = 0.2,
) -> None:
super().__init__()
self.node_encoder = nn.Linear(node_features, hidden_dim)
self.edge_encoder = nn.Linear(edge_features, hidden_dim)
self.layers = nn.ModuleList(
[MPNNLayer(hidden_dim, hidden_dim) for _ in range(num_layers)]
)
self.output = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1),
)
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: torch.Tensor,
batch: torch.Tensor,
) -> torch.Tensor:
num_nodes = x.size(0)
x = self.node_encoder(x)
edge_attr = self.edge_encoder(edge_attr)
for layer in self.layers:
x = x + layer(x, edge_index, edge_attr, num_nodes)
# Sum readout (preserves molecule size information)
num_graphs = batch.max().item() + 1
graph_repr = torch.zeros(num_graphs, x.size(1), device=x.device)
graph_repr.scatter_add_(0, batch.unsqueeze(1).expand_as(x), x)
return self.output(graph_repr).squeeze(-1)
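The MPNN is called the same way, with the bond features passed as an extra argument (here random placeholders of the 6-dimensional size from Step 1):

model = MolecularMPNN(node_features=15, edge_features=6)

x = torch.randn(5, 15)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
edge_attr = torch.randn(6, 6)          # one feature vector per directed edge
batch = torch.tensor([0, 0, 0, 1, 1])

with torch.no_grad():
    print(model(x, edge_index, edge_attr, batch).shape)  # torch.Size([2])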
Step 5: Training and Evaluation
def train_molecular_model(
model: nn.Module,
train_data: list[dict[str, torch.Tensor]],
val_data: list[dict[str, torch.Tensor]],
num_epochs: int = 100,
lr: float = 1e-3,
weight_decay: float = 1e-5,
) -> dict[str, list[float]]:
"""Train a molecular property prediction model.
Args:
model: The GNN model.
        train_data: List of batched graph dicts with keys x, edge_index, batch, y,
            plus edge_attr for models that use edge features.
val_data: Validation data in the same format.
num_epochs: Number of training epochs.
lr: Learning rate.
weight_decay: L2 regularization strength.
Returns:
Dictionary with train_losses and val_rmses.
"""
optimizer = torch.optim.Adam(
model.parameters(), lr=lr, weight_decay=weight_decay
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=10
)
history: dict[str, list[float]] = {"train_losses": [], "val_rmses": []}
for epoch in range(num_epochs):
model.train()
epoch_loss = 0.0
for batch_data in train_data:
optimizer.zero_grad()
            if "edge_attr" in batch_data:  # MPNN-style models also take edge features
                pred = model(
                    batch_data["x"],
                    batch_data["edge_index"],
                    batch_data["edge_attr"],
                    batch_data["batch"],
                )
            else:
                pred = model(
                    batch_data["x"],
                    batch_data["edge_index"],
                    batch_data["batch"],
                )
loss = F.mse_loss(pred, batch_data["y"])
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
epoch_loss += loss.item()
avg_loss = epoch_loss / len(train_data)
history["train_losses"].append(avg_loss)
# Validation
model.eval()
val_preds, val_targets = [], []
with torch.no_grad():
for batch_data in val_data:
                if "edge_attr" in batch_data:
                    pred = model(
                        batch_data["x"],
                        batch_data["edge_index"],
                        batch_data["edge_attr"],
                        batch_data["batch"],
                    )
                else:
                    pred = model(
                        batch_data["x"],
                        batch_data["edge_index"],
                        batch_data["batch"],
                    )
val_preds.append(pred)
val_targets.append(batch_data["y"])
val_preds_cat = torch.cat(val_preds)
val_targets_cat = torch.cat(val_targets)
val_rmse = torch.sqrt(F.mse_loss(val_preds_cat, val_targets_cat)).item()
history["val_rmses"].append(val_rmse)
scheduler.step(val_rmse)
if (epoch + 1) % 20 == 0:
print(
f"Epoch {epoch + 1}/{num_epochs} | "
f"Train Loss: {avg_loss:.4f} | Val RMSE: {val_rmse:.4f}"
)
return history
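The training loop expects each element of train_data to be several molecules merged into one disjoint graph. A minimal collate sketch, assuming each input dict comes from a featurizer like smiles_to_graph in Step 1 plus a scalar target y (drop the edge_attr entry when training the GCN, since the loop above dispatches on that key):

def collate_graphs(graphs: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Merge single-molecule graph dicts into one batched graph dict."""
    xs, edge_indices, edge_attrs, batch_ids, ys = [], [], [], [], []
    node_offset = 0
    for graph_id, g in enumerate(graphs):
        n = g["x"].size(0)
        xs.append(g["x"])
        edge_indices.append(g["edge_index"] + node_offset)  # shift node indices
        edge_attrs.append(g["edge_attr"])
        batch_ids.append(torch.full((n,), graph_id, dtype=torch.long))
        ys.append(g["y"])
        node_offset += n
    return {
        "x": torch.cat(xs),
        "edge_index": torch.cat(edge_indices, dim=1),
        "edge_attr": torch.cat(edge_attrs),
        "batch": torch.cat(batch_ids),
        "y": torch.stack(ys),
    }

train_data would then be a list of such batched dicts, for example one per 32 molecules.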
Results and Analysis
Performance Comparison
| Model | ESOL Test RMSE | Parameters |
|---|---|---|
| Morgan Fingerprint + MLP | ~0.75 | ~330K |
| GCN (3-layer, no edge features) | ~0.62 | ~115K |
| MPNN (3-layer, with edge features) | ~0.55 | ~180K |
The GNN models outperform the fingerprint baseline, particularly on molecules with unusual substructures that are not well-captured by fixed-radius fingerprints. The MPNN's incorporation of bond type information provides a further improvement.
Key Observations
- Edge features matter: The MPNN, which incorporates bond type (single, double, triple, aromatic) into message passing, consistently outperforms the GCN, which ignores edge features. For molecular graphs, bond type is a strong chemical signal.
- Scaffold splitting reveals generalization gaps: When using scaffold splits instead of random splits, all models show increased error. The MPNN degrades least, suggesting it learns more transferable representations. (A sketch of a scaffold split follows this list.)
- Residual connections stabilize training: Without residual connections, the 3-layer GCN overfits quickly. Residuals allow the model to maintain atom-level identity information alongside aggregated neighborhood information.
- Sum readout outperforms mean readout: For solubility prediction, larger molecules tend to be less soluble. Sum readout implicitly captures molecular size, giving it an advantage over size-invariant mean readout.
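A scaffold split groups molecules by their Bemis-Murcko scaffold so that structurally related compounds never straddle the train/test boundary. A rough sketch with RDKit, simplified relative to the MoleculeNet implementation:

from collections import defaultdict
from rdkit.Chem.Scaffolds import MurckoScaffold

def scaffold_split(
    smiles_list: list[str], test_frac: float = 0.2
) -> tuple[list[int], list[int]]:
    """Greedy scaffold split: whole scaffold groups stay on one side of the split."""
    groups: dict[str, list[int]] = defaultdict(list)
    for i, smi in enumerate(smiles_list):
        groups[MurckoScaffold.MurckoScaffoldSmiles(smiles=smi)].append(i)
    train_cutoff = int((1 - test_frac) * len(smiles_list))
    train_idx: list[int] = []
    test_idx: list[int] = []
    # Assign the largest scaffold groups to train first, pushing rare scaffolds to test.
    for group in sorted(groups.values(), key=len, reverse=True):
        if len(train_idx) + len(group) <= train_cutoff:
            train_idx.extend(group)
        else:
            test_idx.extend(group)
    return train_idx, test_idx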
Error Analysis
Examining the highest-error predictions reveals patterns:
- Charged molecules: Molecules with formal charges are systematically mispredicted, likely due to underrepresentation in the training set.
- Large flexible molecules: Molecules with many rotatable bonds adopt multiple conformations in solution, making solubility harder to predict from 2D graph topology alone.
- Rare functional groups: Molecules containing functional groups rarely seen in training data show high error, indicating the model struggles to extrapolate to novel chemistry.
Lessons Learned
- Start with a fingerprint baseline: Morgan fingerprints are a strong baseline that takes minutes to implement. Always compare GNN models against this baseline.
- Featurization is as important as architecture: Carefully designed atom and bond features contribute more to performance than switching between GCN, GAT, and MPNN architectures.
- Use scaffold splits for realistic evaluation: Random splits overestimate performance because structurally similar molecules appear in both train and test sets.
- Consider 3D geometry: For properties that depend on molecular shape (binding affinity, conformational stability), 2D graph-based models may be insufficient; 3D-aware models such as SchNet or DimeNet may be needed.
- Ensemble methods help: Ensembling multiple GNN models with different random seeds or architectures typically reduces RMSE by 5--15%.
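A sketch of the simplest version: train several copies with different seeds and average their predictions (shown for the GCN call signature; add the edge_attr argument for the MPNN):

def ensemble_predict(
    models: list[nn.Module], batch_data: dict[str, torch.Tensor]
) -> torch.Tensor:
    """Average the predictions of independently trained models."""
    preds = []
    with torch.no_grad():
        for m in models:
            m.eval()
            preds.append(
                m(batch_data["x"], batch_data["edge_index"], batch_data["batch"])
            )
    return torch.stack(preds).mean(dim=0)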