Case Study 1: Building a DCGAN for Image Generation
Overview
In this case study, we build a complete Deep Convolutional GAN (DCGAN) pipeline for generating handwritten digit images. We follow the architectural guidelines from Radford et al. (2016), implement proper weight initialization, handle the training loop with careful monitoring, and evaluate the quality of generated samples over the course of training.
This case study demonstrates the practical challenges of GAN training: balancing the generator and discriminator, diagnosing training failure modes, and knowing when training has produced good results.
Problem Definition
Task: Train a DCGAN to generate realistic $28 \times 28$ grayscale images of handwritten digits.
Dataset: MNIST (60,000 training images, no labels used).
Success criteria: Generated images should be visually recognizable as digits, with diversity across all 10 digit classes.
Architecture Design
Generator
The generator maps a 100-dimensional noise vector to a $1 \times 28 \times 28$ image. We adapt the DCGAN architecture for the smaller MNIST image size:
"""DCGAN Generator for MNIST."""
import torch
import torch.nn as nn
torch.manual_seed(42)
class DCGANGenerator(nn.Module):
"""DCGAN generator producing 28x28 grayscale images.
Architecture: z(100) -> 256x7x7 -> 128x14x14 -> 1x28x28
Args:
latent_dim: Dimension of the input noise vector.
feature_maps: Base number of feature maps.
"""
def __init__(
self, latent_dim: int = 100, feature_maps: int = 256
) -> None:
super().__init__()
self.latent_dim = latent_dim
self.main = nn.Sequential(
# Project and reshape: 100 -> 256*7*7
nn.Linear(latent_dim, feature_maps * 7 * 7),
nn.BatchNorm1d(feature_maps * 7 * 7),
nn.ReLU(True),
# Reshape handled in forward()
# 256x7x7 -> 128x14x14
nn.ConvTranspose2d(
feature_maps, feature_maps // 2,
kernel_size=4, stride=2, padding=1, bias=False
),
nn.BatchNorm2d(feature_maps // 2),
nn.ReLU(True),
# 128x14x14 -> 1x28x28
nn.ConvTranspose2d(
feature_maps // 2, 1,
kernel_size=4, stride=2, padding=1, bias=False
),
nn.Tanh(),
)
self._fc = self.main[0]
self._bn = self.main[1]
self._relu = self.main[2]
self._conv_layers = self.main[3:]
self._feature_maps = feature_maps
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""Generate images from noise vectors.
Args:
z: Noise tensor of shape (batch_size, latent_dim).
Returns:
Generated images of shape (batch_size, 1, 28, 28).
"""
h = self._relu(self._bn(self._fc(z)))
h = h.view(-1, self._feature_maps, 7, 7)
return self._conv_layers(h)
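The exact doubling at each upsampling stage follows from the transposed-convolution size arithmetic: output = (in − 1) · stride − 2 · padding + kernel. A standalone check (not part of the case-study code) of the kernel=4, stride=2, padding=1 combination used above:

```python
import torch
import torch.nn as nn

# ConvTranspose2d output size: (H_in - 1) * stride - 2 * padding + kernel_size.
# With kernel_size=4, stride=2, padding=1: (7 - 1) * 2 - 2 + 4 = 14 -- exact doubling.
up = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False)
y = up(torch.randn(8, 256, 7, 7))
print(tuple(y.shape))  # (8, 128, 14, 14)
```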
Discriminator
The discriminator mirrors the generator, using strided convolutions to downsample:
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator for 28x28 grayscale images.

    Architecture: 1x28x28 -> 64x14x14 -> 128x7x7 -> 1

    Args:
        feature_maps: Base number of feature maps.
    """

    def __init__(self, feature_maps: int = 64) -> None:
        super().__init__()
        self.main = nn.Sequential(
            # 1x28x28 -> 64x14x14 (no BatchNorm on the first layer)
            nn.Conv2d(1, feature_maps, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x14x14 -> 128x7x7
            nn.Conv2d(
                feature_maps, feature_maps * 2,
                4, stride=2, padding=1, bias=False,
            ),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x7x7 -> 1
            nn.Flatten(),
            nn.Linear(feature_maps * 2 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify images as real or fake.

        Args:
            x: Image tensor of shape (batch_size, 1, 28, 28).

        Returns:
            Probability of being real, shape (batch_size, 1).
        """
        return self.main(x)
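The same kernel=4, stride=2, padding=1 combination halves the spatial size on the way down. A quick standalone check (again, not part of the case-study code):

```python
import torch
import torch.nn as nn

# Conv2d output size: floor((H_in + 2 * padding - kernel_size) / stride) + 1.
# With kernel_size=4, stride=2, padding=1: (28 + 2 - 4) // 2 + 1 = 14 -- exact halving.
down = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False)
y = down(torch.randn(8, 1, 28, 28))
print(tuple(y.shape))  # (8, 64, 14, 14)
```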
Weight Initialization
Following DCGAN best practices:
def weights_init(m: nn.Module) -> None:
    """Initialize weights from N(0, 0.02), as the DCGAN paper recommends.

    BatchNorm scale parameters are drawn from N(1, 0.02) instead, so they
    start near the identity transform.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
    elif "Linear" in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
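A quick sanity check (illustrative, not part of the case study) that the initializer produces the intended statistics on a convolution layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(64, 128, kernel_size=4)  # 131,072 weights: enough for a stable estimate
nn.init.normal_(conv.weight.data, 0.0, 0.02)
print(conv.weight.mean().item())  # close to 0.0
print(conv.weight.std().item())   # close to 0.02
```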
Training
Training Loop
"""DCGAN training loop with monitoring."""
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Data: normalize to [-1, 1] to match Tanh output
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
train_loader = DataLoader(
train_dataset, batch_size=128, shuffle=True, drop_last=True
)
# Models
generator = DCGANGenerator(latent_dim=100)
discriminator = DCGANDiscriminator()
generator.apply(weights_init)
discriminator.apply(weights_init)
# Optimizers: Adam with beta1=0.5 (DCGAN recommendation)
optimizer_g = torch.optim.Adam(
generator.parameters(), lr=2e-4, betas=(0.5, 0.999)
)
optimizer_d = torch.optim.Adam(
discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)
)
criterion = nn.BCELoss()
# Fixed noise for visualization
fixed_noise = torch.randn(64, 100)
# Training
n_epochs = 25
g_losses = []
d_losses = []
for epoch in range(n_epochs):
for i, (real_images, _) in enumerate(train_loader):
batch_size = real_images.size(0)
real_labels = torch.ones(batch_size, 1) * 0.9 # Label smoothing
fake_labels = torch.zeros(batch_size, 1)
# --- Train Discriminator ---
optimizer_d.zero_grad()
# Real images
output_real = discriminator(real_images)
loss_real = criterion(output_real, real_labels)
# Fake images
noise = torch.randn(batch_size, 100)
fake_images = generator(noise)
output_fake = discriminator(fake_images.detach())
loss_fake = criterion(output_fake, fake_labels)
loss_d = loss_real + loss_fake
loss_d.backward()
optimizer_d.step()
# --- Train Generator ---
optimizer_g.zero_grad()
output_fake = discriminator(fake_images)
loss_g = criterion(output_fake, torch.ones(batch_size, 1))
loss_g.backward()
optimizer_g.step()
g_losses.append(loss_g.item())
d_losses.append(loss_d.item())
# End-of-epoch report
print(
f"Epoch {epoch+1}/{n_epochs} | "
f"D Loss: {d_losses[-1]:.4f} | "
f"G Loss: {g_losses[-1]:.4f} | "
f"D(real): {output_real.mean():.3f} | "
f"D(fake): {output_fake.mean():.3f}"
)
Results Analysis
Training Progression
Over 25 epochs, we typically observe:
- Epoch 1: Generated images are random noise.
- Epoch 5: Vague digit-like shapes begin to emerge.
- Epoch 10: Most generated images are recognizable as digits.
- Epoch 25: Clear, sharp digits with good diversity.
Monitoring Signals
Healthy training indicators:
- Discriminator loss stays between 0.5 and 1.5.
- Generator loss decreases gradually but may oscillate.
- $D(\text{real})$ stays near 0.7--0.9 (not 1.0, thanks to label smoothing).
- $D(\text{fake})$ gradually increases from near 0 toward 0.3--0.5.
Warning signs:
- $D(\text{real}) \to 1.0$ and $D(\text{fake}) \to 0.0$: the discriminator has become too strong.
- Generator loss spikes repeatedly: training instability.
- All generated images look identical: mode collapse.
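Mode collapse in particular can be caught numerically before it is obvious to the eye. One crude heuristic (our own sketch, not a standard metric) is the average pairwise distance between samples in a generated batch; if it trends toward zero across epochs, the generator is producing near-identical images:

```python
import torch

def sample_diversity(images: torch.Tensor) -> float:
    """Average pairwise L2 distance between flattened images in a batch."""
    flat = images.flatten(start_dim=1)
    return torch.cdist(flat, flat).mean().item()

# Identical samples score zero; diverse samples score well above it.
print(sample_diversity(torch.zeros(16, 1, 28, 28)))        # 0.0
print(sample_diversity(torch.randn(16, 1, 28, 28)) > 1.0)  # True
```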
Dealing with Failure
If training fails, try the following, roughly in order:
1. Reduce the discriminator's learning rate.
2. Increase label smoothing (e.g., 0.8 instead of 0.9).
3. Add instance noise to the discriminator's input.
4. Switch to WGAN-GP for more stable training.
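Instance noise (remedy 3) is simple to retrofit: perturb every image the discriminator sees with Gaussian noise whose magnitude decays over training, which softens an overconfident discriminator early on. The starting sigma and the linear decay schedule below are illustrative choices, not values from this case study:

```python
import torch

def add_instance_noise(
    images: torch.Tensor, epoch: int, sigma0: float = 0.1, n_epochs: int = 25
) -> torch.Tensor:
    """Add Gaussian noise whose scale decays linearly to zero over training."""
    sigma = sigma0 * max(0.0, 1.0 - epoch / n_epochs)
    return images + sigma * torch.randn_like(images)

# In the training loop, feed the discriminator noisy versions of both batches:
#     output_real = discriminator(add_instance_noise(real_images, epoch))
#     output_fake = discriminator(add_instance_noise(fake_images.detach(), epoch))
```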
Key Takeaways
- DCGAN's architectural guidelines (strided convolutions, BatchNorm, Tanh/LeakyReLU) are essential for stable training.
- Weight initialization from $\mathcal{N}(0, 0.02)$ prevents early training instabilities.
- Label smoothing (0.9 instead of 1.0 for real labels) is a simple but effective stabilization technique.
- Training monitoring requires tracking both losses AND visual inspection of generated samples.
- The discriminator's $D(\text{real})$ and $D(\text{fake})$ scores provide more diagnostic information than the raw losses.