In This Chapter
- Learning Objectives
- 14.1 Why Graphs?
- 14.2 Graph Fundamentals and Representation
- 14.3 The Spectral Perspective: Graph Convolution via Eigendecomposition
- 14.4 The Spatial Perspective: Message Passing Neural Networks
- 14.5 GCN: Graph Convolutional Networks
- 14.6 GraphSAGE: Inductive Learning via Neighborhood Sampling
- 14.7 GAT: Graph Attention Networks
- 14.8 GIN: The Most Powerful Message Passing GNN
- 14.9 Graph-Level Tasks: Readout and Pooling
- 14.10 Over-Smoothing: The Depth Limit of GNNs
- 14.11 Heterogeneous Graphs and Knowledge Graphs
- 14.12 PyTorch Geometric: A Practical Framework
- 14.13 Beyond Standard Message Passing
- 14.14 Practical Considerations
- 14.15 Summary: Choosing the Right Architecture
- Chapter Summary
Chapter 14: Graph Neural Networks and Geometric Deep Learning — When Your Data Has Structure Beyond Grids and Sequences
"A graph is the most general form of structured data. All other structures — sequences, grids, sets — are special cases." — Adapted from Michael Bronstein et al., "Geometric Deep Learning" (2021)
Learning Objectives
By the end of this chapter, you will be able to:
- Represent graph-structured data using adjacency matrices, edge lists, and feature matrices, and explain why standard neural networks (MLPs, CNNs, RNNs) cannot process graphs directly
- Derive the message passing framework from first principles and implement three foundational GNN architectures: GCN, GraphSAGE, and GAT
- Apply graph neural networks to three canonical tasks: node classification, link prediction, and graph classification
- Explain the Weisfeiler-Leman hierarchy and the expressiveness limitations it imposes on message passing GNNs
- Build end-to-end graph learning pipelines using PyTorch Geometric
14.1 Why Graphs?
Every neural network architecture we have studied in this book assumes a specific data geometry. MLPs (Chapter 6) operate on fixed-size vectors — unstructured feature sets. CNNs (Chapter 8) assume a regular grid where spatial locality and translation equivariance provide inductive bias. RNNs (Chapter 9) and transformers (Chapter 10) assume sequential ordering. These assumptions are powerful when they match the data, and catastrophically wrong when they do not.
Consider two problems central to this book's anchor examples:
StreamRec's recommendation problem. The platform has 5 million users and 200,000 items. The interactions between them form a bipartite graph: users are one set of nodes, items are another, and edges represent engagement. In Chapter 1, we factored the user-item interaction matrix using SVD. In Chapter 13, we learned separate user and item embeddings with a two-tower model. But neither approach exploits the graph structure of the interaction data. User A and User B might not share any items, yet still be connected through a chain of items and other users: a path of four or more hops in the interaction graph (user → item → user → item → user). Paths between two users in a bipartite graph always have even length, and a two-hop path would mean a shared item. A user who likes indie documentaries, avant-garde jazz, and literary fiction shares a structural neighborhood with users who have similar taste patterns, even if their specific item overlap is zero. Matrix factorization sees only direct co-occurrences. Graph neural networks see the neighborhood.
MediCore's molecular property prediction. A drug molecule is not a sequence (it is not linear) and not an image (it does not live on a grid). It is a graph: atoms are nodes, bonds are edges, and the molecular properties (toxicity, solubility, binding affinity) are functions of the entire graph structure. A carbon atom bonded to two oxygens and a hydrogen behaves differently from a carbon atom in a benzene ring: the same node feature (carbon, atomic number 6) produces different chemical behavior depending on its neighborhood. Forcing the molecule into a sequence or a fixed-size vector requires either an arbitrary serialization (SMILES strings, where one molecule has many valid spellings) or a lossy featurization (molecular fingerprints), and both discard or obscure this structural information. Graph neural networks operate on the molecule's native topology.
These are not niche applications. Graphs are everywhere: social networks, citation networks, biological interaction networks, knowledge graphs, traffic networks, supply chains, program dependency graphs, electrical circuits, 3D meshes. Any relational dataset — any data where entities have pairwise relationships — is naturally a graph. The question is not whether your data has graph structure, but whether ignoring that structure is costing you performance.
Understanding Why: This chapter fills a critical gap in most deep learning curricula. The standard progression — MLP, CNN, RNN, transformer — covers sequences and grids but ignores the most general structured data type. Graph neural networks are not a peripheral specialty; they are the mathematically natural framework for any relational learning problem. The spectral and spatial perspectives we develop here connect directly to the linear algebra of Chapter 1 (eigendecomposition of the Laplacian), the attention mechanisms of Chapter 10 (GAT is attention on graphs), and the information-theoretic view of Chapter 4 (how much information can a message passing round transmit?). The math is not incidental to the understanding — it is the understanding.
14.2 Graph Fundamentals and Representation
Notation and Definitions
A graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ consists of a set of $N$ nodes (or vertices) $\mathcal{V} = \{v_1, v_2, \ldots, v_N\}$ and a set of edges $\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}$. An edge $(v_i, v_j) \in \mathcal{E}$ indicates a relationship between nodes $v_i$ and $v_j$.
The graph may be:
- Undirected: $(v_i, v_j) \in \mathcal{E} \Leftrightarrow (v_j, v_i) \in \mathcal{E}$ (e.g., molecular bonds, friendships)
- Directed: $(v_i, v_j) \in \mathcal{E} \not\Rightarrow (v_j, v_i) \in \mathcal{E}$ (e.g., citations, web links, follower graphs)
- Weighted: each edge carries a scalar weight $w_{ij} \in \mathbb{R}$ (e.g., interaction strength, distance)
- Attributed: nodes carry feature vectors $\mathbf{x}_i \in \mathbb{R}^{d}$ and/or edges carry feature vectors $\mathbf{e}_{ij} \in \mathbb{R}^{d_e}$
The Adjacency Matrix
The adjacency matrix $\mathbf{A} \in \{0, 1\}^{N \times N}$ encodes the graph structure:
$$A_{ij} = \begin{cases} 1 & \text{if } (v_i, v_j) \in \mathcal{E} \\ 0 & \text{otherwise} \end{cases}$$
For an undirected graph, $\mathbf{A}$ is symmetric. For a weighted graph, $A_{ij} = w_{ij}$.
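The degree of node $v_i$ is $d_i = \sum_{j} A_{ij}$. For an unweighted, undirected graph this is the number of edges incident to $v_i$; for a weighted graph it is the total weight of those edges. The degree matrix $\mathbf{D}$ is the diagonal matrix with $D_{ii} = d_i$.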
The node feature matrix $\mathbf{X} \in \mathbb{R}^{N \times d}$ stacks all node feature vectors: row $i$ is $\mathbf{x}_i$, the feature vector of node $v_i$.
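As a concrete sketch of these definitions, the NumPy snippet below assumes a small hypothetical path graph $v_0 - v_1 - v_2 - v_3$ (0-based indices) with made-up two-dimensional features. Each undirected edge is stored in both directions so that $\mathbf{A}$ comes out symmetric:

```python
import numpy as np

# Hypothetical undirected graph: path 0 - 1 - 2 - 3, each edge listed once
edges = [(0, 1), (1, 2), (2, 3)]
N = 4

A = np.zeros((N, N))
for i, j in edges:
    A[i, j] = 1.0
    A[j, i] = 1.0          # undirected: store both directions, so A is symmetric

d = A.sum(axis=1)          # d_i = sum_j A_ij
D = np.diag(d)             # degree matrix

# Node feature matrix X: row i is x_i (two made-up features per node)
X = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [0.5, 0.5]])

print(d)                   # degrees of the path graph: 1, 2, 2, 1
```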
The Graph Laplacian
The graph Laplacian is central to spectral graph theory and provides the mathematical bridge between graph structure and signal processing:
$$\mathbf{L} = \mathbf{D} - \mathbf{A}$$
The Laplacian has several critical properties:
- Positive semi-definite: all eigenvalues are non-negative, $\lambda_i \geq 0$.
- Smallest eigenvalue is zero: $\lambda_1 = 0$, with eigenvector $\mathbf{1}$ (the all-ones vector). The multiplicity of the zero eigenvalue equals the number of connected components.
- Quadratic form interpretation: for any signal $\mathbf{f} \in \mathbb{R}^N$ defined on the graph nodes:
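$$\mathbf{f}^\top \mathbf{L} \mathbf{f} = \frac{1}{2} \sum_{i,j} A_{ij} (f_i - f_j)^2 = \sum_{\{v_i, v_j\} \in \mathcal{E}} (f_i - f_j)^2$$
where the last sum counts each undirected edge once (for a weighted graph, each term is multiplied by $w_{ij}$).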
This quadratic form measures the total variation of the signal on the graph — how much the signal changes across edges. A signal that is constant on connected components has zero Laplacian quadratic form; a signal that changes sharply across every edge has large quadratic form.
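These properties can be checked numerically. The NumPy sketch below uses a hypothetical graph (a triangle plus a disjoint edge) and a random signal. It builds $\mathbf{L} = \mathbf{D} - \mathbf{A}$, confirms that the eigenvalues are non-negative with one zero per connected component, and compares $\mathbf{f}^\top \mathbf{L} \mathbf{f}$ with the sum of squared differences over edges, counting each undirected edge once:

```python
import numpy as np

# Hypothetical graph with two connected components: triangle {0, 1, 2} and edge {3, 4}
edges = [(0, 1), (1, 2), (0, 2), (3, 4)]
N = 5
A = np.zeros((N, N))
for i, j in edges:
    A[i, j] = A[j, i] = 1.0
L = np.diag(A.sum(axis=1)) - A                   # L = D - A

eigvals = np.linalg.eigvalsh(L)                  # real eigenvalues, ascending
num_zero = int(np.sum(np.isclose(eigvals, 0.0, atol=1e-9)))

rng = np.random.default_rng(0)
f = rng.normal(size=N)                           # an arbitrary signal on the nodes
quad = f @ L @ f
edge_sum = sum((f[i] - f[j]) ** 2 for i, j in edges)

f_const = np.array([3.0, 3.0, 3.0, -1.0, -1.0])  # constant on each component

print(eigvals.round(6))   # mathematically 0, 0, 2, 3, 3
print(num_zero, np.isclose(quad, edge_sum), np.isclose(f_const @ L @ f_const, 0.0))  # 2 True True
```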
The normalized Laplacian is often preferred for GNN derivations:
$$\tilde{\mathbf{L}} = \mathbf{D}^{-1/2} \mathbf{L} \mathbf{D}^{-1/2} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}$$
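Assuming no isolated nodes (so that $\mathbf{D}^{-1/2}$ exists), its eigenvalues always lie in $[0, 2]$, whatever the node degrees. Its spectrum therefore lives on a fixed scale across graphs, whereas the largest eigenvalue of the unnormalized $\mathbf{L}$ is at least the maximum degree and grows with it.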
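To see the scale difference concretely, the NumPy sketch below uses a hypothetical star graph (one hub joined to $m$ leaves). The largest eigenvalue of $\mathbf{L}$ is $m + 1$ and grows with the hub's degree, while the spectrum of $\tilde{\mathbf{L}}$ stays inside $[0, 2]$ (reaching exactly 2 here, as it does for any connected bipartite graph):

```python
import numpy as np

def star_adjacency(m: int) -> np.ndarray:
    """Hypothetical star graph: hub node 0 connected to leaves 1..m."""
    A = np.zeros((m + 1, m + 1))
    A[0, 1:] = 1.0
    A[1:, 0] = 1.0
    return A

def laplacians(A: np.ndarray):
    """Return (L, L_norm); assumes no isolated nodes so that D^{-1/2} exists."""
    d = A.sum(axis=1)
    L = np.diag(d) - A
    d_inv_sqrt = 1.0 / np.sqrt(d)
    L_norm = np.eye(len(d)) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    return L, L_norm

results = {}
for m in (3, 10, 50):
    L, L_norm = laplacians(star_adjacency(m))
    ev, ev_norm = np.linalg.eigvalsh(L), np.linalg.eigvalsh(L_norm)
    results[m] = (ev.max(), ev_norm.min(), ev_norm.max())
    print(f"m={m:2d}  max eig L = {ev.max():6.2f}  max eig L_norm = {ev_norm.max():.2f}")
```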
Why Standard Architectures Fail on Graphs
Three fundamental properties of graphs prevent direct application of MLPs, CNNs, and RNNs:
- Variable size. Graphs have different numbers of nodes and edges, but an MLP requires a fixed input dimension. A molecular graph might have 10 atoms or 100 atoms, and there is no natural way to pad or truncate.
- No canonical ordering. The nodes of a graph have no natural order. An MLP applied to the flattened adjacency matrix would produce different outputs for different node orderings: the representation is neither permutation invariant (needed for graph-level tasks) nor permutation equivariant (needed for node-level tasks). If we relabel node 3 as node 7, a graph-level prediction should not change, and the prediction for that node should simply follow its new label.
- No regular connectivity. A CNN assumes each "pixel" has the same number of neighbors in the same spatial arrangement. Graph nodes have variable-degree neighborhoods with no spatial layout.
These are not engineering inconveniences — they are fundamental mathematical constraints. Any neural network for graphs must be:
- Permutation equivariant for node-level outputs: $f(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = \mathbf{P} f(\mathbf{X}, \mathbf{A})$ for any permutation matrix $\mathbf{P}$.
- Permutation invariant for graph-level outputs: $g(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = g(\mathbf{X}, \mathbf{A})$.
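Both conditions are easy to test numerically. The NumPy sketch below uses a hypothetical random graph, features, and weights with the simplest neighborhood-sum layer $f(\mathbf{X}, \mathbf{A}) = \text{ReLU}(\mathbf{A}\mathbf{X}\mathbf{W})$ and a sum readout $g$, and checks both identities for a random permutation matrix $\mathbf{P}$. Equivariance holds because $\mathbf{P}^\top \mathbf{P} = \mathbf{I}$ and ReLU acts entrywise; invariance holds because summing over rows ignores their order:

```python
import numpy as np

rng = np.random.default_rng(42)
N, d_in, d_out = 6, 3, 4

# Hypothetical random undirected graph, node features, and weights
upper = np.triu((rng.random((N, N)) < 0.4).astype(float), k=1)
A = upper + upper.T
X = rng.normal(size=(N, d_in))
W = rng.normal(size=(d_in, d_out))

def f(X, A):
    """Node-level layer: sum over neighbors, then a linear map and ReLU."""
    return np.maximum(A @ X @ W, 0.0)

def g(X, A):
    """Graph-level readout: sum of node outputs (order-independent)."""
    return f(X, A).sum(axis=0)

P = np.eye(N)[rng.permutation(N)]   # random permutation matrix

equivariant = np.allclose(f(P @ X, P @ A @ P.T), P @ f(X, A))
invariant = np.allclose(g(P @ X, P @ A @ P.T), g(X, A))
print(equivariant, invariant)       # True True
```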
14.3 The Spectral Perspective: Graph Convolution via Eigendecomposition
The Graph Fourier Transform
In classical signal processing, the Fourier transform decomposes a signal into frequency components — eigenfunctions of the Laplacian operator on the real line. The same idea extends to graphs.
The normalized Laplacian has an eigendecomposition (Chapter 1):
$$\tilde{\mathbf{L}} = \mathbf{U} \mathbf{\Lambda} \mathbf{U}^\top$$
where $\mathbf{U} = [\mathbf{u}_1, \mathbf{u}_2, \ldots, \mathbf{u}_N]$ is the matrix of orthonormal eigenvectors and $\mathbf{\Lambda} = \text{diag}(\lambda_1, \lambda_2, \ldots, \lambda_N)$ is the diagonal matrix of eigenvalues, sorted $0 = \lambda_1 \leq \lambda_2 \leq \cdots \leq \lambda_N$.
The eigenvectors $\mathbf{u}_k$ are the graph Fourier modes, and the eigenvalues $\lambda_k$ are the graph frequencies:
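- $\lambda_1 = 0$: the DC component. For the unnormalized $\mathbf{L}$ its eigenvector is the constant vector $\mathbf{1}$; for the normalized $\tilde{\mathbf{L}}$ (on a connected graph) it is proportional to $\mathbf{D}^{1/2}\mathbf{1}$, i.e., constant once each entry is divided by $\sqrt{d_i}$.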
- Small $\lambda_k$: low-frequency modes (signals that vary slowly across the graph — nearby nodes have similar values).
- Large $\lambda_k$: high-frequency modes (signals that oscillate rapidly — adjacent nodes have very different values).
The graph Fourier transform of a signal $\mathbf{f} \in \mathbb{R}^N$ is:
$$\hat{\mathbf{f}} = \mathbf{U}^\top \mathbf{f}$$
and the inverse transform is $\mathbf{f} = \mathbf{U} \hat{\mathbf{f}}$. This is a direct analog of the discrete Fourier transform, with the graph Laplacian eigenvectors playing the role of sinusoidal basis functions.
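On a cycle graph the Laplacian eigenvectors are exactly sampled sinusoids, which makes the DFT analogy literal. The NumPy sketch below uses a hypothetical 8-node cycle and two test signals. It computes $\hat{\mathbf{f}} = \mathbf{U}^\top \mathbf{f}$, inverts it, and measures each signal's energy-weighted mean frequency $\sum_k \lambda_k \hat{f}_k^2 / \sum_k \hat{f}_k^2$. A slowly varying cosine sits at a low frequency, while a signal that flips sign across every edge sits at the top of the spectrum, $\lambda = 2$:

```python
import numpy as np

# Hypothetical cycle graph on N nodes: i is joined to (i + 1) mod N
N = 8
A = np.zeros((N, N))
for i in range(N):
    A[i, (i + 1) % N] = A[(i + 1) % N, i] = 1.0
d = A.sum(axis=1)
L_norm = np.eye(N) - A / np.sqrt(np.outer(d, d))   # I - D^{-1/2} A D^{-1/2}

lam, U = np.linalg.eigh(L_norm)     # ascending eigenvalues, orthonormal eigenvectors

def gft(f):
    return U.T @ f                  # graph Fourier transform

def igft(f_hat):
    return U @ f_hat                # inverse graph Fourier transform

def mean_frequency(f):
    """Energy-weighted average graph frequency of a signal."""
    f_hat = gft(f)
    return float(np.sum(lam * f_hat ** 2) / np.sum(f_hat ** 2))

f_smooth = np.cos(2 * np.pi * np.arange(N) / N)    # one slow oscillation around the cycle
f_rough = (-1.0) ** np.arange(N)                   # flips sign across every edge

print(np.allclose(igft(gft(f_smooth)), f_smooth))                                # True
print(round(mean_frequency(f_smooth), 3), round(mean_frequency(f_rough), 3))     # 0.293 2.0
```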
Spectral Graph Convolution
In classical signal processing, convolution in the spatial domain equals pointwise multiplication in the frequency domain. The same principle defines convolution on graphs.
A spectral graph convolution applies a learnable filter $g_\theta$ in the spectral domain:
$$\mathbf{y} = g_\theta(\tilde{\mathbf{L}}) \mathbf{f} = \mathbf{U} \, g_\theta(\mathbf{\Lambda}) \, \mathbf{U}^\top \mathbf{f}$$
where $g_\theta(\mathbf{\Lambda}) = \text{diag}(g_\theta(\lambda_1), g_\theta(\lambda_2), \ldots, g_\theta(\lambda_N))$ is the filter applied pointwise to the eigenvalues.
If $g_\theta$ is unconstrained (one learnable parameter per eigenvalue), we have $N$ parameters per filter — impractical for large graphs. More critically, the filter is defined in terms of the specific eigenvectors of a specific graph, so it does not transfer to other graphs.
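The formula $\mathbf{U}\, g_\theta(\mathbf{\Lambda})\, \mathbf{U}^\top \mathbf{f}$ translates directly into code. The NumPy sketch below assumes a hypothetical 20-node path graph and a noisy ramp signal, and uses a fixed heat-kernel low-pass response $g(\lambda) = e^{-3\lambda}$ in place of a learned filter. Because $g(\lambda) \leq 1$ and decays with $\lambda$, filtering lowers the variation $\mathbf{f}^\top \tilde{\mathbf{L}} \mathbf{f}$. Note that `spectral_filter` closes over this particular graph's $\mathbf{U}$, which is exactly why such a filter does not transfer to another graph:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical path graph on N nodes
N = 20
A = np.zeros((N, N))
idx = np.arange(N - 1)
A[idx, idx + 1] = 1.0
A[idx + 1, idx] = 1.0
d = A.sum(axis=1)
L_norm = np.eye(N) - A / np.sqrt(np.outer(d, d))
lam, U = np.linalg.eigh(L_norm)

def spectral_filter(f, g):
    """y = U diag(g(lambda)) U^T f, with g evaluated on this graph's eigenvalues."""
    return U @ (g(lam) * (U.T @ f))

def heat_lowpass(lam_values):
    """Fixed low-pass response exp(-3 lambda); stands in for a learned g_theta."""
    return np.exp(-3.0 * lam_values)

def variation(s):
    return float(s @ L_norm @ s)    # total variation of s under L_norm

f = np.linspace(0.0, 1.0, N) + 0.3 * rng.normal(size=N)   # smooth trend plus noise
y = spectral_filter(f, heat_lowpass)
print(f"variation before: {variation(f):.3f}, after: {variation(y):.3f}")
```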
Chebyshev Polynomial Approximation
Defferrard et al. (2016) addressed both problems by approximating the spectral filter with Chebyshev polynomials of the eigenvalues:
$$g_\theta(\lambda) \approx \sum_{k=0}^{K-1} \theta_k T_k(\tilde{\lambda})$$
where $T_k$ is the $k$-th Chebyshev polynomial, $\tilde{\lambda} = 2\lambda / \lambda_{\max} - 1$ rescales the eigenvalues to $[-1, 1]$, and $\theta_0, \ldots, \theta_{K-1}$ are the $K$ learnable parameters.
This has two critical advantages:
- $K$ parameters per filter instead of $N$. Typically $K = 2$ or $K = 3$.
- Localized in the spatial domain. The filter is a polynomial of degree $K-1$ in the Laplacian, and $(\tilde{\mathbf{L}}^k)_{ij} = 0$ whenever $v_j$ is more than $k$ hops from $v_i$, so the filtered value at each node depends only on nodes at most $K-1$ hops away. The filter is spatially localized without ever computing the eigendecomposition.
The filtered signal becomes:
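$$\mathbf{y} = \sum_{k=0}^{K-1} \theta_k T_k(\bar{\mathbf{L}}) \mathbf{f}, \qquad \bar{\mathbf{L}} = \frac{2}{\lambda_{\max}} \tilde{\mathbf{L}} - \mathbf{I}$$
where $\bar{\mathbf{L}}$ is the rescaled Laplacian (the matrix counterpart of $\tilde{\lambda}$, with eigenvalues in $[-1, 1]$) and $T_k(\bar{\mathbf{L}})$ is a matrix polynomial computed via the Chebyshev recurrence: $T_0(\bar{\mathbf{L}}) = \mathbf{I}$, $T_1(\bar{\mathbf{L}}) = \bar{\mathbf{L}}$, $T_k(\bar{\mathbf{L}}) = 2\bar{\mathbf{L}} T_{k-1}(\bar{\mathbf{L}}) - T_{k-2}(\bar{\mathbf{L}})$. Applying the recurrence to $\mathbf{f}$ directly requires only $K-1$ sparse matrix-vector products.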
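The recurrence can be run directly on the signal, so only matrix-vector products with the rescaled Laplacian are needed. The NumPy sketch below uses a hypothetical 10-node path graph, hypothetical coefficients $\theta$, and the bound $\lambda_{\max} = 2$. It checks the result against the exact spectral filter with response $\sum_k \theta_k T_k(\tilde{\lambda})$, evaluated with `numpy.polynomial.chebyshev.chebval`. It then filters an impulse at node 0 to show that with three coefficients (a degree-2 polynomial) nothing reaches beyond 2 hops:

```python
import numpy as np
from numpy.polynomial import chebyshev

def chebyshev_filter(L_norm, f, theta, lam_max=2.0):
    """y = sum_k theta[k] T_k(L_bar) f via the three-term recurrence (no eigendecomposition)."""
    N = L_norm.shape[0]
    L_bar = (2.0 / lam_max) * L_norm - np.eye(N)   # rescaled Laplacian, spectrum in [-1, 1]
    t_prev, t_curr = f, L_bar @ f                  # T_0(L_bar) f and T_1(L_bar) f
    y = theta[0] * t_prev
    if len(theta) > 1:
        y = y + theta[1] * t_curr
    for k in range(2, len(theta)):
        t_prev, t_curr = t_curr, 2.0 * (L_bar @ t_curr) - t_prev
        y = y + theta[k] * t_curr
    return y

# Hypothetical path graph 0 - 1 - ... - 9
N = 10
A = np.zeros((N, N))
idx = np.arange(N - 1)
A[idx, idx + 1] = 1.0
A[idx + 1, idx] = 1.0
d = A.sum(axis=1)
L_norm = np.eye(N) - A / np.sqrt(np.outer(d, d))

theta = np.array([0.5, -0.3, 0.2])       # K = 3 hypothetical coefficients
lam_max = 2.0
f = np.random.default_rng(0).normal(size=N)

# Exact spectral filter with the same polynomial response, for comparison
lam, U = np.linalg.eigh(L_norm)
g = chebyshev.chebval(2.0 * lam / lam_max - 1.0, theta)
y_exact = U @ (g * (U.T @ f))
y_cheb = chebyshev_filter(L_norm, f, theta, lam_max)
print(np.allclose(y_cheb, y_exact))      # True

impulse = np.zeros(N)
impulse[0] = 1.0
y_imp = chebyshev_filter(L_norm, impulse, theta, lam_max)
print(np.flatnonzero(np.abs(y_imp) > 1e-12))   # [0 1 2]
```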
Fundamentals > Frontier: The spectral perspective is more than historical background. It reveals why graph convolution works: it is a principled frequency-domain filter on graph signals. Understanding this connection lets you reason about what GNNs can and cannot do — why they smooth high-frequency signals (a property that leads to over-smoothing), why the Laplacian eigenvalues encode connectivity (the Fiedler value is the algebraic connectivity), and why the spectral approach connects graph learning to the rich mathematical toolkit of spectral graph theory. Modern GNN practice is spatial, but spectral understanding is what separates application from comprehension.
14.4 The Spatial Perspective: Message Passing Neural Networks
From Spectral to Spatial
Kipf and Welling (2017) made the connection to spatial methods explicit. Starting from the Chebyshev filter truncated to first order (keeping only $T_0$ and $T_1$, i.e., $K = 2$ in the notation above) and approximating $\lambda_{\max} \approx 2$:
$$g_\theta(\tilde{\mathbf{L}}) \approx \theta_0 \mathbf{I} + \theta_1 (\tilde{\mathbf{L}} - \mathbf{I}) = \theta_0 \mathbf{I} - \theta_1 \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}$$
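Setting $\theta_0 = -\theta_1 = \theta$ (a single parameter) gives $\theta\,(\mathbf{I} + \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2})$. This operator has eigenvalues in $[0, 2]$, so stacking many layers can make activations and gradients explode or vanish. Kipf and Welling therefore apply a renormalization trick, replacing it with the normalized adjacency with self-loops $\hat{\mathbf{A}}$ defined below. Generalizing to multi-dimensional features then gives: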
$$\mathbf{H}^{(\ell+1)} = \sigma\!\left( \hat{\mathbf{A}} \mathbf{H}^{(\ell)} \mathbf{W}^{(\ell)} \right)$$
where $\hat{\mathbf{A}} = \tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2}$ is the symmetrically normalized adjacency matrix with self-loops ($\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}$, $\tilde{\mathbf{D}}_{ii} = \sum_j \tilde{A}_{ij}$), $\mathbf{H}^{(\ell)} \in \mathbb{R}^{N \times d_\ell}$ is the matrix of node representations at layer $\ell$ (with $\mathbf{H}^{(0)} = \mathbf{X}$), and $\mathbf{W}^{(\ell)} \in \mathbb{R}^{d_\ell \times d_{\ell+1}}$ is a learnable weight matrix.
This is the Graph Convolutional Network (GCN) layer. Reading the matrix product row by row (row $i$ of $\hat{\mathbf{A}} \mathbf{H}^{(\ell)} \mathbf{W}^{(\ell)}$ belongs to node $v_i$) reveals its spatial interpretation: for each node $v_i$, the GCN layer computes:
$$\mathbf{h}_i^{(\ell+1)} = \sigma\!\left( \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\tilde{d}_i \tilde{d}_j}} \mathbf{h}_j^{(\ell)} \mathbf{W}^{(\ell)} \right)$$
This is a neighborhood aggregation: the new representation of node $i$ is a (normalized) sum of its neighbors' representations (plus its own), linearly transformed and passed through a nonlinearity. Each GCN layer propagates information one hop through the graph.
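A quick NumPy check that the matrix form and the per-node form agree. It assumes a hypothetical 4-node graph with random features and weights, with ReLU standing in for $\sigma$:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical undirected graph on 4 nodes
edges = [(0, 1), (1, 2), (2, 3), (3, 0), (1, 3)]
N, d_in, d_out = 4, 3, 2
A = np.zeros((N, N))
for i, j in edges:
    A[i, j] = A[j, i] = 1.0

A_tilde = A + np.eye(N)                                  # add self-loops
d_tilde = A_tilde.sum(axis=1)                            # degrees including the self-loop
A_hat = A_tilde / np.sqrt(np.outer(d_tilde, d_tilde))    # D~^{-1/2} A~ D~^{-1/2}

H = rng.normal(size=(N, d_in))
W = rng.normal(size=(d_in, d_out))

def relu(z):
    return np.maximum(z, 0.0)

H_next = relu(A_hat @ H @ W)                             # matrix form

H_loop = np.zeros((N, d_out))                            # per-node form, row by row
for i in range(N):
    acc = np.zeros(d_out)
    for j in np.flatnonzero(A_tilde[i]):                 # j in N(i) U {i}
        acc += H[j] @ W / np.sqrt(d_tilde[i] * d_tilde[j])
    H_loop[i] = relu(acc)

print(np.allclose(H_next, H_loop))                       # True
```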
The General Message Passing Framework
Gilmer et al. (2017) unified most GNN architectures under the Message Passing Neural Network (MPNN) framework. A single message passing layer consists of three operations:
1. Message construction. For each edge $(v_j, v_i)$, compute a message:
$$\mathbf{m}_{j \to i}^{(\ell)} = \phi_{\text{msg}}\!\left(\mathbf{h}_i^{(\ell)}, \mathbf{h}_j^{(\ell)}, \mathbf{e}_{ij}\right)$$
where $\phi_{\text{msg}}$ is the message function (a learnable function of the source, target, and edge features).
2. Aggregation. Aggregate messages from all neighbors:
$$\mathbf{m}_i^{(\ell)} = \bigoplus_{j \in \mathcal{N}(i)} \mathbf{m}_{j \to i}^{(\ell)}$$
where $\bigoplus$ is a permutation-invariant aggregation function (sum, mean, max, or more sophisticated functions).
3. Update. Compute the new node representation:
$$\mathbf{h}_i^{(\ell+1)} = \phi_{\text{upd}}\!\left(\mathbf{h}_i^{(\ell)}, \mathbf{m}_i^{(\ell)}\right)$$
where $\phi_{\text{upd}}$ is the update function (typically an MLP or GRU).
Different GNN architectures are different instantiations of $\phi_{\text{msg}}$, $\bigoplus$, and $\phi_{\text{upd}}$.
| Architecture | Message $\phi_{\text{msg}}$ | Aggregation $\bigoplus$ | Update $\phi_{\text{upd}}$ |
|---|---|---|---|
| GCN | $\frac{1}{\sqrt{\tilde{d}_i \tilde{d}_j}} \mathbf{h}_j^{(\ell)} \mathbf{W}$ | Sum | Identity + $\sigma$ |
| GraphSAGE | $\mathbf{h}_j^{(\ell)}$ | Mean / Max / LSTM | $\sigma(\mathbf{W} [\mathbf{h}_i^{(\ell)} \| \mathbf{m}_i])$ |
| GAT | $\alpha_{ij} \mathbf{h}_j^{(\ell)} \mathbf{W}$ | Sum | Identity + $\sigma$ |
| GIN | $\mathbf{h}_j^{(\ell)}$ | Sum | $\text{MLP}((1 + \epsilon) \mathbf{h}_i^{(\ell)} + \mathbf{m}_i)$ |
The framework is permutation equivariant by construction. The same $\phi_{\text{msg}}$ and $\phi_{\text{upd}}$ are applied at every node, and because the aggregation $\bigoplus$ is permutation invariant, each output is independent of the order in which neighbors are processed. Relabeling the nodes therefore only relabels the outputs. This is the mathematical reason that message passing GNNs satisfy the symmetry requirements of Section 14.2.
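To make the three steps concrete, here is a minimal NumPy sketch of one message passing step over an edge list. It is not PyG's `MessagePassing` API: the function `mp_layer`, its arguments, and the toy graph and weights are hypothetical. The example instantiates the GIN row of the table (message $\mathbf{h}_j$, sum aggregation, MLP update). It then confirms that shuffling the edge order, which changes the order in which neighbors are processed, leaves the output unchanged:

```python
import numpy as np

def mp_layer(h, edge_index, msg_fn, aggr, upd_fn):
    """One message passing step; edge_index has shape (2, E), column (j, i) is a message j -> i."""
    src, dst = edge_index
    msgs = msg_fn(h[dst], h[src])                # step 1: phi_msg(h_i, h_j) for every edge
    N, d_msg = h.shape[0], msgs.shape[1]
    if aggr in ("sum", "mean"):                  # step 2: permutation-invariant aggregation
        m = np.zeros((N, d_msg))
        np.add.at(m, dst, msgs)                  # unbuffered scatter-add into target rows
        if aggr == "mean":
            counts = np.bincount(dst, minlength=N)[:, None]
            m = m / np.maximum(counts, 1)        # nodes with no in-edges keep zeros
    elif aggr == "max":
        m = np.full((N, d_msg), -np.inf)
        np.maximum.at(m, dst, msgs)
        m[np.isneginf(m)] = 0.0                  # nodes with no in-edges get zeros
    else:
        raise ValueError(f"unknown aggregation {aggr!r}")
    return upd_fn(h, m)                          # step 3: phi_upd(h_i, m_i)

# Hypothetical undirected toy graph with both directions listed: 0-1, 1-2, 0-2, 2-3
pairs = [(0, 1), (1, 2), (0, 2), (2, 3)]
edge_index = np.array(pairs + [(j, i) for i, j in pairs]).T
rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))
W1, W2 = rng.normal(size=(3, 8)), rng.normal(size=(8, 3))
eps = 0.1

def copy_source(h_i, h_j):
    return h_j                                   # GIN message: the neighbor's state

def gin_update(h, m):
    return np.maximum(((1 + eps) * h + m) @ W1, 0.0) @ W2   # MLP((1 + eps) h_i + m_i)

out = mp_layer(h, edge_index, copy_source, "sum", gin_update)
shuffled = edge_index[:, rng.permutation(edge_index.shape[1])]
out_shuffled = mp_layer(h, shuffled, copy_source, "sum", gin_update)
print(np.allclose(out, out_shuffled))            # True
```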
14.5 GCN: Graph Convolutional Networks
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import Data


class GCNConv(MessagePassing):
    """Graph Convolutional Network layer (Kipf & Welling, 2017).

    Implements the normalized neighborhood aggregation
        h_i' = sum_{j in N(i) U {i}} (1 / sqrt(d_i * d_j)) * h_j * W  (+ bias).
    The nonlinearity sigma is applied by the enclosing model, not by this layer.
    Uses the symmetric normalization from the original paper, derived from
    the first-order Chebyshev approximation of spectral graph convolution.

    Args:
        in_channels: Dimensionality of input node features.
        out_channels: Dimensionality of output node features.
        bias: If True, adds a learnable bias to the output.
    """

    def __init__(
        self, in_channels: int, out_channels: int, bias: bool = True
    ) -> None:
        super().__init__(aggr="add")  # Sum aggregation
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.lin.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Node feature matrix, shape (N, in_channels).
            edge_index: Edge indices in COO format, shape (2, E).

        Returns:
            Updated node features, shape (N, out_channels).
        """
        # Add self-loops: A_tilde = A + I
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Linear transformation: X W
        x = self.lin(x)

        # Symmetric normalization coefficients 1 / sqrt(d_i * d_j)
        row, col = edge_index  # row = source j, col = target i
        deg = degree(col, x.size(0), dtype=x.dtype)  # d_tilde, includes the self-loop
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0.0  # guard against zero degree
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Message passing: sum of normalized neighbor messages at each target
        out = self.propagate(edge_index, x=x, norm=norm)

        if self.bias is not None:
            out = out + self.bias
        return out

    def message(
        self, x_j: torch.Tensor, norm: torch.Tensor
    ) -> torch.Tensor:
        """Construct messages from neighbors.

        Args:
            x_j: Source node features, shape (E, out_channels).
            norm: Normalization coefficients, shape (E,).

        Returns:
            Messages, shape (E, out_channels).
        """
        return norm.view(-1, 1) * x_j


class GCN(nn.Module):
    """Multi-layer Graph Convolutional Network for node classification.

    Args:
        in_channels: Input feature dimensionality.
        hidden_channels: Hidden layer dimensionality.
        out_channels: Number of output classes.
        num_layers: Number of GCN layers (at least 2).
        dropout: Dropout probability.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int = 2,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        if num_layers < 2:
            raise ValueError("GCN needs at least 2 layers (input and output).")
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass through all GCN layers.

        Args:
            x: Node features, shape (N, in_channels).
            edge_index: Edge indices, shape (2, E).

        Returns:
            Node logits, shape (N, out_channels).
        """
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)  # raw logits; softmax lives in the loss
        return x
```
Training a GCN on Cora
The Cora citation dataset is the standard GNN benchmark: 2,708 scientific publications (nodes) classified into 7 categories, with 5,429 citation links (edges) and 1,433-dimensional bag-of-words features.
```python
from typing import Dict

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures


def train_gcn_cora(
    hidden_channels: int = 64,
    num_layers: int = 2,
    learning_rate: float = 0.01,
    weight_decay: float = 5e-4,
    epochs: int = 200,
    dropout: float = 0.5,
) -> Dict[str, float]:
    """Train a GCN on the Cora citation network.

    Uses the standard semi-supervised split: 20 nodes per class for training,
    500 for validation, 1000 for testing.

    Args:
        hidden_channels: Hidden layer size.
        num_layers: Number of GCN layers.
        learning_rate: Optimizer learning rate.
        weight_decay: L2 regularization.
        epochs: Number of training epochs.
        dropout: Dropout probability.

    Returns:
        Dictionary with train, validation, and test accuracy.
    """
    dataset = Planetoid(root="/tmp/Cora", name="Cora",
                        transform=NormalizeFeatures())
    data = dataset[0]

    model = GCN(
        in_channels=dataset.num_features,
        hidden_channels=hidden_channels,
        out_channels=dataset.num_classes,
        num_layers=num_layers,
        dropout=dropout,
    )
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Training loop
    best_val_acc = 0.0
    best_test_acc = 0.0
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Evaluation
        model.eval()
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
            pred = logits.argmax(dim=1)
            train_acc = (pred[data.train_mask] == data.y[data.train_mask]).float().mean().item()
            val_acc = (pred[data.val_mask] == data.y[data.val_mask]).float().mean().item()
            test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        if (epoch + 1) % 50 == 0:
            print(
                f"Epoch {epoch+1:3d} | Loss: {loss.item():.4f} | "
                f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | "
                f"Test: {test_acc:.4f}"
            )

    print(f"\nBest test accuracy: {best_test_acc:.4f}")
    return {
        "train_acc": train_acc,
        "val_acc": best_val_acc,
        "test_acc": best_test_acc,
    }


# Expected output (approximate):
# Epoch  50 | Loss: 0.5312 | Train: 0.9857 | Val: 0.7920 | Test: 0.8100
# Epoch 100 | Loss: 0.3408 | Train: 0.9929 | Val: 0.7980 | Test: 0.8150
# Epoch 150 | Loss: 0.2615 | Train: 1.0000 | Val: 0.7940 | Test: 0.8120
# Epoch 200 | Loss: 0.2148 | Train: 1.0000 | Val: 0.7960 | Test: 0.8130
# Best test accuracy: 0.8150
```
A 2-layer GCN with 64 hidden units achieves approximately 81% test accuracy on Cora — significantly above the ~60% achieved by an MLP on the same features (which ignores graph structure entirely). The gap is the value of the graph.
14.6 GraphSAGE: Inductive Learning via Neighborhood Sampling
The Scalability Problem
GCN has a fundamental scalability limitation: computing the representation of a single node at layer $\ell$ requires the representations of all its neighbors at layer $\ell - 1$. For a $K$-layer GCN, computing one node's embedding requires a $K$-hop neighborhood expansion — which can encompass the entire graph for high-degree nodes or small-world networks.
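In the StreamRec graph with 5.2 million nodes, a 2-layer GCN would require materializing the 2-hop neighborhood of every node during training. A user with 40 interactions, each item interacted with by 1,000 users, produces a 2-hop neighborhood of $40 \times 1{,}000 = 40{,}000$ users. A third layer expands each of those users to their 40 items: $40{,}000 \times 40 = 1.6$ million item visits per target (counting paths, before removing duplicates), eight times the 200,000 items in the catalog, so the receptive field effectively covers the entire item side of the graph.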
The GraphSAGE Solution
Hamilton, Ying, and Leskovec (2017) introduced GraphSAGE (SAmple and aggreGatE) with two key innovations:
- Neighborhood sampling. Instead of aggregating all neighbors, sample a fixed-size subset at each layer. If we sample $k$ neighbors per node at each of $L$ layers, the computation graph per target node has at most $k^L$ nodes at the outermost hop and $1 + k + \cdots + k^L = O(k^L)$ nodes in total: a controllable, fixed budget.
- Inductive learning. GraphSAGE learns an aggregation function (not a fixed embedding lookup), so it generalizes to unseen nodes. A new user who joins StreamRec after training can be embedded by aggregating their neighbors' features, with no retraining required.
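The fixed budget is easy to see in a small sketch. The pure-Python function below is a hypothetical illustration of fixed-fanout sampling, not PyG's `NeighborLoader` (which the implementation later in this section uses). On a dense 200-node complete graph, where a full 2-hop expansion would touch $199 \times 199$ paths, the fanouts $[25, 10]$ cap the computation tree at $1 + 25 + 250$ entries:

```python
import random
from typing import Dict, List

def sample_computation_tree(
    adj: Dict[int, List[int]], target: int, fanouts: List[int], seed: int = 0
) -> List[List[int]]:
    """Return hops[0] = [target] and hops[l] = nodes sampled at hop l.

    Each frontier node draws at most fanouts[l - 1] neighbors without replacement;
    duplicates across frontier nodes are kept, as in a per-target computation tree.
    """
    rng = random.Random(seed)
    hops = [[target]]
    for k in fanouts:
        frontier = []
        for v in hops[-1]:
            nbrs = adj[v]
            frontier.extend(rng.sample(nbrs, min(k, len(nbrs))))
        hops.append(frontier)
    return hops

# Hypothetical dense graph: complete graph on 200 nodes (every node has degree 199)
n = 200
adj = {v: [u for u in range(n) if u != v] for v in range(n)}

hops = sample_computation_tree(adj, target=0, fanouts=[25, 10])
print([len(frontier) for frontier in hops])      # [1, 25, 250]
print(sum(len(frontier) for frontier in hops))   # 276
```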
The GraphSAGE Layer
$$\mathbf{h}_{\mathcal{N}(i)}^{(\ell)} = \text{AGG}\!\left(\left\{\mathbf{h}_j^{(\ell)}, \forall j \in \text{SAMPLE}(\mathcal{N}(i), k)\right\}\right)$$
$$\mathbf{h}_i^{(\ell+1)} = \sigma\!\left(\mathbf{W}^{(\ell)} \cdot [\mathbf{h}_i^{(\ell)} \,\|\, \mathbf{h}_{\mathcal{N}(i)}^{(\ell)}]\right)$$
where $[\cdot \| \cdot]$ denotes concatenation. The key difference from GCN: GraphSAGE explicitly separates the node's own representation from its aggregated neighborhood, then concatenates them before the linear transformation. This preserves the node's identity more strongly than GCN's symmetric sum.
The aggregation function AGG, applied to the sampled set $\mathcal{S} = \text{SAMPLE}(\mathcal{N}(i), k)$, can be:
- Mean: $\text{AGG} = \frac{1}{|\mathcal{S}|} \sum_{j \in \mathcal{S}} \mathbf{h}_j$. This is close to GCN's propagation rule, but it normalizes by the neighborhood size $1/|\mathcal{S}|$ instead of GCN's symmetric $1/\sqrt{\tilde{d}_i \tilde{d}_j}$.
- Max (pooling): $\text{AGG} = \max_{j \in \mathcal{S}} \sigma(\mathbf{W}_{\text{pool}} \mathbf{h}_j + \mathbf{b}_{\text{pool}})$ (element-wise max after a nonlinear transformation).
- LSTM: process the sampled neighbors through an LSTM, applied to a random permutation of them since neighbors have no natural order.
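A minimal NumPy sketch of the layer with the mean and max-pooling aggregators follows. The helper names, the ReLU choice for $\sigma$, and the toy triangle graph are assumptions for illustration. Neighbor lists are taken as already sampled, and $\mathbf{W}$ acts on the concatenation $[\mathbf{h}_i \,\|\, \mathbf{h}_{\mathcal{N}(i)}]$ as in the layer equation above:

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def mean_agg(h_nbrs):
    return h_nbrs.mean(axis=0)

def make_pool_agg(W_pool, b_pool):
    """Max-pooling aggregator: element-wise max over neighbors of relu(W_pool h_j + b_pool)."""
    def pool_agg(h_nbrs):
        return relu(h_nbrs @ W_pool.T + b_pool).max(axis=0)
    return pool_agg

def sage_layer(h, neighbors, W, agg):
    """h_i' = relu(W [h_i || AGG({h_j : j in neighbors[i]})]) for every node i."""
    out = [relu(W @ np.concatenate([h[i], agg(h[nbrs])])) for i, nbrs in enumerate(neighbors)]
    return np.stack(out)

# Hypothetical triangle graph with 2-d features; W = [I | I] adds self and neighbor parts
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
neighbors = [[1, 2], [0, 2], [0, 1]]
W = np.hstack([np.eye(2), np.eye(2)])

out_mean = sage_layer(h, neighbors, W, mean_agg)
# rows: [1.5, 1.0], [1.0, 1.5], [1.5, 1.5]

pool = make_pool_agg(W_pool=np.eye(2), b_pool=np.zeros(2))
out_pool = sage_layer(h, neighbors, W, pool)
# rows: [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]
```

Because the node's own features enter through a separate half of $\mathbf{W}$, nodes 0 and 1 receive different outputs even though the mean aggregator treats their neighborhoods symmetrically.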
Implementation with Mini-Batch Sampling
```python
from typing import Dict, List, Optional

from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader


class GraphSAGE(nn.Module):
    """GraphSAGE model for inductive node classification/embedding.

    Uses mean aggregation with concatenation of self and neighbor
    representations. Supports mini-batch training via neighborhood sampling.

    Args:
        in_channels: Input feature dimensionality.
        hidden_channels: Hidden layer dimensionality.
        out_channels: Output dimensionality.
        num_layers: Number of SAGE layers (at least 2).
        dropout: Dropout probability.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int = 2,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        if num_layers < 2:
            raise ValueError("GraphSAGE needs at least 2 layers (input and output).")
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Node features, shape (N, in_channels).
            edge_index: Edge indices, shape (2, E).

        Returns:
            Node embeddings, shape (N, out_channels).
        """
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x


def train_graphsage_minibatch(
    data: Data,
    num_classes: int,
    hidden_channels: int = 64,
    num_layers: int = 2,
    batch_size: int = 512,
    num_neighbors: Optional[List[int]] = None,
    epochs: int = 50,
    lr: float = 0.005,
) -> Dict[str, float]:
    """Train GraphSAGE with mini-batch neighborhood sampling.

    This is the training paradigm for large graphs: instead of loading
    the entire graph into memory, sample fixed-size neighborhoods around
    a batch of target nodes.

    Args:
        data: PyG Data object with features, edges, and masks.
        num_classes: Number of output classes.
        hidden_channels: Hidden dimensionality.
        num_layers: Number of layers.
        batch_size: Number of target nodes per mini-batch.
        num_neighbors: Neighbors to sample at each hop, e.g., [25, 10];
            its length should equal num_layers.
        epochs: Number of training epochs.
        lr: Learning rate.

    Returns:
        Dictionary with the last epoch's training accuracy (on the sampled
        mini-batches) and the full-graph test accuracy.
    """
    if num_neighbors is None:
        num_neighbors = [25, 10]  # Sample 25 at hop 1, 10 at hop 2

    train_loader = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        input_nodes=data.train_mask,
        shuffle=True,
    )

    model = GraphSAGE(
        in_channels=data.num_features,
        hidden_channels=hidden_channels,
        out_channels=num_classes,
        num_layers=num_layers,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_acc = 0.0
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_correct = 0
        total_nodes = 0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            # Only compute loss on target nodes (the first batch_size nodes)
            target_out = out[:batch.batch_size]
            target_y = batch.y[:batch.batch_size]
            loss = F.cross_entropy(target_out, target_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch.batch_size
            total_correct += (target_out.argmax(dim=1) == target_y).sum().item()
            total_nodes += batch.batch_size

        train_acc = total_correct / total_nodes
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {total_loss/total_nodes:.4f} | "
                  f"Train Acc: {train_acc:.4f}")

    # Full-graph evaluation (fine for small graphs; very large graphs would
    # need sampled or layer-wise inference instead)
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()

    print(f"\nTest accuracy: {test_acc:.4f}")
    return {"train_acc": train_acc, "test_acc": test_acc}
```
The mini-batch sampling is the practical workhorse: for a graph with millions of nodes, it reduces memory from $O(N)$ (full-batch GCN) to $O(B \cdot k^L)$ (where $B$ is the batch size, $k$ is the neighbor sample size, and $L$ is the number of layers).
14.7 GAT: Graph Attention Networks
Motivation: Beyond Symmetric Normalization
GCN normalizes neighbor contributions by degree: $1/\sqrt{\tilde{d}_i \tilde{d}_j}$. This treats all neighbors equally (up to degree correction). GraphSAGE with mean aggregation is similar — it averages over neighbors uniformly.
But not all neighbors are equally informative. In a citation network, a paper may cite both a foundational reference and a tangentially related work; the foundational citation should contribute more to the paper's representation. In the StreamRec interaction graph, a user's deep engagement with a documentary carries more signal than a 10-second accidental click.
Graph Attention Networks (GAT) (Velickovic et al., 2018) solve this by learning data-dependent attention weights — the same core idea as the transformer attention mechanism (Chapter 10), applied to graph neighborhoods.
The GAT Attention Mechanism
For a single attention head, the GAT layer computes:
Step 1: Linear transformation. Transform node features:
$$\mathbf{z}_i = \mathbf{W} \mathbf{h}_i, \quad \mathbf{z}_i \in \mathbb{R}^{d'}$$
Step 2: Attention coefficients. For each edge $(v_j, v_i)$, compute a raw attention score:
$$e_{ij} = \text{LeakyReLU}\!\left(\mathbf{a}^\top [\mathbf{z}_i \,\|\, \mathbf{z}_j]\right)$$
where $\mathbf{a} \in \mathbb{R}^{2d'}$ is a learnable attention vector and $[\cdot \| \cdot]$ is concatenation.
Step 3: Normalization. Softmax over the neighborhood:
$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} \exp(e_{ik})}$$
Step 4: Weighted aggregation:
$$\mathbf{h}_i^{(\ell+1)} = \sigma\!\left(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \alpha_{ij} \mathbf{z}_j\right)$$
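The four steps can be traced in a few lines of NumPy. This is a dense, single-head sketch for small graphs, not PyG's sparse `GATConv`. It makes three assumptions:

- the row-vector convention $\mathbf{z}_i = \mathbf{h}_i \mathbf{W}$, so `W` is the transpose of $\mathbf{W}$ above;
- a LeakyReLU slope of 0.2;
- ELU for $\sigma$.

It also uses the identity $\mathbf{a}^\top [\mathbf{z}_i \| \mathbf{z}_j] = \mathbf{a}_{\text{left}}^\top \mathbf{z}_i + \mathbf{a}_{\text{right}}^\top \mathbf{z}_j$. This lets all raw scores be formed by broadcasting instead of explicit concatenation.

```python
import numpy as np


def leaky_relu(x: np.ndarray, slope: float = 0.2) -> np.ndarray:
    return np.where(x > 0, x, slope * x)


def elu(x: np.ndarray) -> np.ndarray:
    # expm1 on the clipped input avoids overflow warnings for large positives
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))


def gat_head(H, A, W, a):
    """One GAT attention head on a dense graph.

    H: (N, d) node features, one row per node.
    A: (N, N) 0/1 adjacency without self-loops; A[i, j] = 1 if j is a neighbor of i.
    W: (d, d_out) weights, applied as z_i = h_i @ W.
    a: (2 * d_out,) attention vector.
    Returns (alpha, H_next), with alpha of shape (N, N).
    """
    Z = H @ W                                    # Step 1: z_i for every node
    d_out = Z.shape[1]
    # Step 2: a^T [z_i || z_j] = a_left^T z_i + a_right^T z_j
    score_i = Z @ a[:d_out]
    score_j = Z @ a[d_out:]
    E = leaky_relu(score_i[:, None] + score_j[None, :])   # E[i, j] = e_ij
    # Step 3: softmax over N(i) plus i itself; everything else gets -inf
    mask = (A + np.eye(A.shape[0])) > 0
    E = np.where(mask, E, -np.inf)
    E = E - E.max(axis=1, keepdims=True)         # numerical stability
    alpha = np.exp(E)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    # Step 4: attention-weighted sum of transformed neighbors, then sigma = ELU
    return alpha, elu(alpha @ Z)


rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.normal(size=(4, 3))
alpha, H_next = gat_head(H, A, W=rng.normal(size=(3, 2)), a=rng.normal(size=4))
```

Each row of `alpha` sums to 1 and is zero outside $\mathcal{N}(i) \cup \{i\}$. That masking is exactly what turns global attention into neighborhood attention.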
Multi-Head Attention
As in the transformer, multiple attention heads capture different types of relationships. With $H$ heads:
$$\mathbf{h}_i^{(\ell+1)} = \Big\|_{h=1}^{H} \sigma\!\left(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \alpha_{ij}^{(h)} \mathbf{W}^{(h)} \mathbf{h}_j^{(\ell)}\right)$$
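where $\|$ denotes concatenation, so an intermediate layer outputs $H \cdot d'$ features per node. In the final layer, the heads are averaged instead of concatenated. The output then has $d'$ dimensions regardless of $H$; for node classification, that is one dimension per class.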
Comparing GAT and Transformer Attention
The connection between GAT and transformer self-attention is deep:
| Aspect | Transformer (Ch. 10) | GAT |
|---|---|---|
| Input structure | Sequence (positions) | Graph (neighborhoods) |
| Attention scope | All positions (global) | Neighbors only (local) |
| Score computation | $\mathbf{q}_i^\top \mathbf{k}_j / \sqrt{d_k}$ | $\text{LeakyReLU}(\mathbf{a}^\top [\mathbf{z}_i \Vert \mathbf{z}_j])$ |
| Positional information | Positional encoding | Graph topology |
| Complexity per node | $O(n \cdot d)$ | $O(\lvert\mathcal{N}(i)\rvert \cdot d)$ |
The two mechanisms differ mainly in the scoring function (dot-product versus additive). Setting that aside, a transformer with a full attention mask is a GAT on a complete graph, and a GAT with a sparse adjacency is a transformer with a structured sparsity mask. The graph provides the inductive bias that the transformer must learn from data.
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT(nn.Module):
    """Graph Attention Network for node classification.

    Uses multi-head attention in intermediate layers (concatenation)
    and single-head or averaged multi-head in the output layer.

    Args:
        in_channels: Input feature dimensionality.
        hidden_channels: Hidden channels per attention head.
        out_channels: Number of output classes.
        num_layers: Number of GAT layers (at least 2).
        heads: Number of attention heads in intermediate layers.
        dropout: Dropout probability (applied to both features and attention).
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int = 2,
        heads: int = 8,
        dropout: float = 0.6,
    ) -> None:
        super().__init__()
        self.dropout = dropout
        self.convs = nn.ModuleList()
        # First layer: in_channels -> hidden_channels * heads (heads concatenated)
        self.convs.append(
            GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        )
        # Intermediate layers
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(
                    hidden_channels * heads, hidden_channels,
                    heads=heads, dropout=dropout,
                )
            )
        # Output layer: average heads instead of concatenating
        self.convs.append(
            GATConv(
                hidden_channels * heads, out_channels,
                heads=1, concat=False, dropout=dropout,
            )
        )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Node features, shape (N, in_channels).
            edge_index: Edge indices, shape (2, E).
        Returns:
            Node logits, shape (N, out_channels).
        """
        for conv in self.convs[:-1]:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv(x, edge_index)
            x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x
```
14.8 GIN: The Most Powerful Message Passing GNN
The Weisfeiler-Leman Test
How powerful can a message passing GNN be? Xu et al. (2019) answered this question by connecting GNNs to the Weisfeiler-Leman (WL) graph isomorphism test — a classical algorithm from graph theory.
The 1-WL test iteratively refines node labels:
- Initialize: each node receives a label based on its features (or degree, if no features).
- Iterate: each node's label is updated to a hash of its current label and the multiset of its neighbors' labels.
- Terminate: when labels stabilize (no further refinement), the multiset of all node labels forms a graph fingerprint.
If the 1-WL test assigns two graphs different fingerprints, they are provably non-isomorphic. However, the test has blind spots: certain non-isomorphic graphs receive identical fingerprints. For example, any two $k$-regular graphs with the same number of nodes and identical node features are indistinguishable. A 6-cycle and a pair of disjoint triangles (both 2-regular on 6 nodes) are one such pair.
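The refinement loop is short enough to write out. The sketch below uses a hypothetical function name and makes three simplifications:

- every node starts with the same color;
- it runs a fixed number of rounds instead of testing for stabilization (at most $N$ rounds are ever needed on an $N$-node graph);
- it shares one color palette across the graphs being compared, so their fingerprints are comparable.

```python
from collections import Counter


def wl_fingerprints(graphs, num_iters):
    """1-WL color refinement run jointly on several graphs.

    graphs: list of adjacency dicts {node: [neighbors]}. All nodes start with
    color 0 (no features). A shared palette maps each signature
    (own color, sorted neighbor colors) to an integer color.
    Returns one Counter (color -> number of nodes) per graph.
    """
    palette = {}
    colorings = [{v: 0 for v in g} for g in graphs]
    for _ in range(num_iters):
        new_colorings = []
        for g, colors in zip(graphs, colorings):
            new_colors = {}
            for v, neighbors in g.items():
                signature = (colors[v], tuple(sorted(colors[u] for u in neighbors)))
                new_colors[v] = palette.setdefault(signature, len(palette))
            new_colorings.append(new_colors)
        colorings = new_colorings
    return [Counter(colors.values()) for colors in colorings]


hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],
                 3: [4, 5], 4: [3, 5], 5: [3, 4]}
fp_hex, fp_tri = wl_fingerprints([hexagon, two_triangles], num_iters=3)
print(fp_hex == fp_tri)    # True: 1-WL cannot tell them apart

star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
fp_star, fp_path = wl_fingerprints([star, path], num_iters=3)
print(fp_star == fp_path)  # False: degrees already differ after one round
```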
GNNs and the WL Hierarchy
Xu et al. (2019) proved two fundamental results:
- Upper bound: No message passing GNN can distinguish graphs that the 1-WL test cannot distinguish. The message passing framework is at most as powerful as 1-WL.
- Achievability: With enough layers, a message passing GNN is as powerful as 1-WL if the following hold:
  - The aggregation function is injective on multisets (it can distinguish different multisets of neighbor features).
  - The update function is injective (it can distinguish different combinations of self-features and aggregated neighborhoods).
  - For graph-level tasks, the readout is also injective on the multiset of final node representations.
This is a profound result. It says that the sum aggregation (which preserves multiset information) is strictly more powerful than mean or max aggregation (which lose information). And it says that no amount of architectural engineering within the message passing framework can exceed the 1-WL bound.
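A small counterexample shows the information loss. The two neighborhoods below are the multisets $\{1, 1, 2, 2\}$ and $\{1, 2\}$ of scalar features, chosen purely for illustration. Mean and max map them to the same value, while sum keeps them apart.

```python
import numpy as np

# Neighbor features as multisets of 1-D feature vectors (one row per neighbor)
neighbors_a = np.array([[1.0], [1.0], [2.0], [2.0]])  # multiset {1, 1, 2, 2}
neighbors_b = np.array([[1.0], [2.0]])                # multiset {1, 2}


def aggregate(neighbors: np.ndarray) -> dict:
    return {
        "sum": neighbors.sum(axis=0),
        "mean": neighbors.mean(axis=0),
        "max": neighbors.max(axis=0),
    }


agg_a, agg_b = aggregate(neighbors_a), aggregate(neighbors_b)
for name in ("sum", "mean", "max"):
    print(name, agg_a[name], agg_b[name])
# sum: 6 vs 3 (distinguished); mean: 1.5 vs 1.5 and max: 2 vs 2 (collisions)
```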
The Graph Isomorphism Network
The Graph Isomorphism Network (GIN) is designed to be maximally powerful under these constraints:
$$\mathbf{h}_i^{(\ell+1)} = \text{MLP}^{(\ell)}\!\left((1 + \epsilon^{(\ell)}) \cdot \mathbf{h}_i^{(\ell)} + \sum_{j \in \mathcal{N}(i)} \mathbf{h}_j^{(\ell)}\right)$$
where $\epsilon^{(\ell)}$ is a learnable scalar (or fixed to 0). The sum aggregation preserves multiset information, and the MLP provides the injective update function (by the universal approximation theorem, an MLP can approximate any function, including injective ones).
```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool


class GIN(nn.Module):
    """Graph Isomorphism Network: maximally expressive MPNN.

    Uses sum aggregation and MLP update functions to achieve
    the theoretical maximum expressiveness of the message passing
    framework (equivalent to the 1-WL test).

    Args:
        in_channels: Input feature dimensionality.
        hidden_channels: Hidden layer dimensionality.
        out_channels: Number of output classes.
        num_layers: Number of GIN layers.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int = 5,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for i in range(num_layers):
            in_dim = in_channels if i == 0 else hidden_channels
            # Two-layer MLP serves as the (approximately) injective update
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_channels),
                nn.BatchNorm1d(hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels),
            )
            # train_eps=True makes epsilon a learnable scalar per layer
            self.convs.append(GINConv(mlp, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        self.classifier = nn.Linear(hidden_channels, out_channels)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for graph classification.

        Args:
            x: Node features, shape (total_nodes, in_channels).
            edge_index: Edge indices, shape (2, total_edges).
            batch: Batch assignment vector, shape (total_nodes,).
        Returns:
            Graph-level logits, shape (num_graphs, out_channels).
        """
        for conv, bn in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Graph-level readout: sum over all nodes in each graph
        x = global_add_pool(x, batch)
        return self.classifier(x)
```
14.9 Graph-Level Tasks: Readout and Pooling
From Node Embeddings to Graph Embeddings
Node classification and link prediction operate on individual nodes or pairs. Many applications require a representation of the entire graph: molecular property prediction (is this molecule toxic?), graph classification (is this social network a bot ring?), graph regression (what is this polymer's melting point?).
The transition from node embeddings to a graph embedding requires a readout (or pooling) function:
$$\mathbf{h}_{\mathcal{G}} = \text{READOUT}\!\left(\left\{\mathbf{h}_i^{(L)} : v_i \in \mathcal{V}\right\}\right)$$
The readout must be permutation invariant (graph-level properties do not depend on node ordering). Common choices:
| Readout | Formula | Properties |
|---|---|---|
| Sum | $\sum_i \mathbf{h}_i^{(L)}$ | Preserves multiset structure; sensitive to graph size |
| Mean | $\frac{1}{N} \sum_i \mathbf{h}_i^{(L)}$ | Size-invariant; loses cardinality information |
| Max | $\max_i \mathbf{h}_i^{(L)}$ (element-wise) | Captures most "activated" features; ignores frequency |
| Set2Set | Attention-based | Learnable, most expressive; higher cost |
For graph classification, multi-scale readout — concatenating sum, mean, and max — often outperforms any single readout because different properties depend on different aggregation semantics. Sum captures total activity (useful for molecular weight prediction), mean captures average behavior (useful for density), and max captures extreme features (useful for toxicity, where a single functional group can be decisive).
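In a batched setting, each readout is a scatter operation over the batch vector that assigns nodes to graphs. The NumPy sketch below mirrors what PyG's `global_add_pool`, `global_mean_pool`, and `global_max_pool` compute, then concatenates the three. The function name is hypothetical.

```python
import numpy as np


def multi_scale_readout(h: np.ndarray, batch: np.ndarray, num_graphs: int) -> np.ndarray:
    """Concatenate sum, mean, and max readouts per graph.

    h: (total_nodes, d) node embeddings; batch[i] is the graph index of node i.
    Returns an array of shape (num_graphs, 3 * d). Assumes every graph has at
    least one node (an empty graph would keep a max of -inf).
    """
    d = h.shape[1]
    sums = np.zeros((num_graphs, d))
    np.add.at(sums, batch, h)                       # scatter-add rows by graph id
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    means = sums / np.maximum(counts, 1)
    maxes = np.full((num_graphs, d), -np.inf)
    np.maximum.at(maxes, batch, h)                  # scatter-max rows by graph id
    return np.concatenate([sums, means, maxes], axis=1)


h = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = np.array([0, 0, 1])   # nodes 0 and 1 form graph 0; node 2 forms graph 1
out = multi_scale_readout(h, batch, num_graphs=2)
print(out)
```

Graph 0 gets sum $[4, 6]$, mean $[2, 3]$, and max $[3, 4]$. For the single-node graph 1, all three readouts coincide. On single-node graphs the readouts agree; they diverge as graphs grow, which is where the multi-scale concatenation adds information.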
Link Prediction
Link prediction asks: given a graph with some edges, predict which edges are missing or will appear in the future. This is the fundamental problem in recommendation (predict user-item interactions), knowledge graph completion (predict missing facts), and social network analysis (predict future friendships).
The standard approach embeds pairs of nodes and scores the likelihood of an edge:
$$\text{score}(v_i, v_j) = f(\mathbf{h}_i^{(L)}, \mathbf{h}_j^{(L)})$$
Common scoring functions include:
- Dot product: $\text{score} = \mathbf{h}_i^\top \mathbf{h}_j$ (simplest, equivalent to matrix factorization in the linear case)
- Bilinear: $\text{score} = \mathbf{h}_i^\top \mathbf{W} \mathbf{h}_j$ (learnable interaction matrix)
- MLP: $\text{score} = \text{MLP}([\mathbf{h}_i \| \mathbf{h}_j])$ (most expressive, non-symmetric unless the MLP is designed to be)
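Symmetry matters when the relation is directed (for example, "user follows user"). The sketch below uses small hand-picked vectors, chosen only for illustration. The dot product gives the same score in both directions. The bilinear form $\mathbf{h}_i^\top \mathbf{W} \mathbf{h}_j$ does not, unless $\mathbf{W}$ is symmetric.

```python
import numpy as np

h_i = np.array([1.0, 2.0])
h_j = np.array([3.0, 1.0])
W = np.array([[0.0, 1.0],
              [0.0, 0.0]])          # a non-symmetric interaction matrix


def dot_score(a, b):
    return a @ b


def bilinear_score(a, b, W):
    return a @ W @ b


print(dot_score(h_i, h_j), dot_score(h_j, h_i))            # 5 in both directions
print(bilinear_score(h_i, h_j, W), bilinear_score(h_j, h_i, W))  # 1 vs 6: order matters
W_sym = (W + W.T) / 2               # symmetrizing W restores h_i^T W h_j = h_j^T W h_i
print(bilinear_score(h_i, h_j, W_sym), bilinear_score(h_j, h_i, W_sym))  # 3.5 both ways
```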
```python
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import negative_sampling


class LinkPredictor(nn.Module):
    """Link prediction head using dot product scoring.

    Given node embeddings from a GNN encoder, scores candidate edges by
    dot product similarity (or an optional MLP). Scores are logits; apply
    a sigmoid to obtain edge existence probabilities.

    Args:
        in_channels: Dimensionality of node embeddings.
        hidden_channels: Hidden layer size for optional MLP scorer.
        use_mlp: If True, use MLP scorer instead of dot product.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int = 64,
        use_mlp: bool = False,
    ) -> None:
        super().__init__()
        self.use_mlp = use_mlp
        if use_mlp:
            self.scorer = nn.Sequential(
                nn.Linear(2 * in_channels, hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, 1),
            )

    def forward(
        self,
        z: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Score edges.

        Args:
            z: Node embeddings, shape (N, in_channels).
            edge_index: Edges to score, shape (2, E).
        Returns:
            Edge logits, shape (E,).
        """
        src, dst = edge_index
        if self.use_mlp:
            edge_feat = torch.cat([z[src], z[dst]], dim=1)
            return self.scorer(edge_feat).squeeze(-1)
        return (z[src] * z[dst]).sum(dim=1)


def train_link_prediction(
    data: Data,
    hidden_channels: int = 64,
    num_layers: int = 2,
    epochs: int = 100,
    lr: float = 0.01,
) -> Dict[str, float]:
    """Train a GNN for link prediction with negative sampling.

    Splits edges into train/val/test, trains a GCN encoder
    with a dot-product decoder, and evaluates with AUC.

    Args:
        data: PyG Data object.
        hidden_channels: GNN hidden dimensionality.
        num_layers: Number of GNN layers.
        epochs: Training epochs.
        lr: Learning rate.
    Returns:
        Dictionary with validation and test AUC after training.
    """
    from torch_geometric.transforms import RandomLinkSplit
    from sklearn.metrics import roc_auc_score

    # Split edges: 85% train, 5% val, 10% test. Val/test supervision edges
    # come with an equal number of labeled negatives; training negatives are
    # resampled every epoch below instead.
    transform = RandomLinkSplit(
        num_val=0.05, num_test=0.1,
        is_undirected=True, add_negative_train_samples=False,
    )
    train_data, val_data, test_data = transform(data)

    encoder = GCN(  # GCN encoder from Section 14.5
        in_channels=data.num_features,
        hidden_channels=hidden_channels,
        out_channels=hidden_channels,
        num_layers=num_layers,
    )
    decoder = LinkPredictor(in_channels=hidden_channels)
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=lr
    )

    @torch.no_grad()
    def evaluate(split: Data) -> float:
        encoder.eval()
        decoder.eval()
        z = encoder(split.x, split.edge_index)
        # edge_label_index holds held-out positives and sampled negatives;
        # edge_label marks them 1 and 0 respectively.
        scores = decoder(z, split.edge_label_index).sigmoid().cpu()
        return float(roc_auc_score(split.edge_label.cpu().numpy(), scores.numpy()))

    for epoch in range(epochs):
        encoder.train()
        decoder.train()
        optimizer.zero_grad()
        z = encoder(train_data.x, train_data.edge_index)

        # Positive edges: the training split's supervision edges, shape (2, E_pos)
        pos_edge_index = train_data.edge_label_index
        pos_score = decoder(z, pos_edge_index)

        # Negative sampling: fresh non-edges each epoch, one per positive
        neg_edge_index = negative_sampling(
            train_data.edge_index,
            num_nodes=train_data.num_nodes,
            num_neg_samples=pos_edge_index.size(1),
        )
        neg_score = decoder(z, neg_edge_index)

        # Binary cross-entropy on logits
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_score, torch.ones_like(pos_score)
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_score, torch.zeros_like(neg_score)
        )
        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            # Monitor on validation edges; the test split is touched only once
            val_auc = evaluate(val_data)
            print(f"Epoch {epoch+1:3d} | Loss: {loss.item():.4f} | Val AUC: {val_auc:.4f}")

    return {"val_auc": evaluate(val_data), "test_auc": evaluate(test_data)}
```
14.10 Over-Smoothing: The Depth Limit of GNNs
The Problem
Stack enough GCN layers and something pathological happens: all node representations converge to the same vector. This is over-smoothing, and it is the fundamental reason that GNNs are typically shallow (2-3 layers) compared to the deep CNNs (50-150 layers) and transformers (12-96 layers) we studied in previous chapters.
The mathematical explanation is direct. A single GCN layer computes:
$$\mathbf{H}^{(\ell+1)} = \sigma\!\left(\hat{\mathbf{A}} \mathbf{H}^{(\ell)} \mathbf{W}^{(\ell)}\right)$$
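Ignoring the nonlinearity and weight matrices, iterating $K$ times gives $\mathbf{H}^{(K)} \propto \hat{\mathbf{A}}^K \mathbf{X}$. The GCN propagation matrix $\hat{\mathbf{A}} = \tilde{\mathbf{D}}^{-1/2}\tilde{\mathbf{A}}\tilde{\mathbf{D}}^{-1/2}$ is symmetric but not stochastic (unless the graph is regular). For a connected graph with self-loops, its eigenvalues lie in $(-1, 1]$, and the eigenvalue $1$ is simple with eigenvector $\tilde{\mathbf{D}}^{1/2}\mathbf{1}$. Its powers therefore converge to a rank-one matrix: $\hat{\mathbf{A}}^K \to \tilde{\mathbf{D}}^{1/2}\mathbf{1}\mathbf{1}^\top\tilde{\mathbf{D}}^{1/2} / \sum_j \tilde{d}_j$ as $K \to \infty$. Row $i$ of $\hat{\mathbf{A}}^K \mathbf{X}$ approaches $\sqrt{\tilde{d}_i}$ times a single vector shared by all nodes. Representations then differ only by a degree-dependent scale and carry no other node-specific information. With the row-stochastic random-walk normalization $\tilde{\mathbf{D}}^{-1}\tilde{\mathbf{A}}$, the collapse is complete: every node converges to exactly the same vector, the degree-weighted mean of the features.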
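The limit can be checked numerically. The sketch below builds $\hat{\mathbf{A}} = \tilde{\mathbf{D}}^{-1/2}(\mathbf{A} + \mathbf{I})\tilde{\mathbf{D}}^{-1/2}$ for a 4-node path graph and raises it to the 100th power. It then compares the result with the rank-one matrix $\tilde{\mathbf{D}}^{1/2}\mathbf{1}\mathbf{1}^\top\tilde{\mathbf{D}}^{1/2} / \sum_j \tilde{d}_j$, which equals $\frac{1}{N}\mathbf{1}\mathbf{1}^\top$ only for regular graphs. The path graph and the power $K = 100$ are arbitrary illustrative choices.

```python
import numpy as np


def gcn_propagation(A: np.ndarray):
    """Return A_hat = D~^{-1/2} (A + I) D~^{-1/2} and the degrees d~ of A + I."""
    A_tilde = A + np.eye(A.shape[0])              # add self-loops
    d_tilde = A_tilde.sum(axis=1)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d_tilde))
    return D_inv_sqrt @ A_tilde @ D_inv_sqrt, d_tilde


# Path graph 0-1-2-3: connected but not regular
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_hat, d_tilde = gcn_propagation(A)

A_hat_K = np.linalg.matrix_power(A_hat, 100)
sqrt_d = np.sqrt(d_tilde)
limit = np.outer(sqrt_d, sqrt_d) / d_tilde.sum()  # D~^{1/2} 1 1^T D~^{1/2} / sum_j d~_j
print(np.abs(A_hat_K - limit).max())               # essentially zero

# Smooth random features: after dividing out sqrt(d~_i), all rows coincide
X = np.random.default_rng(0).normal(size=(4, 3))
H = A_hat_K @ X
print(H / sqrt_d[:, None])
```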
From the spectral perspective, multiplication by $\hat{\mathbf{A}}$ is a low-pass filter. Each application attenuates high-frequency components (eigenvectors corresponding to large Laplacian eigenvalues). After $K$ applications, only the lowest-frequency component survives: the top eigenvector $\tilde{\mathbf{D}}^{1/2}\mathbf{1}$, the graph analogue of a DC component. The node features are smoothed to uniformity, up to degree scaling.
Empirical Evidence
On Cora, a 2-layer GCN achieves ~81% accuracy. A 4-layer GCN drops to ~78%. An 8-layer GCN drops to ~65%. A 16-layer GCN is near random (~20%). The representations become indistinguishable.
Mitigation Strategies
Several techniques combat over-smoothing:
- Residual connections. Add skip connections as in ResNets (Chapter 7): $\mathbf{h}_i^{(\ell+1)} = \tilde{\mathbf{h}}_i^{(\ell+1)} + \mathbf{h}_i^{(\ell)}$, where $\tilde{\mathbf{h}}_i^{(\ell+1)}$ is the output of the message passing layer. This preserves the original signal alongside the smoothed aggregation.
- JKNet (Jumping Knowledge Networks). Concatenate or attention-pool representations from all layers: $\mathbf{h}_i^{\text{final}} = f(\mathbf{h}_i^{(1)}, \mathbf{h}_i^{(2)}, \ldots, \mathbf{h}_i^{(L)})$. Different nodes may benefit from different receptive field sizes.
- DropEdge. Randomly remove edges during training, analogous to dropout. This reduces the effective propagation range and acts as a regularizer.
- PairNorm and DiffGroupNorm. Normalization schemes that keep node representations spread apart across layers, preventing collapse to a single point. PairNorm, for example, holds the total pairwise squared distance between node representations constant.
- Graph transformers. Replace local message passing with global attention, as in the original transformer. This avoids the depth-smoothing tradeoff entirely but sacrifices the inductive bias of local neighborhoods and has $O(N^2)$ complexity.
Understanding Why: Over-smoothing is not a bug in GNN implementations — it is an inherent property of iterated neighborhood averaging on graphs. Understanding this through the spectral lens (repeated low-pass filtering) connects directly to Chapter 8's discussion of CNNs as spatial filters and Chapter 1's treatment of eigendecomposition. The practical consequence is that GNN depth is not a free hyperparameter to increase for more "expressiveness" — it must be balanced against the information loss from smoothing. This is qualitatively different from CNNs and transformers, where depth generally helps.
14.11 Heterogeneous Graphs and Knowledge Graphs
Beyond Homogeneous Graphs
Real-world graphs are rarely homogeneous (one node type, one edge type). The StreamRec interaction graph has at least two node types (users and items) and potentially multiple edge types (views, likes, purchases, shares). A knowledge graph has entity nodes of many types (person, organization, location, concept) and relation edges of many types (works-at, born-in, is-a, part-of).
A heterogeneous graph $\mathcal{G} = (\mathcal{V}, \mathcal{E}, \tau, \phi)$ adds:
- A node type function $\tau: \mathcal{V} \to \mathcal{T}$ mapping each node to one of $|\mathcal{T}|$ types.
- An edge type function $\phi: \mathcal{E} \to \mathcal{R}$ mapping each edge to one of $|\mathcal{R}|$ relation types.
Relational GCN (R-GCN)
The simplest heterogeneous extension uses separate weight matrices per relation type:
$$\mathbf{h}_i^{(\ell+1)} = \sigma\!\left(\sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} \mathbf{W}_r^{(\ell)} \mathbf{h}_j^{(\ell)} + \mathbf{W}_0^{(\ell)} \mathbf{h}_i^{(\ell)}\right)$$
where $\mathcal{N}_r(i)$ is the set of neighbors connected to $v_i$ via relation $r$. This requires $|\mathcal{R}|$ weight matrices per layer, which can be regularized via basis decomposition:
$$\mathbf{W}_r = \sum_{b=1}^{B} a_{rb} \mathbf{V}_b$$
where $\mathbf{V}_1, \ldots, \mathbf{V}_B$ are shared basis matrices and $a_{rb}$ are relation-specific coefficients. This reduces parameters from $O(|\mathcal{R}| \cdot d^2)$ to $O(B \cdot d^2 + |\mathcal{R}| \cdot B)$.
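The layer and its basis decomposition fit in a short NumPy sketch. It assumes:

- ReLU for $\sigma$;
- a row-vector convention, so messages are $\mathbf{h}_j^\top \mathbf{W}_r$;
- a hypothetical edge-list format of (source, relation, target) triples.

It also prints the parameter counts for an illustrative setting of $d = 64$, $|\mathcal{R}| = 50$, and $B = 4$.

```python
import numpy as np
from collections import Counter


def rgcn_layer(H, edges, coeffs, bases, W_self):
    """One R-GCN layer with basis-decomposed relation weights.

    H: (N, d_in) node features. edges: list of (j, r, i) triples, meaning a
    relation-r message from node j to node i. coeffs: (R, B) coefficients a_rb.
    bases: (B, d_in, d_out) basis matrices V_b. W_self: (d_in, d_out) weight W_0.
    """
    W_rel = np.einsum("rb,bio->rio", coeffs, bases)   # W_r = sum_b a_rb V_b
    out = H @ W_self                                  # self-connection term
    # 1 / |N_r(i)| normalization: count incoming edges per (target, relation)
    in_counts = Counter((i, r) for _, r, i in edges)
    for j, r, i in edges:
        out[i] += H[j] @ W_rel[r] / in_counts[(i, r)]
    return np.maximum(out, 0.0)                       # sigma = ReLU


def rgcn_param_counts(num_relations, num_bases, d):
    full = num_relations * d * d                      # one d x d matrix per relation
    basis = num_bases * d * d + num_relations * num_bases
    return full, basis


rng = np.random.default_rng(0)
N, d_in, d_out, R, B = 4, 3, 2, 3, 2
H = rng.normal(size=(N, d_in))
edges = [(0, 0, 1), (2, 0, 1), (3, 2, 1), (1, 1, 0)]
coeffs = rng.normal(size=(R, B))
bases = rng.normal(size=(B, d_in, d_out))
W_self = rng.normal(size=(d_in, d_out))
H_next = rgcn_layer(H, edges, coeffs, bases, W_self)
print(rgcn_param_counts(num_relations=50, num_bases=4, d=64))  # (204800, 16584)
```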
Knowledge Graph Embeddings
Knowledge graphs represent facts as triples $(h, r, t)$: head entity $h$, relation $r$, tail entity $t$. Examples: (aspirin, treats, headache), (penicillin, interacts_with, warfarin). Knowledge graph completion predicts missing triples. It can combine GNN encoders, which compute entity embeddings from the graph structure, with scoring functions. The scoring functions below also work with directly learned embedding tables, which is how they were originally introduced.
- TransE: $\text{score}(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|$ (translations in embedding space)
- DistMult: $\text{score}(h, r, t) = \mathbf{h}^\top \text{diag}(\mathbf{r}) \mathbf{t}$ (bilinear scoring)
- RotatE: $\text{score}(h, r, t) = -\|\mathbf{h} \circ \mathbf{r} - \mathbf{t}\|$ with $\mathbf{h}, \mathbf{r}, \mathbf{t} \in \mathbb{C}^d$ and $|r_k| = 1$ (rotations in complex space)
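The three scoring functions translate directly into code. The norm is left unspecified above; the sketch below assumes the Euclidean norm. The embeddings are random and serve only to illustrate.

```python
import numpy as np


def transe_score(h, r, t):
    # -||h + r - t||: high when t is h translated by r
    return -np.linalg.norm(h + r - t)


def distmult_score(h, r, t):
    # h^T diag(r) t, written as an elementwise triple product
    return np.sum(h * r * t)


def rotate_score(h, r, t):
    # Complex h, t; unit-modulus r, so h * r rotates each coordinate of h
    return -np.linalg.norm(h * r - t)


rng = np.random.default_rng(0)
d = 4
h, r, t = rng.normal(size=d), rng.normal(size=d), rng.normal(size=d)
print(np.isclose(transe_score(h, r, h + r), 0.0))                     # True: exact translation
print(np.isclose(distmult_score(h, r, t), distmult_score(t, r, h)))   # True: symmetric in h, t

h_c = rng.normal(size=d) + 1j * rng.normal(size=d)
r_c = np.exp(1j * rng.uniform(0, 2 * np.pi, size=d))                 # |r_k| = 1
print(np.isclose(rotate_score(h_c, r_c, h_c * r_c), 0.0))             # True: exact rotation
```

Because $\mathbf{h}^\top \text{diag}(\mathbf{r}) \mathbf{t} = \mathbf{t}^\top \text{diag}(\mathbf{r}) \mathbf{h}$, DistMult gives $(h, r, t)$ and $(t, r, h)$ the same score. It therefore cannot by itself represent asymmetric relations such as "treats".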
In the MediCore pharmaceutical context, knowledge graph completion can predict drug-drug interactions, drug-target bindings, and disease-gene associations — turning the known biomedical knowledge graph into a predictive tool.
14.12 PyTorch Geometric: A Practical Framework
The Data Object
PyTorch Geometric (PyG) represents graphs with the Data class, storing node features, edge indices, labels, and metadata in a single object:
```python
import torch
from torch_geometric.data import Data


def create_example_graph() -> Data:
    """Create a small example graph for illustration.

    Builds a 2 x 3 grid graph with 6 nodes, 7 undirected edges, and
    3-dimensional node features:

        0---1---2
        |   |   |
        3---4---5

    Returns:
        PyG Data object with features, edges, and labels.
    """
    # Node features: 6 nodes, 3 features each
    x = torch.tensor([
        [1.0, 0.0, 0.5],  # Node 0
        [0.5, 1.0, 0.0],  # Node 1
        [0.0, 0.5, 1.0],  # Node 2
        [1.0, 0.5, 0.0],  # Node 3
        [0.5, 0.5, 0.5],  # Node 4
        [0.0, 1.0, 0.5],  # Node 5
    ], dtype=torch.float)

    # Edge indices in COO format (undirected: each edge listed twice)
    edge_index = torch.tensor([
        [0, 1, 1, 2, 0, 3, 1, 4, 2, 5, 3, 4, 4, 5],
        [1, 0, 2, 1, 3, 0, 4, 1, 5, 2, 4, 3, 5, 4],
    ], dtype=torch.long)

    # Node labels (for classification)
    y = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, y=y)
    print(f"Number of nodes: {data.num_nodes}")
    print(f"Number of edges: {data.num_edges}")
    print(f"Number of features: {data.num_features}")
    print(f"Has isolated nodes: {data.has_isolated_nodes()}")
    print(f"Has self-loops: {data.has_self_loops()}")
    print(f"Is undirected: {data.is_undirected()}")
    return data


# Printed output (num_edges is 14 because each of the 7 undirected
# edges is stored in both directions):
# Number of nodes: 6
# Number of edges: 14
# Number of features: 3
# Has isolated nodes: False
# Has self-loops: False
# Is undirected: True
```
Batching Multiple Graphs
For graph-level tasks (e.g., molecular property prediction), each training example is a separate graph. PyG batches multiple graphs into a single disconnected graph, using a batch vector to track which node belongs to which graph:
```python
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader


def demonstrate_batching() -> None:
    """Demonstrate PyG's graph batching mechanism.

    Multiple small graphs are combined into a single large
    disconnected graph. The batch vector maps each node to
    its source graph, enabling graph-level pooling.
    """
    # Create 4 small random graphs with different sizes
    graphs = []
    for i in range(4):
        n_nodes = int(np.random.randint(5, 15))
        n_edges = int(np.random.randint(n_nodes, n_nodes * 3))
        x = torch.randn(n_nodes, 8)
        # Random edges (may include self-loops and duplicates; fine for a demo)
        edge_index = torch.randint(0, n_nodes, (2, n_edges))
        y = torch.tensor([i % 2], dtype=torch.long)  # Binary graph label
        graphs.append(Data(x=x, edge_index=edge_index, y=y))

    loader = PyGDataLoader(graphs, batch_size=2, shuffle=False)
    for batch in loader:
        print(f"Batch: {batch}")
        print(f"  Total nodes: {batch.num_nodes}")
        print(f"  Total edges: {batch.num_edges}")
        print(f"  Batch vector shape: {batch.batch.shape}")
        print(f"  Number of graphs: {batch.num_graphs}")
        print(f"  Graph labels: {batch.y}")
        print()
```
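Conceptually, the batching step shifts each graph's node indices by the number of nodes that came before it. It also records a graph id per node. The sketch below is simplified: it handles only `x` and `edge_index`, whereas PyG's real collation also handles other attributes. The function name is hypothetical.

```python
import numpy as np


def collate_graphs(graphs):
    """Merge small graphs into one disconnected graph, PyG-style.

    graphs: list of (x, edge_index) pairs with x of shape (n_g, d) and
    edge_index of shape (2, e_g) using local node ids 0..n_g-1.
    Returns (x, edge_index, batch) for the merged graph.
    """
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edge_indices.append(edge_index + offset)   # shift into global node ids
        batch.append(np.full(x.shape[0], graph_id))
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(edge_indices, axis=1), np.concatenate(batch)


g0 = (np.zeros((2, 4)), np.array([[0, 1], [1, 0]]))             # 2 nodes, 1 undirected edge
g1 = (np.ones((3, 4)), np.array([[0, 1, 1, 2], [1, 0, 2, 1]]))  # 3-node path
x, edge_index, batch = collate_graphs([g0, g1])
print(edge_index)  # the second graph's nodes become 2, 3, 4
print(batch)       # [0 0 1 1 1]
```

No edge crosses between graphs, so message passing on the merged graph gives the same result as running each graph separately. Pooling by `batch` then recovers one vector per graph.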
Building a Complete Pipeline
```python
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader as PyGDataLoader


def molecular_classification_pipeline(
    dataset_name: str = "MUTAG",
    hidden_channels: int = 64,
    num_layers: int = 3,
    epochs: int = 100,
    batch_size: int = 32,
    lr: float = 0.01,
    n_folds: int = 10,
) -> Dict[str, float]:
    """Graph classification pipeline for molecular datasets.

    Trains a GIN model on a molecular benchmark dataset with
    10-fold cross-validation, following standard evaluation protocol.

    Args:
        dataset_name: Name of the TU dataset (e.g., MUTAG, PTC_MR, PROTEINS).
        hidden_channels: Hidden layer dimensionality.
        num_layers: Number of GIN layers.
        epochs: Training epochs per fold.
        batch_size: Batch size.
        lr: Learning rate.
        n_folds: Number of cross-validation folds.
    Returns:
        Dictionary with mean and std accuracy across folds.
    """
    dataset = TUDataset(root=f"/tmp/{dataset_name}", name=dataset_name)
    print(f"Dataset: {dataset_name}")
    print(f"  Graphs: {len(dataset)}")
    print(f"  Features: {dataset.num_features}")
    print(f"  Classes: {dataset.num_classes}")

    # Stratified cross-validation on graph labels
    labels = np.array([data.y.item() for data in dataset])
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_accs = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(np.zeros(len(labels)), labels)):
        train_dataset = [dataset[int(i)] for i in train_idx]
        test_dataset = [dataset[int(i)] for i in test_idx]
        train_loader = PyGDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = PyGDataLoader(test_dataset, batch_size=batch_size)

        model = GIN(  # GIN from Section 14.8
            in_channels=dataset.num_features,
            hidden_channels=hidden_channels,
            out_channels=dataset.num_classes,
            num_layers=num_layers,
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

        # Training
        for epoch in range(epochs):
            model.train()
            for batch in train_loader:
                optimizer.zero_grad()
                out = model(batch.x, batch.edge_index, batch.batch)
                loss = F.cross_entropy(out, batch.y)
                loss.backward()
                optimizer.step()
            scheduler.step()

        # Evaluation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in test_loader:
                out = model(batch.x, batch.edge_index, batch.batch)
                pred = out.argmax(dim=1)
                correct += (pred == batch.y).sum().item()
                total += batch.y.size(0)
        fold_acc = correct / total
        fold_accs.append(fold_acc)
        print(f"  Fold {fold+1:2d} | Accuracy: {fold_acc:.4f}")

    mean_acc = float(np.mean(fold_accs))
    std_acc = float(np.std(fold_accs))
    print(f"\nResult: {mean_acc:.4f} +/- {std_acc:.4f}")
    return {"mean_acc": mean_acc, "std_acc": std_acc}
```
14.13 Beyond Standard Message Passing
Graph Transformers
The over-smoothing problem and the WL expressiveness limit have motivated architectures that go beyond local message passing:
Graph Transformer (Dwivedi & Bresson, 2021; Ying et al., 2021): Replace local neighborhood aggregation with global self-attention over all nodes. Each node attends to every other node, with positional encodings derived from the graph structure (e.g., random walk positional encodings, Laplacian eigenvector positional encodings).
The attention mechanism is:
$$\alpha_{ij} = \frac{\exp\left(\frac{(\mathbf{W}_Q \mathbf{h}_i)^\top (\mathbf{W}_K \mathbf{h}_j)}{\sqrt{d_k}} + b_{\phi(i,j)}\right)}{\sum_{k \in \mathcal{V}} \exp\left(\frac{(\mathbf{W}_Q \mathbf{h}_i)^\top (\mathbf{W}_K \mathbf{h}_k)}{\sqrt{d_k}} + b_{\phi(i,k)}\right)}$$
where $b_{\phi(i,j)}$ is a learnable bias based on the structural relationship between nodes $i$ and $j$ (e.g., shortest path distance, relative positional encoding).
Graph transformers can be more expressive than 1-WL message passing, but they lose the sparsity advantage: attention is $O(N^2)$ instead of $O(E)$. For large graphs, this is prohibitive without additional sparsification.
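One concrete choice of structural relationship $\phi(i, j)$ is the shortest-path (hop) distance, used with a small table of biases. The NumPy sketch below follows that choice. It assumes a connected graph and clips distances beyond the table to its last entry. The bias values are fixed numbers standing in for learned parameters.

```python
import numpy as np
from collections import deque


def shortest_path_distances(adj):
    """All-pairs hop distances via BFS from every node. adj: list of neighbor lists."""
    n = len(adj)
    dist = np.full((n, n), -1, dtype=int)
    for source in range(n):
        dist[source, source] = 0
        queue = deque([source])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if dist[source, v] < 0:
                    dist[source, v] = dist[source, u] + 1
                    queue.append(v)
    return dist


def spd_biased_attention(H, W_Q, W_K, dist, bias_table):
    """Global attention over all nodes plus a bias b_phi(i,j) indexed by hop distance."""
    d_k = W_Q.shape[1]
    scores = (H @ W_Q) @ (H @ W_K).T / np.sqrt(d_k)
    phi = np.minimum(dist, len(bias_table) - 1)    # assumes no unreachable pairs (-1)
    scores = scores + bias_table[phi]
    scores = scores - scores.max(axis=1, keepdims=True)
    alpha = np.exp(scores)
    return alpha / alpha.sum(axis=1, keepdims=True)


path = [[1], [0, 2], [1, 3], [2]]                  # path graph 0-1-2-3
dist = shortest_path_distances(path)
H = np.random.default_rng(0).normal(size=(4, 3))
W_Q = np.zeros((3, 2))                             # zero queries: only the bias matters
W_K = np.ones((3, 2))
bias_table = np.array([1.0, 0.5, 0.0, -1.0])       # stand-in for learned biases
alpha = spd_biased_attention(H, W_Q, W_K, dist, bias_table)
print(np.round(alpha[0], 3))                       # node 0 attends more to closer nodes
```

The $N \times N$ score matrix is the $O(N^2)$ cost mentioned above. Every pair of nodes gets a score, whether or not they share an edge.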
Higher-Order Message Passing
The $k$-WL hierarchy provides a roadmap for more powerful GNNs. While 1-WL operates on nodes, $k$-WL operates on $k$-tuples. A 2-WL GNN (e.g., Morris et al., 2019) passes messages between pairs of nodes, enabling it to distinguish graphs that 1-WL cannot — at the cost of $O(N^2)$ memory and $O(N^3)$ computation per layer.
Practical higher-order GNNs include:
- PPGN (Provably Powerful Graph Networks): Operate on $N \times N$ feature matrices.
- SpeqNet: Use equivariant operations on subgraph structures.
- Subgraph GNNs: Run standard message passing on node-induced subgraphs and aggregate the results.
These methods are theoretically more powerful but rarely practical for graphs with more than a few thousand nodes.
Equivariant Neural Networks and Geometric Deep Learning
The broader field of geometric deep learning (Bronstein et al., 2021) unifies GNNs with other structured architectures under the principle of symmetry. The key idea: the architecture should respect the symmetries of the data.
| Data Type | Symmetry Group | Architecture |
|---|---|---|
| Sets | Permutation ($S_n$) | DeepSets |
| Graphs | Node permutation ($S_n$) | GNN / MPNN |
| Grids | Translation ($\mathbb{Z}^2$) | CNN |
| Sequences | Translation ($\mathbb{Z}$) | RNN / Transformer |
| 3D point clouds | Rotation + translation ($SE(3)$) | EGNN, SchNet |
| Spherical data | Rotation ($SO(3)$) | Spherical CNN |
This perspective reveals that CNNs, RNNs, and GNNs are all instances of the same principle, equivariant function approximation, applied to different symmetry groups. The transformer is the instructive edge case. Self-attention without positional encodings is permutation-equivariant by construction: it treats its input as a set. Positional encodings deliberately break that symmetry so that order can be used. Any further structure must then be learned from data rather than built in.
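For graphs, the symmetry in the table can be checked numerically. Relabel the nodes with a permutation matrix $\mathbf{P}$ and then apply a message passing layer. The result should equal applying the layer first and relabeling afterwards. The sketch uses a simple sum-aggregation layer: a GIN-style update with $\epsilon = 0$, where a single linear map plus ReLU stands in for the MLP.

```python
import numpy as np


def sum_mp_layer(X, A, W):
    """Sum-aggregation message passing: h_i' = ReLU((h_i + sum_{j in N(i)} h_j) W)."""
    return np.maximum((A + np.eye(A.shape[0])) @ X @ W, 0.0)


rng = np.random.default_rng(0)
n, d = 5, 3
A = np.triu((rng.random((n, n)) < 0.5).astype(float), k=1)
A = A + A.T                                    # random undirected graph, no self-loops
X = rng.normal(size=(n, d))
W = rng.normal(size=(d, 2))

P = np.eye(n)[rng.permutation(n)]              # permutation matrix that relabels nodes
lhs = sum_mp_layer(P @ X, P @ A @ P.T, W)      # relabel first, then apply the layer
rhs = P @ sum_mp_layer(X, A, W)                # apply the layer, then relabel
print(np.allclose(lhs, rhs))                   # True: the layer is permutation-equivariant
```

The check works because $(\mathbf{P}\mathbf{A}\mathbf{P}^\top + \mathbf{I})\mathbf{P}\mathbf{X} = \mathbf{P}(\mathbf{A} + \mathbf{I})\mathbf{X}$ and ReLU acts elementwise. A sum readout over nodes is therefore permutation-invariant.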
14.14 Practical Considerations
Feature Engineering for Graphs
Node features significantly impact GNN performance. When node features are absent (common in social networks), effective alternatives include:
- One-hot node degree: simple but surprisingly effective for small graphs.
- Random features: random vectors assigned to nodes, providing a unique identifier (used in random feature GNNs to boost expressiveness beyond WL).
- Structural features: PageRank, clustering coefficient, betweenness centrality.
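- Positional encodings: Laplacian eigenvectors, which are defined only up to sign, so models must be made sign-invariant (for example, by randomly flipping signs during training); random walk statistics (landing probabilities at different walk lengths).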
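Both kinds of positional encoding take only a few lines to compute on a small dense graph. The sketch below assumes a connected graph with no isolated nodes; the function names are hypothetical.

- `laplacian_pe` uses the normalized Laplacian and flips column signs at random when a generator is passed.
- `random_walk_pe` returns the probabilities $(\mathbf{P}^s)_{ii}$, with $\mathbf{P} = \mathbf{D}^{-1}\mathbf{A}$, that a walk started at node $i$ is back at $i$ after $s$ steps.

```python
import numpy as np


def laplacian_pe(A: np.ndarray, k: int, rng=None) -> np.ndarray:
    """First k non-trivial eigenvectors of L = I - D^{-1/2} A D^{-1/2}.

    Eigenvectors are only defined up to sign, so each column is flipped at
    random when rng is given (a training-time augmentation).
    """
    d = A.sum(axis=1)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    L = np.eye(A.shape[0]) - D_inv_sqrt @ A @ D_inv_sqrt
    _, eigvecs = np.linalg.eigh(L)          # eigenvalues in ascending order
    pe = eigvecs[:, 1:k + 1]                # skip the trivial eigenvalue-0 vector
    if rng is not None:
        pe = pe * rng.choice([-1.0, 1.0], size=k)
    return pe


def random_walk_pe(A: np.ndarray, k: int) -> np.ndarray:
    """Return probabilities (P^s)_{ii} for s = 1..k, with P = D^{-1} A."""
    P = A / A.sum(axis=1, keepdims=True)
    features, P_power = [], np.eye(A.shape[0])
    for _ in range(k):
        P_power = P_power @ P
        features.append(np.diag(P_power))
    return np.stack(features, axis=1)       # shape (N, k)


# 6-cycle: every node is structurally identical
A = np.zeros((6, 6))
for i in range(6):
    A[i, (i + 1) % 6] = A[(i + 1) % 6, i] = 1.0
print(random_walk_pe(A, k=4)[0])   # 0, 0.5, 0, 0.375 for s = 1..4
print(laplacian_pe(A, k=2, rng=np.random.default_rng(0)).shape)  # (6, 2)
```

On a cycle, every node gets the same random-walk encoding. This is a reminder that positional encodings separate nodes only when their surroundings actually differ.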
Negative Sampling for Link Prediction
The choice of negative samples dramatically affects link prediction performance. Uniform random sampling (any non-edge is a negative) produces "easy" negatives — two nodes in different communities are trivially predicted as non-edges. Hard negative sampling — sampling non-edges between nodes that are close in the graph (e.g., at distance 2) — produces more informative gradients.
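A cheap way to draw distance-2 negatives is a two-step random walk $u \to v \to w$. The pair $(u, w)$ is kept only if $w$ is neither $u$ nor a neighbor of $u$, which guarantees distance exactly 2. The helper below is hypothetical. It samples walk endpoints rather than drawing uniformly from all distance-2 pairs, and it may return duplicates.

```python
import random


def sample_distance_two_negatives(adj, num_samples, seed=0, max_tries_per_sample=100):
    """Sample non-edges (u, w) where w is exactly two hops from u.

    adj: dict node -> list of neighbors (undirected graph).
    """
    rng = random.Random(seed)
    neighbor_sets = {v: set(ns) for v, ns in adj.items()}
    starts = [v for v, ns in adj.items() if ns]
    negatives = []
    for _ in range(num_samples * max_tries_per_sample):
        if len(negatives) == num_samples:
            break
        u = rng.choice(starts)
        v = rng.choice(adj[u])       # one hop
        w = rng.choice(adj[v])       # two hops
        if w != u and w not in neighbor_sets[u]:
            negatives.append((u, w))
    return negatives


path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
negs = sample_distance_two_negatives(path, num_samples=5)
print(negs)  # every pair is two steps apart along the path
```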
Common Failure Modes
- Heterophily. GNNs assume homophily: connected nodes tend to have similar labels. On heterophilous graphs (e.g., dating networks, where linked nodes have different genders), standard GNNs underperform MLPs. Mitigation: separate ego and neighbor representations, use signed message passing, or add non-local connections.
- Feature dominance. When node features are highly informative (e.g., bag-of-words in citation networks), the graph structure adds relatively little. Always compare against an MLP baseline to quantify the graph's contribution.
- Label leakage in transductive settings. In standard semi-supervised node classification (Cora, CiteSeer), the test nodes and their edges are visible during training; only the labels are masked. This means the model can "see" the test distribution. Inductive evaluation (held-out nodes not in the training graph) is more rigorous.
14.15 Summary: Choosing the Right Architecture
| Architecture | Scalability | Expressiveness | Inductive? | Best For |
|---|---|---|---|---|
| GCN | Medium (full-batch) | Low (degree-weighted mean) | No (transductive) | Homophilous semi-supervised |
| GraphSAGE | High (mini-batch sampling) | Medium (various aggregators) | Yes | Large-scale, new nodes |
| GAT | Medium | Medium (learned attention) | Yes | When neighbor importance varies |
| GIN | Medium | High (WL-equivalent) | Yes | Graph classification, expressiveness |
| Graph Transformer | Low ($O(N^2)$) | High (global attention) | Yes | Small graphs, max expressiveness |
For the StreamRec content platform (Case Study 1): GraphSAGE or a heterogeneous GAT, because the graph is large (millions of nodes), bipartite (users and items), and continuously growing (new users and items require inductive capability).
For MediCore molecular prediction (Case Study 2): GIN or an attention-based GNN, because molecular graphs are small (10-100 atoms), expressiveness matters (stereochemistry, functional groups), and graph classification requires a powerful readout.
Fundamentals > Frontier: Graph neural networks are among the most rapidly evolving areas of deep learning. New architectures appear weekly. The fundamentals — message passing, spectral theory, WL expressiveness, over-smoothing — are stable. Master these, and every new paper is a variation on a theme you already understand. Ignore them, and every new paper is an incomprehensible collection of acronyms. The spectral-spatial connection, the WL hierarchy, and the geometric deep learning perspective are the conceptual foundations that do not change.
Chapter Summary
This chapter developed graph neural networks from both spectral and spatial perspectives. We began with the graph Laplacian and its eigendecomposition (connecting to Chapter 1's linear algebra), derived graph convolution in the frequency domain, and showed how the Chebyshev approximation yields spatially localized filters. The spatial perspective — the message passing framework — unifies GCN, GraphSAGE, GAT, and GIN as different instantiations of message construction, aggregation, and update functions.
The Weisfeiler-Leman hierarchy establishes a fundamental expressiveness limit: no message passing GNN can distinguish graphs that the 1-WL test cannot. GIN achieves this upper bound through sum aggregation and MLP updates. Over-smoothing limits GNN depth: iterated neighborhood averaging acts as a low-pass filter that destroys high-frequency signal. These are not engineering problems with engineering solutions — they are mathematical properties of the message passing framework.
For the progressive project, Chapter 14's milestone (M6) applies these ideas to the StreamRec recommendation system: modeling the user-item interaction graph with a GNN-based collaborative filter and comparing against the matrix factorization baseline from Chapter 1 and the two-tower model from Chapter 13. Case Study 2 applies GNNs to molecular property prediction, demonstrating the power of graph-native representations for data that is inherently non-Euclidean.
The next chapter transitions from prediction to the question that separates data science from statistics: when does a pattern in data reflect a causal relationship, and when is it merely a correlation?