Case Study 1: Diagnosing and Fixing a Failing Training Run

Overview

You are an ML engineer at a startup building a document classification system. A junior colleague has trained a deep neural network to classify business documents into 50 categories, but the model is performing poorly---validation accuracy is stuck at 12% (random guessing across 50 classes would give only 2%, so the model has learned little beyond the class frequencies). Your task is to diagnose the issues and systematically fix them.

This case study walks through a realistic debugging scenario, applying the diagnostic techniques from Section 12.9 to identify and resolve multiple interacting problems.


The Initial Setup

The colleague's code looks like this:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Model: 6-layer MLP for document classification
class DocumentClassifier(nn.Module):
    """A deep MLP for document classification.

    Args:
        input_dim: Dimension of TF-IDF features.
        num_classes: Number of document categories.
    """

    def __init__(self, input_dim: int = 5000, num_classes: int = 50):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, 512),
            nn.Sigmoid(),
            nn.Linear(512, 256),
            nn.Sigmoid(),
            nn.Linear(256, 128),
            nn.Sigmoid(),
            nn.Linear(128, num_classes),
            nn.Softmax(dim=1),  # Problem!
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

# Training
model = DocumentClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

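# train_loader is assumed to be a DataLoader over (TF-IDF features, labels)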
for epoch in range(100):
    model.eval()  # Problem!
    for inputs, targets in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    # No validation evaluation

The colleague reports: "The loss barely decreases and validation accuracy is stuck at 12%."


Step 1: The Overfit-One-Batch Test

Before analyzing the full training run, we apply the overfit-one-batch test (Section 12.9.3):

torch.manual_seed(42)

model = DocumentClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

batch = next(iter(train_loader))
inputs, targets = batch

model.train()
for step in range(500):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        _, predicted = outputs.max(1)
        acc = predicted.eq(targets).float().mean().item()
        print(f"Step {step}: loss={loss.item():.6f}, acc={acc:.4f}")

Output:

Step 0: loss=3.912023, acc=0.0312
Step 50: loss=3.911998, acc=0.0312
Step 100: loss=3.911985, acc=0.0312
Step 150: loss=3.911979, acc=0.0312
Step 200: loss=3.911975, acc=0.0312
...

The loss is barely moving and accuracy is not improving. The model cannot even overfit a single batch---this confirms there are fundamental bugs.


Step 2: Systematic Bug Identification

We go through the debugging checklist from Section 12.9.4:

Bug 1: Softmax Before CrossEntropyLoss

The model applies nn.Softmax(dim=1) as the final layer, then passes the output to nn.CrossEntropyLoss. As discussed in Section 12.1.2, CrossEntropyLoss internally applies log-softmax to its input. Applying softmax first means the loss effectively computes log(softmax(softmax(z))). Because the inner softmax squashes every value into (0, 1), the "logits" seen by the outer log-softmax span less than one unit, so the predicted distribution over 50 classes is nearly uniform no matter what the network does, and the gradients are correspondingly tiny.

Fix: Remove the Softmax layer from the model. CrossEntropyLoss expects raw logits.
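
The effect is easy to verify directly. The following sketch (random stand-in tensors, not the project's data) compares the gradient magnitude CrossEntropyLoss produces for raw logits versus pre-softmaxed outputs:

import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 50, requires_grad=True)
targets = torch.randint(0, 50, (4,))

# Correct: raw logits go straight into CrossEntropyLoss.
criterion(logits, targets).backward()
print(f"raw logits:    mean |grad| = {logits.grad.abs().mean():.2e}")

logits.grad = None
# Buggy: softmax applied first, as in the colleague's model.
criterion(torch.softmax(logits, dim=1), targets).backward()
print(f"softmax first: mean |grad| = {logits.grad.abs().mean():.2e}")

On a typical run the second gradient comes out dramatically smaller, which matches the flat loss curve observed in Step 1.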

Bug 2: Sigmoid Activations with No Initialization Strategy

The model uses sigmoid activations throughout a 6-layer network. Sigmoid squashes all values to [0, 1], and its gradient $\sigma'(z) = \sigma(z)(1-\sigma(z))$ has a maximum of 0.25 at $z=0$. With five sigmoid layers between the six linear layers, the gradient reaching the first layer is scaled by at most $0.25^5 \approx 0.001$ on the way back, causing severe vanishing gradients.

Additionally, there is no explicit weight initialization; PyTorch's default for nn.Linear (a Kaiming-uniform variant) is not matched to sigmoid activations.

Fix: Replace sigmoid with ReLU and use He initialization.
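
We can observe the vanishing gradient directly by probing the colleague's model. A minimal sketch, assuming the original sigmoid DocumentClassifier and the imports from above are in scope, with random stand-in data:

torch.manual_seed(0)
model = DocumentClassifier()  # the original sigmoid version
outputs = model(torch.randn(32, 5000))
loss = nn.CrossEntropyLoss()(outputs, torch.randint(0, 50, (32,)))
loss.backward()

# Compare mean gradient magnitude at the first and last Linear layers.
first = model.layers[0].weight.grad.abs().mean()
last = model.layers[10].weight.grad.abs().mean()  # index 10 = final Linear
print(f"first Linear: {first:.2e}, last Linear: {last:.2e}")

The first layer's gradient comes out far smaller than the last layer's, exactly the decay predicted above.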

Bug 3: model.eval() During Training

The training loop calls model.eval() instead of model.train(). Evaluation mode switches batch normalization to its running statistics and disables dropout. This particular model has neither, so the call happens to be harmless here, but it is still a serious bug: it would silently break those layers the moment they were added.

Fix: Call model.train() before the training loop.
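
A cheap defensive check turns this silent mix-up into a loud failure. One possible guard (a sketch, not part of the original code):

def check_train_mode(model: nn.Module) -> None:
    """Raise if the model is not in training mode inside the training loop."""
    if not model.training:
        raise RuntimeError("call model.train() before the training loop")

Calling check_train_mode(model) at the top of each epoch catches the bug immediately.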

Bug 4: Missing optimizer.zero_grad()

The training loop never calls optimizer.zero_grad(). PyTorch accumulates gradients by default, so every loss.backward() adds to the gradients of all previous batches. The update direction becomes dominated by stale gradients computed at old parameter values, and the ever-growing magnitude prevents any useful learning.

Fix: Call optimizer.zero_grad() before each forward pass.
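
The accumulation is easy to see on a toy example. In this sketch (a single-parameter quadratic, purely illustrative), each backward() adds to the stored gradient because it is never cleared:

param = torch.zeros(1, requires_grad=True)

for step in range(3):
    loss = (param - 1.0) ** 2
    loss.backward()  # adds to param.grad instead of replacing it
    print(f"step {step}: param.grad = {param.grad.item():.1f}")
# Prints -2.0, -4.0, -6.0: each pass stacks another -2 on top of the last.

With an optimizer in the loop, each step would apply this ever-growing sum rather than the current batch's gradient.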

Bug 5: Learning Rate Too Low

With a learning rate of 0.0001 and plain SGD (no momentum), the parameter updates are tiny. Combined with the vanishing gradient problem from sigmoid, the effective update magnitude is negligible.

Fix: Use a more appropriate optimizer (AdamW) with a reasonable learning rate (e.g., 3e-4), or use SGD with momentum and a higher learning rate.
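
Either of the following configurations would be a reasonable starting point (illustrative values, to be tuned against validation performance):

# Option A: AdamW with a typical default learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# Option B: SGD with momentum and a higher learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)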

Bug 6: No Normalization

A 6-layer network without normalization layers is prone to internal covariate shift, especially without careful initialization.

Fix: Add batch normalization or layer normalization between layers.

Bug 7: No Validation Monitoring

The code has no validation evaluation, making it impossible to detect overfitting or assess true model performance.

Fix: Add periodic validation evaluation.
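
A small reusable helper keeps evaluation from being forgotten. A minimal sketch (the full training loop in Step 3 inlines the same logic):

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Return classification accuracy of model over loader."""
    model.eval()
    correct = total = 0
    for inputs, targets in loader:
        predicted = model(inputs).argmax(dim=1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)
    return correct / total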


Step 3: The Fixed Model

Here is the corrected code with all bugs fixed:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

torch.manual_seed(42)


class DocumentClassifier(nn.Module):
    """A deep MLP for document classification.

    Args:
        input_dim: Dimension of TF-IDF features.
        num_classes: Number of document categories.
        dropout_rate: Dropout probability.
    """

    def __init__(
        self,
        input_dim: int = 5000,
        num_classes: int = 50,
        dropout_rate: float = 0.3,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_classes),
            # No softmax---CrossEntropyLoss handles it
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input features of shape (batch_size, input_dim).

        Returns:
            Raw logits of shape (batch_size, num_classes).
        """
        return self.layers(x)


def init_weights(module: nn.Module) -> None:
    """Initialize weights with He initialization for ReLU.

    Args:
        module: A PyTorch module.
    """
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm1d):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


# Create model
model = DocumentClassifier()
model.apply(init_weights)

# Loss function (no softmax in model, CrossEntropyLoss handles it)
criterion = nn.CrossEntropyLoss()

# Optimizer: AdamW with proper learning rate and weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, eta_min=1e-6
)

# Training loop
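# train_loader and val_loader are assumed to be defined elsewhere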
best_val_acc = 0.0

for epoch in range(100):
    # Training
    model.train()  # Fixed: was model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        optimizer.zero_grad()  # Fixed: was missing

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Gradient clipping for safety
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100.0 * correct / total

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += targets.size(0)
            val_correct += predicted.eq(targets).sum().item()

    val_loss /= len(val_loader)
    val_acc = 100.0 * val_correct / val_total

    scheduler.step()

    print(
        f"Epoch {epoch}: train_loss={train_loss:.4f}, train_acc={train_acc:.1f}%, "
        f"val_loss={val_loss:.4f}, val_acc={val_acc:.1f}%"
    )

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pt")


Step 4: Verifying the Fix

We re-run the overfit-one-batch test with the fixed model:

Step 0: loss=3.9120, acc=0.0312
Step 50: loss=2.1043, acc=0.4375
Step 100: loss=0.4521, acc=0.8750
Step 150: loss=0.0312, acc=1.0000
Step 200: loss=0.0041, acc=1.0000

The model now rapidly overfits a single batch, confirming the pipeline is correct.


Step 5: Monitoring the Full Training Run

With the fixed code, training on the full dataset produces healthy loss curves:

Epoch 0: train_loss=3.4521, train_acc=8.2%, val_loss=3.1204, val_acc=12.5%
Epoch 5: train_loss=2.1043, train_acc=38.4%, val_loss=2.2156, val_acc=35.1%
Epoch 10: train_loss=1.4521, train_acc=56.7%, val_loss=1.6204, val_acc=51.3%
Epoch 20: train_loss=0.8043, train_acc=74.2%, val_loss=1.0156, val_acc=68.5%
Epoch 50: train_loss=0.3521, train_acc=89.1%, val_loss=0.6204, val_acc=82.3%
Epoch 100: train_loss=0.1243, train_acc=96.4%, val_loss=0.5812, val_acc=84.7%

Validation accuracy improved from 12% to 84.7%---a transformation driven entirely by fixing bugs and applying standard training practices.


Lessons Learned

  1. Always run the overfit-one-batch test first. It is the fastest way to distinguish bugs from hyperparameter issues.

  2. Never apply softmax before CrossEntropyLoss. This is perhaps the single most common PyTorch bug.

  3. Sigmoid activations cause vanishing gradients in deep networks. Use ReLU-family activations with He initialization.

  4. Missing zero_grad() causes silent gradient accumulation. The model may still "train" but with degraded performance.

  5. Calling model.eval() during training disables dropout and switches batch norm to its running statistics. Always verify the train/eval mode.

  6. A systematic checklist catches bugs faster than guessing. Work through the debugging checklist in Section 12.9.4 methodically.

  7. Multiple bugs can mask each other. The original code had 7 bugs that interacted in complex ways. Fix them one at a time and verify each fix.


Discussion Questions

  1. If the colleague had only one of these bugs (e.g., only the softmax-before-CrossEntropyLoss issue), would the model still achieve reasonable accuracy? Why or why not?

  2. The fixed model shows a gap between training accuracy (96.4%) and validation accuracy (84.7%). What additional techniques from this chapter could reduce this gap?

  3. How would your debugging approach differ if the loss was decreasing but validation accuracy was not improving?

  4. The original code used SGD with lr=0.0001. With the fixed architecture (ReLU, batch norm), would SGD with momentum and a proper learning rate (e.g., 0.1 with step decay) achieve comparable results to AdamW?