Case Study 2: Node Classification on Citation Networks
Problem Description
Citation networks are a canonical benchmark for graph neural networks. In a citation network, each node represents a scientific paper, each edge represents a citation from one paper to another, and the task is to classify papers by their topic---using only a small fraction of labeled papers.
This is a semi-supervised problem: labeled data is scarce, but the graph structure provides rich relational information. A paper's topic is strongly correlated with the topics of the papers it cites and the papers that cite it. GNNs exploit this homophily (tendency of connected nodes to share labels) to propagate label information through the network.
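The strength of this effect can be quantified directly. The short sketch below (not part of the case-study pipeline; it assumes an `edge_index` tensor and a label vector `y` in the format used throughout this section) computes the edge homophily ratio, i.e. the fraction of edges whose endpoints share a label. For Cora this ratio is commonly reported to be around 0.8.

import torch

def edge_homophily(edge_index: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of edges whose two endpoints carry the same label."""
    src, dst = edge_index
    return (y[src] == y[dst]).float().mean().item()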
The Cora Dataset
- Nodes: 2,708 scientific papers
- Edges: 5,429 citation links (treated as undirected)
- Node features: 1,433-dimensional binary bag-of-words vectors
- Classes: 7 topics (Case_Based, Genetic_Algorithms, Neural_Networks, Probabilistic_Methods, Reinforcement_Learning, Rule_Learning, Theory)
- Training labels: 140 nodes (20 per class, ~5% of total)
- Validation: 500 nodes
- Test: 1,000 nodes
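The implementation below builds a synthetic stand-in with matching statistics so the code runs without any downloads. If PyTorch Geometric is available, the real dataset and its standard split can instead be loaded with the Planetoid loader (a sketch; the `root` directory is an arbitrary choice):

# Optional: load the real Cora split with PyTorch Geometric.
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="data/Planetoid", name="Cora")  # downloads on first use
cora = dataset[0]  # Data object with x, edge_index, y, and train/val/test masks
print(cora.num_nodes, cora.num_edges, dataset.num_classes)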
Why This Problem Is Interesting
With only 5% labeled data, a standard MLP (ignoring graph structure) achieves about 55--60% accuracy. A 2-layer GCN achieves about 81%, and a tuned GAT reaches about 83%. The gap demonstrates the value of relational information.
Implementation
Step 1: Data Loading and Exploration
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
torch.manual_seed(42)
np.random.seed(42)
def create_synthetic_citation_network(
num_nodes: int = 2708,
num_features: int = 1433,
num_classes: int = 7,
num_edges: int = 5429,
num_train: int = 140,
num_val: int = 500,
num_test: int = 1000,
) -> dict[str, torch.Tensor]:
"""Create a synthetic citation network mimicking Cora statistics.
Generates a graph with homophilic structure: nodes of the same
class are more likely to be connected.
Returns:
Dictionary with x, edge_index, y, and train/val/test masks.
"""
# Assign class labels
y = torch.randint(0, num_classes, (num_nodes,))
# Create sparse binary features
x = torch.zeros(num_nodes, num_features)
for i in range(num_nodes):
nonzero_indices = torch.randint(0, num_features, (30,))
x[i, nonzero_indices] = 1.0
    # Create edges with a homophilic bias: 80% of edges are forced to connect
    # same-class nodes; the remaining 20% pick a uniformly random destination.
src_nodes, dst_nodes = [], []
for _ in range(num_edges):
src = torch.randint(0, num_nodes, (1,)).item()
if torch.rand(1).item() < 0.8:
# Same-class connection
same_class = (y == y[src]).nonzero(as_tuple=True)[0]
dst = same_class[torch.randint(0, len(same_class), (1,))].item()
else:
dst = torch.randint(0, num_nodes, (1,)).item()
src_nodes.append(src)
dst_nodes.append(dst)
# Make undirected
edge_index = torch.tensor(
[src_nodes + dst_nodes, dst_nodes + src_nodes], dtype=torch.long
)
# Create masks
perm = torch.randperm(num_nodes)
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[perm[:num_train]] = True
val_mask[perm[num_train : num_train + num_val]] = True
test_mask[perm[num_train + num_val : num_train + num_val + num_test]] = True
return {
"x": x,
"edge_index": edge_index,
"y": y,
"train_mask": train_mask,
"val_mask": val_mask,
"test_mask": test_mask,
}
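A quick sanity check of the generator output (shapes and split sizes; the edge count doubles because each edge is stored in both directions):

data = create_synthetic_citation_network()
print(data["x"].shape)           # torch.Size([2708, 1433])
print(data["edge_index"].shape)  # torch.Size([2, 10858]); 5,429 edges, both directions
print(data["train_mask"].sum().item(),
      data["val_mask"].sum().item(),
      data["test_mask"].sum().item())  # 140 500 1000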
Step 2: GCN Model
class GCNConvLayer(nn.Module):
"""GCN convolution with symmetric normalization."""
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(in_channels, out_channels))
self.bias = nn.Parameter(torch.zeros(out_channels))
nn.init.xavier_uniform_(self.weight)
def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
num_nodes = x.size(0)
# Add self-loops
loop_index = torch.arange(num_nodes, device=x.device)
loop_index = loop_index.unsqueeze(0).repeat(2, 1)
edge_index_sl = torch.cat([edge_index, loop_index], dim=1)
row, col = edge_index_sl
# Degree computation
deg = torch.zeros(num_nodes, device=x.device)
deg.scatter_add_(0, row, torch.ones(row.size(0), device=x.device))
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0.0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Transform then aggregate
x = x @ self.weight
out = torch.zeros_like(x)
out.scatter_add_(
0, col.unsqueeze(1).expand_as(x[row]), norm.unsqueeze(1) * x[row]
)
return out + self.bias
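The layer computes the symmetrically normalized propagation rule D^{-1/2}(A + I)D^{-1/2} X W + b. As a sanity check (a small sketch, not part of the case study itself), the scatter-based implementation can be compared against the equivalent dense computation on a tiny graph:

# Compare the scatter-based layer against a dense D^{-1/2} (A + I) D^{-1/2} X W + b.
n, f_in, f_out = 4, 3, 2
x_small = torch.randn(n, f_in)
edge_small = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected: both directions listed

layer = GCNConvLayer(f_in, f_out)
with torch.no_grad():
    out_sparse = layer(x_small, edge_small)
    a_hat = torch.eye(n)
    a_hat[edge_small[0], edge_small[1]] = 1.0            # adjacency plus self-loops
    d_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)
    a_norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    out_dense = a_norm @ x_small @ layer.weight + layer.bias
print(torch.allclose(out_sparse, out_dense, atol=1e-5))  # expected: True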
class GCNModel(nn.Module):
"""Two-layer GCN for node classification."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: int,
dropout: float = 0.5,
) -> None:
super().__init__()
self.conv1 = GCNConvLayer(in_channels, hidden_channels)
self.conv2 = GCNConvLayer(hidden_channels, out_channels)
self.dropout = dropout
def forward(
self, x: torch.Tensor, edge_index: torch.Tensor
) -> torch.Tensor:
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
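A minimal forward pass to confirm the output shape, using the synthetic data from Step 1:

data = create_synthetic_citation_network()
model = GCNModel(in_channels=1433, hidden_channels=16, out_channels=7)
out = model(data["x"], data["edge_index"])
print(out.shape)  # torch.Size([2708, 7]): per-node log-probabilities over the 7 classes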
Step 3: GAT Model
class GATConvLayer(nn.Module):
"""Single-head GAT convolution layer."""
def __init__(
self,
in_channels: int,
out_channels: int,
num_heads: int = 1,
dropout: float = 0.6,
concat: bool = True,
) -> None:
super().__init__()
self.num_heads = num_heads
self.out_channels = out_channels
self.concat = concat
self.W = nn.Parameter(torch.empty(in_channels, num_heads * out_channels))
self.a_src = nn.Parameter(torch.empty(num_heads, out_channels))
self.a_dst = nn.Parameter(torch.empty(num_heads, out_channels))
self.dropout = nn.Dropout(dropout)
self.leaky_relu = nn.LeakyReLU(0.2)
nn.init.xavier_uniform_(self.W)
        nn.init.xavier_uniform_(self.a_src)
        nn.init.xavier_uniform_(self.a_dst)
def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
num_nodes = x.size(0)
# Add self-loops
loop = torch.arange(num_nodes, device=x.device).unsqueeze(0).repeat(2, 1)
edge_index_sl = torch.cat([edge_index, loop], dim=1)
row, col = edge_index_sl
# Linear transform
h = (x @ self.W).view(num_nodes, self.num_heads, self.out_channels)
# Attention scores
attn_src = (h[row] * self.a_src).sum(dim=-1)
attn_dst = (h[col] * self.a_dst).sum(dim=-1)
attn = self.leaky_relu(attn_src + attn_dst)
# Numerically stable softmax
attn_max = torch.full((num_nodes, self.num_heads), -1e9, device=x.device)
attn_max.scatter_reduce_(
0, col.unsqueeze(1).expand_as(attn), attn, reduce="amax"
)
attn = torch.exp(attn - attn_max[col])
attn_sum = torch.zeros(num_nodes, self.num_heads, device=x.device)
attn_sum.scatter_add_(0, col.unsqueeze(1).expand_as(attn), attn)
attn = attn / (attn_sum[col] + 1e-8)
attn = self.dropout(attn)
# Weighted aggregation
out = torch.zeros(num_nodes, self.num_heads, self.out_channels, device=x.device)
weighted = h[row] * attn.unsqueeze(-1)
out.scatter_add_(
0, col.unsqueeze(1).unsqueeze(2).expand_as(weighted), weighted
)
if self.concat:
return out.view(num_nodes, self.num_heads * self.out_channels)
return out.mean(dim=1)
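With `concat=True` the heads are concatenated, so a layer with 8 heads and 8 output channels produces 64 features per node. A quick check on the synthetic data (eval mode disables the attention dropout, so the output is deterministic given the random initialization):

layer = GATConvLayer(in_channels=1433, out_channels=8, num_heads=8)
layer.eval()  # turn off attention dropout
data = create_synthetic_citation_network()
out = layer(data["x"], data["edge_index"])
print(out.shape)  # torch.Size([2708, 64]): 8 heads * 8 channels, concatenated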
class GATModel(nn.Module):
"""Two-layer GAT for node classification."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: int,
heads: int = 8,
dropout: float = 0.6,
) -> None:
super().__init__()
self.conv1 = GATConvLayer(
in_channels, hidden_channels, num_heads=heads, dropout=dropout
)
self.conv2 = GATConvLayer(
hidden_channels * heads, out_channels, num_heads=1,
dropout=dropout, concat=False,
)
self.dropout = dropout
def forward(
self, x: torch.Tensor, edge_index: torch.Tensor
) -> torch.Tensor:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
Step 4: Training and Evaluation Framework
def train_and_evaluate(
model: nn.Module,
data: dict[str, torch.Tensor],
num_epochs: int = 200,
lr: float = 0.01,
weight_decay: float = 5e-4,
) -> dict[str, list[float]]:
"""Train a node classification model and track metrics.
Args:
model: GNN model for node classification.
data: Dict with x, edge_index, y, train/val/test masks.
num_epochs: Training epochs.
lr: Learning rate.
weight_decay: L2 regularization.
Returns:
Dict with train_losses, val_accs, test_accs.
"""
optimizer = torch.optim.Adam(
model.parameters(), lr=lr, weight_decay=weight_decay
)
x = data["x"]
edge_index = data["edge_index"]
y = data["y"]
train_mask = data["train_mask"]
val_mask = data["val_mask"]
test_mask = data["test_mask"]
history: dict[str, list[float]] = {
"train_losses": [],
"val_accs": [],
"test_accs": [],
}
best_val_acc = 0.0
best_test_acc = 0.0
for epoch in range(num_epochs):
# Training
model.train()
optimizer.zero_grad()
out = model(x, edge_index)
loss = F.nll_loss(out[train_mask], y[train_mask])
loss.backward()
optimizer.step()
history["train_losses"].append(loss.item())
# Evaluation
model.eval()
with torch.no_grad():
logits = model(x, edge_index)
pred = logits.argmax(dim=1)
val_acc = (pred[val_mask] == y[val_mask]).float().mean().item()
test_acc = (pred[test_mask] == y[test_mask]).float().mean().item()
history["val_accs"].append(val_acc)
history["test_accs"].append(test_acc)
if val_acc > best_val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
if (epoch + 1) % 50 == 0:
print(
f"Epoch {epoch + 1:3d} | Loss: {loss.item():.4f} | "
f"Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}"
)
print(f"\nBest Val Acc: {best_val_acc:.4f} | Corresponding Test Acc: {best_test_acc:.4f}")
return history
def compare_models(data: dict[str, torch.Tensor]) -> None:
"""Compare MLP, GCN, and GAT on the citation network."""
in_channels = data["x"].size(1)
num_classes = data["y"].max().item() + 1
# MLP Baseline (ignores graph structure)
print("=" * 60)
print("MLP Baseline (no graph structure)")
print("=" * 60)
mlp = nn.Sequential(
nn.Linear(in_channels, 64),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(64, num_classes),
nn.LogSoftmax(dim=1),
)
# For MLP, wrap in a simple adapter
class MLPWrapper(nn.Module):
def __init__(self, mlp: nn.Module) -> None:
super().__init__()
self.mlp = mlp
def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
train_and_evaluate(MLPWrapper(mlp), data, num_epochs=200)
# GCN
print("\n" + "=" * 60)
print("GCN (2 layers, 16 hidden)")
print("=" * 60)
gcn = GCNModel(in_channels, 16, num_classes, dropout=0.5)
train_and_evaluate(gcn, data, num_epochs=200)
# GAT
print("\n" + "=" * 60)
print("GAT (2 layers, 8 hidden, 8 heads)")
print("=" * 60)
gat = GATModel(in_channels, 8, num_classes, heads=8, dropout=0.6)
train_and_evaluate(gat, data, num_epochs=200, lr=0.005, weight_decay=5e-4)
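To run the full comparison end-to-end (note that the synthetic graph only mimics Cora's statistics, so the printed accuracies will not match the real-Cora numbers reported in the Results section below):

if __name__ == "__main__":
    data = create_synthetic_citation_network()
    compare_models(data)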
Step 5: Analysis
def analyze_predictions(
model: nn.Module,
data: dict[str, torch.Tensor],
class_names: list[str] | None = None,
) -> None:
"""Analyze model predictions by class and neighborhood structure."""
model.eval()
x = data["x"]
edge_index = data["edge_index"]
y = data["y"]
test_mask = data["test_mask"]
num_classes = y.max().item() + 1
if class_names is None:
class_names = [f"Class_{i}" for i in range(num_classes)]
with torch.no_grad():
logits = model(x, edge_index)
pred = logits.argmax(dim=1)
# Per-class accuracy
print("\nPer-class accuracy on test set:")
print("-" * 40)
for c in range(num_classes):
mask = test_mask & (y == c)
if mask.sum() > 0:
acc = (pred[mask] == y[mask]).float().mean().item()
print(f" {class_names[c]:30s}: {acc:.4f} ({mask.sum().item()} samples)")
# Accuracy by node degree
row = edge_index[0]
degree = torch.zeros(x.size(0), dtype=torch.long)
degree.scatter_add_(0, row, torch.ones(row.size(0), dtype=torch.long))
print("\nAccuracy by node degree (test nodes):")
print("-" * 40)
test_degrees = degree[test_mask]
test_correct = (pred[test_mask] == y[test_mask]).float()
for low, high in [(0, 2), (2, 5), (5, 10), (10, 50), (50, 1000)]:
bucket = (test_degrees >= low) & (test_degrees < high)
if bucket.sum() > 0:
acc = test_correct[bucket].mean().item()
print(f" Degree [{low}, {high}): {acc:.4f} ({bucket.sum().item()} nodes)")
Results
Performance Comparison on Cora
| Model | Test Accuracy | Parameters | Training Time (200 epochs) |
|---|---|---|---|
| MLP (no edges) | ~57% | ~92K | ~2s |
| GCN (2 layers, 16 hidden) | ~81% | ~23K | ~4s |
| GAT (8 heads, 8 hidden) | ~83% | ~92K | ~12s |
| GraphSAGE (mean, 16 hidden) | ~80% | ~47K | ~5s |
Key Findings
- Graph structure is critical: The 24-percentage-point gap between MLP (57%) and GCN (81%) demonstrates that citation links carry substantial information about paper topics. Connected papers are likely to share the same topic.
- Attention provides modest gains: GAT outperforms GCN by about 2 percentage points. The learned attention weights allow the model to focus on the most relevant citations, which is particularly useful when a paper cites works from multiple fields.
- Low-degree nodes benefit most from graph structure: Nodes with very few connections show the largest improvement when switching from MLP to GCN. The graph structure compensates for limited node features by borrowing information from neighbors.
- High-degree nodes can suffer from over-smoothing: Nodes with many connections aggregate a diverse set of features, which can dilute their representation. GAT mitigates this by attending selectively.
- Feature quality matters: Replacing bag-of-words features with random features reduces GCN accuracy from 81% to about 65%. The graph structure alone is not sufficient: good node features and graph structure are complementary. A quick ablation sketch follows this list.
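The feature-quality finding can be probed with a one-line ablation (a sketch; the 81% to ~65% figures above refer to the real Cora features, not this synthetic stand-in):

# Ablation: keep the graph, replace bag-of-words features with pure noise, retrain.
data = create_synthetic_citation_network()
data_random = dict(data)
data_random["x"] = torch.randn_like(data["x"])  # destroys feature information
gcn = GCNModel(in_channels=1433, hidden_channels=16, out_channels=7)
train_and_evaluate(gcn, data_random, num_epochs=200)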
Attention Weight Analysis
Examining GAT attention weights reveals interpretable patterns (a sketch for extracting the coefficients from GATConvLayer follows this list):
- Papers tend to attend most strongly to papers in the same topic area
- High-degree hub papers (frequently cited surveys) receive moderate but consistent attention from many nodes
- Cross-topic citations receive lower attention weights, suggesting the model learns to filter noise from off-topic references
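The `GATConvLayer` above does not return its attention coefficients. One way to inspect them (a sketch that mirrors the layer's forward pass; the helper name `extract_attention` is introduced here purely for illustration) is to recompute the normalized scores from a trained layer with dropout disabled:

def extract_attention(
    layer: GATConvLayer, x: torch.Tensor, edge_index: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Recompute the normalized attention coefficients of a trained GATConvLayer.

    Returns (edge_index_with_self_loops, attn), where attn has shape [E, heads].
    A single global max is subtracted before exponentiation; the shift cancels in
    the per-destination normalization, so the result equals the per-node softmax.
    """
    with torch.no_grad():
        num_nodes = x.size(0)
        loop = torch.arange(num_nodes, device=x.device).unsqueeze(0).repeat(2, 1)
        edge_index_sl = torch.cat([edge_index, loop], dim=1)
        row, col = edge_index_sl
        h = (x @ layer.W).view(num_nodes, layer.num_heads, layer.out_channels)
        scores = layer.leaky_relu(
            (h[row] * layer.a_src).sum(dim=-1) + (h[col] * layer.a_dst).sum(dim=-1)
        )
        scores = torch.exp(scores - scores.max())
        denom = torch.zeros(num_nodes, layer.num_heads, device=x.device)
        denom.scatter_add_(0, col.unsqueeze(1).expand_as(scores), scores)
        return edge_index_sl, scores / (denom[col] + 1e-8)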
Lessons Learned
- Semi-supervised learning on graphs is remarkably data-efficient: With only 20 labeled examples per class, GNNs achieve 80%+ accuracy by leveraging graph structure to propagate information.
- Model depth is limited: Adding a third GCN layer does not improve accuracy on Cora and may even hurt (a three-layer variant for testing this is sketched after this list). This is consistent with over-smoothing: with more layers, node representations become increasingly similar, and on Cora two hops already capture most of the label-relevant neighborhood.
- Transductive vs. inductive matters: On Cora's fixed split, GCN and GAT achieve similar performance. For inductive settings (new nodes at test time), GraphSAGE and GAT are preferred over GCN.
- Hyperparameter sensitivity: GAT is more sensitive to learning rate and dropout than GCN. The attention mechanism benefits from higher dropout (0.6 vs. 0.5) and a lower learning rate (0.005 vs. 0.01).
- Standard benchmarks have limitations: The Cora/Citeseer/Pubmed benchmarks are small and highly homophilic. Performance on these datasets does not always predict performance on larger, more heterogeneous real-world graphs.
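A three-layer variant for testing the depth observation above (illustrative only; the class name GCN3Model is not part of the case-study code and simply stacks the GCNConvLayer defined in Step 2):

class GCN3Model(nn.Module):
    """Three-layer GCN for probing over-smoothing on Cora-like graphs."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.conv1 = GCNConvLayer(in_channels, hidden_channels)
        self.conv2 = GCNConvLayer(hidden_channels, hidden_channels)
        self.conv3 = GCNConvLayer(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv3(x, edge_index)
        return F.log_softmax(x, dim=1)

# train_and_evaluate(GCN3Model(1433, 16, 7), data)  # typically no better than 2 layers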