Case Study 1: Building a DCGAN for Image Generation

Overview

In this case study, we build a complete Deep Convolutional GAN (DCGAN) pipeline for generating handwritten digit images. We follow the architectural guidelines from Radford et al. (2016), implement proper weight initialization, handle the training loop with careful monitoring, and evaluate the quality of generated samples over the course of training.

This case study demonstrates the practical challenges of GAN training: balancing the generator and discriminator, diagnosing training failure modes, and knowing when training has produced good results.


Problem Definition

Task: Train a DCGAN to generate realistic $28 \times 28$ grayscale images of handwritten digits.

Dataset: MNIST (60,000 training images, no labels used).

Success criteria: Generated images should be visually recognizable as digits, with diversity across all 10 digit classes.


Architecture Design

Generator

The generator maps a 100-dimensional noise vector to a $1 \times 28 \times 28$ image. We adapt the DCGAN architecture for the smaller MNIST image size:

"""DCGAN Generator for MNIST."""

import torch
import torch.nn as nn

torch.manual_seed(42)


class DCGANGenerator(nn.Module):
    """DCGAN generator producing 28x28 grayscale images.

    Architecture: z(100) -> 256x7x7 -> 128x14x14 -> 1x28x28

    Args:
        latent_dim: Dimension of the input noise vector.
        feature_maps: Base number of feature maps.
    """

    def __init__(
        self, latent_dim: int = 100, feature_maps: int = 256
    ) -> None:
        super().__init__()
        self.latent_dim = latent_dim

        # Project: 100 -> 256*7*7 (reshaped to 256x7x7 in forward())
        self.project = nn.Sequential(
            nn.Linear(latent_dim, feature_maps * 7 * 7),
            nn.BatchNorm1d(feature_maps * 7 * 7),
            nn.ReLU(True),
        )

        self.upsample = nn.Sequential(
            # 256x7x7 -> 128x14x14
            nn.ConvTranspose2d(
                feature_maps, feature_maps // 2,
                kernel_size=4, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(feature_maps // 2),
            nn.ReLU(True),

            # 128x14x14 -> 1x28x28
            nn.ConvTranspose2d(
                feature_maps // 2, 1,
                kernel_size=4, stride=2, padding=1, bias=False
            ),
            nn.Tanh(),
        )

        self._feature_maps = feature_maps

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Generate images from noise vectors.

        Args:
            z: Noise tensor of shape (batch_size, latent_dim).

        Returns:
            Generated images of shape (batch_size, 1, 28, 28).
        """
        h = self.project(z)
        h = h.view(-1, self._feature_maps, 7, 7)
        return self.upsample(h)

Discriminator

The discriminator mirrors the generator, using strided convolutions to downsample:

class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator for 28x28 grayscale images.

    Architecture: 1x28x28 -> 64x14x14 -> 128x7x7 -> 1

    Args:
        feature_maps: Base number of feature maps.
    """

    def __init__(self, feature_maps: int = 64) -> None:
        super().__init__()
        self.main = nn.Sequential(
            # 1x28x28 -> 64x14x14 (no BatchNorm on first layer)
            nn.Conv2d(1, feature_maps, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 64x14x14 -> 128x7x7
            nn.Conv2d(
                feature_maps, feature_maps * 2,
                4, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # 128x7x7 -> 1
            nn.Flatten(),
            nn.Linear(feature_maps * 2 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify images as real or fake.

        Args:
            x: Image tensor of shape (batch_size, 1, 28, 28).

        Returns:
            Probability of being real, shape (batch_size, 1).
        """
        return self.main(x)
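The spatial sizes annotated in the layer comments can be sanity-checked with the standard output-size formulas: a strided convolution gives $\lfloor (n + 2p - k)/s \rfloor + 1$, and a transposed convolution gives $(n - 1)s - 2p + k$. A quick arithmetic check confirms that kernel size 4, stride 2, padding 1 exactly doubles (or halves) the spatial resolution at each stage:

```python
def conv_out(size: int, k: int, s: int, p: int) -> int:
    """Output size of a strided convolution."""
    return (size + 2 * p - k) // s + 1


def convt_out(size: int, k: int, s: int, p: int) -> int:
    """Output size of a transposed convolution."""
    return (size - 1) * s - 2 * p + k


# Generator upsampling path: 7 -> 14 -> 28
assert convt_out(7, k=4, s=2, p=1) == 14
assert convt_out(14, k=4, s=2, p=1) == 28

# Discriminator downsampling path: 28 -> 14 -> 7
assert conv_out(28, k=4, s=2, p=1) == 14
assert conv_out(14, k=4, s=2, p=1) == 7
```

Running this check before training catches shape mismatches that would otherwise only surface as a runtime error on the first batch.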

Weight Initialization

Following DCGAN best practices:

def weights_init(m: nn.Module) -> None:
    """Initialize weights from N(0, 0.02) as recommended by DCGAN."""
    classname = m.__class__.__name__
    if "Conv" in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
    elif "Linear" in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

Training

Training Loop

"""DCGAN training loop with monitoring."""

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Data: normalize to [-1, 1] to match Tanh output
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
train_loader = DataLoader(
    train_dataset, batch_size=128, shuffle=True, drop_last=True
)

# Models
generator = DCGANGenerator(latent_dim=100)
discriminator = DCGANDiscriminator()
generator.apply(weights_init)
discriminator.apply(weights_init)

# Optimizers: Adam with beta1=0.5 (DCGAN recommendation)
optimizer_g = torch.optim.Adam(
    generator.parameters(), lr=2e-4, betas=(0.5, 0.999)
)
optimizer_d = torch.optim.Adam(
    discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)
)

criterion = nn.BCELoss()

# Fixed noise for visualization
fixed_noise = torch.randn(64, 100)

# Training
n_epochs = 25
g_losses = []
d_losses = []

for epoch in range(n_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1) * 0.9  # Label smoothing
        fake_labels = torch.zeros(batch_size, 1)

        # --- Train Discriminator ---
        optimizer_d.zero_grad()

        # Real images
        output_real = discriminator(real_images)
        loss_real = criterion(output_real, real_labels)

        # Fake images
        noise = torch.randn(batch_size, 100)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        loss_fake = criterion(output_fake, fake_labels)

        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_d.step()

        # --- Train Generator ---
        optimizer_g.zero_grad()
        output_fake = discriminator(fake_images)
        loss_g = criterion(output_fake, torch.ones(batch_size, 1))
        loss_g.backward()
        optimizer_g.step()

        g_losses.append(loss_g.item())
        d_losses.append(loss_d.item())

    # End-of-epoch report
    print(
        f"Epoch {epoch+1}/{n_epochs} | "
        f"D Loss: {d_losses[-1]:.4f} | "
        f"G Loss: {g_losses[-1]:.4f} | "
        f"D(real): {output_real.mean():.3f} | "
        f"D(fake): {output_fake.mean():.3f}"
    )

Results Analysis

Training Progression

Over 25 epochs, we typically observe:

  • Epoch 1: Generated images are random noise.
  • Epoch 5: Vague digit-like shapes begin to emerge.
  • Epoch 10: Most generated images are recognizable as digits.
  • Epoch 25: Clear, sharp digits with good diversity.
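Because the per-iteration values collected in `g_losses` and `d_losses` are noisy, the trends above are easier to read after smoothing. A minimal moving-average helper (the window size of 100 is an arbitrary choice):

```python
def moving_average(values: list[float], window: int = 100) -> list[float]:
    """Smooth noisy per-iteration losses with a simple moving average."""
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value leaving the window
        out.append(running / min(i + 1, window))
    return out
```

Plotting `moving_average(g_losses)` and `moving_average(d_losses)` side by side makes oscillations and divergence far easier to spot than the raw curves.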

Monitoring Signals

Healthy training indicators:

  • Discriminator loss stays between 0.5 and 1.5.
  • Generator loss decreases gradually but may oscillate.
  • $D(\text{real})$ stays near 0.7--0.9 (not 1.0, thanks to label smoothing).
  • $D(\text{fake})$ gradually increases from near 0 toward 0.3--0.5.

Warning signs:

  • $D(\text{real}) \to 1.0$ and $D(\text{fake}) \to 0.0$: Discriminator too strong.
  • Generator loss spikes repeatedly: Training instability.
  • All generated images look identical: Mode collapse.
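These signals can be turned into an automated check run alongside the epoch report. The thresholds below are illustrative choices on our part, not canonical values:

```python
def diagnose_gan(d_real: float, d_fake: float) -> str:
    """Map mean discriminator scores to a rough training diagnosis.

    Thresholds are heuristic and should be tuned per experiment.
    """
    if d_real > 0.98 and d_fake < 0.02:
        return "warning: discriminator overpowering generator"
    if abs(d_real - d_fake) < 0.05:
        return "warning: discriminator barely separates real from fake"
    if 0.7 <= d_real <= 0.9 and d_fake <= 0.5:
        return "healthy"
    return "inconclusive"
```

A call such as `diagnose_gan(output_real.mean().item(), output_fake.mean().item())` at the end of each epoch flags the failure modes above without requiring manual inspection of every run.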

Dealing with Failure

If training fails:

  1. Reduce the discriminator's learning rate.
  2. Increase label smoothing (e.g., 0.8 instead of 0.9).
  3. Add instance noise to the discriminator's input.
  4. Switch to WGAN-GP for more stable training.
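Instance noise (option 3) can be sketched as Gaussian noise added to both real and fake images before they reach the discriminator, with a standard deviation that decays to zero over training. The helper below is one simple realization; the starting level of 0.1 is an assumed value:

```python
import torch


def add_instance_noise(images: torch.Tensor, epoch: int, n_epochs: int,
                       initial_std: float = 0.1) -> torch.Tensor:
    """Add Gaussian noise that decays linearly to zero by the final epoch."""
    std = initial_std * max(0.0, 1.0 - epoch / n_epochs)
    return images + std * torch.randn_like(images)
```

In the discriminator step of the training loop, this would replace `discriminator(real_images)` with `discriminator(add_instance_noise(real_images, epoch, n_epochs))`, and likewise for the detached fake images, blurring the distinction between the two distributions early in training.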


Key Takeaways

  1. DCGAN's architectural guidelines (strided convolutions, BatchNorm, Tanh/LeakyReLU) are essential for stable training.
  2. Weight initialization from $\mathcal{N}(0, 0.02)$ prevents early training instabilities.
  3. Label smoothing (0.9 instead of 1.0 for real labels) is a simple but effective stabilization technique.
  4. Training monitoring requires tracking both losses AND visual inspection of generated samples.
  5. The discriminator's $D(\text{real})$ and $D(\text{fake})$ scores provide more diagnostic information than the raw losses.