In This Chapter
- Introduction: When Data Has Structure Beyond Grids and Sequences
- 37.1 Graph Fundamentals
- 37.2 The Message Passing Framework
- 37.3 Graph Convolutional Networks (GCN)
- 37.4 GraphSAGE: Inductive Learning on Graphs
- 37.5 Graph Attention Networks (GAT)
- 37.6 Graph-Level Tasks: Readout and Pooling
- 37.7 Node Classification
- 37.8 Molecular Property Prediction
- 37.9 Knowledge Graphs
- 37.10 Applications of Graph Neural Networks
- 37.11 PyTorch Geometric: The Practical GNN Library
- 37.12 Advanced GNN Architectures
- 37.13 Practical Considerations
- Summary
- Quick Reference
Chapter 37: Graph Neural Networks and Structured Data
Introduction: When Data Has Structure Beyond Grids and Sequences
Throughout this book, we have worked with data that fits neatly into regular structures. Images are grids of pixels, perfectly suited to convolutional neural networks (Chapter 14). Text is a sequence of tokens, naturally handled by recurrent networks (Chapter 15) and Transformers (Chapter 19). But many of the most important datasets in the world are neither grids nor sequences---they are graphs.
A molecule is a graph: atoms are nodes, chemical bonds are edges. A social network is a graph: users are nodes, friendships are edges. A citation network is a graph: papers are nodes, citations are edges. Protein interactions, transportation networks, knowledge bases, recommendation systems, financial transaction networks, computer programs---all of these are most naturally represented as graphs.
The challenge is fundamental: graphs have no fixed ordering of nodes, no regular spatial structure, and varying connectivity patterns. You cannot simply flatten a graph into a vector and feed it to a multilayer perceptron. Convolution filters designed for grids do not transfer to irregular neighborhoods. The Transformer's positional encodings assume a sequential ordering that graphs lack.
Graph Neural Networks (GNNs) solve this problem by operating directly on graph structure through a simple but powerful idea: message passing. Each node aggregates information from its neighbors, transforms it, and updates its own representation. By stacking multiple message-passing layers, each node's representation incorporates information from increasingly distant parts of the graph.
This chapter provides a rigorous, implementation-focused treatment of GNNs. We will build from graph fundamentals through the most important GNN architectures---Graph Convolutional Networks (GCN), GraphSAGE, and Graph Attention Networks (GAT)---and then apply them to real-world problems in molecular property prediction, citation networks, and knowledge graphs. All implementations use PyTorch and PyTorch Geometric, the dominant library for GNN development.
Prerequisites
Before diving in, you should be comfortable with:
- Neural network fundamentals and PyTorch (Chapters 4--7, 11--12)
- The attention mechanism (Chapter 18)
- Basic linear algebra, especially matrix multiplication (Chapter 2)
- Convolutional neural networks, at a conceptual level (Chapter 14)
37.1 Graph Fundamentals
37.1.1 What Is a Graph?
A graph $G = (V, E)$ consists of a set of nodes (or vertices) $V$ and a set of edges $E \subseteq V \times V$ connecting pairs of nodes. Each node $v \in V$ may carry a feature vector $\mathbf{x}_v \in \mathbb{R}^d$, and each edge $(u, v) \in E$ may carry edge features $\mathbf{e}_{uv} \in \mathbb{R}^{d_e}$.
Graphs come in several varieties:
- Undirected graphs: Edges have no direction. If $(u, v) \in E$, then $(v, u) \in E$. Social friendships are undirected.
- Directed graphs: Edges have direction. A citation from paper A to paper B does not imply B cites A.
- Weighted graphs: Edges carry scalar weights. In a transportation network, edge weights might represent distances.
- Heterogeneous graphs: Nodes and edges have different types. In an e-commerce graph, nodes might be users, products, and brands, with different edge types (purchased, reviewed, manufactured_by).
37.1.2 The Adjacency Matrix
The most common algebraic representation of a graph is the adjacency matrix $\mathbf{A} \in \{0, 1\}^{N \times N}$, where $N = |V|$:
$$A_{ij} = \begin{cases} 1 & \text{if } (i, j) \in E \\ 0 & \text{otherwise} \end{cases}$$
For undirected graphs, $\mathbf{A}$ is symmetric: $A_{ij} = A_{ji}$. For weighted graphs, $A_{ij}$ holds the edge weight instead of a binary indicator.
The adjacency matrix immediately gives us useful information:
- The degree of node $i$ is $d_i = \sum_j A_{ij}$
- The degree matrix $\mathbf{D}$ is a diagonal matrix with $D_{ii} = d_i$
- The graph Laplacian is $\mathbf{L} = \mathbf{D} - \mathbf{A}$
- The normalized Laplacian is $\mathbf{L}_{\text{norm}} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}$
The Laplacian is to graphs what the derivative is to continuous functions---it measures local variation. Its eigenvalues and eigenvectors define the "frequencies" of the graph, which forms the foundation of spectral graph theory. We will return to this connection in Section 37.3 when deriving the GCN layer from spectral graph theory.
Worked Example: A Small Social Network. Consider a graph with 4 users, where user 0 is friends with users 1 and 3, user 1 is friends with users 0 and 2, user 2 is friends with users 1 and 3, and user 3 is friends with users 0 and 2. The adjacency matrix is:
$$\mathbf{A} = \begin{pmatrix} 0 & 1 & 0 & 1 \\ 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 0 & 1 & 0 \end{pmatrix}$$
The degree matrix is $\mathbf{D} = \text{diag}(2, 2, 2, 2)$, since each node has exactly 2 neighbors. The Laplacian is:
$$\mathbf{L} = \mathbf{D} - \mathbf{A} = \begin{pmatrix} 2 & -1 & 0 & -1 \\ -1 & 2 & -1 & 0 \\ 0 & -1 & 2 & -1 \\ -1 & 0 & -1 & 2 \end{pmatrix}$$
The eigenvalues of $\mathbf{L}$ are $\{0, 2, 2, 4\}$. The Laplacian always has a zero eigenvalue. Its eigenvector is the constant vector, because every row of $\mathbf{L}$ sums to zero. For a connected graph, this eigenvalue has multiplicity exactly one. The largest eigenvalue reflects the maximum "frequency," or variation, across the graph. This graph has a repeated eigenvalue of 2, which reflects its symmetry.
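The worked example can be checked numerically with the minimal PyTorch sketch below, which uses float64 to keep rounding noise small. Every node has degree 2, so the normalized Laplacian here is simply $\mathbf{L}/2$, and its eigenvalues are half of those above.

```python
import torch

# Adjacency matrix of the 4-node cycle from the worked example (0-1-2-3-0)
A = torch.tensor([
    [0., 1., 0., 1.],
    [1., 0., 1., 0.],
    [0., 1., 0., 1.],
    [1., 0., 1., 0.],
], dtype=torch.float64)

deg = A.sum(dim=1)                      # d_i = sum_j A_ij
D = torch.diag(deg)                     # degree matrix
L = D - A                               # graph Laplacian
D_inv_sqrt = torch.diag(deg.pow(-0.5))
L_norm = torch.eye(4, dtype=torch.float64) - D_inv_sqrt @ A @ D_inv_sqrt

# Both matrices are symmetric, so eigvalsh returns real eigenvalues in ascending order
eigvals = torch.linalg.eigvalsh(L)
eigvals_norm = torch.linalg.eigvalsh(L_norm)
print(eigvals)       # approximately [0., 2., 2., 4.]
print(eigvals_norm)  # approximately [0., 1., 1., 2.]
```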
37.1.3 Sparse Representations
Most real-world graphs are sparse: the number of edges $|E|$ is much smaller than $N^2$. Storing the full adjacency matrix is wasteful and often impossible for large graphs. Instead, we use sparse representations:
- Edge list (COO format): A tensor of shape $(2, |E|)$ listing source and target nodes for each edge. This is the default in PyTorch Geometric.
- Adjacency list: For each node, store a list of its neighbors.
- CSR/CSC format: Compressed sparse row/column format, efficient for matrix operations.
```python
import torch

# A simple undirected graph with 4 nodes and 4 edges
#   0 -- 1
#   |    |
#   3 -- 2
# Each undirected edge is stored twice, once per direction (8 columns).
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 0],  # source nodes
    [1, 0, 2, 1, 3, 2, 0, 3],  # target nodes
], dtype=torch.long)

# Node features: 4 nodes, each with 3 features
x = torch.randn(4, 3)

# Edge index shape: [2, num_edges] (COO format)
print(f"Edge index shape: {edge_index.shape}")  # torch.Size([2, 8])
print(f"Node feature shape: {x.shape}")         # torch.Size([4, 3])
```
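For small graphs, converting between the representations listed above is straightforward. The sketch below rebuilds three things from the same `edge_index`:

- the dense adjacency matrix;
- the node degrees;
- an adjacency list.

It also builds a PyTorch sparse COO tensor. Materializing a dense matrix like this only makes sense for small graphs.

```python
import torch

edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 0],  # source nodes
    [1, 0, 2, 1, 3, 2, 0, 3],  # target nodes
], dtype=torch.long)
num_nodes = 4

# Dense adjacency: A[i, j] = 1 for every stored edge (i, j)
A = torch.zeros(num_nodes, num_nodes)
A[edge_index[0], edge_index[1]] = 1.0

# The same matrix as a sparse COO tensor (no dense N x N storage needed)
A_sparse = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)), (num_nodes, num_nodes))

# Degrees computed directly from the edge list
deg = torch.bincount(edge_index[0], minlength=num_nodes)

# Adjacency list: for each node, the targets of its outgoing edges
adj_list = {v: [] for v in range(num_nodes)}
for src, dst in edge_index.t().tolist():
    adj_list[src].append(dst)

print(adj_list)  # {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
```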
37.1.4 The Node Feature Matrix
For a graph with $N$ nodes, each having $d$-dimensional features, the node feature matrix is $\mathbf{X} \in \mathbb{R}^{N \times d}$. Row $i$ of $\mathbf{X}$ is the feature vector $\mathbf{x}_i$ of node $i$.
What constitutes node features depends entirely on the domain:
- In a molecular graph, node features encode atom type (one-hot), atomic number, charge, hybridization, and whether the atom is in a ring.
- In a citation network, node features might be bag-of-words or learned embeddings of the paper's abstract.
- In a social network, node features could be user profiles: age, location, activity level.
When no natural features exist, common choices include one-hot node identity, random features, or learned embedding tables (similar to word embeddings in NLP).
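Two of these featureless options look as follows in PyTorch. This is a minimal sketch with arbitrary sizes. In both cases each node gets its own coordinates or parameters, so a node that was absent during training has no row. This ties both options to the transductive setting.

```python
import torch
import torch.nn as nn

num_nodes, emb_dim = 100, 16

# Option 1: one-hot node identity, i.e. the N x N identity matrix
x_onehot = torch.eye(num_nodes)

# Option 2: a learned embedding table with one trainable row per node;
# the rows are updated by backpropagation like any other weight.
node_embedding = nn.Embedding(num_nodes, emb_dim)
x_learned = node_embedding(torch.arange(num_nodes))  # [num_nodes, emb_dim]
```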
37.2 The Message Passing Framework
37.2.1 The Core Idea
The fundamental operation of graph neural networks is message passing (also called neighborhood aggregation). The idea is simple: to compute the representation of a node, gather information from its neighbors, combine it, and use it to update the node's state.
Formally, at layer $\ell$, each node $v$ performs three steps:
- Message computation: Each neighbor $u \in \mathcal{N}(v)$ sends a message: $$\mathbf{m}_{u \to v}^{(\ell)} = \phi\!\left(\mathbf{h}_u^{(\ell-1)}, \mathbf{h}_v^{(\ell-1)}, \mathbf{e}_{uv}\right)$$
- Aggregation: Messages from all neighbors are combined: $$\mathbf{m}_v^{(\ell)} = \bigoplus_{u \in \mathcal{N}(v)} \mathbf{m}_{u \to v}^{(\ell)}$$
- Update: The node representation is updated: $$\mathbf{h}_v^{(\ell)} = \psi\!\left(\mathbf{h}_v^{(\ell-1)}, \mathbf{m}_v^{(\ell)}\right)$$
Here, $\phi$ is the message function, $\bigoplus$ is a permutation-invariant aggregation (sum, mean, max), and $\psi$ is the update function. The initial node representations are $\mathbf{h}_v^{(0)} = \mathbf{x}_v$, the input features.
This framework is important because it ensures a key requirement: permutation equivariance. Reordering the nodes of a graph should produce correspondingly reordered outputs, not different outputs. Since the aggregation $\bigoplus$ is permutation-invariant (the sum of a set of numbers does not depend on the order), the entire message passing operation is equivariant to node permutations.
Why permutation equivariance matters. Unlike images (where pixel (0,0) is always the top-left corner) or sequences (where position 1 is always first), graphs have no canonical node ordering. If you relabel node 3 as node 7, the graph itself has not changed, so the model's output should simply be relabeled accordingly. Formally, let $\mathbf{P}$ be a permutation matrix. A function $f$ is equivariant if $f(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^T) = \mathbf{P} f(\mathbf{X}, \mathbf{A})$. The message passing framework guarantees this because the aggregation depends only on the multiset of neighbor features, not their ordering.
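The equivariance property can be checked numerically. The sketch below uses the simplest layer of this kind, $f(\mathbf{X}, \mathbf{A}) = \mathbf{A}\mathbf{X}\mathbf{W}$: sum aggregation followed by a shared linear map, which matches the matrix form in Section 37.2.3. The graph, features, and weights are random and chosen only for illustration.

```python
import torch

torch.manual_seed(0)
N, d_in, d_out = 5, 3, 2

# A random undirected graph: symmetric 0/1 adjacency without self-loops
upper = torch.triu((torch.rand(N, N) < 0.5).double(), diagonal=1)
A = upper + upper.t()
X = torch.randn(N, d_in, dtype=torch.float64)
W = torch.randn(d_in, d_out, dtype=torch.float64)

def layer(X: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Sum aggregation over neighbors followed by a shared linear map: A X W."""
    return A @ X @ W

# A random permutation matrix P (rows of the identity, reordered)
P = torch.eye(N, dtype=torch.float64)[torch.randperm(N)]

lhs = layer(P @ X, P @ A @ P.t())  # relabel the nodes, then apply the layer
rhs = P @ layer(X, A)              # apply the layer, then relabel the outputs
print(torch.allclose(lhs, rhs))    # True
```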
Formal connection to the Weisfeiler-Lehman test. The message passing framework has a deep theoretical connection to the Weisfeiler-Lehman (WL) graph isomorphism test, a classical algorithm from graph theory. The 1-WL test iteratively refines node "colors" (labels) by hashing each node's current color together with the multiset of its neighbors' colors. After $k$ iterations, two graphs that have different color histograms are provably non-isomorphic. Xu et al. (2019) showed that the expressive power of any message passing GNN is bounded above by the 1-WL test---we will explore this result in detail in Section 37.11.1.
Worked Example: One Round of Message Passing. Consider a triangle graph with 3 nodes, where each node has a scalar feature: $h_0 = 1$, $h_1 = 2$, $h_2 = 3$. Using sum aggregation with no transformation (for simplicity), after one message passing step:
- Node 0 receives messages from neighbors 1 and 2: $m_0 = h_1 + h_2 = 2 + 3 = 5$
- Node 1 receives messages from neighbors 0 and 2: $m_1 = h_0 + h_2 = 1 + 3 = 4$
- Node 2 receives messages from neighbors 0 and 1: $m_2 = h_0 + h_1 = 1 + 2 = 3$
If the update function simply replaces the old feature with the aggregated message, the new features are $(5, 4, 3)$. Notice that node 0, which had the smallest feature, now has the largest representation because its neighbors had large features. This illustrates how message passing diffuses information through the graph.
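The triangle example can be reproduced in a few lines of PyTorch. The sketch computes it twice: once from an edge list, where each edge $(u, v)$ sends $h_u$ to $v$, and once as a matrix-vector product with the adjacency matrix.

```python
import torch

# Triangle graph: every pair of nodes is connected (both directions stored)
edge_index = torch.tensor([
    [0, 0, 1, 1, 2, 2],  # source u
    [1, 2, 0, 2, 0, 1],  # target v
])
h = torch.tensor([1.0, 2.0, 3.0])  # scalar feature per node

# Edge-list form: index_add_ sums the incoming messages h_u at each target v
m = torch.zeros(3).index_add_(0, edge_index[1], h[edge_index[0]])
print(m)  # tensor([5., 4., 3.])

# Matrix form: the same sum aggregation is A @ h (row = receiver, column = sender)
A = torch.zeros(3, 3)
A[edge_index[1], edge_index[0]] = 1.0
print(A @ h)  # tensor([5., 4., 3.])
```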
37.2.2 Receptive Field and Over-Smoothing
After $L$ layers of message passing, each node's representation depends on all nodes within $L$ hops. The receptive field of a node grows exponentially with depth in dense graphs---a phenomenon analogous to how convolutional neural networks aggregate over larger spatial regions with more layers.
However, there is a critical problem: over-smoothing. As the number of layers increases, all node representations converge to a similar value, making nodes indistinguishable. Intuitively, after enough rounds of averaging, every node sees the same "average" of the entire graph.
This limits most practical GNNs to 2--4 layers, in stark contrast to the deep architectures common in vision and language. Mitigating over-smoothing is an active research area, with approaches including:
- Residual connections: Add skip connections from earlier layers, as we saw in Chapter 14 for ResNets. In GNNs, the residual connection takes the form $\mathbf{h}_v^{(\ell)} = \mathbf{h}_v^{(\ell-1)} + \text{GNN\_layer}(\mathbf{h}_v^{(\ell-1)}, \mathcal{N}(v))$.
- Jumping Knowledge Networks (Xu et al., 2018): Instead of using only the final layer's representations, concatenate or select representations from all layers: $\mathbf{h}_v^{\text{final}} = \text{AGG}(\mathbf{h}_v^{(0)}, \mathbf{h}_v^{(1)}, \ldots, \mathbf{h}_v^{(L)})$. The aggregation can be concatenation, max-pooling, or an LSTM over layers.
- DropEdge (Rong et al., 2020): Randomly remove a fraction of edges during each training iteration. This slows information propagation, reducing over-smoothing while also acting as a regularizer.
- PairNorm (Zhao and Akoglu, 2020): Explicitly normalize node features at each layer to maintain a target total pairwise distance between representations.
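Below is a minimal sketch of DropEdge-style edge removal. It assumes the graph is undirected, stores every edge in both directions, and has no self-loops. The helper name `drop_edge_undirected` is hypothetical. It drops each undirected pair as a unit so that the sampled graph stays symmetric. That is one reasonable design choice, not necessarily the exact procedure of Rong et al.

```python
import torch
from typing import Optional


def drop_edge_undirected(
    edge_index: torch.Tensor,
    p: float = 0.2,
    training: bool = True,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Randomly drop a fraction p of undirected edges (hypothetical helper).

    Assumes edge_index stores every undirected edge in both directions
    and contains no self-loops.
    """
    if not training:
        return edge_index                   # use the full graph at evaluation time
    src, dst = edge_index
    undirected = edge_index[:, src < dst]   # one copy of each undirected edge
    keep = torch.rand(undirected.size(1), generator=generator) >= p
    kept = undirected[:, keep]
    # Re-add the reverse direction so the result stays symmetric
    return torch.cat([kept, kept.flip(0)], dim=1)


# Usage on a 4-cycle; a fresh subgraph would be sampled at every training step
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]])
dropped = drop_edge_undirected(edge_index, p=0.5, generator=torch.Generator().manual_seed(0))
```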
Quantifying over-smoothing. A useful diagnostic is the Mean Average Distance (MAD): compute the average pairwise cosine distance between all node representations. If MAD approaches zero as depth increases, the model is over-smoothing. Tracking this metric is a practical way to choose the number of layers.
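The sketch below follows the description above in simplified form. It averages the cosine distance over all pairs of distinct nodes; published variants may mask or weight pairs differently. It then applies GCN-style smoothing with no weights or nonlinearity to random features on the 4-cycle from Section 37.1.2, and MAD collapses toward zero as depth grows.

```python
import torch
import torch.nn.functional as F


def mean_average_distance(h: torch.Tensor) -> float:
    """Average cosine distance over all pairs of distinct nodes (rows of h)."""
    h_unit = F.normalize(h, p=2, dim=1)
    cos_sim = h_unit @ h_unit.t()                       # [N, N] cosine similarities
    off_diag = ~torch.eye(h.size(0), dtype=torch.bool)  # drop each node's self-pair
    return (1.0 - cos_sim[off_diag]).mean().item()


# GCN-style propagation S = D~^-1/2 (A + I) D~^-1/2 on the 4-cycle
A = torch.tensor([[0., 1., 0., 1.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 0., 1., 0.]], dtype=torch.float64)
A_tilde = A + torch.eye(4, dtype=torch.float64)
d_inv_sqrt = A_tilde.sum(dim=1).pow(-0.5)
S = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]

torch.manual_seed(0)
h = torch.randn(4, 8, dtype=torch.float64)
mad_history = [mean_average_distance(h)]
for _ in range(20):  # 20 "layers" of pure smoothing
    h = S @ h
    mad_history.append(mean_average_distance(h))

print(f"MAD at depth 0: {mad_history[0]:.3f}, at depth 20: {mad_history[-1]:.2e}")
```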
Historical context. The over-smoothing problem was first formally characterized by Li, Han, and Wu (2018). They proved that repeatedly applying GCN's (linear) smoothing operator drives node representations within each connected component to depend only on node degree. Under the symmetric normalization, each node's features become proportional to $\sqrt{\tilde{d}_i}$ times a shared vector. This result connects GNNs to the theory of random walks and Markov chains. It also gives a precise mathematical account of when and why over-smoothing occurs.
37.2.3 Message Passing as Matrix Multiplication
The message passing operation can often be expressed as sparse matrix multiplication. For a simple sum aggregation with a linear message function:
$$\mathbf{H}^{(\ell)} = \sigma\!\left(\mathbf{A} \mathbf{H}^{(\ell-1)} \mathbf{W}^{(\ell)}\right)$$
where $\mathbf{A}$ is the adjacency matrix, $\mathbf{H}^{(\ell-1)} \in \mathbb{R}^{N \times d_{\ell-1}}$ is the matrix of all node representations at layer $\ell-1$, $\mathbf{W}^{(\ell)} \in \mathbb{R}^{d_{\ell-1} \times d_\ell}$ is a learnable weight matrix, and $\sigma$ is a nonlinearity.
This reveals the connection to spectral graph theory. Multiplying by a suitably normalized adjacency matrix acts as a low-pass filter on the graph, smoothing features across connected nodes. Different normalizations of $\mathbf{A}$ yield different filtering properties.
37.3 Graph Convolutional Networks (GCN)
37.3.1 Derivation from Spectral Graph Theory
To truly understand GCN, we must trace its derivation from spectral graph theory. This provides the theoretical justification for the specific normalization used in the GCN layer and reveals why GCN acts as a low-pass graph filter.
Graph Fourier Transform. Just as the classical Fourier transform decomposes a signal into frequency components using sinusoids, the graph Fourier transform decomposes a graph signal using the eigenvectors of the graph Laplacian. Let $\mathbf{L} = \mathbf{U} \boldsymbol{\Lambda} \mathbf{U}^T$ be the eigendecomposition of the normalized Laplacian, where $\mathbf{U}$ is the matrix of eigenvectors and $\boldsymbol{\Lambda} = \text{diag}(\lambda_1, \ldots, \lambda_N)$ contains the eigenvalues. For a signal $\mathbf{x}$ on the graph, its graph Fourier transform is:
$$\hat{\mathbf{x}} = \mathbf{U}^T \mathbf{x}$$
where: $\hat{\mathbf{x}}$ is the signal in the spectral (frequency) domain, $\mathbf{U}^T$ is the graph Fourier transform matrix, and $\mathbf{x}$ is the signal in the spatial (node) domain.
The inverse transform is $\mathbf{x} = \mathbf{U} \hat{\mathbf{x}}$.
Spectral graph convolution. A convolution on a graph can be defined as element-wise multiplication in the spectral domain (just as for classical signals). A spectral graph convolution with filter $g_\theta$ is:
$$g_\theta \star \mathbf{x} = \mathbf{U} \, g_\theta(\boldsymbol{\Lambda}) \, \mathbf{U}^T \mathbf{x}$$
where $g_\theta(\boldsymbol{\Lambda}) = \text{diag}(g_\theta(\lambda_1), \ldots, g_\theta(\lambda_N))$ is a spectral filter function applied to each eigenvalue.
The problem with this formulation is computational: it requires the full eigendecomposition of $\mathbf{L}$ ($O(N^3)$ cost) and the filter has $N$ free parameters. This is impractical for large graphs.
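The following sketch makes these definitions concrete on the 4-cycle from Section 37.1.2, using its normalized Laplacian $\mathbf{I} - \mathbf{A}/2$. The filter $g(\lambda) = e^{-\lambda}$ is an arbitrary low-pass choice for illustration. The signal $[1, -1, 1, -1]$ is the eigenvector with the largest eigenvalue ($\lambda = 2$), so the filter simply scales it by $e^{-2} \approx 0.135$. Note the explicit eigendecomposition: this is exactly the $O(N^3)$ step that becomes impractical on large graphs.

```python
import torch

# Normalized Laplacian of the 4-cycle (every degree is 2, so L_norm = I - A / 2)
A = torch.tensor([[0., 1., 0., 1.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 0., 1., 0.]], dtype=torch.float64)
L_norm = torch.eye(4, dtype=torch.float64) - A / 2

# Eigendecomposition L_norm = U diag(lam) U^T (O(N^3) in general)
lam, U = torch.linalg.eigh(L_norm)

x = torch.tensor([1.0, -1.0, 1.0, -1.0], dtype=torch.float64)  # a graph signal
x_hat = U.t() @ x   # graph Fourier transform
x_back = U @ x_hat  # inverse transform

# Spectral convolution with an illustrative low-pass filter g(lam) = exp(-lam)
g = torch.exp(-lam)
y = U @ torch.diag(g) @ U.t() @ x

print(lam)  # approximately [0., 1., 1., 2.]
print(y)    # approximately exp(-2) * x, since x is the lam = 2 eigenvector
```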
Chebyshev polynomial approximation. Defferrard et al. (2016) proposed approximating $g_\theta(\boldsymbol{\Lambda})$ using Chebyshev polynomials of degree $K$:
$$g_\theta(\boldsymbol{\Lambda}) \approx \sum_{k=0}^{K} \theta_k T_k(\tilde{\boldsymbol{\Lambda}})$$
where $T_k$ is the $k$-th Chebyshev polynomial, $\tilde{\boldsymbol{\Lambda}} = 2\boldsymbol{\Lambda}/\lambda_{\max} - \mathbf{I}$ rescales eigenvalues to $[-1, 1]$, and $\theta_k$ are learnable coefficients.
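This reduces the filter to $K+1$ parameters and avoids the eigendecomposition entirely. With $\tilde{\mathbf{L}} = 2\mathbf{L}/\lambda_{\max} - \mathbf{I}$, each term $T_k(\tilde{\mathbf{L}})\mathbf{x}$ follows from the previous two via the Chebyshev recurrence $T_0(z) = 1$, $T_1(z) = z$, $T_k(z) = 2z\,T_{k-1}(z) - T_{k-2}(z)$. The whole filter therefore costs $K$ sparse matrix-vector products. Only $\lambda_{\max}$ is needed, and it can be estimated cheaply; for the normalized Laplacian it is at most 2.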
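The sketch below checks the recurrence against the explicit spectral definition. It assumes a 5-cycle and arbitrary coefficients $\theta_k$. $\lambda_{\max}$ is computed exactly here only so that both sides can be compared. On the spectral side, $T_k(z) = \cos(k \arccos z)$ for $|z| \le 1$. The helper name `chebyshev_filter` is hypothetical.

```python
import torch
from typing import Sequence


def chebyshev_filter(L: torch.Tensor, x: torch.Tensor,
                     theta: Sequence[float], lam_max: float) -> torch.Tensor:
    """Compute sum_k theta_k T_k(L_tilde) x with the Chebyshev recurrence."""
    n = L.size(0)
    L_tilde = 2.0 * L / lam_max - torch.eye(n, dtype=L.dtype)  # spectrum rescaled to [-1, 1]
    t_prev, t_curr = x, L_tilde @ x                              # T_0(L~) x and T_1(L~) x
    out = theta[0] * t_prev
    if len(theta) > 1:
        out = out + theta[1] * t_curr
    for k in range(2, len(theta)):
        # T_k(L~) x = 2 L~ T_{k-1}(L~) x - T_{k-2}(L~) x: one matrix-vector product per step
        t_prev, t_curr = t_curr, 2.0 * (L_tilde @ t_curr) - t_prev
        out = out + theta[k] * t_curr
    return out


# A 5-cycle; every degree is 2, so L_norm = I - A / 2
A = torch.zeros(5, 5, dtype=torch.float64)
for i in range(5):
    A[i, (i + 1) % 5] = A[(i + 1) % 5, i] = 1.0
L_norm = torch.eye(5, dtype=torch.float64) - A / 2

lam, U = torch.linalg.eigh(L_norm)
lam_max = lam.max().item()     # exact here only so both sides can be compared
theta = [0.5, -0.3, 0.2, 0.1]  # arbitrary filter coefficients
torch.manual_seed(0)
x = torch.randn(5, dtype=torch.float64)

# Spectral side: g(lam) = sum_k theta_k T_k(z), with z the rescaled eigenvalue
z = (2.0 * lam / lam_max - 1.0).clamp(-1.0, 1.0)
g = sum(t * torch.cos(k * torch.arccos(z)) for k, t in enumerate(theta))
y_spectral = U @ (g * (U.t() @ x))
y_cheb = chebyshev_filter(L_norm, x, theta, lam_max)
print(torch.allclose(y_spectral, y_cheb))  # True
```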
From ChebNet to GCN. Kipf and Welling (2017) made two simplifying choices: set $K = 1$ (first-order approximation only) and approximate $\lambda_{\max} \approx 2$. With these simplifications, the spectral filter reduces to:
$$g_\theta \star \mathbf{x} \approx \theta_0 \mathbf{x} + \theta_1 (\mathbf{L} - \mathbf{I}) \mathbf{x} = \theta_0 \mathbf{x} - \theta_1 \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{x}$$
Further constraining $\theta = \theta_0 = -\theta_1$ to reduce parameters:
$$g_\theta \star \mathbf{x} \approx \theta (\mathbf{I} + \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}) \mathbf{x}$$
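The renormalization trick replaces $\mathbf{I} + \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}$ with $\tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2}$, where $\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}$ and $\tilde{\mathbf{D}}$ is its degree matrix. The original operator has eigenvalues in $[0, 2]$. Applying it repeatedly across stacked layers can therefore cause numerical instabilities and exploding or vanishing gradients. The renormalized operator's eigenvalues lie in $(-1, 1]$. This yields the GCN propagation rule.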
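The eigenvalue ranges can be seen directly on the 4-cycle from Section 37.1.2. The sketch below builds both operators and compares their spectra. It also compares the spectral norm of ten stacked applications, a stand-in for ten layers without weights or nonlinearities.

```python
import torch

A = torch.tensor([[0., 1., 0., 1.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [1., 0., 1., 0.]], dtype=torch.float64)
I = torch.eye(4, dtype=torch.float64)

# First-order operator I + D^-1/2 A D^-1/2 (before renormalization)
d_inv_sqrt = A.sum(dim=1).pow(-0.5)
before = I + d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

# Renormalized operator D~^-1/2 (A + I) D~^-1/2
A_tilde = A + I
dt_inv_sqrt = A_tilde.sum(dim=1).pow(-0.5)
after = dt_inv_sqrt[:, None] * A_tilde * dt_inv_sqrt[None, :]

print(torch.linalg.eigvalsh(before))  # approximately [0, 1, 1, 2]
print(torch.linalg.eigvalsh(after))   # approximately [-1/3, 1/3, 1/3, 1]

# Ten stacked applications: the spectral norm grows as 2**10 without renormalization
norm_before = torch.linalg.matrix_norm(torch.linalg.matrix_power(before, 10), ord=2)
norm_after = torch.linalg.matrix_norm(torch.linalg.matrix_power(after, 10), ord=2)
print(norm_before.item(), norm_after.item())  # about 1024 and about 1
```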
The key insight from this derivation is that GCN is a first-order approximation of a spectral graph convolution. The symmetric normalization is not an arbitrary design choice---it emerges naturally from the spectral formulation.
37.3.2 The GCN Layer
Kipf and Welling (2017) introduced the Graph Convolutional Network with the following layer-wise propagation rule:
$$\mathbf{H}^{(\ell)} = \sigma\!\left(\tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2} \mathbf{H}^{(\ell-1)} \mathbf{W}^{(\ell)}\right)$$
where $\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}$ is the adjacency matrix with added self-loops, and $\tilde{\mathbf{D}}$ is the degree matrix of $\tilde{\mathbf{A}}$.
Let us unpack this step by step:
- Self-loops ($\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}$): When aggregating neighbor information, a node should also include its own features. Adding self-loops ensures this.
- Symmetric normalization ($\tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2}$): Without normalization, nodes with many neighbors would have much larger aggregated values than nodes with few neighbors. The symmetric normalization scales each message from $u$ to $v$ by $1/\sqrt{\tilde{d}_u \tilde{d}_v}$, the inverse of the geometric mean of the source and target degrees.
- Linear transformation ($\mathbf{W}^{(\ell)}$): A learnable weight matrix projects the aggregated features. This is analogous to a learnable convolution filter.
- Nonlinearity ($\sigma$): Typically ReLU. Without a nonlinearity, stacking layers would collapse into a single linear transformation.
For a single node $v$, the GCN update can be written as:
$$\mathbf{h}_v^{(\ell)} = \sigma\!\left(\mathbf{W}^{(\ell)} \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\tilde{d}_u \tilde{d}_v}} \mathbf{h}_u^{(\ell-1)}\right)$$
where $\tilde{d}_u = 1 + \sum_j A_{uj}$ is the degree in the self-loop-augmented graph.
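The per-node sum and the matrix form compute the same thing. The sketch below checks this on a small hypothetical graph, a path 0-1-2-3 plus an edge 1-3. It leaves out $\mathbf{W}^{(\ell)}$ and $\sigma$, since they are applied identically in both forms.

```python
import torch

torch.manual_seed(0)
# Undirected edges 0-1, 1-2, 2-3, 1-3, each stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 1, 3],
                           [1, 0, 2, 1, 3, 2, 3, 1]])
N = 4
x = torch.randn(N, 3, dtype=torch.float64)

# Matrix form: D~^-1/2 (A + I) D~^-1/2 X
A = torch.zeros(N, N, dtype=torch.float64)
A[edge_index[1], edge_index[0]] = 1.0
A_tilde = A + torch.eye(N, dtype=torch.float64)
d_tilde = A_tilde.sum(dim=1)  # d~_v = 1 + sum_j A_vj
out_matrix = (A_tilde / torch.sqrt(d_tilde[:, None] * d_tilde[None, :])) @ x

# Per-node form: sum over u in N(v) plus v itself of h_u / sqrt(d~_u d~_v)
out_node = torch.zeros_like(x)
for v in range(N):
    neighbors = edge_index[0][edge_index[1] == v].tolist() + [v]
    for u in neighbors:
        out_node[v] += x[u] / torch.sqrt(d_tilde[u] * d_tilde[v])

print(torch.allclose(out_matrix, out_node))  # True
```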
37.3.3 GCN Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)


class GCNLayer(nn.Module):
    """A single Graph Convolutional Network layer (nonlinearity applied by the caller)."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        # The bias is kept separate so it is added once, after aggregation,
        # instead of being propagated and rescaled along with the features.
        self.linear = nn.Linear(in_features, out_features, bias=False)
        nn.init.xavier_uniform_(self.linear.weight)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, num_nodes: int
    ) -> torch.Tensor:
        """Forward pass of GCN layer.

        Args:
            x: Node feature matrix [num_nodes, in_features].
            edge_index: Edge indices in COO format [2, num_edges]. Assumed to be
                undirected (both directions stored) and free of self-loops.
            num_nodes: Number of nodes in the graph.

        Returns:
            Updated node features [num_nodes, out_features].
        """
        # Add self-loops: A~ = A + I
        self_loops = torch.arange(num_nodes, device=x.device).unsqueeze(0).repeat(2, 1)
        edge_index_with_loops = torch.cat([edge_index, self_loops], dim=1)

        # Degree d~ of each node (in-degree; equals degree for undirected graphs)
        row, col = edge_index_with_loops
        deg = torch.zeros(num_nodes, device=x.device, dtype=x.dtype)
        deg.scatter_add_(0, col, torch.ones(col.size(0), device=x.device, dtype=x.dtype))
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.0

        # Symmetric normalization coefficient 1 / sqrt(d~_u d~_v) for each edge
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Linear transformation (H W)
        x = self.linear(x)

        # Message passing: each target node `col` sums its normalized neighbor features
        out = torch.zeros_like(x)
        out.index_add_(0, col, norm.unsqueeze(1) * x[row])

        if self.bias is not None:
            out = out + self.bias
        return out
```
37.3.4 Limitations of GCN
GCN is elegant and effective but has notable limitations:
- Transductive by default: The original GCN operates on the full graph adjacency matrix. It cannot easily generalize to nodes not seen during training (inductive setting).
- Fixed normalization: All neighbors contribute equally (up to degree normalization). There is no mechanism to learn which neighbors are more important.
- Shallow depth: Over-smoothing limits GCN to 2--3 layers in practice.
- Homophily assumption: GCN works best when connected nodes have similar labels (homophily). It struggles on heterophilous graphs where connected nodes tend to have different labels.
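One common way to quantify homophily is the edge homophily ratio: the fraction of edges whose two endpoints share a label. The sketch below is minimal; the helper name `edge_homophily` and the toy labels are hypothetical.

```python
import torch


def edge_homophily(edge_index: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of edges whose two endpoints have the same label."""
    src, dst = edge_index
    return (labels[src] == labels[dst]).float().mean().item()


# The 4-cycle 0-1-2-3-0, both directions stored
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]])
print(edge_homophily(edge_index, torch.tensor([0, 0, 1, 1])))  # 0.5: half the edges join same-label nodes
print(edge_homophily(edge_index, torch.tensor([0, 1, 0, 1])))  # 0.0: every edge joins different labels
```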
37.4 GraphSAGE: Inductive Learning on Graphs
37.4.1 Sampling and Aggregating
Hamilton, Ying, and Leskovec (2017) introduced GraphSAGE (SAmple and aggreGatE) to address the transductive limitation of GCN. The key innovations are:
- Sampling: Instead of using the full neighborhood, sample a fixed number of neighbors at each layer. This makes computation tractable for very large graphs and enables mini-batch training.
- Learnable aggregation: Instead of using only symmetric normalization, GraphSAGE supports multiple aggregation functions:
  - Mean aggregator: Average of neighbor features (similar to GCN)
  - LSTM aggregator: Apply an LSTM to a random permutation of neighbors
  - Pooling aggregator: Element-wise max of transformed neighbor features
- Concatenation with self: Instead of mixing self-features with neighbor features before transformation, GraphSAGE concatenates them:
$$\mathbf{h}_v^{(\ell)} = \sigma\!\left(\mathbf{W}^{(\ell)} \cdot \text{CONCAT}\!\left(\mathbf{h}_v^{(\ell-1)},\; \text{AGG}\!\left(\{\mathbf{h}_u^{(\ell-1)} : u \in \mathcal{N}_S(v)\}\right)\right)\right)$$
where $\mathcal{N}_S(v)$ is a fixed-size sample from the neighborhood of $v$.
37.4.2 The Neighbor Sampling Strategy
GraphSAGE's sampling approach is critical to its scalability and deserves careful examination. The key insight is that we do not need all neighbors to compute a useful aggregation---a random sample of fixed size suffices.
The sampling procedure. For an $L$-layer GraphSAGE with sample sizes $(S_1, S_2, \ldots, S_L)$, where $S_\ell$ is the number of neighbors sampled at layer $\ell$:
- Select a batch of $B$ target nodes (the nodes whose representations we want to compute).
- For each target node $v$, uniformly sample $S_L$ neighbors from $\mathcal{N}(v)$. If $|\mathcal{N}(v)| < S_L$, sample with replacement.
- For each of those sampled neighbors, uniformly sample $S_{L-1}$ of their neighbors.
- Continue recursively until reaching layer 1.
- Compute representations bottom-up: first compute layer-1 representations for all sampled nodes, then layer-2, and so on up to layer $L$.
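The recursive sampling steps above can be sketched in pure Python. This is a minimal sketch with several assumptions:

- `adj` is a hypothetical adjacency list;
- every node has at least one neighbor;
- sampled nodes are not deduplicated, whereas production samplers typically merge repeated nodes and record which sampled node feeds which parent.

```python
import random
from typing import Dict, List, Sequence


def sample_computation_tree(adj: Dict[int, List[int]], targets: Sequence[int],
                            sample_sizes: Sequence[int], seed: int = 0) -> List[List[int]]:
    """Node-wise GraphSAGE-style sampling (hypothetical helper).

    sample_sizes = [S_1, ..., S_L]; the target nodes sample S_L neighbors first.
    Returns the frontiers, from the targets down to the deepest hop.
    """
    rng = random.Random(seed)
    frontiers = [list(targets)]
    for s in reversed(sample_sizes):  # S_L, S_{L-1}, ..., S_1
        next_frontier = []
        for v in frontiers[-1]:
            neighbors = adj[v]
            if len(neighbors) >= s:
                next_frontier.extend(rng.sample(neighbors, s))     # without replacement
            else:
                next_frontier.extend(rng.choices(neighbors, k=s))  # with replacement
        frontiers.append(next_frontier)
    return frontiers


adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
frontiers = sample_computation_tree(adj, targets=[0, 1], sample_sizes=[3, 2])
print([len(f) for f in frontiers])  # [2, 4, 12]
```

Representations are then computed in the opposite order: layer 1 for the deepest frontier, then upward to the targets.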
This creates a computation tree (also called a computation graph or sampling fan-out) rooted at each target node. With $L$ layers and sample sizes $(S_1, \ldots, S_L)$, the deepest level of one target node's tree holds at most $\prod_{\ell=1}^{L} S_\ell$ sampled nodes. The whole tree holds at most $1 + S_L + S_L S_{L-1} + \cdots + \prod_{\ell=1}^{L} S_\ell$.
Typical sample sizes. In practice, common choices are $S_1 = 25$, $S_2 = 10$ for a 2-layer model, yielding at most $25 \times 10 = 250$ nodes at the deepest hop per target. For a 3-layer model, $S_1 = 25, S_2 = 10, S_3 = 10$ gives at most 2,500 nodes at the deepest hop, which is still very manageable compared to a graph with millions of nodes.
Variance and bias. Sampling introduces variance: different random samples produce different representations for the same node. When the $S_\ell$ neighbors are drawn i.i.d. (with replacement), the variance of their mean is the per-neighbor variance divided by $S_\ell$, so it decreases as $1/S_\ell$. At a single layer, the sampled mean is an unbiased estimator of the full-neighborhood mean. Once several layers with nonlinearities are stacked, however, the final representation is no longer exactly unbiased. In practice, the stochastic noise from sampling also acts as a regularizer, similar to dropout.
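Both properties can be checked with a quick Monte Carlo simulation. The sketch below assumes a hypothetical node with 200 neighbors carrying random scalar features. It estimates the bias and variance of the sampled mean for $S = 5$ and $S = 25$, drawing neighbors i.i.d. with replacement. The variance ratio should come out close to $25/5 = 5$.

```python
import random
import statistics

rng = random.Random(0)
# Hypothetical scalar features for the 200 neighbors of one node
neighbor_feats = [rng.gauss(0.0, 1.0) for _ in range(200)]
full_mean = statistics.fmean(neighbor_feats)

results = {}
for s in (5, 25):
    # Each estimate averages s neighbors drawn i.i.d. with replacement
    estimates = [statistics.fmean(rng.choices(neighbor_feats, k=s)) for _ in range(5000)]
    bias = statistics.fmean(estimates) - full_mean
    variance = statistics.pvariance(estimates)
    results[s] = (bias, variance)
    print(f"S={s}: bias={bias:+.4f}, variance={variance:.4f}")

print(results[5][1] / results[25][1])  # close to 5
```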
37.4.3 Mini-Batch Training
GraphSAGE's sampling approach enables mini-batch training on graphs with millions of nodes. The full training pipeline works as follows:
- Shuffle all labeled nodes and form mini-batches of size $B$.
- For each mini-batch, expand the computation trees by sampling neighbors at each layer.
- Load only the required node features (not the full feature matrix) into GPU memory.
- Compute representations bottom-up through the sampled subgraph.
- Compute loss only on the $B$ target nodes and backpropagate.
With the same sample size $k$ at every layer, the bound from Section 37.4.2 becomes at most $k^L$ nodes at the deepest hop of each computation tree, which remains manageable even for enormous graphs.
Comparison with other sampling strategies. GraphSAGE's node-wise sampling is not the only option. Several alternatives have been developed:
| Strategy | Description | Trade-off |
|---|---|---|
| Node-wise (GraphSAGE) | Sample fixed neighbors per node per layer | Simple; high variance for hub nodes |
| Layer-wise (FastGCN) | Sample nodes per layer independently | Lower variance; loses neighbor connectivity |
| Subgraph (GraphSAINT) | Sample entire subgraphs via random walks | Preserves local structure; variable compute |
| Cluster (Cluster-GCN) | Partition graph into clusters, train on subgraphs | Efficient; boundary effects |
| Historical embeddings (GNNAutoScale) | Cache embeddings from previous iterations | Reduces recomputation; stale embeddings |
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphSAGELayer(nn.Module):
    """A single GraphSAGE layer with mean aggregation."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        # W acts on CONCAT(h_v, neighbor mean), hence 2 * in_features inputs
        self.linear = nn.Linear(in_features * 2, out_features)
        nn.init.xavier_uniform_(self.linear.weight)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, num_nodes: int
    ) -> torch.Tensor:
        """Forward pass with mean aggregation and concat.

        Args:
            x: Node features [num_nodes, in_features].
            edge_index: COO edge indices [2, num_edges].
            num_nodes: Total node count.

        Returns:
            Updated node features [num_nodes, out_features].
        """
        row, col = edge_index

        # Mean aggregation of neighbor features (messages flow row -> col)
        neighbor_sum = torch.zeros(num_nodes, x.size(1), device=x.device, dtype=x.dtype)
        neighbor_sum.index_add_(0, col, x[row])
        degree = torch.zeros(num_nodes, device=x.device, dtype=x.dtype)
        degree.index_add_(0, col, torch.ones(col.size(0), device=x.device, dtype=x.dtype))
        degree = degree.clamp(min=1).unsqueeze(1)  # isolated nodes get a zero mean
        neighbor_mean = neighbor_sum / degree

        # Concatenate self and neighbor features
        combined = torch.cat([x, neighbor_mean], dim=1)

        # Linear transformation, then L2 normalization of each output row.
        # Note: the nonlinearity sigma of the GraphSAGE update is not applied here.
        out = self.linear(combined)
        out = F.normalize(out, p=2, dim=1)
        return out
```
37.4.4 Advantages of GraphSAGE
- Inductive: Can generalize to unseen nodes and even entirely new graphs. This is critical for production systems where new users and items are continuously added.
- Scalable: Neighborhood sampling enables mini-batch training on graphs with billions of edges.
- Flexible aggregation: Different aggregators suit different tasks. In practice, the mean aggregator is the most common starting point, with the pool aggregator sometimes outperforming it on heterogeneous graphs.
- Practical: GraphSAGE is widely used in industry---Pinterest, for instance, deployed PinSage (a variant) for recommendation at scale, and Uber uses GraphSAGE-based models for fraud detection.
When to choose GraphSAGE over GCN. If your graph fits in GPU memory and is static (no new nodes at test time), GCN is simpler and often performs comparably. If your graph has millions of nodes, new nodes appear dynamically, or you need mini-batch training, GraphSAGE is the better choice.
37.5 Graph Attention Networks (GAT)
37.5.1 Attention on Graphs
Veličković et al. (2018) introduced Graph Attention Networks, bringing the attention mechanism (Chapter 18) to graphs. The key insight: not all neighbors are equally important, and the model should learn to weight them.
Intuition. Consider a citation network where you want to classify a paper by topic. Not all cited papers are equally relevant---a paper on "deep learning for NLP" is more informative about a paper's topic than a generic "introduction to statistics" reference. GCN treats all citations equally (up to degree normalization), but GAT learns to attend more strongly to the relevant citations.
For each node $v$ and neighbor $u \in \mathcal{N}(v)$, GAT computes an unnormalized attention coefficient:
$$e_{vu} = \text{LeakyReLU}\!\left(\mathbf{a}^T \left[\mathbf{W} \mathbf{h}_v \,\|\, \mathbf{W} \mathbf{h}_u\right]\right)$$
where:
- $\mathbf{W} \in \mathbb{R}^{d' \times d}$ is a shared linear transformation that projects features into a new space
- $\mathbf{a} \in \mathbb{R}^{2d'}$ is a learnable attention vector
- $\|$ denotes concatenation, producing a $2d'$-dimensional vector
- LeakyReLU (with negative slope 0.2) is used instead of ReLU to allow small gradients for negative inputs
Decomposing the attention mechanism. The concatenation-then-dot-product form can be decomposed for efficiency. If we split $\mathbf{a} = [\mathbf{a}_{\text{src}}; \mathbf{a}_{\text{dst}}]$ where $\mathbf{a}_{\text{src}}, \mathbf{a}_{\text{dst}} \in \mathbb{R}^{d'}$, then:
$$e_{vu} = \text{LeakyReLU}\!\left(\mathbf{a}_{\text{dst}}^T \mathbf{W} \mathbf{h}_v + \mathbf{a}_{\text{src}}^T \mathbf{W} \mathbf{h}_u\right)$$
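This additive decomposition means we can precompute the scalars $\mathbf{a}_{\text{dst}}^T \mathbf{W} \mathbf{h}_v$ and $\mathbf{a}_{\text{src}}^T \mathbf{W} \mathbf{h}_u$ once per node. They are then added pairwise only for existing edges, instead of forming a $2d'$-dimensional concatenation for every edge. PyTorch Geometric uses this strategy, and our code below likewise splits $\mathbf{a}$ into source and destination parts.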
These coefficients are normalized across the neighborhood using softmax:
$$\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in \mathcal{N}(v)} \exp(e_{vk})}$$
The softmax ensures that attention weights are non-negative and sum to 1 over each node's neighborhood, making them interpretable as a probability distribution over neighbors.
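Because each node has a different number of neighbors, this softmax must be computed separately over each target node's incoming edges. The sketch below does this with plain PyTorch tensor operations, using a single attention head for simplicity; the helper name `neighborhood_softmax` is hypothetical. It subtracts each target's maximum score before exponentiating. This is the usual trick for numerical stability and does not change the result.

```python
import torch


def neighborhood_softmax(scores: torch.Tensor, index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Softmax of per-edge scores over the edges that share a target node.

    scores: [E] unnormalized attention scores e_vu.
    index:  [E] target node v of each edge.
    """
    # Per-target maximum, subtracted for numerical stability
    max_per_node = torch.full((num_nodes,), float("-inf"), dtype=scores.dtype)
    max_per_node = max_per_node.scatter_reduce(0, index, scores, reduce="amax")
    exp_scores = torch.exp(scores - max_per_node[index])
    # Sum of exponentials per target node
    denom = torch.zeros(num_nodes, dtype=scores.dtype).index_add_(0, index, exp_scores)
    return exp_scores / denom[index]


# Node 0 has two incoming edges, node 1 has three
index = torch.tensor([0, 0, 1, 1, 1])
scores = torch.tensor([1.0, 2.0, 0.0, 0.0, 0.0])
alpha = neighborhood_softmax(scores, index, num_nodes=2)
print(alpha)  # node 1's three weights are each 1/3
```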
The updated representation is then a weighted sum:
$$\mathbf{h}_v' = \sigma\!\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu}\, \mathbf{W} \mathbf{h}_u\right)$$
Comparison with Transformer attention. GAT's attention is related to but distinct from the Transformer attention we studied in Chapter 18. In the Transformer, attention uses separate query, key, and value projections and computes $\text{softmax}(\mathbf{Q}\mathbf{K}^T / \sqrt{d_k})\mathbf{V}$. GAT uses a simpler additive attention with a single shared projection $\mathbf{W}$. GATv2 (Brody et al., 2022) later showed that the original GAT's attention is static---it can rank neighbors but cannot change the ranking based on the query node---and proposed a more expressive variant that applies the nonlinearity after concatenation but before the attention dot product.
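The "static attention" observation can be checked numerically. The sketch below makes three simplifying assumptions:

- the parameters are random;
- every node is treated as a neighbor of every node;
- all names are hypothetical.

GAT's score for query $v$ and candidate $u$ has the form $\text{LeakyReLU}(c_v + s_u)$. Because LeakyReLU is strictly increasing, every query ranks the candidates in the order of $s_u$. GATv2 applies the nonlinearity before the dot product with $\mathbf{a}$, so no such constraint holds.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d, d_out = 6, 4, 8
h = torch.randn(N, d)

# GAT: e[v, u] = LeakyReLU(a_dst . W h_v + a_src . W h_u)
W = torch.randn(d_out, d)
a_dst, a_src = torch.randn(d_out), torch.randn(d_out)
Wh = h @ W.t()  # [N, d_out]
e_gat = F.leaky_relu((Wh @ a_dst)[:, None] + (Wh @ a_src)[None, :], negative_slope=0.2)

# GATv2: e[v, u] = a . LeakyReLU(W2 [h_v || h_u])
W2 = torch.randn(d_out, 2 * d)
a2 = torch.randn(d_out)
pairs = torch.cat([h[:, None, :].expand(N, N, d), h[None, :, :].expand(N, N, d)], dim=-1)  # [N, N, 2d]
e_gatv2 = F.leaky_relu(pairs @ W2.t(), negative_slope=0.2) @ a2  # [N, N]

# Top-ranked candidate u for each query v
print(e_gat.argmax(dim=1))    # the same index for every query
print(e_gatv2.argmax(dim=1))  # may differ from one query to the next
```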
37.5.2 Multi-Head Attention
Like Transformer attention (Chapter 19), GAT uses multiple attention heads to stabilize learning and capture different relationship patterns:
$$\mathbf{h}_v' = \Big\|_{k=1}^{K} \sigma\!\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu}^{(k)}\, \mathbf{W}^{(k)} \mathbf{h}_u\right)$$
where $\|$ is concatenation across heads. In the final layer, averaging is typically used instead of concatenation:
$$\mathbf{h}_v' = \sigma\!\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{u \in \mathcal{N}(v)} \alpha_{vu}^{(k)}\, \mathbf{W}^{(k)} \mathbf{h}_u\right)$$
37.5.3 Implementation of GAT
```python
import torch
import torch.nn as nn


class GATLayer(nn.Module):
    """A single Graph Attention Network layer."""

    def __init__(
        self, in_features: int, out_features: int, num_heads: int = 4,
        dropout: float = 0.6, concat: bool = True,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.concat = concat
        self.W = nn.Linear(in_features, out_features * num_heads, bias=False)
        # The two halves of the attention vector a, one pair per head
        self.a_src = nn.Parameter(torch.empty(num_heads, out_features))
        self.a_dst = nn.Parameter(torch.empty(num_heads, out_features))
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.a_src)
        nn.init.xavier_uniform_(self.a_dst)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, num_nodes: int
    ) -> torch.Tensor:
        """Forward pass with multi-head attention.

        Args:
            x: Node features [num_nodes, in_features].
            edge_index: COO edge indices [2, num_edges]; messages flow from
                edge_index[0] (source) to edge_index[1] (target). Self-loops
                are added here, so do not include them.
            num_nodes: Number of nodes.

        Returns:
            [num_nodes, num_heads * out_features] if concat,
            otherwise [num_nodes, out_features].
        """
        # Add self-loops so that every node also attends to itself
        loops = torch.arange(num_nodes, device=x.device)
        edge_index = torch.cat([edge_index, torch.stack([loops, loops])], dim=1)
        row, col = edge_index  # row = source u, col = target v

        # Linear transform: [N, heads, out]
        h = self.W(x).view(num_nodes, self.num_heads, self.out_features)

        # Decomposed attention: one scalar per node and head for each half of a
        alpha_src = (h * self.a_src).sum(dim=-1)  # [N, heads]
        alpha_dst = (h * self.a_dst).sum(dim=-1)  # [N, heads]
        attn = self.leaky_relu(alpha_src[row] + alpha_dst[col])  # [E, heads]

        # Numerically stable softmax over each target node's incoming edges
        index = col.unsqueeze(1).expand_as(attn)
        attn_max = h.new_zeros(num_nodes, self.num_heads)
        attn_max.scatter_reduce_(0, index, attn, reduce='amax', include_self=False)
        attn = torch.exp(attn - attn_max[col])
        attn_sum = h.new_zeros(num_nodes, self.num_heads)
        attn_sum.scatter_add_(0, index, attn)
        attn = attn / (attn_sum[col] + 1e-16)
        attn = self.dropout(attn)

        # Weighted aggregation of transformed source features at each target
        weighted = h[row] * attn.unsqueeze(-1)  # [E, heads, out]
        out = h.new_zeros(num_nodes, self.num_heads, self.out_features)
        out.index_add_(0, col, weighted)

        if self.concat:
            return out.reshape(num_nodes, self.num_heads * self.out_features)
        return out.mean(dim=1)
```
37.5.4 GAT vs. GCN vs. GraphSAGE
| Feature | GCN | GraphSAGE | GAT |
|---|---|---|---|
| Neighbor weighting | Fixed (degree-based) | Equal (mean/max/LSTM) | Learned (attention) |
| Inductive capability | Limited | Yes | Yes |
| Computational cost | Low | Medium | Higher |
| Interpretability | Low | Low | Attention weights |
| Multi-head support | No | No | Yes |
| Scalability | Full-batch | Mini-batch (sampling) | Full-batch or mini-batch |
37.6 Graph-Level Tasks: Readout and Pooling
37.6.1 Node-Level vs. Graph-Level Prediction
The architectures above produce node-level representations. Many tasks require graph-level predictions:
- Predicting whether a molecule is toxic (the molecule is the graph)
- Classifying a protein by function
- Predicting the properties of a material
To go from node representations to a graph representation, we need a readout (or pooling) operation.
37.6.2 Simple Readout Functions
The simplest approach is to aggregate all node representations:
$$\mathbf{h}_G = \text{READOUT}\!\left(\{\mathbf{h}_v^{(L)} : v \in V\}\right)$$
Common choices:
- Mean pooling: $\mathbf{h}_G = \frac{1}{|V|} \sum_{v \in V} \mathbf{h}_v^{(L)}$
- Sum pooling: $\mathbf{h}_G = \sum_{v \in V} \mathbf{h}_v^{(L)}$
- Max pooling: $\mathbf{h}_G = \max_{v \in V} \mathbf{h}_v^{(L)}$ (element-wise)
Mean pooling is insensitive to graph size---useful when comparing graphs of different sizes. Sum pooling preserves information about graph size, which may be important (larger molecules generally have different properties than smaller ones). Xu et al. (2019) showed that sum pooling is more expressive than mean or max.
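The expressiveness ordering shows up on tiny multisets of node embeddings (the embeddings are made up for illustration). Mean and max cannot tell two copies of a vector from four, and max also cannot see how often each value occurs, while sum separates both pairs.

```python
import torch


def readouts(h: torch.Tensor) -> dict[str, torch.Tensor]:
    """Mean, sum, and element-wise max readout over a [num_nodes, dim] matrix."""
    return {"mean": h.mean(dim=0), "sum": h.sum(dim=0), "max": h.max(dim=0).values}


e1, e2 = [1.0, 0.0], [0.0, 1.0]
pairs = {
    "A vs B (same embedding, 2 vs 4 nodes)": (torch.tensor([e1] * 2), torch.tensor([e1] * 4)),
    "C vs D (same values, different counts)": (torch.tensor([e1, e2]), torch.tensor([e1, e2, e2])),
}
for name, (g1, g2) in pairs.items():
    r1, r2 = readouts(g1), readouts(g2)
    blind = [k for k in r1 if torch.equal(r1[k], r2[k])]
    print(name, "-> indistinguishable by:", blind)
# A vs B (same embedding, 2 vs 4 nodes) -> indistinguishable by: ['mean', 'max']
# C vs D (same values, different counts) -> indistinguishable by: ['max']
```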
37.6.3 Hierarchical Pooling
Simple readout functions discard structural information. Hierarchical pooling methods create a multi-level coarsening of the graph, analogous to spatial pooling in CNNs. Just as a CNN progressively reduces spatial resolution through pooling layers (as we saw in Chapter 14), hierarchical graph pooling progressively reduces the number of nodes, creating a pyramid of increasingly coarse graph representations.
DiffPool (Ying et al., 2018) learns a soft assignment matrix $\mathbf{S}^{(\ell)} \in \mathbb{R}^{N_\ell \times N_{\ell+1}}$ that maps $N_\ell$ nodes to $N_{\ell+1}$ clusters ($N_{\ell+1} < N_\ell$), creating a coarsened graph. At each pooling level:
$$\mathbf{X}^{(\ell+1)} = \mathbf{S}^{(\ell)T} \mathbf{Z}^{(\ell)}, \quad \mathbf{A}^{(\ell+1)} = \mathbf{S}^{(\ell)T} \mathbf{A}^{(\ell)} \mathbf{S}^{(\ell)}$$
where $\mathbf{Z}^{(\ell)}$ is the node embedding matrix from a GNN layer. The assignment matrix $\mathbf{S}^{(\ell)}$ is itself computed by a separate GNN, so the clustering is learned end-to-end. DiffPool is expressive but expensive: the assignment matrix and the coarsened adjacency are dense, so memory grows quadratically in $N$ when the number of clusters is a fixed fraction of the number of nodes, limiting it to graphs with at most a few thousand nodes.
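One DiffPool coarsening step is just two matrix products once $\mathbf{S}^{(\ell)}$ is available. In this sketch random logits stand in for the output of the assignment GNN and random features stand in for the embedding GNN (both are hypothetical placeholders); a row-wise softmax turns the logits into soft cluster memberships.

```python
import torch

torch.manual_seed(0)
N, d, K = 6, 4, 2  # 6 nodes pooled into 2 clusters
A = torch.tensor([[0, 1, 1, 0, 0, 0],
                  [1, 0, 1, 0, 0, 0],
                  [1, 1, 0, 1, 0, 0],
                  [0, 0, 1, 0, 1, 1],
                  [0, 0, 0, 1, 0, 1],
                  [0, 0, 0, 1, 1, 0]], dtype=torch.float)
Z = torch.randn(N, d)              # stand-in for the embedding GNN output Z
S_logits = torch.randn(N, K)       # stand-in for the assignment GNN output
S = torch.softmax(S_logits, dim=-1)  # soft assignment; each row sums to 1

X_next = S.T @ Z                   # [K, d] cluster features
A_next = S.T @ A @ S               # [K, K] dense, weighted coarsened adjacency

# Since S @ 1 = 1, the total edge weight 1^T A 1 is preserved by coarsening
print(A.sum().item(), A_next.sum().item())
```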
TopKPooling (Gao and Ji, 2019) takes a simpler approach. It learns a projection vector $\mathbf{p}$ and computes a scalar score for each node: $s_i = \mathbf{x}_i^T \mathbf{p} / \|\mathbf{p}\|$. The top-$k$ scoring nodes are retained, and all others are dropped along with their edges. The retained nodes' features are gated by their scores: $\tilde{\mathbf{x}}_i = \mathbf{x}_i \cdot \text{sigmoid}(s_i)$. This is memory-efficient ($O(N)$) but may discard important structural information.
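The scoring, selection, gating, and edge-filtering steps fit in one short function. This is a minimal sketch: the function name and `ratio` parameter are our own choices, and PyG's `TopKPooling` layer exposes more options.

```python
import math

import torch


def topk_pool(x: torch.Tensor, edge_index: torch.Tensor, p: torch.Tensor, ratio: float = 0.5):
    """Keep the top ceil(ratio * N) nodes by score, gate their features, and filter edges."""
    num_nodes = x.size(0)
    k = max(1, math.ceil(ratio * num_nodes))
    score = (x @ p) / p.norm()                         # s_i = x_i^T p / ||p||
    perm = score.topk(k).indices                        # retained nodes, highest score first
    x_out = x[perm] * torch.sigmoid(score[perm]).unsqueeze(-1)  # gate features by score
    # Relabel retained nodes as 0..k-1; dropped nodes map to -1
    mapping = torch.full((num_nodes,), -1, dtype=torch.long, device=x.device)
    mapping[perm] = torch.arange(k, device=x.device)
    src, dst = mapping[edge_index[0]], mapping[edge_index[1]]
    keep = (src >= 0) & (dst >= 0)                      # only edges between retained nodes
    return x_out, torch.stack([src[keep], dst[keep]]), perm


# Path graph 0-1-2-3 with 1-D features; p = [1] makes each score equal to the feature
x = torch.tensor([[1.0], [3.0], [2.0], [0.0]])
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
x_out, edge_out, perm = topk_pool(x, edge_index, p=torch.tensor([1.0]))
print(perm.tolist())      # [1, 2]
print(edge_out.tolist())  # [[0, 1], [1, 0]]: only the 1-2 edge survives
```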
SAGPool (Lee et al., 2019) improves on TopKPooling by using a GNN to compute attention scores rather than a simple projection. This means the importance score of each node depends on its neighborhood, not just its own features. The GNN-based score captures structural importance that a simple projection misses.
Set2Set (Vinyals et al., 2016) uses an attention-based readout with a recurrent mechanism. It performs multiple "read" steps over the node set, attending to different nodes each time and accumulating information. This is more expressive than simple aggregation because it can capture complex set-level interactions, at the cost of sequential computation.
Practical guidance. For most graph classification tasks, start with simple mean or sum readout. If performance plateaus and you suspect that structural information is important, try hierarchical pooling. DiffPool works well for small graphs (molecular datasets), while TopKPooling and SAGPool scale to larger graphs.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes the GCNLayer class defined earlier in this chapter,
# called as layer(x, edge_index, num_nodes).


class SimpleGraphClassifier(nn.Module):
    """Graph classification model with GCN layers and readout."""

    def __init__(
        self, in_features: int, hidden_features: int, num_classes: int,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNLayer(in_features, hidden_features))
        for _ in range(num_layers - 1):
            self.convs.append(GCNLayer(hidden_features, hidden_features))
        self.classifier = nn.Linear(hidden_features, num_classes)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor,
        batch: torch.Tensor, num_nodes: int,
    ) -> torch.Tensor:
        """Forward pass with mean readout.

        Args:
            x: Batched node features [total_nodes, in_features].
            edge_index: Batched edge indices [2, total_edges].
            batch: Batch assignment vector [total_nodes].
            num_nodes: Total number of nodes.

        Returns:
            Graph-level predictions [num_graphs, num_classes].
        """
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, num_nodes))
        # Mean readout per graph: sum node embeddings per graph, then divide by node counts
        num_graphs = int(batch.max().item()) + 1
        graph_repr = torch.zeros(num_graphs, x.size(1), device=x.device)
        graph_repr.scatter_add_(0, batch.unsqueeze(1).expand_as(x), x)
        counts = torch.zeros(num_graphs, device=x.device)
        counts.scatter_add_(0, batch, torch.ones(batch.size(0), device=x.device))
        graph_repr = graph_repr / counts.unsqueeze(1).clamp(min=1)
        return self.classifier(graph_repr)
```
37.7 Node Classification
37.7.1 The Semi-Supervised Setting
Node classification is perhaps the most studied GNN task. The setup is semi-supervised: you have a single large graph where a small fraction of nodes have labels, and the goal is to predict labels for the remaining nodes.
The classic benchmark is the Cora citation network:
- 2,708 scientific papers (nodes)
- 5,429 citations (edges)
- 1,433-dimensional bag-of-words features per paper
- 7 classes (topics)
- Only 140 labeled nodes for training (about 5%)
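GNNs excel here because message passing lets information flow along edges: each unlabeled node's representation is built from its neighbors' features, and the supervision on the labeled nodes trains weights shared by every node, so labeled nodes influence nearby unlabeled ones.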
37.7.2 Training a GCN for Node Classification
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes the GCNLayer class defined earlier in this chapter.


class NodeClassificationGCN(nn.Module):
    """Two-layer GCN for node classification."""

    def __init__(
        self, in_features: int, hidden_features: int, num_classes: int,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.conv1 = GCNLayer(in_features, hidden_features)
        self.conv2 = GCNLayer(hidden_features, num_classes)
        self.dropout = dropout

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, num_nodes: int
    ) -> torch.Tensor:
        """Two-layer GCN forward pass.

        Args:
            x: Node features [N, in_features].
            edge_index: COO edges [2, E].
            num_nodes: Number of nodes N.

        Returns:
            Log-probabilities per node [N, num_classes].
        """
        x = self.conv1(x, edge_index, num_nodes)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, num_nodes)
        return F.log_softmax(x, dim=1)
```
The training loop follows the standard PyTorch pattern, with one key difference: we compute the loss only on the labeled training nodes, but forward propagation runs on the entire graph:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_node_classification(
    model: nn.Module,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    labels: torch.Tensor,
    train_mask: torch.Tensor,
    num_epochs: int = 200,
    lr: float = 0.01,
    weight_decay: float = 5e-4,
) -> list[float]:
    """Train a GNN for semi-supervised node classification.

    The forward pass covers every node; the loss uses only nodes where train_mask is True.
    """
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    num_nodes = x.size(0)
    losses = []
    for _ in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(x, edge_index, num_nodes)  # log-probabilities for all nodes
        loss = F.nll_loss(out[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses
```
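The loop above never looks at held-out nodes. To evaluate, one option is to switch to eval mode (which disables dropout), run the full graph once, and score only the masked nodes. The `masked_accuracy` helper and the `test_mask` input are our own illustrative names, with `test_mask` assumed to be a boolean mask analogous to `train_mask`.

```python
import torch


@torch.no_grad()
def masked_accuracy(log_probs: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> float:
    """Accuracy over the nodes selected by a boolean mask (e.g., a validation or test mask)."""
    pred = log_probs.argmax(dim=1)
    correct = (pred[mask] == labels[mask]).sum().item()
    return correct / int(mask.sum())


# Typical use after training (model, x, edge_index, labels, test_mask assumed to exist):
#   model.eval()
#   out = model(x, edge_index, x.size(0))
#   test_acc = masked_accuracy(out, labels, test_mask)
```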
37.8 Molecular Property Prediction
37.8.1 Molecules as Graphs
Drug discovery is one of the most impactful applications of GNNs. A molecule is naturally represented as a graph where:
- Nodes are atoms, with features encoding atom type, formal charge, number of hydrogens, aromaticity, and hybridization
- Edges are chemical bonds, with features encoding bond type (single, double, triple, aromatic), stereochemistry, and whether the bond is in a ring
The task is typically graph regression (predicting a continuous property such as solubility or hydration free energy) or graph classification (for example toxic/non-toxic, or active/inactive against a biological target).
37.8.2 Featurization
Molecular featurization converts a SMILES string (a text representation of molecular structure) into a graph with meaningful node and edge features:
```python
from typing import Any


def atom_features(atom: Any) -> list[float]:
    """Compute features for a single atom (e.g., an RDKit Atom object).

    Features: atom type (one-hot), degree, formal charge,
    number of Hs, aromaticity, ring membership.
    """
    atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'P', 'Other']
    symbol = atom.GetSymbol()
    if symbol not in atom_types[:-1]:
        symbol = 'Other'
    type_encoding = [1.0 if symbol == t else 0.0 for t in atom_types]
    features = type_encoding + [
        atom.GetDegree() / 4.0,  # scalar features divided by rough typical maxima
        atom.GetFormalCharge() / 2.0,
        atom.GetTotalNumHs() / 4.0,
        1.0 if atom.GetIsAromatic() else 0.0,
        1.0 if atom.IsInRing() else 0.0,
    ]
    return features
```
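Bonds become edges. Because bonds are undirected, each one is stored as two directed edges sharing the same features. The sketch below assumes the bonds have already been extracted as `(atom_i, atom_j, bond_type)` tuples (with RDKit this could be `[(b.GetBeginAtomIdx(), b.GetEndAtomIdx(), str(b.GetBondType())) for b in mol.GetBonds()]`); the bond-type vocabulary and function name are our own choices.

```python
import torch

BOND_TYPES = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]  # assumed vocabulary


def bonds_to_graph(bonds: list[tuple[int, int, str]]) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn (atom_i, atom_j, bond_type) tuples into COO edges plus one-hot edge features."""
    rows, cols, feats = [], [], []
    for i, j, bond_type in bonds:
        one_hot = [1.0 if bond_type == t else 0.0 for t in BOND_TYPES]
        rows += [i, j]       # i -> j and j -> i
        cols += [j, i]
        feats += [one_hot, one_hot]
    edge_index = torch.tensor([rows, cols], dtype=torch.long)
    edge_attr = torch.tensor(feats, dtype=torch.float).view(-1, len(BOND_TYPES))
    return edge_index, edge_attr


# Heavy atoms of ethanol, C-C-O, joined by two single bonds
edge_index, edge_attr = bonds_to_graph([(0, 1, "SINGLE"), (1, 2, "SINGLE")])
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_attr.shape)      # torch.Size([4, 4])
```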
37.8.3 Message Passing Neural Network (MPNN)
Gilmer et al. (2017) proposed the Message Passing Neural Network framework specifically for molecular property prediction. The key extension beyond basic GCN is the incorporation of edge features into the message function:
$$\mathbf{m}_{u \to v} = \text{MLP}_\text{msg}\!\left(\mathbf{h}_u, \mathbf{e}_{uv}\right)$$
This is critical for molecules, where the bond type (single, double, triple) fundamentally changes the interaction between atoms.
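A single edge-aware message passing step can be sketched as follows. The message follows the formula above (an MLP on $[\mathbf{h}_u \,\|\, \mathbf{e}_{uv}]$), messages are summed at each target, and the update uses a GRU cell, the update function Gilmer et al. used; they also explored other message functions, so treat this as one illustrative configuration rather than the MPNN.

```python
import torch
import torch.nn as nn


class EdgeMPNNLayer(nn.Module):
    """One message passing step with edge features.

    Message: m_{u->v} = MLP_msg([h_u || e_uv]); aggregation: sum over incoming edges;
    update: h_v' = GRUCell(sum of messages, h_v).
    """

    def __init__(self, node_dim: int, edge_dim: int, hidden: int = 64) -> None:
        super().__init__()
        self.msg = nn.Sequential(
            nn.Linear(node_dim + edge_dim, hidden), nn.ReLU(), nn.Linear(hidden, node_dim)
        )
        self.update = nn.GRUCell(node_dim, node_dim)

    def forward(self, h: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index                                        # messages flow src -> dst
        messages = self.msg(torch.cat([h[src], edge_attr], dim=-1))  # [E, node_dim]
        agg = torch.zeros_like(h).index_add_(0, dst, messages)       # sum per target node
        return self.update(agg, h)                                   # GRU: input=agg, hidden=h
```

Because the bond features enter every message, changing a bond from single to double changes the atoms' updated states even when the atom features are identical.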
37.8.4 Benchmarks and Results
The MoleculeNet benchmark (Wu et al., 2018) provides standardized datasets for molecular property prediction:
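| Dataset | Task | Metric | # Molecules | Approx. GNN result |
|---|---|---|---|---|
| ESOL | Aqueous solubility (regression) | RMSE, log mol/L (lower is better) | 1,128 | ~0.55 |
| FreeSolv | Hydration free energy (regression) | RMSE, kcal/mol (lower is better) | 642 | ~1.10 |
| BBBP | Blood-brain barrier penetration (classification) | ROC-AUC (higher is better) | 2,039 | ~0.92 |
| HIV | Inhibition of HIV replication (classification) | ROC-AUC (higher is better) | 41,127 | ~0.80 |
| Tox21 | Toxicity across 12 assays (classification) | ROC-AUC (higher is better) | 7,831 | ~0.85 |

These figures are rough: reported scores vary considerably with the model and with the data split, and random splits typically give more optimistic numbers than scaffold splits.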
State-of-the-art molecular GNNs like SchNet, DimeNet, and SphereNet also incorporate 3D geometric information (atom distances, angles, dihedral angles), moving beyond the 2D graph topology.
37.9 Knowledge Graphs
37.9.1 What Are Knowledge Graphs?
A knowledge graph stores factual information as a collection of (head, relation, tail) triples. For example: (Albert Einstein, born_in, Ulm), (Ulm, located_in, Germany). Nodes are entities, and typed edges represent relations.
Major knowledge graphs include Freebase, Wikidata, DBpedia, and proprietary ones at Google, Amazon, and Meta. They power search engines, recommendation systems, and question-answering systems.
37.9.2 Knowledge Graph Embeddings
The fundamental task on knowledge graphs is link prediction: given a head entity and a relation, predict the tail entity. For example, given (Albert Einstein, born_in, ?), the model should rank Ulm highly. This requires learning low-dimensional embeddings of entities and relations.
TransE (Bordes et al., 2013) is the simplest and most intuitive knowledge graph embedding model. The core idea: relations are translations in embedding space.
Intuition first: If we think of entities as points in a vector space, then a relation like "born_in" defines a direction. Starting from the point for "Albert Einstein" and moving in the "born_in" direction should land near the point for "Ulm."
Formula: $$\text{score}(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|$$
where: $\mathbf{h} \in \mathbb{R}^d$ is the head entity embedding, $\mathbf{r} \in \mathbb{R}^d$ is the relation embedding (the translation vector), and $\mathbf{t} \in \mathbb{R}^d$ is the tail entity embedding.
Training minimizes a margin-based loss: for each positive triple $(h, r, t)$, sample a negative triple $(h', r, t)$ or $(h, r, t')$ by corrupting either the head or tail, and push the positive score above the negative score by a margin $\gamma$:
$$\mathcal{L} = \sum_{(h,r,t) \in \mathcal{T}} \sum_{(h',r,t') \in \mathcal{T}'} \max\!\left(0,\; \gamma + \|\mathbf{h} + \mathbf{r} - \mathbf{t}\| - \|\mathbf{h}' + \mathbf{r} - \mathbf{t}'\|\right)$$
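The score and loss translate directly into tensor code. The 2-D embeddings below are hand-picked toy values, not learned ones, and the entity names are only labels; since the distance is the negated score, the hinge term $\gamma + d_{\text{pos}} - d_{\text{neg}}$ becomes `margin - pos_score + neg_score`.

```python
import torch


def transe_score(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """score(h, r, t) = -||h + r - t||_2, computed row-wise for a batch of triples."""
    return -torch.linalg.vector_norm(h + r - t, dim=-1)


def transe_margin_loss(pos_score: torch.Tensor, neg_score: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
    """Sum of max(0, margin + d_pos - d_neg), where d = -score is the distance."""
    return torch.clamp(margin - pos_score + neg_score, min=0.0).sum()


# Toy embeddings chosen by hand for illustration
einstein = torch.tensor([0.0, 0.0])
born_in = torch.tensor([1.0, 0.0])
ulm = torch.tensor([1.0, 0.0])
germany = torch.tensor([0.0, 2.0])  # far from einstein + born_in
munich = torch.tensor([1.0, 0.5])   # close to einstein + born_in

pos = transe_score(einstein, born_in, ulm)  # 0: a perfect translation
neg = transe_score(einstein.expand(2, 2), born_in, torch.stack([germany, munich]))  # corrupted tails
loss = transe_margin_loss(pos, neg, margin=1.0)
print(loss.item())  # 0.5: only the Munich negative violates the margin
```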
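Limitation: TransE cannot model symmetric relations. If $\mathbf{h} + \mathbf{r} = \mathbf{t}$ and $\mathbf{t} + \mathbf{r} = \mathbf{h}$, adding the two equations gives $\mathbf{r} = \mathbf{0}$ and hence $\mathbf{h} = \mathbf{t}$, so any two entities linked by a symmetric relation would need the same embedding. It also struggles with 1-to-N relations: if $(h, r, t_1)$ and $(h, r, t_2)$ both hold, both tails must lie near the same point $\mathbf{h} + \mathbf{r}$, forcing $\mathbf{t}_1 \approx \mathbf{t}_2$.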
DistMult (Yang et al., 2015) uses a bilinear model: $\text{score}(h, r, t) = \mathbf{h}^T \text{diag}(\mathbf{r}) \mathbf{t} = \sum_i h_i r_i t_i$. This can model symmetric relations (since the score is symmetric in $\mathbf{h}$ and $\mathbf{t}$) but cannot model antisymmetric relations.
ComplEx (Trouillon et al., 2016) extends DistMult to complex-valued embeddings: $\text{score}(h, r, t) = \text{Re}(\mathbf{h}^T \text{diag}(\mathbf{r}) \bar{\mathbf{t}})$, where $\bar{\mathbf{t}}$ is the complex conjugate. The asymmetry of the conjugate operation allows ComplEx to model both symmetric and antisymmetric relations.
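These symmetry properties can be verified numerically with random embeddings, since PyTorch supports complex tensors directly. The check below also shows that ComplEx recovers both behaviors: a purely real relation embedding makes its score symmetric, and a purely imaginary one makes it antisymmetric (swapping head and tail flips the sign).

```python
import torch


def distmult_score(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """sum_i h_i r_i t_i (real embeddings)."""
    return (h * r * t).sum(-1)


def complex_score(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Re(sum_i h_i r_i conj(t_i)) (complex embeddings)."""
    return (h * r * t.conj()).sum(-1).real


torch.manual_seed(0)
d = 4
h, r, t = torch.randn(3, d).unbind(0)
print(distmult_score(h, r, t).item(), distmult_score(t, r, h).item())  # always equal

hc, rc, tc = torch.randn(3, d, dtype=torch.cfloat).unbind(0)
print(complex_score(hc, rc, tc).item(), complex_score(tc, rc, hc).item())  # generally differ

r_real = rc.real.to(torch.cfloat)         # real relation: symmetric scores
r_imag = (1j * rc.imag).to(torch.cfloat)  # imaginary relation: antisymmetric scores
print(complex_score(hc, r_real, tc).item(), complex_score(tc, r_real, hc).item())
print(complex_score(hc, r_imag, tc).item(), complex_score(tc, r_imag, hc).item())
```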
RotatE (Sun et al., 2019) represents relations as rotations in complex space.
Intuition first: Instead of translating entity embeddings, RotatE rotates them. Each relation corresponds to an element-wise rotation of the head embedding to align it with the tail.
Formula: $$\mathbf{t} = \mathbf{h} \circ \mathbf{r}, \quad |r_i| = 1 \;\; \forall i$$
where $\circ$ is the Hadamard (element-wise) product in complex space, and each component of $\mathbf{r}$ lies on the unit circle: $r_i = e^{i\theta_i}$. The scoring function is:
$$\text{score}(h, r, t) = -\|\mathbf{h} \circ \mathbf{r} - \mathbf{t}\|$$
Worked example: Suppose we have 2-dimensional complex embeddings. Let $\mathbf{h} = (1+i, 2+0i)$ and the relation "born_in" has rotation angles $\theta_1 = \pi/4$ and $\theta_2 = \pi/2$, so $\mathbf{r} = (e^{i\pi/4}, e^{i\pi/2}) = (\frac{\sqrt{2}}{2}+\frac{\sqrt{2}}{2}i,\; i)$. Then $\mathbf{h} \circ \mathbf{r} = ((1+i)(\frac{\sqrt{2}}{2}+\frac{\sqrt{2}}{2}i),\; (2)(i)) = (\sqrt{2}i,\; 2i)$. The tail entity "Ulm" should have an embedding close to this rotated vector.
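The arithmetic of this worked example can be checked in plain Python, which has built-in complex numbers. The `rotate_score` helper is our own illustration of the scoring function, using the Euclidean norm over complex components.

```python
import cmath
import math

h = [1 + 1j, 2 + 0j]
thetas = [math.pi / 4, math.pi / 2]
r = [cmath.exp(1j * theta) for theta in thetas]  # each |r_i| = 1
rotated = [h_i * r_i for h_i, r_i in zip(h, r)]  # Hadamard product h ∘ r


def rotate_score(h, r, t):
    """score(h, r, t) = -||h ∘ r - t||, Euclidean norm over complex components."""
    return -math.sqrt(sum(abs(h_i * r_i - t_i) ** 2 for h_i, r_i, t_i in zip(h, r, t)))


print(rotated)  # close to [1.414...j, 2j], up to round-off in the real parts
print(rotate_score(h, r, rotated))  # a tail equal to h ∘ r scores (essentially) 0
```

Because every $|r_i| = 1$, rotation never changes the modulus of a component; only its phase moves.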
RotatE can model all three relation patterns: symmetric (rotation by 0 or $\pi$), antisymmetric (rotation by any non-trivial angle), and composition ($\mathbf{r}_1 \circ \mathbf{r}_2$ is another rotation).
37.9.3 GNNs on Knowledge Graphs
GNN-based approaches like R-GCN (Relational Graph Convolutional Network, Schlichtkrull et al., 2018) extend GCN to handle multiple relation types. For each relation type $r$, there is a separate weight matrix:
$$\mathbf{h}_v^{(\ell)} = \sigma\!\left(\sum_{r \in \mathcal{R}} \sum_{u \in \mathcal{N}_r(v)} \frac{1}{|\mathcal{N}_r(v)|} \mathbf{W}_r^{(\ell)} \mathbf{h}_u^{(\ell-1)} + \mathbf{W}_0^{(\ell)} \mathbf{h}_v^{(\ell-1)}\right)$$
where $\mathcal{N}_r(v)$ is the set of neighbors of $v$ under relation $r$, and $\mathbf{W}_0^{(\ell)}$ is a self-connection weight.
The challenge with R-GCN is the explosion of parameters when there are many relation types. Basis decomposition addresses this by expressing each $\mathbf{W}_r$ as a linear combination of a small number of basis matrices:
$$\mathbf{W}_r = \sum_{b=1}^{B} a_{rb} \mathbf{V}_b$$
where $\mathbf{V}_b$ are shared basis matrices and $a_{rb}$ are relation-specific coefficients.
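The saving from basis decomposition is easy to quantify. In this sketch (module name and sizes are illustrative), 100 relation types with $64 \times 64$ weights would need 409,600 parameters if each $\mathbf{W}_r$ were free, while 4 shared bases plus coefficients need $4 \cdot 64 \cdot 64 + 100 \cdot 4 = 16{,}784$.

```python
import torch
import torch.nn as nn


class BasisRelationWeights(nn.Module):
    """W_r = sum_b a_rb V_b: relation-specific matrices built from B shared bases."""

    def __init__(self, num_relations: int, num_bases: int, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.bases = nn.Parameter(torch.randn(num_bases, in_dim, out_dim) * 0.1)  # V_b
        self.coeffs = nn.Parameter(torch.randn(num_relations, num_bases))         # a_rb

    def forward(self) -> torch.Tensor:
        # [R, B] x [B, in, out] -> [R, in, out]; apply to features as h @ W_r
        return torch.einsum("rb,bio->rio", self.coeffs, self.bases)


weights = BasisRelationWeights(num_relations=100, num_bases=4, in_dim=64, out_dim=64)
W = weights()
print(W.shape)                                        # torch.Size([100, 64, 64])
print(sum(p.numel() for p in weights.parameters()))   # 16784, versus 100 * 64 * 64 = 409600
```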
37.9.4 Heterogeneous Graphs and Relational Message Passing
Many real-world graphs are heterogeneous: they contain multiple types of nodes and edges. An e-commerce graph, for instance, might have user nodes, product nodes, and brand nodes, connected by "purchased," "reviewed," "manufactured_by," and "similar_to" edge types. Standard GNNs assume a single node type and edge type, limiting their applicability.
Heterogeneous graph formalization. A heterogeneous graph $G = (V, E, \tau, \phi)$ extends a standard graph with a node-type function $\tau: V \to \mathcal{A}$ and an edge-type function $\phi: E \to \mathcal{R}$, where $\mathcal{A}$ is the set of node types and $\mathcal{R}$ is the set of relation types. Each node type may have a different feature dimensionality, and each relation type may require a different message function.
R-GCN for heterogeneous graphs. As we saw earlier, R-GCN uses relation-specific weight matrices. In practice, R-GCN is implemented by computing separate message passing operations for each relation type and summing the results:
$$\mathbf{h}_v^{(\ell)} = \sigma\!\left(\sum_{r \in \mathcal{R}} \sum_{u \in \mathcal{N}_r(v)} \frac{1}{c_{v,r}} \mathbf{W}_r^{(\ell)} \mathbf{h}_u^{(\ell-1)} + \mathbf{W}_0^{(\ell)} \mathbf{h}_v^{(\ell-1)}\right)$$
where $c_{v,r}$ is a normalization constant (typically $|\mathcal{N}_r(v)|$).
Heterogeneous Graph Transformer (HGT). Hu et al. (2020) extended the Transformer attention mechanism to heterogeneous graphs. For a target node $t$ of type $\tau(t)$ and a source node $s$ of type $\tau(s)$ connected by relation $\phi(e)$, HGT computes type-dependent queries, keys, and values:
$$\mathbf{Q}_t = \mathbf{W}_Q^{\tau(t)} \mathbf{h}_t, \quad \mathbf{K}_s = \mathbf{W}_K^{\tau(s)} \mathbf{h}_s, \quad \mathbf{V}_s = \mathbf{W}_V^{\tau(s)} \mathbf{h}_s$$
with relation-dependent attention:
$$\alpha_{ts} = \text{softmax}_{s \in \mathcal{N}(t)}\!\left(\frac{\mathbf{Q}_t^T \mathbf{W}_{\text{ATT}}^{\phi(e)} \mathbf{K}_s}{\sqrt{d}}\right)$$
This captures type-specific semantics while maintaining the efficiency of the Transformer attention pattern.
Applications of heterogeneous GNNs. Heterogeneous GNNs are widely used in modern recommendation systems (see Section 37.10), where user-item-feature interactions naturally form heterogeneous graphs. Pinterest's PinSage, for example, operates on a bipartite graph of pins and boards.
37.10 Applications of Graph Neural Networks
37.10.1 Drug Discovery and Molecular Design
Drug discovery is one of the highest-impact applications of GNNs. The pipeline includes:
- Virtual screening: Score millions of candidate molecules for a target protein using GNN-based property prediction. A trained GNN can evaluate molecules orders of magnitude faster than physical simulations.
- Lead optimization: Given a promising "hit" molecule, modify its structure to improve properties like binding affinity, solubility, and toxicity. GNNs predict how changes to atoms and bonds affect these properties.
- De novo design: Use generative GNN models (graph VAEs, graph diffusion models) to generate entirely new molecular structures with desired properties.
- Retrosynthesis: Given a target molecule, predict the synthetic routes to produce it. GNNs learn to propose reaction steps by treating chemical reactions as graph transformations.
Companies including Recursion, Insilico Medicine, and Relay Therapeutics have built GNN-powered drug discovery platforms, and several AI-designed drugs have entered clinical trials.
37.10.2 Social Network Analysis
Social networks are natural graphs, and GNNs power several key applications:
- Community detection: Identify clusters of closely connected users using learned node representations.
- Influence prediction: Predict how information spreads through a network by modeling cascades as temporal graph processes.
- Fake account detection: Classify accounts as legitimate or fake based on their connectivity patterns. Fake accounts often have distinctive structural signatures (e.g., forming dense clusters with other fakes, having few mutual connections with legitimate users).
- Content recommendation: Predict which content a user will engage with based on their position in the social graph and their interactions.
37.10.3 Recommendation Systems
GNN-based recommendation systems model the user-item interaction as a bipartite graph and learn representations that capture collaborative filtering signals:
- LightGCN (He et al., 2020): A simplified GCN that removes feature transformations and nonlinearities, showing that simple neighborhood aggregation on the user-item graph is sufficient for strong recommendation performance. The propagation rule is simply: $\mathbf{e}_u^{(\ell+1)} = \sum_{i \in \mathcal{N}(u)} \frac{1}{\sqrt{|\mathcal{N}(u)|}\sqrt{|\mathcal{N}(i)|}} \mathbf{e}_i^{(\ell)}$.
- PinSage (Ying et al., 2018): Pinterest's production GNN-based recommendation system, processing a graph with 3 billion nodes and 18 billion edges. It uses random-walk-based neighbor sampling and importance pooling.
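The LightGCN rule above can be sketched with a dense normalized adjacency on a toy bipartite graph; dense matrices are used only for readability, and real systems use sparse ones. Following LightGCN, the final embedding averages the outputs of all layers (including layer 0), and a user-item score is the dot product of final embeddings. The interactions and dimensions are made up.

```python
import torch


def lightgcn_embeddings(
    interactions: list[tuple[int, int]], num_users: int, num_items: int,
    emb: torch.Tensor, num_layers: int = 2,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Propagate layer-0 embeddings emb [num_users + num_items, d] over the user-item graph."""
    n = num_users + num_items
    A = torch.zeros(n, n)
    for u, i in interactions:  # users are nodes 0..U-1, items follow
        A[u, num_users + i] = A[num_users + i, u] = 1.0
    d_inv_sqrt = A.sum(dim=1).pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0              # isolated nodes
    A_hat = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]  # 1 / sqrt(|N(u)| |N(i)|) per edge
    layers = [emb]
    for _ in range(num_layers):
        layers.append(A_hat @ layers[-1])  # no weight matrix, no nonlinearity
    final = torch.stack(layers).mean(dim=0)  # uniform combination of layers 0..K
    return final[:num_users], final[num_users:]


torch.manual_seed(0)
interactions = [(0, 0), (0, 1), (1, 1), (1, 2)]  # (user, item) pairs
users, items = lightgcn_embeddings(interactions, num_users=2, num_items=3, emb=torch.randn(5, 8))
scores = users @ items.T  # [num_users, num_items] predicted preference scores
```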
37.11 PyTorch Geometric: The Practical GNN Library
37.11.1 Overview
PyTorch Geometric (PyG) is one of the most widely used libraries for GNN development. It provides:
- Efficient message passing infrastructure
- Dozens of GNN layer implementations (GCN, GAT, GraphSAGE, GIN, and many more)
- Built-in datasets (Cora, Citeseer, PPI, QM9, ZINC, and many more)
- Mini-batch handling for graphs of varying sizes
- Data transforms and augmentations
- Integration with OGB (Open Graph Benchmark)
37.11.2 The Data Object
PyG represents a graph as a Data object:
```python
import torch
from torch_geometric.data import Data

# Create a graph with 4 nodes and 4 (undirected) edges
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 0],
    [1, 0, 2, 1, 3, 2, 0, 3],
], dtype=torch.long)
x = torch.randn(4, 16)  # 4 nodes, 16 features each
y = torch.tensor([0, 1, 0, 1])  # Node labels
data = Data(x=x, edge_index=edge_index, y=y)
print(data)
# Data(x=[4, 16], edge_index=[2, 8], y=[4])
print(f"Number of nodes: {data.num_nodes}")  # 4
print(f"Number of edges: {data.num_edges}")  # 8: each undirected edge is stored twice
print(f"Has self-loops: {data.has_self_loops()}")  # False
print(f"Is undirected: {data.is_undirected()}")  # True
```
37.11.3 Built-in Datasets
```python
from torch_geometric.datasets import Planetoid, TUDataset, MoleculeNet

# Citation networks
cora = Planetoid(root='/tmp/Cora', name='Cora')
data = cora[0]
print(f"Cora: {data.num_nodes} nodes, {data.num_edges} edges, "
      f"{cora.num_features} features, {cora.num_classes} classes")

# Graph classification datasets
proteins = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')
print(f"PROTEINS: {len(proteins)} graphs, "
      f"{proteins.num_features} features, {proteins.num_classes} classes")

# Molecular datasets
esol = MoleculeNet(root='/tmp/ESOL', name='ESOL')
print(f"ESOL: {len(esol)} molecules")
```
37.11.4 Using PyG Layers
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, GCNConv, global_mean_pool


class PyGNodeClassifier(nn.Module):
    """Node classification model using PyG layers."""

    def __init__(
        self, in_channels: int, hidden_channels: int, out_channels: int,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, data: Data) -> torch.Tensor:
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


class PyGGraphClassifier(nn.Module):
    """Graph classification model using PyG layers (expects num_layers >= 2)."""

    def __init__(
        self, in_channels: int, hidden_channels: int, out_channels: int,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        self.convs = nn.ModuleList()
        # Hidden GAT layers concatenate 4 heads -> hidden_channels * 4 features
        self.convs.append(GATConv(in_channels, hidden_channels, heads=4, concat=True))
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(hidden_channels * 4, hidden_channels, heads=4, concat=True)
            )
        # Final GAT layer uses a single head -> hidden_channels features
        self.convs.append(GATConv(hidden_channels * 4, hidden_channels, heads=1))
        self.classifier = nn.Linear(hidden_channels, out_channels)

    def forward(self, data: Data) -> torch.Tensor:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs[:-1]:
            x = F.elu(conv(x, edge_index))
        x = self.convs[-1](x, edge_index)
        x = global_mean_pool(x, batch)  # [num_graphs, hidden_channels]
        return self.classifier(x)
```
37.11.5 Mini-Batching Graphs
In standard deep learning, mini-batching is trivial: stack tensors along a new batch dimension. For graphs, this is more complex because graphs have different sizes. PyG solves this elegantly by combining multiple graphs into a single disconnected graph:
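```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# A toy dataset: 100 random graphs of 5-14 nodes, each with one graph-level label
dataset = []
for _ in range(100):
    n = int(torch.randint(5, 15, (1,)))
    edge_index = torch.randint(0, n, (2, 2 * n))  # random edges, for illustration only
    y = torch.randint(0, 2, (1,))
    dataset.append(Data(x=torch.randn(n, 16), edge_index=edge_index, y=y))

loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
# Each batch contains:
# - batch.x: [total_nodes_in_batch, 16]
# - batch.edge_index: [2, total_edges_in_batch], node indices offset per graph
# - batch.batch: [total_nodes_in_batch] mapping each node to its graph index
# - batch.y: [32] graph-level labels
```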
The batch.batch tensor is the key---it tells the readout function which nodes belong to which graph, enabling operations like global_mean_pool(x, batch.batch).
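Under the hood, batching amounts to concatenating node features and shifting each graph's edge indices by the number of nodes that precede it. The hand-rolled `collate_graphs` below is our own helper (not a PyG function) and handles only `x` and `edge_index`; PyG's `DataLoader` performs this bookkeeping, and more, automatically.

```python
import torch


def collate_graphs(graphs: list[tuple[torch.Tensor, torch.Tensor]]):
    """Merge (x, edge_index) pairs into one disconnected graph plus a batch vector."""
    xs, edges, batch = [], [], []
    offset = 0
    for g, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)  # shift node ids past earlier graphs
        batch.append(torch.full((x.size(0),), g, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs), torch.cat(edges, dim=1), torch.cat(batch)


# A 3-node path and a 2-node graph become one 5-node graph with two components
g1 = (torch.randn(3, 4), torch.tensor([[0, 1], [1, 2]]))
g2 = (torch.randn(2, 4), torch.tensor([[0], [1]]))
x, edge_index, batch = collate_graphs([g1, g2])
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
```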
37.12 Advanced GNN Architectures
37.12.1 Graph Isomorphism Network (GIN)
Xu et al. (2019) analyzed the expressive power of GNNs through the lens of the Weisfeiler-Lehman (WL) graph isomorphism test. They showed that message passing GNNs are at most as powerful as the WL test at distinguishing graphs, that GNNs with mean or max aggregation are strictly less powerful (they can fail on graphs the WL test tells apart), and that a GNN with sum aggregation and an injective update function can match it.
The Graph Isomorphism Network implements this:
$$\mathbf{h}_v^{(\ell)} = \text{MLP}^{(\ell)}\!\left((1 + \epsilon^{(\ell)}) \cdot \mathbf{h}_v^{(\ell-1)} + \sum_{u \in \mathcal{N}(v)} \mathbf{h}_u^{(\ell-1)}\right)$$
where $\epsilon$ is a learnable scalar. The MLP provides the injective function, and sum aggregation preserves multiset information.
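A GIN layer takes only a few lines on top of scatter-style aggregation. This sketch is a simplified stand-in for PyG's `GINConv` (the two-layer MLP and its width are our choices). The example uses two star graphs whose nodes all carry identical features: the centers see 2 versus 4 identical neighbors, which a mean would conflate but a sum separates.

```python
import torch
import torch.nn as nn


class GINLayer(nn.Module):
    """h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} h_u) with a learnable eps."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.eps = nn.Parameter(torch.zeros(1))  # learnable epsilon, initialized at 0
        self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        neighbor_sum = torch.zeros_like(x).index_add_(0, dst, x[src])  # sum, not mean
        return self.mlp((1 + self.eps) * x + neighbor_sum)


def star(k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Star graph: node 0 is the center, nodes 1..k are leaves; all features are ones."""
    leaves = torch.arange(1, k + 1)
    center = torch.zeros(k, dtype=torch.long)
    edge_index = torch.cat([torch.stack([leaves, center]), torch.stack([center, leaves])], dim=1)
    return torch.ones(k + 1, 3), edge_index


torch.manual_seed(0)
layer = GINLayer(3, 32)
out2, out4 = layer(*star(2)), layer(*star(4))
print(torch.allclose(out2[0], out4[0]))  # centers differ under sum aggregation
```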
37.12.2 Equivariant GNNs for Geometric Data
When working with 3D molecular structures, we need GNNs that respect 3D symmetries: rotations, translations, and reflections. Invariant and equivariant architectures designed for this include:
- SchNet: Uses continuous-filter convolutions based on interatomic distances
- DimeNet: Incorporates bond angles via directional message passing
- EGNN (Equivariant GNN): Updates both node features and 3D coordinates while maintaining E(3) equivariance
- TFN (Tensor Field Networks): Uses spherical harmonics for fully SO(3)-equivariant message passing
These architectures are crucial for molecular dynamics, protein structure prediction (as in AlphaFold), and materials science.
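To make equivariance concrete, here is a simplified layer in the style of EGNN. Messages depend on node features and squared interatomic distances (unchanged by rotations, reflections, and translations), and coordinates move along relative position vectors scaled by a learned scalar. The in-degree normalization, SiLU activations, and residual feature update are our simplifications. The check at the end transforms the input by a random orthogonal matrix and shift; features should be unchanged and coordinates transformed identically.

```python
import torch
import torch.nn as nn


class EGNNLayer(nn.Module):
    """Simplified E(3)-equivariant message passing layer in the style of EGNN."""

    def __init__(self, dim: int, hidden: int = 32) -> None:
        super().__init__()
        # phi_e builds a message from (h_i, h_j, ||x_i - x_j||^2)
        self.phi_e = nn.Sequential(
            nn.Linear(2 * dim + 1, hidden), nn.SiLU(), nn.Linear(hidden, hidden), nn.SiLU()
        )
        self.phi_x = nn.Linear(hidden, 1)  # scalar weight for each relative vector
        self.phi_h = nn.Sequential(nn.Linear(dim + hidden, hidden), nn.SiLU(), nn.Linear(hidden, dim))

    def forward(self, h: torch.Tensor, pos: torch.Tensor, edge_index: torch.Tensor):
        src, dst = edge_index                          # messages j = src -> i = dst
        rel = pos[dst] - pos[src]                      # x_i - x_j: transforms with the input
        dist2 = (rel ** 2).sum(dim=-1, keepdim=True)   # invariant quantity
        m = self.phi_e(torch.cat([h[dst], h[src], dist2], dim=-1))
        deg = h.new_zeros(h.size(0), 1).index_add_(0, dst, h.new_ones(dst.size(0), 1)).clamp(min=1)
        # Coordinates move along relative vectors scaled by invariant weights
        pos_out = pos + pos.new_zeros(pos.shape).index_add_(0, dst, rel * self.phi_x(m)) / deg
        # Features are updated from invariant messages only
        m_sum = h.new_zeros(h.size(0), m.size(1)).index_add_(0, dst, m)
        h_out = h + self.phi_h(torch.cat([h, m_sum], dim=-1))
        return h_out, pos_out


torch.manual_seed(0)
layer = EGNNLayer(dim=4)
h = torch.randn(5, 4)
pos = torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4], [1, 0, 2, 1, 3, 2, 4, 3]])
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix (rotation or reflection)
t = torch.randn(3)                         # random translation
h1, pos1 = layer(h, pos, edge_index)
h2, pos2 = layer(h, pos @ Q.T + t, edge_index)
print((h1 - h2).abs().max().item())              # ~0: features are invariant
print((pos1 @ Q.T + t - pos2).abs().max().item())  # ~0: coordinates are equivariant
```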
37.12.3 Graph Transformers
Recent work has extended the Transformer architecture to graphs:
- Graphormer (Ying et al., 2021): Adds graph structure through centrality encoding, spatial encoding (shortest path distances), and edge encoding in the attention mechanism. Won the OGB Large-Scale Challenge.
- GPS (General, Powerful, Scalable): Combines local message passing with global attention, using positional encodings (Laplacian eigenvectors, random walk probabilities) to inject structural information.
- TokenGT: Treats nodes and edges as tokens for a standard Transformer.
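Graph Transformers address a key limitation of purely local message passing, namely that information from distant nodes must travel through many layers, where over-smoothing and bottlenecks degrade it. They do so by allowing global attention, in which every node can attend to every other node. The price is the efficiency of sparse message passing: full attention is $O(N^2)$ in the number of nodes.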
37.13 Practical Considerations
37.13.1 When to Use GNNs
Use GNNs when:
- Your data has an inherent graph structure (molecules, networks, knowledge graphs)
- Relationships between entities are as important as entity features
- You need to reason about local and global structure simultaneously
- You need inductive learning over varying graph structures

Do not use GNNs when:
- Your data is naturally tabular, sequential, or grid-structured
- The graph structure is artificial or uninformative
- You have very few nodes (simple MLPs often suffice)
- The graph is essentially complete (all nodes connected), where standard architectures work fine
37.13.2 Scaling GNNs
For large graphs (millions of nodes), several strategies exist:
- Neighbor sampling (GraphSAGE): Sample a fixed number of neighbors per layer
- Cluster-GCN: Partition the graph into clusters and train on subgraphs
- GraphSAINT: Sample subgraphs using random walks or node/edge sampling
- Distributed training: Use libraries like DistDGL or PyG's distributed data parallelism
37.13.3 Common Pitfalls
- Data leakage through edges: In node classification, test nodes may be connected to training nodes. This is a feature, not a bug, but make sure you are not leaking label information through edge features.
- Over-smoothing with deep models: Monitor the cosine similarity between node representations as you add layers. If representations become too similar, reduce depth.
- Ignoring graph structure in baselines: Always compare against a simple MLP on node features (ignoring edges). If the MLP performs comparably, the graph structure may not be informative.
- Feature engineering matters: In molecular property prediction, careful atom and bond featurization often matters more than the choice of GNN architecture.
- Random splits can be misleading: For molecular datasets, use scaffold splits (splitting by molecular substructure) rather than random splits, as random splits overestimate generalization performance.
- Undirected edge representation: Remember that undirected edges must be represented as two directed edges in the COO format. A common bug is adding only $(u, v)$ without $(v, u)$, which creates a directed graph and produces unexpected results.
- Feature normalization: Normalize node features before feeding them to the GNN. Batch normalization or layer normalization between GNN layers also helps stabilize training.
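The undirected-edge pitfall is cheap to guard against in code. The sketch below assumes the PyG convention of a `[2, num_edges]` integer `edge_index` tensor; the helper names `symmetrize_edge_index` and `is_symmetric` are hypothetical. It appends every reversed edge and then drops duplicate columns, so applying it to an already symmetric graph changes nothing.

```python
import torch


def symmetrize_edge_index(edge_index: torch.Tensor) -> torch.Tensor:
    """Return a COO edge index containing both (u, v) and (v, u) for every input edge."""
    reversed_edges = edge_index.flip(0)                    # swap rows: (u, v) -> (v, u)
    both = torch.cat([edge_index, reversed_edges], dim=1)
    return torch.unique(both, dim=1)                       # drop duplicate columns


def is_symmetric(edge_index: torch.Tensor) -> bool:
    """True if every directed edge (u, v) has its reverse (v, u)."""
    edges = {tuple(col) for col in edge_index.t().tolist()}
    return all((v, u) in edges for u, v in edges)


# Path 0 - 1 - 2, mistakenly stored with only one direction per edge.
edge_index = torch.tensor([[0, 1],
                           [1, 2]])
print(is_symmetric(edge_index))              # False
fixed = symmetrize_edge_index(edge_index)
print(fixed.size(1), is_symmetric(fixed))    # 4 True
```

PyG ships an equivalent utility, `torch_geometric.utils.to_undirected`, which is usually the better choice in practice; the sketch only shows what such a function has to do.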
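The over-smoothing check can be computed directly. The following is a minimal sketch under stated assumptions: it measures the mean pairwise cosine similarity between node representations and, to isolate the smoothing effect, repeatedly applies parameter-free mean aggregation (row-normalized $\mathbf{A} + \mathbf{I}$, with no learned weights and no nonlinearity) to random features on a four-node path graph. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def mean_pairwise_cosine(h: torch.Tensor) -> float:
    """Average cosine similarity over all ordered pairs of distinct rows of h [N, d]."""
    z = F.normalize(h, p=2, dim=1)              # unit-length rows
    sim = z @ z.t()                              # [N, N] cosine similarities
    n = h.size(0)
    off_diagonal = sim.sum() - sim.diagonal().sum()
    return (off_diagonal / (n * (n - 1))).item()


def mean_aggregation_matrix(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Row-normalized (A + I): each node averages itself and its neighbors."""
    a = torch.eye(num_nodes)
    a[edge_index[0], edge_index[1]] = 1.0       # symmetric because edges are stored both ways
    return a / a.sum(dim=1, keepdim=True)


# Undirected path 0 - 1 - 2 - 3, stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
p = mean_aggregation_matrix(edge_index, num_nodes=4)

torch.manual_seed(0)
h = torch.randn(4, 8)
similarity = {}
for depth in range(51):
    if depth in (0, 2, 8, 50):
        similarity[depth] = mean_pairwise_cosine(h)
    h = p @ h                                    # one round of mean aggregation
print(similarity)
```

Because repeated averaging on a connected graph with self-loops drives every row toward the same vector, the similarity approaches 1 as depth grows. In a trained model, learned weights and nonlinearities can change how quickly this happens, but logging the same statistic per layer on real embeddings is a cheap early warning.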
37.13.4 Hyperparameter Guidelines
The following ranges are common starting points rather than tuned optima; the best values depend on the dataset and should be selected on a validation set:
| Hyperparameter | Typical Range | Notes |
|---|---|---|
| Number of layers | 2-4 | Deeper causes over-smoothing |
| Hidden dimension | 64-256 | Larger for more complex graphs |
| Learning rate | 1e-3 to 1e-2 | Use Adam optimizer |
| Dropout | 0.3-0.6 | Higher for small datasets |
| Number of heads (GAT) | 4-8 | More heads for complex relationships |
| Sample size (GraphSAGE) | 10-25 per layer | Balance variance vs. compute |
| Weight decay | 5e-4 to 5e-3 | Regularization for small graphs |
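One way to use these ranges is random search. The sketch below samples configurations from the table, drawing learning rate and weight decay on a log scale because their ranges span an order of magnitude. The dictionary keys, the choice of powers of two for the hidden dimension, and the helper names are assumptions for illustration and are not tied to any particular library.

```python
import math
import random
from typing import Any, Dict

# Ranges from the table above. Tuples are inclusive (low, high); lists are discrete choices.
SEARCH_SPACE: Dict[str, Any] = {
    "num_layers": (2, 4),
    "hidden_dim": [64, 128, 256],    # assumption: powers of two within 64-256
    "lr": (1e-3, 1e-2),              # sampled log-uniformly
    "dropout": (0.3, 0.6),
    "gat_heads": (4, 8),             # used only by GAT
    "sage_num_neighbors": (10, 25),  # per layer, used only by GraphSAGE
    "weight_decay": (5e-4, 5e-3),    # sampled log-uniformly
}


def log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Uniform in log space: equal ratios (e.g. 1e-3..2e-3 and 5e-3..1e-2) are equally likely."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config(rng: random.Random) -> Dict[str, Any]:
    """Draw one hypothetical training configuration from SEARCH_SPACE."""
    s = SEARCH_SPACE
    return {
        "num_layers": rng.randint(*s["num_layers"]),    # randint bounds are inclusive
        "hidden_dim": rng.choice(s["hidden_dim"]),
        "lr": log_uniform(rng, *s["lr"]),
        "dropout": rng.uniform(*s["dropout"]),
        "gat_heads": rng.randint(*s["gat_heads"]),
        "sage_num_neighbors": rng.randint(*s["sage_num_neighbors"]),
        "weight_decay": log_uniform(rng, *s["weight_decay"]),
    }


rng = random.Random(42)
for _ in range(3):
    print(sample_config(rng))
```

Each sampled configuration would then be trained and scored on a validation split; the table's notes, such as keeping models shallow to limit over-smoothing, remain a good guide when narrowing the ranges.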
Summary
Graph neural networks extend deep learning to data with arbitrary structure. The message passing framework---where nodes iteratively aggregate information from their neighbors---provides a flexible and principled way to learn on graphs while respecting their permutation symmetry.
We covered three foundational architectures: GCN (spectral-inspired, simple, effective), GraphSAGE (inductive, scalable, industry-proven), and GAT (attention-weighted, expressive, interpretable). For graph-level tasks, readout functions aggregate node representations, with hierarchical pooling preserving structural information. We explored key applications: node classification in citation networks, molecular property prediction for drug discovery, and knowledge graph reasoning.
PyTorch Geometric provides the practical infrastructure for GNN development, with efficient message passing, dozens of built-in architectures, and seamless mini-batching. Advanced directions include maximally expressive message-passing architectures (GIN, which matches the discriminative power of the 1-WL test), equivariant models for 3D geometry, and graph Transformers that combine local message passing with global attention.
The field of graph neural networks is evolving rapidly, driven by applications in drug discovery, recommendation systems, social network analysis, and scientific computing. The foundations covered in this chapter provide the vocabulary, mathematical framework, and implementation skills needed to engage with this exciting area.
Quick Reference
| Concept | Key Equation / Idea |
|---|---|
| Adjacency matrix | $A_{ij} = 1$ if edge $(i,j)$ exists, else $0$ |
| Message passing | Aggregate neighbor info, transform, update |
| GCN | $\mathbf{H}^{(\ell)} = \sigma(\tilde{\mathbf{D}}^{-1/2}\tilde{\mathbf{A}}\tilde{\mathbf{D}}^{-1/2}\mathbf{H}^{(\ell-1)}\mathbf{W}^{(\ell)})$ |
| GraphSAGE | Sample neighbors, aggregate, concatenate with self |
| GAT | Learned attention weights over neighbors |
| Readout | Sum/mean/max over node representations for graph-level tasks |
| Over-smoothing | Node representations converge with too many layers |
| WL test | 1-WL upper-bounds the distinguishing power of message-passing GNNs |
| PyG Data | Data(x, edge_index, y, ...) |
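The GCN row can be checked numerically. The NumPy sketch below implements one layer of the propagation rule on a three-node path graph, with $\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}$, $\tilde{\mathbf{D}}$ its degree matrix, and $\sigma$ taken to be ReLU (an assumption; the formula leaves the nonlinearity open). With one-hot features and an identity weight matrix, the output is exactly the normalized adjacency, which makes the coefficients easy to read off.

```python
import numpy as np


def gcn_layer(A: np.ndarray, H: np.ndarray, W: np.ndarray) -> np.ndarray:
    """ReLU(D~^{-1/2} A~ D~^{-1/2} H W) with A~ = A + I and D~ the degree matrix of A~."""
    A_tilde = A + np.eye(A.shape[0])             # add self-loops
    deg = A_tilde.sum(axis=1)                     # diagonal of D~ (always >= 1)
    D_inv_sqrt = np.diag(deg ** -0.5)
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt     # symmetric normalization
    return np.maximum(A_hat @ H @ W, 0.0)         # sigma = ReLU


# Path graph 0 - 1 - 2. Degrees of A~ are (2, 3, 2).
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
H = np.eye(3)   # one-hot node features
W = np.eye(3)   # identity weights, so the output equals the normalized adjacency
print(np.round(gcn_layer(A, H, W), 3))
```

Entry $(i, j)$ equals $1/\sqrt{\tilde d_i \tilde d_j}$ when $j$ is a neighbor of $i$ or $j = i$, and $0$ otherwise: $1/2$ for each end node's self-loop, $1/3$ for the middle node's, and $1/\sqrt{6} \approx 0.408$ for each edge.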
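The WL row can also be made concrete. Below is a minimal sketch of 1-WL color refinement, the variant that bounds message-passing GNNs, using a color palette shared across graphs so that their color histograms are comparable; the function name and the adjacency-dictionary format are assumptions. A 6-cycle and two disjoint triangles are both 2-regular, so refinement never separates them. With identical initial node features, a message-passing GNN likewise assigns every node the same embedding in both graphs, and any sum, mean, or max readout then yields the same graph embedding. A path and a triangle, by contrast, are separated after one round.

```python
from collections import Counter
from typing import Dict, List

Graph = Dict[int, List[int]]  # node -> list of neighbors (undirected, stored both ways)


def wl_histograms(graphs: List[Graph], iterations: int = 3) -> List[Counter]:
    """1-WL color refinement with a palette shared by all graphs.

    Different histograms prove two graphs are non-isomorphic; equal
    histograms are inconclusive.
    """
    colorings = [{v: 0 for v in g} for g in graphs]        # uniform initial color
    for _ in range(iterations):
        palette: Dict[tuple, int] = {}                      # signature -> new color
        new_colorings = []
        for g, colors in zip(graphs, colorings):
            new_colors = {}
            for v, neighbors in g.items():
                # Signature: own color plus the multiset of neighbor colors.
                signature = (colors[v], tuple(sorted(colors[u] for u in neighbors)))
                new_colors[v] = palette.setdefault(signature, len(palette))
            new_colorings.append(new_colors)
        colorings = new_colorings
    return [Counter(colors.values()) for colors in colorings]


hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path = {0: [1], 1: [0, 2], 2: [1]}
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}

h1, h2 = wl_histograms([hexagon, two_triangles])
print(h1 == h2)   # True: 1-WL cannot tell these apart
h3, h4 = wl_histograms([path, triangle])
print(h3 == h4)   # False: node degrees already differ
```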