Case Study 1: Building a Simple Diffusion Model

Overview

In this case study, you will build a complete Denoising Diffusion Probabilistic Model (DDPM) from scratch using PyTorch. Starting with the mathematical foundations covered in the chapter, you will implement the forward diffusion process, a U-Net denoising network, the training loop, and the sampling algorithm. By the end, you will have a working model that generates novel images from pure noise, trained on the MNIST dataset.

Problem Statement

Build and train an unconditional DDPM that can generate realistic handwritten digits. The model should learn the data distribution of MNIST images (28x28 grayscale) and produce diverse, high-quality samples through the iterative denoising process.

Approach

Step 1: Noise Schedule Implementation

We implement a linear noise schedule with $T = 1000$ timesteps:

  • $\beta_{\min} = 10^{-4}$, $\beta_{\max} = 0.02$
  • Pre-compute $\alpha_t$, $\bar{\alpha}_t$, and all derived quantities needed for training and sampling
  • Store these as registered buffers in a PyTorch module for device-agnostic computation
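The steps above can be sketched as a small PyTorch module; the class and buffer names (`NoiseSchedule`, `alpha_bars`, etc.) are illustrative, not taken from the accompanying code:

```python
import torch
import torch.nn as nn

class NoiseSchedule(nn.Module):
    """Linear beta schedule with derived quantities pre-computed once.
    Hyperparameters follow the text: T = 1000, beta in [1e-4, 0.02]."""

    def __init__(self, T=1000, beta_min=1e-4, beta_max=0.02):
        super().__init__()
        betas = torch.linspace(beta_min, beta_max, T)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        # Registered buffers travel with the module on .to(device),
        # so training and sampling code never hard-codes a device.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bars", alpha_bars)
        self.register_buffer("sqrt_alpha_bars", alpha_bars.sqrt())
        self.register_buffer("sqrt_one_minus_alpha_bars", (1.0 - alpha_bars).sqrt())

schedule = NoiseSchedule()
```

Because the buffers are part of the module's state, they are also saved and restored with `state_dict()` alongside the model weights.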

Step 2: Forward Process

The forward process corrupts a clean image to any timestep $t$ in a single closed-form step:

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}$$

We verify the implementation by visualizing the same image at different noise levels, confirming that early timesteps preserve most of the signal while later timesteps approach pure Gaussian noise.
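A minimal implementation of this closed-form sampler might look as follows (function and variable names are illustrative):

```python
import torch

# Minimal linear schedule (T = 1000, beta in [1e-4, 0.02]), pre-computed once.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)
sqrt_ab = alpha_bars.sqrt()
sqrt_1mab = (1.0 - alpha_bars).sqrt()

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in one step via the closed form above."""
    if noise is None:
        noise = torch.randn_like(x0)
    # Index per-example coefficients, then broadcast over (B, C, H, W).
    return (sqrt_ab[t].view(-1, 1, 1, 1) * x0
            + sqrt_1mab[t].view(-1, 1, 1, 1) * noise)

x0 = torch.rand(4, 1, 28, 28)                     # a batch of "clean" images
xt = q_sample(x0, torch.tensor([0, 10, 500, 999]))
```

At $t = 0$ the coefficient on $\mathbf{x}_0$ is close to 1, so the output barely differs from the input; at $t = 999$ it is dominated by the noise term, matching the visual check described above.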

Step 3: U-Net Architecture

Our U-Net uses:

  • 4 resolution levels: 28 -> 14 -> 7 -> 3 (approximately)
  • Channel progression: 64 -> 128 -> 256 -> 256
  • 2 residual blocks per resolution level
  • Sinusoidal timestep embeddings projected through a 2-layer MLP
  • Group normalization (8 groups) and SiLU activation throughout
  • Skip connections between corresponding encoder and decoder levels

The architecture has approximately 10 million parameters — small enough to train on a single GPU in under an hour.
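The sinusoidal timestep embedding is the one non-obvious component; a standard Transformer-style version might look like this (the `dim=64` default is an assumption, chosen to match the base channel count):

```python
import math
import torch

def timestep_embedding(t, dim=64):
    """Sinusoidal embedding of integer timesteps, Transformer-style.
    Returns (B, dim); in the U-Net this is then fed through a 2-layer MLP
    and added into each residual block."""
    half = dim // 2
    # Geometric frequency ladder from 1 down to ~1/10000.
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t.float()[:, None] * freqs[None, :]    # (B, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = timestep_embedding(torch.tensor([1, 500, 999]))
```

Nearby timesteps map to nearby embedding vectors, which gives the network the smooth notion of noise level that a raw integer index lacks.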

Step 4: Training

Training configuration:

  • Optimizer: AdamW with learning rate $2 \times 10^{-4}$, weight decay $10^{-4}$
  • Batch size: 128
  • Epochs: 50
  • Loss: Mean squared error between predicted and actual noise
  • EMA decay: 0.9999 for the model weights

At each training step, we:

  1. Sample a batch of images from MNIST
  2. Sample random timesteps uniformly from $\{1, \ldots, T\}$
  3. Sample random noise $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
  4. Compute $\mathbf{x}_t$ using the closed-form forward process
  5. Predict the noise using the U-Net: $\hat{\boldsymbol{\epsilon}} = \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$
  6. Compute the loss $\|\boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}\|^2$ and backpropagate
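These steps can be sketched as a single function. The toy `Conv2d` stands in for the U-Net so the example runs; a real noise predictor would also receive the timestep $t$:

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

model = torch.nn.Conv2d(1, 1, 3, padding=1)       # placeholder for the U-Net
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)

def train_step(x0):
    t = torch.randint(0, T, (x0.shape[0],))       # uniform random timesteps
    eps = torch.randn_like(x0)                    # target noise
    ab = alpha_bars[t].view(-1, 1, 1, 1)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps   # closed-form forward process
    eps_hat = model(xt)                           # real model: model(xt, t)
    loss = F.mse_loss(eps_hat, eps)               # simple noise-prediction loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

loss = train_step(torch.rand(8, 1, 28, 28))
```

The outer loop simply iterates this over MNIST batches for 50 epochs, updating the EMA copy of the weights after each step.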

Step 5: Sampling

We implement both DDPM and DDIM sampling:

DDPM (1000 steps): Start from $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and iteratively denoise using the learned noise prediction and the posterior variance.
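One reverse DDPM step, using $\sigma_t^2 = \beta_t$ as the posterior variance (the simpler of the two standard choices), can be sketched as:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def ddpm_step(xt, eps_hat, t):
    """Given x_t and the predicted noise, sample x_{t-1}."""
    # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)
    mean = (xt - betas[t] / (1 - alpha_bars[t]).sqrt() * eps_hat) / alphas[t].sqrt()
    if t == 0:
        return mean                               # no noise on the final step
    return mean + betas[t].sqrt() * torch.randn_like(xt)

x = torch.randn(2, 1, 28, 28)                     # x_T ~ N(0, I)
x_prev = ddpm_step(x, torch.zeros_like(x), t=999)
```

The full sampler just calls this in a loop from $t = T - 1$ down to $0$, feeding each model prediction back in.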

DDIM (50 steps): Use the DDIM update rule with $\sigma_t = 0$ (deterministic) and a subsequence of 50 uniformly spaced timesteps for 20x faster sampling.
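A sketch of the deterministic DDIM update over such a subsequence (variable names are illustrative):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)
# 50 uniformly spaced timesteps, traversed from high noise to low.
taus = torch.linspace(0, T - 1, 50).long()

def ddim_step(xt, eps_hat, t, t_prev):
    """Deterministic DDIM update (sigma_t = 0) from timestep t to t_prev."""
    ab_t = alpha_bars[t]
    ab_prev = alpha_bars[t_prev] if t_prev >= 0 else torch.tensor(1.0)
    # Predict the clean image, then re-noise it to the earlier timestep.
    x0_hat = (xt - (1 - ab_t).sqrt() * eps_hat) / ab_t.sqrt()
    return ab_prev.sqrt() * x0_hat + (1 - ab_prev).sqrt() * eps_hat

x = torch.randn(2, 1, 28, 28)
x_prev = ddim_step(x, torch.zeros_like(x), taus[-1], taus[-2])
```

Because the update is deterministic, the same $\mathbf{x}_T$ always yields the same sample, which is also convenient for debugging and interpolation experiments.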

Results

After 50 epochs of training:

  • Training loss: Converges to approximately 0.025 (MSE on noise prediction)
  • Sample quality: Generated digits are clearly recognizable and diverse, covering all 10 digit classes
  • DDPM vs. DDIM: 50-step DDIM produces samples that are visually indistinguishable from 1000-step DDPM
  • FID score: Approximately 5-8 on MNIST (competitive with simple GAN baselines)

Observations

  1. Early training (epochs 1-5): The model first learns to generate blob-like structures with roughly correct intensity distributions.
  2. Mid training (epochs 10-20): Global digit structure emerges — shapes are recognizable but edges are soft.
  3. Late training (epochs 30-50): Fine details sharpen — strokes become crisp and digit-specific features become distinct.

Key Lessons

  1. Pre-computation matters: Pre-computing all schedule-derived quantities ($\bar{\alpha}_t$, posterior variance, etc.) avoids redundant computation during training and speeds up the loop significantly.

  2. EMA is critical: The exponential moving average model produces noticeably cleaner samples than the raw training model, especially for longer training runs.
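A minimal EMA update might look as follows (the helper name is illustrative; PyTorch's `torch.optim.swa_utils.AveragedModel` offers similar functionality):

```python
import torch

def ema_update(ema_model, model, decay=0.9999):
    """In-place EMA of weights: ema <- decay * ema + (1 - decay) * online."""
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1.0 - decay)

# Usage: keep a frozen copy of the model and update it after every optimizer step;
# sample from ema_model, not from the raw training model.
model = torch.nn.Linear(4, 4)
ema_model = torch.nn.Linear(4, 4)
ema_model.load_state_dict(model.state_dict())
ema_update(ema_model, model)
```

With decay 0.9999 the EMA averages over roughly the last ten thousand steps, smoothing out the noise in individual gradient updates.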

  3. Timestep embedding design: The sinusoidal embedding followed by MLP projection gives the model a smooth, continuous representation of the noise level. Using a simple integer embedding degrades quality significantly.

  4. DDIM enables practical generation: 1000-step DDPM sampling takes about 30 seconds per batch on a GPU. 50-step DDIM reduces this to 1.5 seconds with negligible quality loss.

  5. Simple architecture suffices for simple data: MNIST is low-resolution and relatively simple, so a modest U-Net with 10M parameters achieves good results. Higher-resolution, more complex datasets require proportionally larger models.

Extensions

  • Add class conditioning by embedding the digit label and injecting it alongside the timestep embedding
  • Implement classifier-free guidance for conditional generation
  • Replace the linear schedule with a cosine schedule and compare results
  • Scale up to CIFAR-10 (32x32 RGB) and observe the increased difficulty

Code Reference

The complete implementation is available in code/case-study-code.py.