Case Study 1: Anomaly Detection with Autoencoders

Overview

In this case study, we build a complete anomaly detection system using autoencoders. The core idea is simple but powerful: train an autoencoder to reconstruct "normal" data, then use high reconstruction error as a signal that an input is anomalous. This approach is widely used in manufacturing (detecting defective products), cybersecurity (identifying unusual network traffic), healthcare (flagging abnormal medical scans), and finance (spotting fraudulent transactions).

We use the MNIST dataset as a controlled testbed. We train the autoencoder on a subset of digit classes (the "normal" distribution) and then evaluate its ability to detect digits from classes it has never seen (the "anomalies"). This setup mirrors real-world scenarios where we have abundant examples of normal behavior but few or no examples of anomalies.


Problem Definition

Task: Given an image of a handwritten digit, determine whether it belongs to the "normal" set (digits 0--4) or is anomalous (digits 5--9).

Training data: Only images of digits 0--4 (no anomalies during training).

Evaluation: The model must assign an anomaly score to each test image. We evaluate using ROC-AUC and precision-recall metrics.

Why this is challenging: The model never sees anomalies during training. It must learn what "normal" looks like well enough that anything different stands out.


Approach

Step 1: Data Preparation

We split MNIST into normal (classes 0--4) and anomalous (classes 5--9). The autoencoder is trained exclusively on normal data. At test time, we evaluate on both normal and anomalous examples.

"""Anomaly detection data preparation.

Splits MNIST into normal (digits 0-4) and anomalous (digits 5-9).
"""

import torch
from torchvision import datasets, transforms

torch.manual_seed(42)

transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# Filter training data to normal classes only
normal_classes = torch.tensor([0, 1, 2, 3, 4])
normal_mask = torch.isin(train_dataset.targets, normal_classes)
train_normal = torch.utils.data.Subset(
    train_dataset, torch.where(normal_mask)[0].tolist()
)

print(f"Training samples (normal only): {len(train_normal)}")
print(f"Test samples (all classes): {len(test_dataset)}")

Step 2: Autoencoder Architecture

We use a convolutional autoencoder that can efficiently capture spatial patterns in images. The encoder compresses the $28 \times 28$ input into a 32-dimensional latent code, and the decoder reconstructs it.

"""Convolutional autoencoder for anomaly detection."""

import torch.nn as nn


class AnomalyAutoencoder(nn.Module):
    """Convolutional autoencoder optimized for anomaly detection.

    Uses a moderate bottleneck to ensure anomalous inputs
    that differ from normal training data produce high
    reconstruction error.
    """

    def __init__(self, latent_dim: int = 32) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, latent_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(
                64, 32, 3, stride=2, padding=1, output_padding=1
            ),
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 1, 3, stride=2, padding=1, output_padding=1
            ),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.encoder(x)
        return self.decoder(z)
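
To see why the flattened encoder output has $64 \times 7 \times 7$ features, we can trace the spatial size through the two stride-2 convolutions with the standard output-size formula. This is a quick arithmetic sketch, separate from the model code:

```python
# Conv output size: floor((in + 2*pad - kernel) / stride) + 1
def conv_out(size: int, kernel: int = 3, stride: int = 2, pad: int = 1) -> int:
    return (size + 2 * pad - kernel) // stride + 1

h = 28           # MNIST input height (and width)
h = conv_out(h)  # after Conv2d(1, 32, ...): 28 -> 14
h = conv_out(h)  # after Conv2d(32, 64, ...): 14 -> 7
print(h, 64 * h * h)  # 7 3136, matching nn.Linear(64 * 7 * 7, latent_dim)
```

The same arithmetic run in reverse explains the `output_padding=1` in the transposed convolutions: it is needed to recover exactly 14 and then 28 pixels.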

Step 3: Training

We train with MSE reconstruction loss on normal data only. The model learns the manifold of normal digits.

"""Training loop for anomaly detection autoencoder."""

from torch.utils.data import DataLoader

torch.manual_seed(42)

model = AnomalyAutoencoder(latent_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

train_loader = DataLoader(train_normal, batch_size=128, shuffle=True)

for epoch in range(20):
    total_loss = 0.0
    for batch_images, _ in train_loader:
        reconstructed = model(batch_images)
        loss = criterion(reconstructed, batch_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_images.size(0)

    avg_loss = total_loss / len(train_normal)
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/20, Loss: {avg_loss:.6f}")

Step 4: Anomaly Scoring

The anomaly score for each test image is its mean squared reconstruction error, averaged over all pixels. Normal images (similar to the training data) should reconstruct well (low error), while anomalous images should reconstruct poorly (high error).

"""Compute anomaly scores for all test images."""

import numpy as np

model.eval()
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

all_scores = []
all_labels = []  # 0 = normal, 1 = anomaly

with torch.no_grad():
    for images, targets in test_loader:
        reconstructed = model(images)
        # Per-sample reconstruction error
        errors = ((images - reconstructed) ** 2).mean(dim=(1, 2, 3))
        all_scores.extend(errors.numpy())
        # Mark digits 5-9 as anomalous
        is_anomaly = (targets >= 5).long()
        all_labels.extend(is_anomaly.numpy())

all_scores = np.array(all_scores)
all_labels = np.array(all_labels)

Step 5: Evaluation

We evaluate using standard anomaly detection metrics: ROC-AUC, precision-recall AUC, and detection rate at a fixed false positive rate.

"""Evaluate anomaly detection performance."""

from sklearn.metrics import roc_auc_score, average_precision_score

roc_auc = roc_auc_score(all_labels, all_scores)
pr_auc = average_precision_score(all_labels, all_scores)

# Detection rate at 5% false positive rate
threshold = np.percentile(
    all_scores[all_labels == 0], 95
)
detected = (all_scores[all_labels == 1] > threshold).mean()

print(f"ROC-AUC: {roc_auc:.4f}")
print(f"PR-AUC: {pr_auc:.4f}")
print(f"Detection rate at 5% FPR: {detected:.4f}")

Results and Analysis

Expected Performance

On the MNIST normal-vs-anomalous split, this autoencoder typically achieves:

- ROC-AUC: 0.92--0.96
- Detection rate at 5% FPR: 0.75--0.85

The digits 5--9 have structural differences from 0--4 (different stroke patterns, curves, line orientations), so the autoencoder struggles to reconstruct them, producing high reconstruction error.

Per-Class Analysis

Not all anomalous digits are equally easy to detect:

- Digit 9 is often the hardest to detect because it shares structural similarity with digit 4 (both have a vertical stroke with a top element).
- Digit 8 can be confused with digit 0 (both are closed loops).
- Digit 7 is typically easy to detect because its diagonal stroke is rare in digits 0--4.
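
A per-class breakdown is straightforward to compute from the scores and the raw test labels. The sketch below uses synthetic scores as a stand-in for the `all_scores` array from Step 4, since the exact values depend on the trained model:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-ins: anomalous digits (5-9) get higher mean error.
digits = rng.integers(0, 10, size=2000)
scores = np.where(
    digits >= 5,
    rng.normal(0.040, 0.010, size=2000),  # anomalous: high error
    rng.normal(0.020, 0.005, size=2000),  # normal: low error
)

# Mean anomaly score per digit class
for digit in range(10):
    mean_score = scores[digits == digit].mean()
    print(f"digit {digit}: mean anomaly score {mean_score:.4f}")
```

With real scores, classes whose means sit closest to the normal range (such as 9) are exactly the ones that drag down the detection rate.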

This per-class analysis reveals an important principle: anomaly detection performance depends on the distance between the anomalous and normal distributions in feature space.

Choosing the Threshold

In practice, the detection threshold must balance false positives and false negatives based on the application's cost structure:

- Manufacturing: Missing a defect (false negative) is costly. Set a low threshold to catch more anomalies, accepting more false positives.
- Fraud detection: Investigating false positives is expensive. Set a higher threshold to ensure high precision.
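
One principled way to encode this trade-off is to pick the threshold that minimizes expected cost given per-error costs. The sketch below uses illustrative cost values and synthetic scores; both are assumptions, not numbers from the case study:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic scores: label 1 = anomaly, with higher scores on average.
labels = rng.integers(0, 2, size=2000)
scores = np.where(labels == 1,
                  rng.normal(0.040, 0.010, size=2000),
                  rng.normal(0.020, 0.005, size=2000))

COST_FN = 10.0  # e.g. manufacturing: a missed defect is expensive
COST_FP = 1.0   # cost of investigating a false alarm

best_threshold, best_cost = None, float("inf")
for t in np.quantile(scores, np.linspace(0.01, 0.99, 99)):
    fn = ((scores <= t) & (labels == 1)).sum()  # missed anomalies
    fp = ((scores > t) & (labels == 0)).sum()   # false alarms
    cost = COST_FN * fn + COST_FP * fp
    if cost < best_cost:
        best_threshold, best_cost = t, cost

print(f"chosen threshold: {best_threshold:.4f}, expected cost: {best_cost:.0f}")
```

Raising `COST_FN` relative to `COST_FP` pushes the chosen threshold down, matching the manufacturing scenario above.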

Limitations and Extensions

This approach has several limitations:

1. Reconstruction error is not always the best anomaly score. Some anomalies may coincidentally reconstruct well.
2. The model assumes the training data is clean. Contaminated training data degrades performance.
3. Scalability: For high-resolution images, consider using a VAE (the KL divergence provides an additional anomaly signal) or a more sophisticated architecture.

Extensions include using the latent code distance from the training distribution (via a GMM or KDE fit in latent space) as an alternative anomaly score, or combining reconstruction error with latent space distance for improved detection.
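
As a sketch of the latent-space idea: fit a Gaussian mixture to the latent codes of normal training data, then use negative log-likelihood under that mixture as an anomaly score. The latent codes below are synthetic stand-ins; real ones would come from `model.encoder`:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Synthetic 32-d latent codes: normal codes cluster tightly,
# anomalous codes fall elsewhere (an illustrative assumption).
z_normal = rng.normal(0.0, 1.0, size=(500, 32))
z_anomalous = rng.normal(1.5, 1.0, size=(100, 32))

# Fit the density model on normal latents only.
gmm = GaussianMixture(n_components=5, random_state=0).fit(z_normal)

# Anomaly score: negative log-likelihood (higher = more anomalous).
nll_normal = -gmm.score_samples(z_normal)
nll_anomalous = -gmm.score_samples(z_anomalous)
print(f"mean NLL normal:    {nll_normal.mean():.1f}")
print(f"mean NLL anomalous: {nll_anomalous.mean():.1f}")
```

A combined score such as a weighted sum of reconstruction error and latent NLL often detects anomalies that either signal alone misses.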


Key Takeaways

  1. Autoencoders naturally provide anomaly scores via reconstruction error---no anomaly labels needed during training.
  2. The bottleneck dimension controls sensitivity: too large and the model reconstructs everything (including anomalies); too small and it fails to reconstruct normal data.
  3. Per-class analysis reveals which anomalies are hard to detect, guiding architecture improvements.
  4. Threshold selection is a business decision, not a purely technical one.
  5. This approach generalizes beyond MNIST to any domain where normal data is abundant but anomalies are rare or undefined.