Case Study 1: Diagnosing and Fixing a Failing Training Run
Overview
You are an ML engineer at a startup building a document classification system. A junior colleague has trained a deep neural network to classify business documents into 50 categories, but the model is performing poorly---validation accuracy is stuck at 12%, well above the 2% random baseline for 50 classes, yet far below what the task requires. Your task is to diagnose the issues and systematically fix them.
This case study walks through a realistic debugging scenario, applying the diagnostic techniques from Section 12.9 to identify and resolve multiple interacting problems.
The Initial Setup
The colleague's code looks like this:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Model: 6-layer MLP for document classification
class DocumentClassifier(nn.Module):
    """A deep MLP for document classification.

    Args:
        input_dim: Dimension of TF-IDF features.
        num_classes: Number of document categories.
    """

    def __init__(self, input_dim: int = 5000, num_classes: int = 50):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, 512),
            nn.Sigmoid(),
            nn.Linear(512, 256),
            nn.Sigmoid(),
            nn.Linear(256, 128),
            nn.Sigmoid(),
            nn.Linear(128, num_classes),
            nn.Softmax(dim=1),  # Problem!
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
```
```python
# Training
model = DocumentClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

for epoch in range(100):
    model.eval()  # Problem!
    for inputs, targets in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
# No validation evaluation
```
The colleague reports: "The loss barely decreases and validation accuracy is stuck at 12%."
Step 1: The Overfit-One-Batch Test
Before analyzing the full training run, we apply the overfit-one-batch test (Section 12.9.3):
```python
torch.manual_seed(42)
model = DocumentClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

batch = next(iter(train_loader))
inputs, targets = batch

model.train()
for step in range(500):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        _, predicted = outputs.max(1)
        acc = predicted.eq(targets).float().mean().item()
        print(f"Step {step}: loss={loss.item():.6f}, acc={acc:.4f}")
```
Output:

```
Step 0: loss=3.912023, acc=0.0312
Step 50: loss=3.911998, acc=0.0312
Step 100: loss=3.911985, acc=0.0312
Step 150: loss=3.911979, acc=0.0312
Step 200: loss=3.911975, acc=0.0312
...
```
The loss is barely moving and accuracy is not improving. The model cannot even overfit a single batch---this confirms there are fundamental bugs.
Step 2: Systematic Bug Identification
We go through the debugging checklist from Section 12.9.4:
Bug 1: Softmax Before CrossEntropyLoss
The model applies nn.Softmax(dim=1) as the final layer, then passes the output to nn.CrossEntropyLoss. As discussed in Section 12.1.2, CrossEntropyLoss internally applies log-softmax to its input. Applying softmax first means the loss effectively computes log(softmax(softmax(z))). Because the inner softmax squashes every logit into [0, 1], the loss sees a nearly flat input and produces near-zero gradients.
Fix: Remove the Softmax layer from the model. CrossEntropyLoss expects raw logits.
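The gradient-shrinking effect is easy to see in isolation. The following sketch (using made-up logits and targets, not the colleague's data) compares the gradient with and without the extra softmax:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = (torch.randn(8, 50) * 5).requires_grad_()  # fake logits for 50 classes
targets = torch.randint(0, 50, (8,))
criterion = nn.CrossEntropyLoss()

# Correct: raw logits go straight into CrossEntropyLoss
grad_ok = torch.autograd.grad(criterion(logits, targets), logits)[0]

# Bug: softmax first, so the internal log-softmax sees values squashed into [0, 1]
grad_bug = torch.autograd.grad(
    criterion(torch.softmax(logits, dim=1), targets), logits
)[0]

print(f"grad norm (correct): {grad_ok.norm():.4f}")
print(f"grad norm (buggy):   {grad_bug.norm():.6f}")  # dramatically smaller
```

The buggy variant's gradient norm is orders of magnitude smaller, which matches the near-frozen loss observed in the overfit-one-batch test.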
Bug 2: Sigmoid Activations with No Initialization Strategy
The model uses sigmoid activations throughout a 6-layer network. Sigmoid squashes all values to [0, 1], and its gradient $\sigma'(z) = \sigma(z)(1-\sigma(z))$ has a maximum of 0.25 at $z=0$. In a 6-layer network, the gradient is multiplied by at most $0.25^6 \approx 0.000244$ on the way back, causing severe vanishing gradients.
Additionally, there is no explicit weight initialization---PyTorch's default (Kaiming uniform) is designed for ReLU, not sigmoid.
Fix: Replace sigmoid with ReLU and use He initialization.
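The 0.25 bound on the sigmoid derivative, and the resulting 6-layer attenuation factor, can be checked numerically with a quick sketch:

```python
import torch

# Sigmoid derivative sigma'(z) = sigma(z) * (1 - sigma(z)), maximal at z = 0
z = torch.linspace(-6.0, 6.0, 1201)
deriv = torch.sigmoid(z) * (1 - torch.sigmoid(z))
print(f"max sigmoid derivative: {deriv.max():.4f}")  # 0.2500, at z = 0

# Best-case factor on the activation gradient through 6 sigmoid layers
print(f"6-layer bound: {0.25 ** 6:.6f}")  # 0.000244
```

In practice the factor is even smaller, since pre-activations rarely sit exactly at zero.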
Bug 3: model.eval() During Training
The training loop calls model.eval() instead of model.train(). While this model does not currently use batch normalization or dropout, this is still a serious bug that would prevent normalization layers from working correctly if added later.
Fix: Call model.train() before the training loop.
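A quick way to see why the mode matters is to watch a standalone dropout layer (illustrative only, not part of the colleague's model) in both modes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()    # training mode: entries zeroed at random, survivors scaled by 1/(1-p)
print(drop(x))  # each entry is either 0.0 or 2.0

drop.eval()     # eval mode: dropout is a no-op
print(drop(x))  # all ones
```

Batch normalization is affected analogously: in eval mode it uses stored running statistics instead of batch statistics, so it never learns from the data.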
Bug 4: Missing optimizer.zero_grad()
The training loop never calls optimizer.zero_grad(). Each loss.backward() adds to the existing .grad buffers rather than replacing them, so every update is taken along the sum of all gradients computed since training began. The update direction becomes dominated by stale gradients from old batches and old parameter values, and the effective step size grows without bound---the model "trains" but learns nothing useful.
Fix: Call optimizer.zero_grad() before each forward pass.
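A minimal illustration of the accumulation, on a single scalar parameter:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

# Without zeroing, .grad sums across backward calls
for _ in range(3):
    loss = 2.0 * w  # d(loss)/dw = 2 on every pass
    loss.backward()
print(w.grad)       # tensor(6.) -- three gradients summed, not the current one

w.grad = None       # analogous to optimizer.zero_grad() for each parameter
loss = 2.0 * w
loss.backward()
print(w.grad)       # tensor(2.) -- only the current gradient
```

(Modern PyTorch's optimizer.zero_grad() sets gradients to None by default, which is what the manual reset above mimics.)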
Bug 5: Learning Rate Too Low
With a learning rate of 0.0001 and plain SGD (no momentum), the parameter updates are tiny. Combined with the vanishing gradient problem from sigmoid, the effective update magnitude is negligible.
Fix: Use a more appropriate optimizer (AdamW) with a reasonable learning rate (e.g., 3e-4), or use SGD with momentum and a higher learning rate.
Bug 6: No Normalization
A 6-layer network without normalization layers is prone to internal covariate shift, especially without careful initialization.
Fix: Add batch normalization or layer normalization between layers.
Bug 7: No Validation Monitoring
The code has no validation evaluation, making it impossible to detect overfitting or assess true model performance.
Fix: Add periodic validation evaluation.
Step 3: The Fixed Model
Here is the corrected code with all bugs fixed:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

torch.manual_seed(42)

class DocumentClassifier(nn.Module):
    """A deep MLP for document classification.

    Args:
        input_dim: Dimension of TF-IDF features.
        num_classes: Number of document categories.
        dropout_rate: Dropout probability.
    """

    def __init__(
        self,
        input_dim: int = 5000,
        num_classes: int = 50,
        dropout_rate: float = 0.3,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_classes),
            # No softmax---CrossEntropyLoss handles it
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input features of shape (batch_size, input_dim).

        Returns:
            Raw logits of shape (batch_size, num_classes).
        """
        return self.layers(x)

def init_weights(module: nn.Module) -> None:
    """Initialize weights with He initialization for ReLU.

    Args:
        module: A PyTorch module.
    """
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm1d):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

# Create model
model = DocumentClassifier()
model.apply(init_weights)

# Loss function (no softmax in model, CrossEntropyLoss handles it)
criterion = nn.CrossEntropyLoss()

# Optimizer: AdamW with proper learning rate and weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, eta_min=1e-6
)

# Training loop
best_val_acc = 0.0
for epoch in range(100):
    # Training
    model.train()  # Fixed: was model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in train_loader:
        optimizer.zero_grad()  # Fixed: was missing
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        # Gradient clipping for safety
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    train_loss = running_loss / len(train_loader)
    train_acc = 100.0 * correct / total

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += targets.size(0)
            val_correct += predicted.eq(targets).sum().item()
    val_loss /= len(val_loader)
    val_acc = 100.0 * val_correct / val_total

    scheduler.step()

    print(
        f"Epoch {epoch}: train_loss={train_loss:.4f}, train_acc={train_acc:.1f}%, "
        f"val_loss={val_loss:.4f}, val_acc={val_acc:.1f}%"
    )

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pt")
```
Step 4: Verifying the Fix
We re-run the overfit-one-batch test with the fixed model:
```
Step 0: loss=3.9120, acc=0.0312
Step 50: loss=2.1043, acc=0.4375
Step 100: loss=0.4521, acc=0.8750
Step 150: loss=0.0312, acc=1.0000
Step 200: loss=0.0041, acc=1.0000
```
The model now rapidly overfits a single batch, confirming the pipeline is correct.
Step 5: Monitoring the Full Training Run
With the fixed code, training on the full dataset produces healthy loss curves:
```
Epoch 0: train_loss=3.4521, train_acc=8.2%, val_loss=3.1204, val_acc=12.5%
Epoch 5: train_loss=2.1043, train_acc=38.4%, val_loss=2.2156, val_acc=35.1%
Epoch 10: train_loss=1.4521, train_acc=56.7%, val_loss=1.6204, val_acc=51.3%
Epoch 20: train_loss=0.8043, train_acc=74.2%, val_loss=1.0156, val_acc=68.5%
Epoch 50: train_loss=0.3521, train_acc=89.1%, val_loss=0.6204, val_acc=82.3%
Epoch 100: train_loss=0.1243, train_acc=96.4%, val_loss=0.5812, val_acc=84.7%
```
Validation accuracy improved from 12% to 84.7%---a transformation driven entirely by fixing bugs and applying standard training practices.
Lessons Learned
- Always run the overfit-one-batch test first. It is the fastest way to distinguish bugs from hyperparameter issues.
- Never apply softmax before CrossEntropyLoss. This is perhaps the single most common PyTorch bug.
- Sigmoid activations cause vanishing gradients in deep networks. Use ReLU-family activations with He initialization.
- Missing zero_grad() causes silent gradient accumulation. The model may still "train" but with degraded performance.
- Calling model.eval() during training disables dropout and switches batch norm to its stored running statistics. Always verify the train/eval mode.
- A systematic checklist catches bugs faster than guessing. Work through the debugging checklist in Section 12.9.4 methodically.
- Multiple bugs can mask each other. The original code had 7 bugs that interacted in complex ways. Fix them one at a time and verify each fix.
Discussion Questions
1. If the colleague had only one of these bugs (e.g., only the softmax-before-CrossEntropyLoss issue), would the model still achieve reasonable accuracy? Why or why not?
2. The fixed model shows a gap between training accuracy (96.4%) and validation accuracy (84.7%). What additional techniques from this chapter could reduce this gap?
3. How would your debugging approach differ if the loss were decreasing but validation accuracy was not improving?
4. The original code used SGD with lr=0.0001. With the fixed architecture (ReLU, batch norm), would SGD with momentum and a proper learning rate (e.g., 0.1 with step decay) achieve comparable results to AdamW?