
from the adversarial training framework and minimax formulation to DCGAN, Wasserstein GAN, conditional GANs, StyleGAN, and evaluation metrics for generative models."

Prerequisites:
  • Chapter 11: Neural Networks from Scratch (forward/backward pass, gradient descent)
  • Chapter 12: Training Deep Networks (optimizers, batch normalization, training stability)
  • Chapter 14: Convolutional Neural Networks (convolutions, transposed convolutions)
  • Chapter 16: Autoencoders and Representation Learning (generative models, latent spaces, VAEs)
  • Chapter 4: Probability, Statistics, and Information Theory (distributions, KL divergence, Jensen-Shannon divergence)

Learning objectives:
  • Formulate the GAN training objective as a minimax game between generator and discriminator
  • Explain and diagnose training instability, mode collapse, and vanishing gradients in GAN training
  • Implement a DCGAN with architectural best practices for stable training
  • Derive the Wasserstein distance and explain how WGAN addresses GAN training problems
  • Implement conditional GANs for class-conditional generation
  • Evaluate generative models using FID and Inception Score
  • Describe the key innovations in StyleGAN and their effect on generation quality

Key terms: generative adversarial network, generator, discriminator, adversarial training, minimax game, Nash equilibrium, mode collapse, training instability, DCGAN, Wasserstein distance, Earth Mover's distance, Lipschitz constraint, weight clipping, gradient penalty, spectral normalization, conditional GAN, class-conditional generation, StyleGAN, progressive growing, style mixing, Frechet Inception Distance (FID), Inception Score (IS)

Estimated time: 4-5 hours
Difficulty: Intermediate to Advanced


Chapter 17: Generative Adversarial Networks

"The coolest idea in deep learning in the last 20 years." --- Yann LeCun, on GANs

In Chapter 16, we explored two approaches to generation: autoencoders that learn by reconstruction, and Variational Autoencoders that maximize a lower bound on the data likelihood. Both approaches are principled, but they share a common weakness: they tend to produce blurry outputs, hedging their bets by averaging over multiple plausible generations.

Generative Adversarial Networks (GANs) take a radically different approach. Instead of maximizing likelihood or minimizing reconstruction error, GANs set up a game between two neural networks: a generator that creates fake data, and a discriminator that tries to tell real from fake. Through this adversarial process, the generator learns to produce data so realistic that the discriminator cannot distinguish it from real examples.

This game-theoretic framework, introduced by Ian Goodfellow and colleagues in 2014, unleashed a revolution in generative modeling. GANs can produce photorealistic images, create art, generate synthetic training data, perform image-to-image translation, and much more. They also introduced deep challenges in training stability that spawned an enormous body of research on improved architectures and loss functions.

In this chapter, we build GANs from the ground up. We start with the mathematical framework, work through the training dynamics, implement DCGAN for stable image generation, derive the Wasserstein GAN that addresses fundamental training issues, add conditioning for controlled generation, survey StyleGAN's architectural innovations, and learn to evaluate generative models rigorously.


17.1 The Adversarial Training Framework

17.1.1 The Core Idea

Imagine a counterfeiter (generator) and a detective (discriminator). The counterfeiter's goal is to produce fake currency that looks real. The detective's goal is to distinguish fake bills from genuine ones. As the detective gets better at spotting fakes, the counterfeiter must improve the quality of the forgeries. This arms race drives both parties toward perfection.

Formally, a GAN consists of:

  • Generator $G_\theta: \mathbb{R}^k \to \mathbb{R}^d$: Takes a random noise vector $\mathbf{z} \sim p_z(\mathbf{z})$ (typically $\mathcal{N}(\mathbf{0}, \mathbf{I})$ or $\text{Uniform}(-1, 1)$) and maps it to a synthetic data point $G(\mathbf{z})$ in data space.
  • Discriminator $D_\phi: \mathbb{R}^d \to [0, 1]$: Takes a data point (real or generated) and outputs the probability that it is real.

The noise vector $\mathbf{z}$ lives in a latent space of dimension $k$ (similar to the latent space in VAEs from Chapter 16, but without an encoder to map data points to it).

17.1.2 The Minimax Objective

GANs are trained by solving a minimax game:

$$\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} [\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_z} [\log(1 - D(G(\mathbf{z})))]$$

Let us unpack this:

  • Discriminator's goal (maximize $V$): Assign high probability to real data ($D(\mathbf{x}) \to 1$) and low probability to fake data ($D(G(\mathbf{z})) \to 0$). This maximizes both terms.

  • Generator's goal (minimize $V$): Fool the discriminator by generating data that receives high probability ($D(G(\mathbf{z})) \to 1$). This minimizes the second term (since $\log(1 - D(G(\mathbf{z}))) \to -\infty$ as $D(G(\mathbf{z})) \to 1$).

17.1.3 Optimal Discriminator

For a fixed generator $G$, the optimal discriminator is:

$$D^*(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}$$

where $p_g$ is the distribution induced by the generator. This is a density ratio estimator: it outputs the probability that a sample came from the real distribution rather than the generator.

Proof sketch: The discriminator maximizes:

$$V(D) = \int \left[ p_{\text{data}}(\mathbf{x}) \log D(\mathbf{x}) + p_g(\mathbf{x}) \log(1 - D(\mathbf{x})) \right] d\mathbf{x}$$

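For each $\mathbf{x}$, the integrand has the form $a \log y + b \log(1 - y)$ with $a = p_{\text{data}}(\mathbf{x})$, $b = p_g(\mathbf{x})$, and $y = D(\mathbf{x})$. Its derivative with respect to $y$, $\frac{a}{y} - \frac{b}{1 - y}$, vanishes at $y = \frac{a}{a + b}$, and because the function is concave in $y$ this stationary point is the maximum. Maximizing the integrand separately at every $\mathbf{x}$ therefore yields $D^*$.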
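The pointwise argument can be checked numerically. The sketch below fixes hypothetical density values $a = p_{\text{data}}(\mathbf{x})$ and $b = p_g(\mathbf{x})$ at a single point, maximizes $a \log y + b \log(1 - y)$ over a fine grid of $y \in (0, 1)$, and compares the result with $a / (a + b)$. The grid resolution and the example values are arbitrary choices made for illustration.

```python
import numpy as np


def numeric_optimal_d(a: float, b: float, n: int = 100_001) -> float:
    """Grid-search the y in (0, 1) that maximizes a*log(y) + b*log(1 - y).

    a stands in for p_data(x) and b for p_g(x) at one fixed point x.
    """
    y = np.linspace(1e-4, 1.0 - 1e-4, n)
    objective = a * np.log(y) + b * np.log1p(-y)  # log1p(-y) == log(1 - y)
    return float(y[np.argmax(objective)])


for a, b in [(0.3, 0.1), (0.2, 0.2), (0.05, 0.45)]:
    print(f"a={a}, b={b}: grid argmax={numeric_optimal_d(a, b):.4f}, "
          f"a/(a+b)={a / (a + b):.4f}")
```

In each case the grid maximizer agrees with $a / (a + b)$ to within the grid spacing.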

17.1.4 Generator's Implicit Objective

Substituting the optimal discriminator $D^*$ into the value function:

$$V(D^*, G) = -\log 4 + 2 \cdot D_{\text{JS}}(p_{\text{data}} \| p_g)$$

where $D_{\text{JS}}$ is the Jensen-Shannon divergence:

$$D_{\text{JS}}(p \| q) = \frac{1}{2} D_{\text{KL}}\left(p \middle\| \frac{p+q}{2}\right) + \frac{1}{2} D_{\text{KL}}\left(q \middle\| \frac{p+q}{2}\right)$$

The JSD is always non-negative and equals zero if and only if $p = q$. Therefore, the global minimum of the minimax game is achieved when $p_g = p_{\text{data}}$---the generator perfectly reproduces the data distribution.

This is a beautiful theoretical result: GAN training implicitly minimizes the Jensen-Shannon divergence between the generated and real distributions.

Derivation of the JSD connection. Let us trace through the substitution step by step. With $D^*(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}$, we substitute into the value function:

$$V(D^*, G) = \int p_{\text{data}}(\mathbf{x}) \log \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})} d\mathbf{x} + \int p_g(\mathbf{x}) \log \frac{p_g(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})} d\mathbf{x}$$

Let $m(\mathbf{x}) = \frac{1}{2}(p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x}))$ be the mixture distribution. Then:

$$V(D^*, G) = \int p_{\text{data}} \log \frac{p_{\text{data}}}{2m} d\mathbf{x} + \int p_g \log \frac{p_g}{2m} d\mathbf{x}$$

$$= D_{\text{KL}}(p_{\text{data}} \| m) + D_{\text{KL}}(p_g \| m) - 2\log 2$$

$$= 2 \cdot D_{\text{JS}}(p_{\text{data}} \| p_g) - \log 4$$

Since $D_{\text{JS}} \geq 0$ with equality iff $p_{\text{data}} = p_g$, the minimum of $V(D^*, G)$ is $-\log 4$, achieved when the generator perfectly matches the data distribution.

Worked example: 1D Gaussian. Consider the simplest possible GAN: both $p_{\text{data}}$ and $p_g$ are 1D Gaussians. Let $p_{\text{data}} = \mathcal{N}(0, 1)$ and $p_g = \mathcal{N}(\mu, 1)$, where the generator controls $\mu$. The JSD between two Gaussians with the same variance but different means can be computed numerically, and it decreases monotonically as $\mu \to 0$. The GAN training process pushes $\mu$ toward 0, aligning the generated distribution with the data distribution. This toy example helps build intuition: the GAN is performing distribution matching, not point matching.
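Both the derivation and this 1D example can be checked numerically. The sketch below assumes $p_{\text{data}} = \mathcal{N}(0, 1)$ and $p_g = \mathcal{N}(\mu, 1)$ as in the example and approximates every integral by a Riemann sum on a fixed grid (the grid limits and resolution are arbitrary choices). It computes $D_{\text{JS}}$ for several values of $\mu$ and compares $V(D^*, G)$, evaluated directly with $D^* = p_{\text{data}} / (p_{\text{data}} + p_g)$, against $2 \cdot D_{\text{JS}} - \log 4$.

```python
import numpy as np

# Integration grid wide enough that both densities are negligible at the edges
x = np.linspace(-12.0, 15.0, 200_001)
dx = x[1] - x[0]


def gaussian_pdf(mu: float, sigma: float = 1.0) -> np.ndarray:
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))


def kl(p: np.ndarray, q: np.ndarray) -> float:
    """Riemann-sum approximation of KL(p || q) on the grid."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])) * dx)


def jsd(p: np.ndarray, q: np.ndarray) -> float:
    m = 0.5 * (p + q)  # mixture distribution
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def value_at_optimal_d(p_data: np.ndarray, p_g: np.ndarray) -> float:
    """V(D*, G) with D* = p_data / (p_data + p_g), evaluated on the grid."""
    total = p_data + p_g
    integrand = p_data * np.log(p_data / total) + p_g * np.log(p_g / total)
    return float(np.sum(integrand) * dx)


p_data = gaussian_pdf(0.0)
for mu in [0.0, 0.5, 1.0, 2.0, 3.0]:
    p_g = gaussian_pdf(mu)
    js = jsd(p_data, p_g)
    v = value_at_optimal_d(p_data, p_g)
    print(f"mu={mu:.1f}  JSD={js:.4f}  V(D*,G)={v:.4f}  "
          f"2*JSD-log4={2 * js - np.log(4):.4f}")
```

The JSD is zero at $\mu = 0$, grows as $|\mu|$ grows, and stays below $\log 2$; the last two columns agree, as the derivation predicts.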

17.1.5 Training Algorithm

In practice, we alternate between updating the discriminator and the generator:

For each training iteration:
    1. Sample a minibatch of m real examples {x_1, ..., x_m}
    2. Sample a minibatch of m noise vectors {z_1, ..., z_m}
    3. Update discriminator by ascending its stochastic gradient:
       nabla_phi [1/m sum(log D(x_i) + log(1 - D(G(z_i))))]
    4. Sample a new minibatch of m noise vectors {z_1, ..., z_m}
    5. Update generator by descending its stochastic gradient:
       nabla_theta [1/m sum(log(1 - D(G(z_i))))]

A common practice is to update the discriminator $k$ times for each generator update (often $k = 1$, sometimes $k = 5$ for WGAN).

17.1.6 Non-Saturating Generator Loss

The original generator loss $\log(1 - D(G(\mathbf{z})))$ has a practical problem: early in training, the discriminator easily identifies generated data, so $D(G(\mathbf{z})) \approx 0$. In this regime the loss saturates: it stays close to $\log 1 = 0$ and barely changes as the generator improves, so its gradient is near zero and the generator learns almost nothing.

To see this quantitatively, the gradient of the original loss with respect to the generator parameters is:

$$\nabla_\theta \log(1 - D(G(\mathbf{z}))) = \frac{-\nabla_\theta D(G(\mathbf{z}))}{1 - D(G(\mathbf{z}))}$$

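When $D(G(\mathbf{z})) \approx 0$ (the discriminator is confident the sample is fake), the denominator is close to 1, but the numerator is small. To see why, write the discriminator's output as $D = \sigma(a)$, where $a$ is its pre-sigmoid logit; then $\nabla_\theta D = D(1 - D)\,\nabla_\theta a$, which vanishes as $D \to 0$ because the sigmoid saturates. This is the vanishing gradient problem.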

The non-saturating alternative replaces the generator's objective:

$$\text{Original:} \quad \min_G \mathbb{E}_{\mathbf{z}} [\log(1 - D(G(\mathbf{z})))]$$ $$\text{Non-saturating:} \quad \max_G \mathbb{E}_{\mathbf{z}} [\log D(G(\mathbf{z}))]$$

The gradient of the non-saturating loss is:

$$\nabla_\theta \log D(G(\mathbf{z})) = \frac{\nabla_\theta D(G(\mathbf{z}))}{D(G(\mathbf{z}))}$$

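When $D(G(\mathbf{z})) \approx 0$, the small denominator cancels the small numerator: with $D = \sigma(a)$, the gradient becomes $(1 - D)\,\nabla_\theta a \approx \nabla_\theta a$, so the learning signal is not suppressed. This provides much stronger learning signals early in training.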

Both objectives have the same fixed point (the optimal generator). Written as a loss to be minimized, the non-saturating generator objective is:

$$\mathcal{L}_G = -\mathbb{E}_{\mathbf{z}} [\log D(G(\mathbf{z}))]$$

This is the version used in virtually all practical GAN implementations. Note that this changes the nature of the game: it is no longer a true minimax game (the generator and discriminator optimize different objectives), which has implications for convergence analysis.
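The difference between the two generator losses is easiest to see on the discriminator's logit. The sketch below assumes $D(G(\mathbf{z})) = \sigma(a)$ for a scalar logit $a$ and uses PyTorch autograd to differentiate both losses (each written as something the generator minimizes) with respect to $a$; the gradient with respect to $\theta$ is this value times $\nabla_\theta a$. For a logit of $-6$ (a confidently rejected fake), the original loss gives a gradient of about $-0.0025$, while the non-saturating loss gives about $-0.9975$.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def generator_grads_wrt_logit(logit: float) -> Tuple[float, float]:
    """Gradients of both generator losses w.r.t. the discriminator logit a.

    Assumes D(G(z)) = sigmoid(a); both losses are minimized by the generator.
    """
    a = torch.tensor(logit, dtype=torch.float64, requires_grad=True)

    # Original (saturating) loss: log(1 - sigmoid(a)) == logsigmoid(-a)
    (grad_saturating,) = torch.autograd.grad(F.logsigmoid(-a), a)

    # Non-saturating loss: -log(sigmoid(a))
    (grad_non_saturating,) = torch.autograd.grad(-F.logsigmoid(a), a)

    return grad_saturating.item(), grad_non_saturating.item()


for logit in [-6.0, -2.0, 0.0, 2.0]:
    g_sat, g_ns = generator_grads_wrt_logit(logit)
    d = torch.sigmoid(torch.tensor(logit)).item()
    print(f"D(G(z))={d:.4f}  saturating grad={g_sat:+.4f}  "
          f"non-saturating grad={g_ns:+.4f}")
```

Analytically, the two gradients are $-\sigma(a)$ and $-(1 - \sigma(a))$: they point the same way, but the original loss's gradient vanishes exactly where the generator is doing worst.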

17.1.7 Alternative GAN Loss Functions

Beyond the original and non-saturating losses, several other loss functions have been proposed:

Least Squares GAN (LSGAN): Replaces the binary cross-entropy with a squared error:

$$\mathcal{L}_D = \frac{1}{2}\mathbb{E}_{\mathbf{x}}[(D(\mathbf{x}) - 1)^2] + \frac{1}{2}\mathbb{E}_{\mathbf{z}}[D(G(\mathbf{z}))^2]$$ $$\mathcal{L}_G = \frac{1}{2}\mathbb{E}_{\mathbf{z}}[(D(G(\mathbf{z})) - 1)^2]$$

The squared error penalizes generated samples that are far from the decision boundary, providing stronger gradients for samples that the discriminator confidently classifies as fake. LSGAN is related to minimizing the Pearson $\chi^2$ divergence between $p_{\text{data}}$ and $p_g$.

Hinge Loss GAN: Used in spectral normalization GAN (SN-GAN) and BigGAN:

$$\mathcal{L}_D = -\mathbb{E}_{\mathbf{x}}[\min(0, -1 + D(\mathbf{x}))] - \mathbb{E}_{\mathbf{z}}[\min(0, -1 - D(G(\mathbf{z})))]$$ $$\mathcal{L}_G = -\mathbb{E}_{\mathbf{z}}[D(G(\mathbf{z}))]$$

The hinge loss saturates once the discriminator is confident (score > 1 for real, score < -1 for fake), preventing the discriminator from becoming arbitrarily confident and providing more stable training.
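The three loss pairs can be written side by side. The sketch below assumes the discriminator returns raw, pre-sigmoid scores (the standard-GAN version applies the sigmoid internally through `binary_cross_entropy_with_logits`). For brevity the same `fake_scores` feed both losses here; in a real training loop the discriminator loss would be computed on detached generated samples, as in Section 17.3.6. The function names are illustrative.

```python
from typing import Tuple

import torch
import torch.nn.functional as F

Losses = Tuple[torch.Tensor, torch.Tensor]  # (discriminator loss, generator loss)


def bce_losses(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> Losses:
    """Standard GAN discriminator loss plus the non-saturating generator loss."""
    d_loss = (F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
              + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits)))
    g_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
    return d_loss, g_loss


def lsgan_losses(real_scores: torch.Tensor, fake_scores: torch.Tensor) -> Losses:
    """Least-squares GAN with targets 1 (real) and 0 (fake)."""
    d_loss = 0.5 * ((real_scores - 1) ** 2).mean() + 0.5 * (fake_scores ** 2).mean()
    g_loss = 0.5 * ((fake_scores - 1) ** 2).mean()
    return d_loss, g_loss


def hinge_losses(real_scores: torch.Tensor, fake_scores: torch.Tensor) -> Losses:
    """Hinge loss: -min(0, -1 + s) == relu(1 - s) and -min(0, -1 - s) == relu(1 + s)."""
    d_loss = F.relu(1.0 - real_scores).mean() + F.relu(1.0 + fake_scores).mean()
    g_loss = -fake_scores.mean()
    return d_loss, g_loss


real = torch.tensor([2.0, 0.5])
fake = torch.tensor([-3.0, 0.0])
for name, fn in [("BCE", bce_losses), ("LSGAN", lsgan_losses), ("hinge", hinge_losses)]:
    d, g = fn(real, fake)
    print(f"{name:6s} D loss={d.item():.4f}  G loss={g.item():.4f}")
```

Note how the hinge discriminator loss ignores the real score of 2.0 and the fake score of -3.0 entirely: both are already beyond the margin.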


17.2 Training Dynamics and Challenges

17.2.1 Nash Equilibrium and Its Fragility

GAN training seeks a Nash equilibrium: a point where neither the generator nor the discriminator can improve unilaterally. Unlike standard optimization (minimizing a single loss), GAN training involves two players with competing objectives. This makes training fundamentally harder.

Formally, a Nash equilibrium of the GAN game is a pair $(\theta^*, \phi^*)$ such that:

$$V(D_{\phi^*}, G_{\theta^*}) \geq V(D_\phi, G_{\theta^*}) \quad \text{for all } \phi$$ $$V(D_{\phi^*}, G_{\theta^*}) \leq V(D_{\phi^*}, G_\theta) \quad \text{for all } \theta$$

In other words, the discriminator $D_{\phi^*}$ is the best response to the generator $G_{\theta^*}$, and vice versa. Goodfellow et al. (2014) proved that the unique Nash equilibrium of the GAN game (in function space, with unlimited capacity) is $p_g = p_{\text{data}}$ and $D^*(\mathbf{x}) = \frac{1}{2}$ everywhere.

However, gradient-based training does not always converge to a Nash equilibrium. Even for the simple bilinear game $\min_x \max_y xy$, simultaneous gradient descent-ascent with step size $\eta$ ($x$ descends, $y$ ascends) produces the updates $x_{t+1} = x_t - \eta y_t$ and $y_{t+1} = y_t + \eta x_t$. Writing this in matrix form:

$$\begin{bmatrix} x_{t+1} \\ y_{t+1} \end{bmatrix} = \begin{bmatrix} 1 & -\eta \\ \eta & 1 \end{bmatrix} \begin{bmatrix} x_t \\ y_t \end{bmatrix}$$

The eigenvalues of this update matrix are $1 \pm i\eta$, which have magnitude $\sqrt{1 + \eta^2} > 1$. The iterates spiral outward from the equilibrium $(0, 0)$ rather than converging to it. This is the fundamental challenge: gradient descent finds minima, not saddle points of minimax games.
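The spiral is easy to reproduce. The sketch below runs simultaneous gradient descent-ascent on $V(x, y) = xy$ from $(1, 1)$ with $\eta = 0.1$ and, for comparison, an alternating variant in which $y$ is updated using the already-updated $x$. The starting point, step size, and step count are arbitrary. The simultaneous iterates move away from $(0, 0)$ by a factor of exactly $\sqrt{1 + \eta^2}$ per step; the alternating iterates stay bounded but keep circling the equilibrium instead of converging to it.

```python
import math
from typing import List


def simulate(eta: float, steps: int, alternating: bool,
             x0: float = 1.0, y0: float = 1.0) -> List[float]:
    """Gradient descent-ascent on V(x, y) = x * y.

    Returns the distance of each iterate from the equilibrium (0, 0).
    """
    x, y = x0, y0
    norms = [math.hypot(x, y)]
    for _ in range(steps):
        if alternating:
            x = x - eta * y          # x descends (dV/dx = y)
            y = y + eta * x          # y ascends (dV/dy = x), using the new x
        else:
            x, y = x - eta * y, y + eta * x  # both updates use the old iterate
        norms.append(math.hypot(x, y))
    return norms


eta, steps = 0.1, 100
sim_norms = simulate(eta, steps, alternating=False)
alt_norms = simulate(eta, steps, alternating=True)
print(f"simultaneous: start {sim_norms[0]:.3f}, end {sim_norms[-1]:.3f}")
print(f"predicted end: {sim_norms[0] * math.sqrt(1 + eta ** 2) ** steps:.3f}")
print(f"alternating: min {min(alt_norms):.3f}, max {max(alt_norms):.3f}")
```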

Several techniques have been proposed to address this convergence issue:

  • Alternating optimization (updating D first, then updating G against the new D, often with several D steps per G step) partially addresses the problem by keeping D closer to its best response.
  • Gradient penalty methods (Section 17.4) modify the loss landscape to be more amenable to gradient-based optimization.
  • Spectral normalization constrains the discriminator's Lipschitz constant, preventing it from changing too rapidly.

17.2.2 Mode Collapse: Analysis and Solutions

Mode collapse is the most common GAN failure mode. Instead of learning the full data distribution, the generator discovers a few outputs that fool the discriminator and produces only those, ignoring the diversity of the real data.

For example, when training a GAN on MNIST, mode collapse might manifest as the generator producing only a single digit (say, all 1s) rather than all ten digits. The discriminator eventually catches on, the generator switches to a different digit, and the cycle continues without convergence.

Mode collapse occurs because the generator's objective does not explicitly reward diversity. The generator is rewarded for fooling the discriminator, and producing a single highly realistic sample can achieve this, at least temporarily.

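Formal analysis. Consider a data distribution with $K$ modes (clusters): $p_{\text{data}} = \sum_{k=1}^{K} \pi_k \cdot p_k$. Full mode collapse occurs when $p_g$ concentrates on a single mode; partial mode collapse occurs when $p_g$ covers only a subset of the modes. The key insight is that the generator's objective penalizes the two kinds of error asymmetrically. Under an optimal discriminator, the gradient of the non-saturating generator loss from Section 17.1.6 equals the gradient of $D_{\text{KL}}(p_g \| p_{\text{data}}) - 2 D_{\text{JS}}(p_{\text{data}} \| p_g)$ (Arjovsky & Bottou, 2017). The reverse KL term becomes enormous when the generator places mass where $p_{\text{data}}$ is nearly zero (unrealistic samples) but stays bounded when the generator simply omits modes: a generator that perfectly reproduces one of $K$ equally weighted, well-separated modes has $D_{\text{KL}}(p_g \| p_{\text{data}}) = \log K$. The gradient therefore pushes strongly toward realism and only weakly toward coverage.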

Diagnosing mode collapse:

  • Visually inspect generated samples for lack of diversity.
  • Count the distinct modes in generated data (e.g., using a pretrained classifier). For MNIST, classify 10,000 generated samples and count how many of the 10 digit classes are represented.
  • Monitor the discriminator's loss: dramatic oscillations can signal mode collapse.
  • Compute the reverse KL divergence between the generated and real class distributions. For a class-balanced dataset such as MNIST, a generated class distribution close to uniform (a divergence near zero) indicates good diversity.
  • Run a birthday paradox test: generate many samples and measure how often near-duplicates appear. High duplication rates indicate low diversity.
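The classifier-based checks in the list above reduce to a few lines once the generated samples have been labeled. In the sketch below, `pred_labels` is assumed to come from a pretrained classifier applied to generated samples (the classifier itself is not shown), the real class distribution defaults to uniform as for MNIST, and `mode_coverage` is a hypothetical helper name.

```python
from typing import Optional, Tuple

import numpy as np


def mode_coverage(
    pred_labels: np.ndarray,
    num_classes: int = 10,
    real_class_probs: Optional[np.ndarray] = None,
) -> Tuple[int, float]:
    """Return (number of classes covered, KL(p_gen || p_real)) for predicted labels."""
    counts = np.bincount(pred_labels, minlength=num_classes)
    p_gen = counts / counts.sum()
    if real_class_probs is None:
        # Assume a class-balanced dataset such as MNIST
        real_class_probs = np.full(num_classes, 1.0 / num_classes)
    covered = int(np.count_nonzero(counts))
    mask = p_gen > 0  # classes with no generated samples contribute 0 to KL(p_gen || p_real)
    reverse_kl = float(np.sum(p_gen[mask] * np.log(p_gen[mask] / real_class_probs[mask])))
    return covered, reverse_kl


# Hypothetical classifier outputs for 10,000 generated samples
collapsed = np.zeros(10_000, dtype=int)   # every sample classified as digit 0
diverse = np.arange(10_000) % 10          # perfectly balanced predictions
print(mode_coverage(collapsed))  # 1 class covered, KL = log(10)
print(mode_coverage(diverse))    # 10 classes covered, KL = 0
```

Reporting both numbers helps: the class count flags missing modes outright, while the divergence also catches a generator that covers every class but heavily over-produces a few of them.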

Solutions to mode collapse:

  1. Minibatch discrimination (Salimans et al., 2016): Add a layer to the discriminator that computes statistics across the minibatch, allowing it to detect when the generator produces too-similar samples. This gives the discriminator a "diversity detector."

  2. Unrolled GANs (Metz et al., 2017): Instead of optimizing the generator against the current discriminator, optimize against a discriminator that has been updated $k$ steps into the future. This gives the generator a longer-term view that discourages mode-hopping.

  3. Wasserstein distance (Section 17.4): The Wasserstein distance provides gradients that encourage the generator to cover all modes, because it measures the cost of transporting mass from $p_g$ to $p_{\text{data}}$.

  4. Mode regularization: Add a diversity-encouraging term to the generator loss, such as penalizing low variance in the generated samples or rewarding coverage of known modes.

17.2.3 Training Instability

GAN training is notoriously unstable for several reasons:

  1. Vanishing gradients: When the discriminator is too good, it assigns near-zero probability to all generated samples. The generator receives vanishing gradients and stops learning.

  2. Oscillation: The generator and discriminator chase each other without converging. The generator shifts from one mode to another, and the discriminator adapts, creating oscillatory dynamics.

  3. Hyperparameter sensitivity: Learning rates, architecture choices, batch size, and the ratio of discriminator to generator updates all significantly affect training. A setting that works for one dataset may fail on another.

17.2.4 Practical Stabilization Techniques

Over the years, numerous tricks have been developed to stabilize GAN training:

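  • One-sided label smoothing: Use a soft target for real samples (e.g., 0.9 instead of 1.0) to prevent the discriminator from becoming overconfident, while keeping the fake target at 0. Salimans et al. (2016) advise against smoothing the fake labels: with smoothed fake targets, the optimal discriminator no longer gives generated samples in regions where $p_{\text{data}}$ is near zero any incentive to move toward the data.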
  • Feature matching: Instead of maximizing $D(G(\mathbf{z}))$, minimize the difference between the discriminator's intermediate features for real and generated data.
  • Minibatch discrimination: Allow the discriminator to compare samples within a minibatch, making it harder for the generator to collapse to a single mode.
  • Historical averaging: Penalize the parameters for deviating from their running average.
  • Two-timescale learning rates: Use different learning rates for the generator and discriminator.
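Two of these tricks are short enough to sketch directly. The code below assumes the discriminator produces raw logits and that `real_features` and `fake_features` are intermediate discriminator activations for a real and a generated batch (how to extract them depends on the architecture and is not shown). The target value 0.9 follows the example above, and both function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def d_loss_one_sided_smoothing(real_logits: torch.Tensor,
                               fake_logits: torch.Tensor,
                               real_target: float = 0.9) -> torch.Tensor:
    """Discriminator BCE loss with smoothed real targets and hard fake targets."""
    real_targets = torch.full_like(real_logits, real_target)  # smoothed
    fake_targets = torch.zeros_like(fake_logits)              # left at 0
    return (F.binary_cross_entropy_with_logits(real_logits, real_targets)
            + F.binary_cross_entropy_with_logits(fake_logits, fake_targets))


def feature_matching_loss(real_features: torch.Tensor,
                          fake_features: torch.Tensor) -> torch.Tensor:
    """Squared L2 distance between batch-mean features (generator loss).

    real_features would normally be detached so only the generator is updated.
    """
    return (real_features.mean(dim=0) - fake_features.mean(dim=0)).pow(2).sum()


fake_logits = torch.full((4,), -20.0)   # fakes already rejected confidently
for logit in [2.1972, 10.0]:            # 2.1972 is roughly log(9), i.e. sigmoid = 0.9
    loss = d_loss_one_sided_smoothing(torch.full((4,), logit), fake_logits)
    print(f"real logit {logit:>6}: D loss {loss.item():.4f}")
```

With a real target of 0.9, the discriminator's loss is smallest when its real-sample output is about 0.9 and grows again as it becomes more confident, which is exactly the overconfidence penalty the trick is meant to add.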

We will see more principled solutions in Sections 17.3 (DCGAN architecture guidelines) and 17.4 (WGAN).


17.3 DCGAN: Deep Convolutional GAN

17.3.1 Architecture Guidelines

Radford et al. (2016) identified a set of architectural guidelines that dramatically improved GAN training stability. Their Deep Convolutional GAN (DCGAN) became the standard architecture for image generation and remains influential today.

The key architectural principles:

  1. Replace pooling with strided convolutions: Use strided convolutions in the discriminator and transposed convolutions (fractionally-strided) in the generator instead of max pooling or upsampling layers.

  2. Use batch normalization in both networks: BatchNorm stabilizes training by normalizing intermediate activations. Exception: do not use BatchNorm in the generator's output layer or the discriminator's input layer.

  3. Remove fully connected layers: Use fully convolutional architectures for deeper models.

  4. Generator activations: Use ReLU for all layers except the output, which uses Tanh (mapping to $[-1, 1]$).

  5. Discriminator activations: Use LeakyReLU for all layers (slope 0.2 for the negative part).

17.3.2 Generator Architecture

The DCGAN generator maps a noise vector $\mathbf{z} \in \mathbb{R}^{100}$ to an image through a series of transposed convolutions:

z (100)  → Project → 512 x 4 x 4    (BN + ReLU)
         → ConvT  → 256 x 8 x 8    (BN + ReLU)
         → ConvT  → 128 x 16 x 16  (BN + ReLU)
         → ConvT  → 64 x 32 x 32   (BN + ReLU)
         → ConvT  → 3 x 64 x 64    (Tanh)

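The first step projects $\mathbf{z}$ to a $512 \times 4 \times 4$ tensor. The original paper describes this as a fully connected projection followed by a reshape; the implementation in Section 17.3.5 instead applies a $4 \times 4$ transposed convolution to $\mathbf{z}$ viewed as a $100 \times 1 \times 1$ tensor, which is an equivalent linear map. Each subsequent transposed convolution doubles the spatial resolution, while the number of channels decreases, following a pyramid structure.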

17.3.3 Discriminator Architecture

The discriminator is the mirror image: it takes an image and progressively reduces spatial resolution while increasing channel count:

Image (3 x 64 x 64)
  → Conv  → 64 x 32 x 32   (LeakyReLU)
  → Conv  → 128 x 16 x 16  (BN + LeakyReLU)
  → Conv  → 256 x 8 x 8    (BN + LeakyReLU)
  → Conv  → 512 x 4 x 4    (BN + LeakyReLU)
  → Conv  → 1 x 1 x 1      (Sigmoid)
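Both resolution pyramids follow from the standard output-size formulas. The sketch below assumes PyTorch's conventions for `Conv2d` and `ConvTranspose2d` with dilation 1 and no output padding, together with the kernel-4 layer settings used in the implementation of Section 17.3.5 (stride 1 and padding 0 for the first generator layer and the last discriminator layer; stride 2 and padding 1 everywhere else).

```python
from typing import Callable, List, Tuple

Layer = Tuple[int, int, int]  # (kernel, stride, padding)


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # ConvTranspose2d output size (dilation=1, output_padding=0)
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Conv2d output size (dilation=1)
    return (size + 2 * padding - kernel) // stride + 1


def trace(size: int, layers: List[Layer],
          out_fn: Callable[[int, int, int, int], int]) -> List[int]:
    """Spatial size after each layer, starting from the input size."""
    sizes = [size]
    for kernel, stride, padding in layers:
        sizes.append(out_fn(sizes[-1], kernel, stride, padding))
    return sizes


generator_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4
discriminator_layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]

print(trace(1, generator_layers, conv_transpose_out))  # [1, 4, 8, 16, 32, 64]
print(trace(64, discriminator_layers, conv_out))       # [64, 32, 16, 8, 4, 1]
```

With kernel 4, stride 2, and padding 1, the transposed convolution maps $n$ to $2(n - 1) - 2 + 4 = 2n$ and the strided convolution maps $n$ to $n/2$ (for even $n$), which is why the two networks mirror each other exactly.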

17.3.4 Weight Initialization

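DCGAN uses a specific weight initialization: all convolutional weights are drawn from a zero-mean normal distribution with standard deviation 0.02, $\mathcal{N}(0, 0.02^2)$, and batch-normalization scale parameters from $\mathcal{N}(1, 0.02^2)$ with zero bias. This small standard deviation prevents large activations early in training.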

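import torch.nn as nn


def weights_init(m: nn.Module) -> None:
    """DCGAN initialization; apply to every submodule with model.apply(weights_init)."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Conv2d and ConvTranspose2d weights: N(0, 0.02^2)
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm scale: N(1, 0.02^2); shift: 0
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)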

17.3.5 Full PyTorch Implementation

Here is a complete DCGAN implementation for generating $64 \times 64$ images:

import torch
import torch.nn as nn

torch.manual_seed(42)


class DCGANGenerator(nn.Module):
    """DCGAN Generator: maps noise to 3x64x64 images.

    Args:
        latent_dim: Dimensionality of the noise vector z.
        ngf: Base number of generator feature maps.
    """

    def __init__(self, latent_dim: int = 100, ngf: int = 64) -> None:
        super().__init__()
        self.main = nn.Sequential(
            # Input: (latent_dim) -> (ngf*8, 4, 4)
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (3, 64, 64)
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.main(z.view(z.size(0), -1, 1, 1))


class DCGANDiscriminator(nn.Module):
    """DCGAN Discriminator: classifies 3x64x64 images as real/fake.

    Args:
        ndf: Base number of discriminator feature maps.
    """

    def __init__(self, ndf: int = 64) -> None:
        super().__init__()
        self.main = nn.Sequential(
            # Input: (3, 64, 64) -> (ndf, 32, 32)
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*8, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (1, 1, 1)
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.main(x).view(-1, 1)


# --- Demo ---
G = DCGANGenerator(latent_dim=100)
D = DCGANDiscriminator()
G.apply(weights_init)  # Using weights_init from Section 17.3.4
D.apply(weights_init)

z = torch.randn(4, 100)
fake_images = G(z)
predictions = D(fake_images)
print(f"Generated images shape: {fake_images.shape}")  # (4, 3, 64, 64)
print(f"Discriminator output: {predictions.squeeze().tolist()}")

17.3.6 Training Loop

The training loop alternates between discriminator and generator updates:

import torch
import torch.nn as nn

torch.manual_seed(42)

# Assume G and D are defined as in Section 17.3.5; a real loop would iterate over a DataLoader
latent_dim = 100
lr = 2e-4
criterion = nn.BCELoss()

optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

# One training step (simplified)
real_images = torch.randn(64, 3, 64, 64)  # Placeholder for real batch
batch_size = real_images.size(0)

real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# --- Update Discriminator ---
optimizer_D.zero_grad()
# Real images
real_output = D(real_images)
loss_real = criterion(real_output, real_labels)
# Fake images
z = torch.randn(batch_size, latent_dim)
fake_images = G(z)
fake_output = D(fake_images.detach())  # detach: don't update G
loss_fake = criterion(fake_output, fake_labels)
loss_D = loss_real + loss_fake
loss_D.backward()
optimizer_D.step()

# --- Update Generator ---
optimizer_G.zero_grad()
fake_output = D(fake_images)  # No detach: update G through D
loss_G = criterion(fake_output, real_labels)  # Non-saturating loss
loss_G.backward()
optimizer_G.step()

print(f"D loss: {loss_D.item():.4f}, G loss: {loss_G.item():.4f}")

Key implementation details to note:

  • The detach() call on fake_images when updating the discriminator prevents gradients from flowing into the generator.
  • The generator loss uses real_labels (not fake_labels) because the generator wants the discriminator to classify its outputs as real. This is the non-saturating formulation from Section 17.1.6.
  • The reduced $\beta_1 = 0.5$ in Adam matters for stability; Radford et al. (2016) reported that the default $\beta_1 = 0.9$ led to training oscillation and instability.

17.3.7 Training Details

Typical DCGAN training uses:

  • Adam optimizer with $\beta_1 = 0.5, \beta_2 = 0.999$ (reduced $\beta_1$ from the default 0.9)
  • Learning rate: $2 \times 10^{-4}$
  • Latent dimension: 100 (standard normal)
  • Images normalized to $[-1, 1]$ (matching the Tanh output)
  • Equal numbers of generator and discriminator updates ($k = 1$)


17.4 Wasserstein GAN (WGAN)

17.4.1 The Problem with Jensen-Shannon Divergence

The standard GAN implicitly minimizes the Jensen-Shannon divergence (Section 17.1.4). But JSD has a critical flaw: when the supports of $p_{\text{data}}$ and $p_g$ do not overlap (which is almost always the case for high-dimensional data on low-dimensional manifolds), the JSD is a constant ($\log 2$) regardless of how close the distributions are. This means:

  1. The discriminator can perfectly separate real from fake data.
  2. The generator receives zero useful gradient information.
  3. Training fails or requires extremely careful balancing.

This saturation is a major source of training instability in standard GANs.

17.4.2 The Wasserstein Distance

The Wasserstein-1 distance (also called the Earth Mover's Distance) addresses this problem:

$$W(p_{\text{data}}, p_g) = \inf_{\gamma \in \Pi(p_{\text{data}}, p_g)} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \gamma} \left[ \|\mathbf{x} - \mathbf{y}\| \right]$$

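Here $\Pi(p_{\text{data}}, p_g)$ is the set of all joint distributions (transport plans, or couplings) $\gamma(\mathbf{x}, \mathbf{y})$ whose marginals are $p_{\text{data}}$ and $p_g$; each $\gamma$ specifies how much mass to move between each pair of locations. Intuitively, $W$ is the minimum cost of transporting mass from distribution $p_g$ to $p_{\text{data}}$, where the cost of moving a unit of mass from $\mathbf{x}$ to $\mathbf{y}$ is $\|\mathbf{x} - \mathbf{y}\|$.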

The crucial advantage: the Wasserstein distance is continuous, and under mild assumptions on the generator differentiable almost everywhere, even when the distributions have non-overlapping support. It therefore provides meaningful gradients that tell the generator which direction to move.
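The contrast with the JSD can be made concrete with samples. The sketch below assumes a narrow "real" distribution, Uniform(0, 0.1), and a "generated" distribution made of the same samples shifted by a constant. The JSD is approximated from histograms on a common grid, and $W_1$ uses the fact that for two equal-size 1D samples the optimal coupling pairs sorted values. The distributions, bin grid, and shifts are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
real = rng.uniform(0.0, 0.1, size=10_000)  # narrow "real" distribution
bins = np.linspace(-1.0, 7.0, 801)         # common histogram grid (bin width 0.01)


def histogram_jsd(a: np.ndarray, b: np.ndarray) -> float:
    """JSD between the histograms of two samples on the shared grid."""
    p, _ = np.histogram(a, bins=bins)
    q, _ = np.histogram(b, bins=bins)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)

    def kl(u: np.ndarray, v: np.ndarray) -> float:
        mask = u > 0
        return float(np.sum(u[mask] * np.log(u[mask] / v[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def wasserstein_1d(a: np.ndarray, b: np.ndarray) -> float:
    """W1 between equal-size 1D samples: match sorted values (optimal in 1D)."""
    return float(np.mean(np.abs(np.sort(a) - np.sort(b))))


for shift in [0.5, 1.0, 2.0, 5.0]:
    fake = real + shift  # disjoint support from `real` for every shift here
    print(f"shift={shift:3.1f}  JSD={histogram_jsd(real, fake):.4f}  "
          f"W1={wasserstein_1d(real, fake):.4f}")
```

For every shift the JSD estimate is pinned at $\log 2 \approx 0.6931$, while $W_1$ equals the shift itself, so only $W_1$ tells the generator that a smaller shift is better.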

17.4.3 Kantorovich-Rubinstein Duality

The Wasserstein distance has a dual formulation that is much more practical to compute:

$$W(p_{\text{data}}, p_g) = \sup_{\|f\|_L \leq 1} \left[ \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} [f(\mathbf{x})] - \mathbb{E}_{\mathbf{x} \sim p_g} [f(\mathbf{x})] \right]$$

where the supremum is over all 1-Lipschitz functions $f$ (functions satisfying $|f(\mathbf{x}) - f(\mathbf{y})| \leq \|\mathbf{x} - \mathbf{y}\|$ for all $\mathbf{x}, \mathbf{y}$).

The intuition behind this duality is elegant. A 1-Lipschitz function cannot change too rapidly --- its output changes by at most 1 unit when its input changes by 1 unit. The supremum finds the Lipschitz function that maximizes the difference in expected values between the two distributions. If the distributions are far apart, we can find a Lipschitz function that assigns very different values to them. If they are close, no Lipschitz function can distinguish them well.

Worked example. Consider two 1D distributions: $p_{\text{data}} = \delta(x - 1)$ (a point mass at 1) and $p_g = \delta(x - 3)$ (a point mass at 3). The dual objective is $\mathbb{E}_{p_{\text{data}}}[f] - \mathbb{E}_{p_g}[f] = f(1) - f(3)$. The choice $f(x) = x$ gives $1 - 3 = -2$, the wrong sign, whereas $f(x) = -x$ gives $-1 - (-3) = 2$. No 1-Lipschitz function can do better, because $|f(1) - f(3)| \leq |1 - 3| = 2$. Hence $W(p_{\text{data}}, p_g) = 2$, which is simply the distance between the two point masses. This confirms the "earth mover" intuition: we need to move one unit of mass a distance of 2.

In the WGAN framework:

  • The discriminator (now called the critic, since it no longer classifies) approximates the optimal 1-Lipschitz function $f$.
  • The critic outputs a real-valued score (not a probability; there is no sigmoid).
  • The generator minimizes the Wasserstein distance.

17.4.4 WGAN Training Objective

The WGAN objectives are:

$$\mathcal{L}_{\text{critic}} = \mathbb{E}_{\mathbf{z} \sim p_z} [f_w(G(\mathbf{z}))] - \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} [f_w(\mathbf{x})]$$

$$\mathcal{L}_{\text{generator}} = -\mathbb{E}_{\mathbf{z} \sim p_z} [f_w(G(\mathbf{z}))]$$

Note the absence of $\log$---the critic gives raw scores, not log-probabilities.

17.4.5 Enforcing the Lipschitz Constraint

The critical challenge in WGAN is enforcing the 1-Lipschitz constraint on the critic. Three approaches have been proposed:

Weight clipping (WGAN): After each gradient update, clip all critic weights to $[-c, c]$ (e.g., $c = 0.01$). This is simple but crude: it biases the critic toward overly simple functions, and training is sensitive to $c$. Too small a value tends to produce vanishing gradients, while too large a value makes the critic slow to converge and can produce exploding gradients.
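A WGAN training loop with weight clipping might look like the following sketch. The toy 2D data, the small MLPs, and the loop length are hypothetical; the hyperparameters (RMSprop with learning rate $5 \times 10^{-5}$, $n_{\text{critic}} = 5$ critic steps per generator step, $c = 0.01$, batch size 64) are the defaults reported for the original WGAN. The critic has no sigmoid, and neither loss contains a logarithm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: 2D data, 8D noise, small MLPs
critic = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))  # raw score, no sigmoid
generator = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2))

clip_value = 0.01   # c
n_critic = 5        # critic updates per generator update
batch_size = 64
opt_critic = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
opt_gen = torch.optim.RMSprop(generator.parameters(), lr=5e-5)


def sample_real(n: int) -> torch.Tensor:
    # Hypothetical "real" data: a 2D Gaussian centered at (2, 2)
    return torch.randn(n, 2) + 2.0


for step in range(10):
    # Critic: minimize E[f(G(z))] - E[f(x)], then clip every weight to [-c, c]
    for _ in range(n_critic):
        real = sample_real(batch_size)
        fake = generator(torch.randn(batch_size, 8)).detach()
        loss_c = critic(fake).mean() - critic(real).mean()
        opt_critic.zero_grad()
        loss_c.backward()
        opt_critic.step()
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-clip_value, clip_value)

    # Generator: minimize -E[f(G(z))]
    fake = generator(torch.randn(batch_size, 8))
    loss_g = -critic(fake).mean()
    opt_gen.zero_grad()
    loss_g.backward()
    opt_gen.step()

print(f"critic loss {loss_c.item():.5f}, generator loss {loss_g.item():.5f}")
```

The generator's backward pass also writes gradients into the critic's parameters; they are discarded by `opt_critic.zero_grad()` before the next critic update.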

Gradient penalty (WGAN-GP): Gulrajani et al. (2017) proposed penalizing the gradient norm along interpolations between real and generated samples:

$$\mathcal{L}_{\text{GP}} = \lambda \mathbb{E}_{\hat{\mathbf{x}} \sim p_{\hat{x}}} \left[ (\|\nabla_{\hat{\mathbf{x}}} f_w(\hat{\mathbf{x}})\|_2 - 1)^2 \right]$$

where $\hat{\mathbf{x}} = \alpha \mathbf{x} + (1 - \alpha) G(\mathbf{z})$ with $\alpha \sim \text{Uniform}(0, 1)$. This penalizes the critic for having gradients that deviate from norm 1, encouraging the Lipschitz constraint. The penalty weight $\lambda = 10$ is standard.

Why interpolated points? The theoretical justification is that the optimal critic has gradient norm 1 almost everywhere along straight lines connecting pairs of points that are coupled by the optimal transport plan between $p_{\text{data}}$ and $p_g$. Interpolating between randomly paired real and generated samples is a cheap stand-in for those lines, so the penalty enforces a soft version of the Lipschitz constraint in the region between the two distributions, without the need for weight clipping.

PyTorch implementation of the gradient penalty:

import torch
import torch.autograd as autograd

torch.manual_seed(42)


def gradient_penalty(
    critic: torch.nn.Module,
    real_data: torch.Tensor,
    fake_data: torch.Tensor,
    lambda_gp: float = 10.0,
) -> torch.Tensor:
    """Compute WGAN-GP gradient penalty.

    Args:
        critic: The critic network.
        real_data: Batch of real samples (B, C, H, W).
        fake_data: Batch of generated samples (B, C, H, W).
        lambda_gp: Penalty coefficient (default: 10).

    Returns:
        The gradient penalty loss term.
    """
    batch_size = real_data.size(0)
    # Random interpolation coefficient
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    alpha = alpha.expand_as(real_data)

    # Interpolate between real and fake
    interpolated = (
        alpha * real_data + (1 - alpha) * fake_data
    ).requires_grad_(True)

    # Critic score on interpolated data
    critic_interpolated = critic(interpolated)

    # Compute gradients w.r.t. interpolated data
    gradients = autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten and compute norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Penalize deviation from norm 1
    penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    return penalty

The total critic loss for WGAN-GP is:

$$\mathcal{L}_{\text{critic}} = \underbrace{\mathbb{E}_{\mathbf{z}}[f_w(G(\mathbf{z}))] - \mathbb{E}_{\mathbf{x}}[f_w(\mathbf{x})]}_{\text{Wasserstein estimate}} + \underbrace{\lambda \mathbb{E}_{\hat{\mathbf{x}}}[(\|\nabla_{\hat{\mathbf{x}}} f_w(\hat{\mathbf{x}})\|_2 - 1)^2]}_{\text{gradient penalty}}$$

Spectral normalization: Miyato et al. (2018) divide each weight matrix by its spectral norm (largest singular value). The normalization is recomputed at every forward pass during training, which enforces a Lipschitz constraint directly, layer by layer. For a weight matrix $\mathbf{W}$, the spectral norm is:

$$\sigma(\mathbf{W}) = \max_{\mathbf{h}: \|\mathbf{h}\| = 1} \|\mathbf{W}\mathbf{h}\|$$

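The normalized weight is $\bar{\mathbf{W}} = \mathbf{W} / \sigma(\mathbf{W})$. Since $\|\bar{\mathbf{W}}\mathbf{h}\| \leq \|\mathbf{h}\|$ for all $\mathbf{h}$, each normalized linear layer is 1-Lipschitz. Common activations such as ReLU and LeakyReLU are also 1-Lipschitz, and a composition of 1-Lipschitz functions is 1-Lipschitz, so the entire network is 1-Lipschitz. For convolutional layers the bound is only approximate, because the normalization is applied to the kernel reshaped into a matrix rather than to the full convolution operator. In PyTorch, spectral normalization is available as: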

import torch.nn as nn
from torch.nn.utils import spectral_norm

# Apply spectral normalization to a convolutional layer
layer = spectral_norm(nn.Conv2d(64, 128, 4, 2, 1))

Spectral normalization is simpler than gradient penalty (no need for interpolated samples), adds minimal computational overhead, and works well in practice. It has become the default choice for discriminator/critic normalization in many modern GAN architectures.
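A full SVD at every step would be expensive. PyTorch's `spectral_norm` instead reshapes the weight into a 2-D matrix and estimates $\sigma(\mathbf{W})$ with power iteration, by default one iteration per forward pass that reuses the vectors from the previous step.

The standalone sketch below runs many power iterations on a random matrix and compares the estimate with the exact value from `torch.linalg.matrix_norm(W, ord=2)`. The matrix size, the iteration count and the helper name are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)


def spectral_norm_estimate(W: torch.Tensor, n_iters: int = 100) -> torch.Tensor:
    """Estimate the largest singular value of a 2-D matrix W by power iteration."""
    u = torch.randn(W.shape[0])
    u = u / u.norm()
    for _ in range(n_iters):
        v = W.t() @ u            # update right singular vector estimate
        v = v / v.norm()
        u = W @ v                # update left singular vector estimate
        u = u / u.norm()
    return u @ (W @ v)           # sigma is approximately u^T W v


W = torch.randn(128, 64)
sigma_est = spectral_norm_estimate(W)
sigma_true = torch.linalg.matrix_norm(W, ord=2)   # exact, via SVD
W_bar = W / sigma_est                             # spectrally normalized weight
print(f"estimate={sigma_est.item():.4f}  exact={sigma_true.item():.4f}")
print(f"sigma(W_bar)={torch.linalg.matrix_norm(W_bar, ord=2).item():.4f}")  # close to 1
```

In training, a single iteration per step suffices because the weights change slowly, so the previous step's vectors are already a good starting point.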

17.4.6 WGAN Training Guidelines

WGAN training differs from standard GAN training in several ways:

| Aspect | Standard GAN | WGAN / WGAN-GP |
|---|---|---|
| Critic output | Sigmoid (probability) | Linear (real-valued score) |
| Loss function | Binary cross-entropy | Wasserstein estimate (difference of mean critic scores) |
| Critic updates per generator update | 1 | 5 (typical) |
| Optimizer | Adam ($\beta_1 = 0.5$) | RMSprop (WGAN) or Adam with $\beta_1 = 0$, $\beta_2 = 0.9$ (WGAN-GP) |
| BatchNorm in critic | Yes | No (for WGAN-GP) |
| Training signal | Saturates when critic is strong | Meaningful gradients throughout |

17.4.7 WGAN-GP Training Implementation

Here is the key training step for WGAN-GP, showing how it differs from standard GAN training:

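import torch
import torch.nn as nn

torch.manual_seed(42)

# gradient_penalty is defined in Section 17.4.5.
n_critic = 5  # Critic updates per generator update
latent_dim = 100

# Minimal placeholder networks so the snippet runs end to end; in practice,
# use DCGAN-style convolutional models. Note the critic has a linear output
# (no Sigmoid) and no BatchNorm.
generator = nn.Sequential(
    nn.Linear(latent_dim, 3 * 64 * 64),
    nn.Tanh(),
    nn.Unflatten(1, (3, 64, 64)),
)
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))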

optimizer_C = torch.optim.Adam(
    critic.parameters(), lr=1e-4, betas=(0.0, 0.9)
)
optimizer_G = torch.optim.Adam(
    generator.parameters(), lr=1e-4, betas=(0.0, 0.9)
)

# One training iteration (simplified)
for _ in range(n_critic):
    real_data = torch.randn(64, 3, 64, 64)  # Placeholder
    z = torch.randn(64, latent_dim)
    fake_data = generator(z).detach()

    # Critic loss: maximize E[f(real)] - E[f(fake)]
    # (minimize the negative)
    critic_real = critic(real_data).mean()
    critic_fake = critic(fake_data).mean()
    gp = gradient_penalty(critic, real_data, fake_data)
    loss_C = critic_fake - critic_real + gp

    optimizer_C.zero_grad()
    loss_C.backward()
    optimizer_C.step()

# Generator update
z = torch.randn(64, latent_dim)
fake_data = generator(z)
loss_G = -critic(fake_data).mean()

optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()

# The Wasserstein estimate is critic_real - critic_fake (from the last critic step).
# As the generator improves, this estimate should shrink over training, so loss_C
# becomes less negative.
wasserstein_estimate = (critic_real - critic_fake).item()
print(f"Wasserstein estimate: {wasserstein_estimate:.4f}")

Note the key differences from standard GAN training: no log, no Sigmoid in the critic, 5 critic updates per generator update, Adam with $\beta_1 = 0$, and the gradient penalty term.

17.4.8 Advantages of WGAN

  1. Meaningful loss: The critic loss correlates with sample quality. In standard GANs, the discriminator loss is not informative about generation quality. In WGAN, a decreasing Wasserstein estimate indicates improving generation.

  2. More stable training: There is much less need to carefully balance the generator and the critic. Training the critic close to optimality is beneficial, because a stronger critic provides better gradients to the generator. In standard GANs, by contrast, a strong discriminator causes vanishing gradients.

  3. Reduced mode collapse: The Wasserstein distance remains informative even when the two distributions have non-overlapping supports. It therefore provides gradients that encourage the generator to cover all modes. Intuitively, even if the generator covers only one mode, the Wasserstein distance "knows" how far away the other modes are and provides gradients pointing toward them. The original WGAN experiments showed no evidence of mode collapse, although the method does not strictly rule it out.
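A small numerical illustration of why non-overlapping supports matter uses two point masses on a 1-D grid. Once the supports are disjoint, the Jensen-Shannon divergence is stuck at $\ln 2$ however far apart they are. The Wasserstein-1 distance, by contrast, grows with the separation, so it still tells the generator which way to move.

The grid size and the helper names are illustrative. The 1-D Wasserstein distance is computed with the standard identity $W_1 = \int |F_p - F_q|$ over the two cumulative distribution functions.

```python
import numpy as np


def kl_divergence(a: np.ndarray, b: np.ndarray) -> float:
    """KL(a || b) for discrete distributions, skipping bins where a is zero."""
    mask = a > 0
    return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))


def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Jensen-Shannon divergence (natural log), bounded above by ln 2."""
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def wasserstein_1d(p: np.ndarray, q: np.ndarray, bin_width: float = 1.0) -> float:
    """W1 between two histograms on the same 1-D grid: sum of |CDF_p - CDF_q|."""
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))) * bin_width)


n_bins = 20
p = np.zeros(n_bins)
p[0] = 1.0  # "real" distribution: all mass at position 0

for shift in (1, 5, 10):
    q = np.zeros(n_bins)
    q[shift] = 1.0  # "generated" distribution: all mass at position `shift`
    print(f"shift={shift:2d}  JS={js_divergence(p, q):.4f}  W1={wasserstein_1d(p, q):.1f}")
# JS prints 0.6931 (ln 2) every time; W1 equals the shift.
```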


17.5 Conditional GANs

17.5.1 Adding Control to Generation

Standard GANs generate data from random noise, with no control over what is produced. Conditional GANs (cGANs) add control by conditioning both the generator and discriminator on additional information $\mathbf{y}$ (e.g., class labels, text descriptions, or other images).

The conditional minimax objective is:

$$\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} [\log D(\mathbf{x} | \mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p_z} [\log(1 - D(G(\mathbf{z} | \mathbf{y}) | \mathbf{y}))]$$

17.5.2 Conditioning Mechanisms

There are several ways to inject the condition $\mathbf{y}$ into the networks:

Concatenation: For the generator, concatenate the one-hot label $\mathbf{y}$ with the noise vector $\mathbf{z}$. For the discriminator, concatenate $\mathbf{y}$ with the image (as an additional channel) or with a flattened feature vector.

Embedding and projection: Learn an embedding of $\mathbf{y}$ and combine it with intermediate features via element-wise multiplication or addition. This is more flexible than concatenation and works better for complex conditions.

Auxiliary classifier (AC-GAN): Add a classification head to the discriminator that predicts the class of the input. The generator is trained to both fool the discriminator and produce correctly classified samples.
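For convolutional discriminators, a common way to append a label "as an additional channel" is to broadcast its one-hot vector into constant feature maps, one per class. The sketch below makes the concatenation option concrete.

It assumes 28×28 single-channel images and 10 classes. The helper name `concat_label_channels` is hypothetical.

```python
import torch
import torch.nn.functional as F


def concat_label_channels(
    images: torch.Tensor, labels: torch.Tensor, num_classes: int
) -> torch.Tensor:
    """Append one constant feature map per class to a batch of images.

    Args:
        images: (B, C, H, W) image batch.
        labels: (B,) integer class labels.
        num_classes: Number of classes K.

    Returns:
        (B, C + K, H, W) tensor; channel C + k is all ones for samples of class k.
    """
    b, _, h, w = images.shape
    one_hot = F.one_hot(labels, num_classes).to(images.dtype)  # (B, K)
    label_maps = one_hot.view(b, num_classes, 1, 1).expand(b, num_classes, h, w)
    return torch.cat([images, label_maps], dim=1)


images = torch.randn(4, 1, 28, 28)
labels = torch.tensor([3, 0, 9, 3])
d_input = concat_label_channels(images, labels, num_classes=10)
print(d_input.shape)  # torch.Size([4, 11, 28, 28])
```

The first convolution of the discriminator then simply takes `C + K` input channels.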

17.5.3 Class-Conditional Generation

The most common cGAN application is class-conditional image generation: generating images of a specified class. For MNIST, this means generating images conditioned on the digit label (0--9).

The generator receives both noise $\mathbf{z}$ and a one-hot label $\mathbf{y}$:

$$G(\mathbf{z}, \mathbf{y}) = \text{image of class } \mathbf{y}$$

The discriminator evaluates both the image quality and its consistency with the label:

$$D(\mathbf{x}, \mathbf{y}) = \text{probability that } \mathbf{x} \text{ is a real image of class } \mathbf{y}$$

17.5.4 PyTorch Implementation of a Conditional GAN

Here is the key modification for a conditional generator and discriminator on MNIST:

import torch
import torch.nn as nn

torch.manual_seed(42)


class ConditionalGenerator(nn.Module):
    """Conditional generator for MNIST (28x28 grayscale).

    Args:
        latent_dim: Dimensionality of noise vector z.
        num_classes: Number of classes for conditioning.
        embed_dim: Dimensionality of class embedding.
    """

    def __init__(
        self,
        latent_dim: int = 100,
        num_classes: int = 10,
        embed_dim: int = 50,
    ) -> None:
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(
        self, z: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        label_embed = self.label_embedding(labels)
        x = torch.cat([z, label_embed], dim=1)
        return self.model(x).view(-1, 1, 28, 28)


# --- Demo ---
G = ConditionalGenerator()
z = torch.randn(8, 100)
labels = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
fake_images = G(z, labels)
print(f"Conditional output shape: {fake_images.shape}")  # (8, 1, 28, 28)

The class label is embedded into a dense vector and concatenated with the noise vector before being passed through the generator. The discriminator receives both the image and the label, allowing it to judge not just whether the image is realistic, but whether it matches the specified class.
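The discriminator side is described above only in words. A matching sketch in the same MLP style as `ConditionalGenerator` follows.

The layer sizes, and the choice to concatenate a learned label embedding with the flattened image, are assumptions that mirror the generator rather than a prescribed design. The network outputs a raw logit, so it pairs with `nn.BCEWithLogitsLoss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)


class ConditionalDiscriminator(nn.Module):
    """Conditional discriminator for MNIST: scores (image, label) pairs.

    Args:
        num_classes: Number of classes for conditioning.
        embed_dim: Dimensionality of class embedding.
    """

    def __init__(self, num_classes: int = 10, embed_dim: int = 50) -> None:
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        self.model = nn.Sequential(
            nn.Linear(784 + embed_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),  # raw logit (no Sigmoid)
        )

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        x = images.view(images.size(0), -1)         # flatten to (B, 784)
        label_embed = self.label_embedding(labels)  # (B, embed_dim)
        return self.model(torch.cat([x, label_embed], dim=1))


# --- Demo ---
D = ConditionalDiscriminator()
images = torch.randn(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))
logits = D(images, labels)
# Real images are scored with their true labels, so this term pushes D(x, y) toward "real".
loss_real = nn.BCEWithLogitsLoss()(logits, torch.ones_like(logits))
print(f"Discriminator output shape: {logits.shape}")  # (8, 1)
```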

17.5.5 Pix2Pix: Image-to-Image Translation

Pix2Pix (Isola et al., 2017) extends conditional GANs to image-to-image translation: converting images from one domain to another (e.g., edges to photos, day to night, segmentation maps to photorealistic images).

The generator uses a U-Net architecture (encoder-decoder with skip connections), and the discriminator is a PatchGAN that classifies whether each $N \times N$ patch of the output is real or fake. The total loss combines the adversarial loss with an L1 reconstruction loss:

$$\mathcal{L} = \mathcal{L}_{\text{GAN}} + \lambda \mathcal{L}_{L1}$$

where:

  • $\mathcal{L}_{\text{GAN}}$ is the standard conditional adversarial loss
  • $\mathcal{L}_{L1} = \mathbb{E}_{\mathbf{x}, \mathbf{y}} [\|\mathbf{y} - G(\mathbf{x})\|_1]$ is the L1 reconstruction loss between the generated output and the ground-truth target
  • $\lambda = 100$ is the standard weighting

The L1 loss enforces low-frequency correctness (overall structure), while the GAN loss supplies high-frequency details (textures, sharpness). L1 produces less blurring than L2 for the following reason. The L2-optimal prediction is the conditional mean, which averages over all plausible outputs. The L1-optimal prediction is the conditional median, which is less prone to averaging distinct modes.
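A minimal sketch of the generator objective follows. It assumes the discriminator returns a grid of raw patch logits, as the PatchGAN described next does, and it uses the non-saturating form of the adversarial term.

The function name and the tensor shapes are illustrative. Note that `F.l1_loss` averages over pixels rather than summing, which only rescales the L1 term.

```python
import torch
import torch.nn.functional as F


def pix2pix_generator_loss(
    d_fake_logits: torch.Tensor,
    fake: torch.Tensor,
    target: torch.Tensor,
    lam: float = 100.0,
) -> torch.Tensor:
    """Adversarial loss plus lambda-weighted L1 reconstruction loss.

    Args:
        d_fake_logits: Discriminator logits for (input, fake) pairs, e.g. (B, 1, 30, 30).
        fake: Generator output G(x), shape (B, C, H, W).
        target: Ground-truth output y, same shape as `fake`.
        lam: Weight on the L1 term (100 in the Pix2Pix paper).
    """
    # Non-saturating GAN term: push every patch logit toward "real" (label 1).
    adv = F.binary_cross_entropy_with_logits(
        d_fake_logits, torch.ones_like(d_fake_logits)
    )
    l1 = F.l1_loss(fake, target)  # mean absolute error over all pixels
    return adv + lam * l1


torch.manual_seed(0)
target = torch.rand(2, 3, 64, 64) * 2 - 1
fake = target + 0.1                  # off by 0.1 everywhere
logits = torch.zeros(2, 1, 6, 6)     # D maximally unsure: BCE term = ln 2
loss = pix2pix_generator_loss(logits, fake, target)
print(f"{loss.item():.4f}")          # ln 2 + 100 * 0.1, about 10.6931
```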

The PatchGAN discriminator is a critical innovation. Instead of producing a single real/fake score for the entire image, it produces a grid of scores, one for each overlapping $70 \times 70$ patch. This is equivalent to assuming that pixels separated by more than 70 pixels are independent, which is a reasonable assumption for texture quality. The PatchGAN is also much cheaper to evaluate than a full-image discriminator.
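The sketch below follows the layout commonly described for the 70×70 PatchGAN: 4×4 convolutions with 64, 128, 256 and 512 filters, the first three with stride 2, then a stride-1 convolution down to one channel. BatchNorm is omitted on the first layer. Treat the exact layer list as an assumption.

The point of the sketch is to show two things. A 256×256 input pair yields a 30×30 grid of patch logits, and each logit's receptive field works out to 70 pixels.

```python
import torch
import torch.nn as nn


def patch_discriminator(in_channels: int = 6) -> nn.Sequential:
    """70x70 PatchGAN-style discriminator returning a grid of patch logits.

    in_channels is 6 because the conditioning image and the (real or generated)
    output image are concatenated along the channel axis.
    """
    def block(c_in: int, c_out: int, stride: int, norm: bool = True) -> list:
        layers = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1)]
        if norm:
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.LeakyReLU(0.2))
        return layers

    return nn.Sequential(
        *block(in_channels, 64, stride=2, norm=False),
        *block(64, 128, stride=2),
        *block(128, 256, stride=2),
        *block(256, 512, stride=1),
        nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),  # one logit per patch
    )


def receptive_field(kernel_sizes: list, strides: list) -> int:
    """Receptive field of a stack of conv layers, computed from the output back."""
    rf = 1
    for k, s in reversed(list(zip(kernel_sizes, strides))):
        rf = (rf - 1) * s + k
    return rf


D = patch_discriminator()
x = torch.randn(1, 3, 256, 256)   # conditioning image (e.g. an edge map)
y = torch.randn(1, 3, 256, 256)   # real or generated target image
logits = D(torch.cat([x, y], dim=1))
print(logits.shape)                                       # torch.Size([1, 1, 30, 30])
print(receptive_field([4, 4, 4, 4, 4], [2, 2, 2, 1, 1]))  # 70
```

The per-patch logits are then averaged inside the adversarial loss.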

17.5.6 CycleGAN: Unpaired Image Translation

A limitation of Pix2Pix is that it requires paired training data (input-output pairs). CycleGAN (Zhu et al., 2017) removes this requirement using a cycle consistency loss. Given two domains $A$ and $B$ (e.g., horses and zebras):

  • Generator $G_{A \to B}$ maps domain $A$ to domain $B$.
  • Generator $G_{B \to A}$ maps domain $B$ to domain $A$.
  • Cycle consistency: If we translate from $A$ to $B$ and back, we should recover the original: $G_{B \to A}(G_{A \to B}(\mathbf{x}_A)) \approx \mathbf{x}_A$.

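The cycle consistency loss is added to the adversarial losses of both generators (with weight $\lambda = 10$ in the original paper). It is: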

$$\mathcal{L}_{\text{cyc}} = \mathbb{E}_{\mathbf{x}_A} [\|G_{B \to A}(G_{A \to B}(\mathbf{x}_A)) - \mathbf{x}_A\|_1] + \mathbb{E}_{\mathbf{x}_B} [\|G_{A \to B}(G_{B \to A}(\mathbf{x}_B)) - \mathbf{x}_B\|_1]$$

This elegant constraint ensures that the translation is meaningful: the model cannot simply map all inputs to a single output in the target domain, because it must be able to recover the original input.
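A minimal sketch of the cycle loss follows. It assumes the two generators are arbitrary `nn.Module`s that map images to images of the same shape. In the full CycleGAN objective this term is added to the two adversarial losses.

The 1×1 convolutions below are placeholders for real image-to-image networks. As with Pix2Pix, `F.l1_loss` averages rather than sums over pixels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cycle_consistency_loss(
    G_AB: nn.Module, G_BA: nn.Module, x_A: torch.Tensor, x_B: torch.Tensor
) -> torch.Tensor:
    """L1 cycle loss: A -> B -> A and B -> A -> B should reconstruct the input."""
    rec_A = G_BA(G_AB(x_A))
    rec_B = G_AB(G_BA(x_B))
    return F.l1_loss(rec_A, x_A) + F.l1_loss(rec_B, x_B)


# Placeholder "generators" standing in for real image-to-image networks.
torch.manual_seed(0)
G_AB = nn.Conv2d(3, 3, kernel_size=1)
G_BA = nn.Conv2d(3, 3, kernel_size=1)
x_A, x_B = torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)
print(cycle_consistency_loss(G_AB, G_BA, x_A, x_B).item())

# Sanity check: identity mappings reconstruct perfectly, so the loss is zero.
identity = nn.Identity()
print(cycle_consistency_loss(identity, identity, x_A, x_B).item())  # 0.0
```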


17.6 StyleGAN: State-of-the-Art Generation

17.6.1 Key Innovations

StyleGAN (Karras et al., 2019) and its successors (StyleGAN2, StyleGAN3) are among the highest-quality GAN-based image generators. Several architectural innovations make this possible:

Mapping network: Instead of feeding $\mathbf{z}$ directly to the generator, a mapping network $f: \mathbb{R}^{512} \to \mathbb{R}^{512}$ first transforms it into an intermediate latent space $\mathcal{W}$: $\mathbf{w} = f(\mathbf{z})$. The $\mathcal{W}$ space is less entangled than the $\mathcal{Z}$ space, meaning each dimension in $\mathcal{W}$ is more likely to control a single factor of variation.
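Here is a sketch of the mapping network. It assumes the commonly cited configuration of 8 fully connected layers of width 512 with LeakyReLU(0.2).

Rescaling $\mathbf{z}$ to unit mean square before the MLP follows the original implementation as far as we know, so treat it as an assumption. StyleGAN's equalized learning rate and its reduced learning rate for the mapping layers are omitted.

```python
import torch
import torch.nn as nn


class MappingNetwork(nn.Module):
    """z -> w: an MLP that maps input noise to the intermediate latent space W."""

    def __init__(self, z_dim: int = 512, w_dim: int = 512, n_layers: int = 8) -> None:
        super().__init__()
        layers = []
        in_dim = z_dim
        for _ in range(n_layers):
            layers += [nn.Linear(in_dim, w_dim), nn.LeakyReLU(0.2)]
            in_dim = w_dim
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Rescale each z to unit mean square before the MLP.
        z = z * torch.rsqrt(z.pow(2).mean(dim=1, keepdim=True) + 1e-8)
        return self.net(z)


torch.manual_seed(0)
f = MappingNetwork()
w = f(torch.randn(4, 512))
print(w.shape)  # torch.Size([4, 512])
```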

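Adaptive Instance Normalization (AdaIN): The style vector $\mathbf{w}$ controls generation through AdaIN at each layer. Each feature map $\mathbf{x}_i$ is first normalized by its own spatial mean $\mu(\mathbf{x}_i)$ and standard deviation $\sigma(\mathbf{x}_i)$. It is then scaled and shifted by the style parameters $\gamma_s(\mathbf{w})$ and $\beta_s(\mathbf{w})$, which a learned affine transformation produces from $\mathbf{w}$: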

$$\text{AdaIN}(\mathbf{x}_i, \mathbf{w}) = \gamma_s(\mathbf{w}) \cdot \frac{\mathbf{x}_i - \mu(\mathbf{x}_i)}{\sigma(\mathbf{x}_i)} + \beta_s(\mathbf{w})$$
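The equation translates almost line for line into a PyTorch module. In this sketch, a single `nn.Linear` plays the role of the learned affine map that produces $(\gamma_s, \beta_s)$ for every channel.

The $\gamma$ half of its bias is initialized to 1 and the $\beta$ half to 0, so the layer starts out as plain instance normalization. That initialization is a choice made here for illustration.

```python
import torch
import torch.nn as nn


class AdaIN(nn.Module):
    """Adaptive instance normalization driven by a style vector w."""

    def __init__(self, num_channels: int, w_dim: int = 512, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        # Learned affine map from w to per-channel (gamma_s, beta_s).
        self.style = nn.Linear(w_dim, 2 * num_channels)
        with torch.no_grad():
            self.style.bias[:num_channels].fill_(1.0)  # gamma starts near 1
            self.style.bias[num_channels:].zero_()     # beta starts near 0

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Per-sample, per-channel statistics over the spatial dimensions.
        mu = x.mean(dim=(2, 3), keepdim=True)
        sigma = ((x - mu).pow(2).mean(dim=(2, 3), keepdim=True) + self.eps).sqrt()
        gamma, beta = self.style(w).chunk(2, dim=1)     # each (B, C)
        return gamma[:, :, None, None] * (x - mu) / sigma + beta[:, :, None, None]


torch.manual_seed(0)
adain = AdaIN(num_channels=64)
x = torch.randn(2, 64, 16, 16) * 3 + 5      # features with mean 5 and std 3
w = torch.zeros(2, 512)                     # w = 0 gives gamma = 1, beta = 0 here
out = adain(x, w)                           # so out is just instance-normalized x
```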

Different layers control different levels of detail: early layers control high-level attributes (pose, face shape), middle layers control medium features (facial features, hairstyle), and late layers control fine details (color, microstructure).

Noise injection: Stochastic variation (freckles, hair strands, background texture) is added through per-pixel noise inputs at each layer, separate from the style vector.
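A minimal sketch of noise injection follows. A single-channel Gaussian noise image is broadcast to all feature maps and scaled by a learned per-channel factor.

Initializing that factor to zero, so that the layer starts as the identity, is a choice made for this illustration. The class name is hypothetical.

```python
from typing import Optional

import torch
import torch.nn as nn


class NoiseInjection(nn.Module):
    """Adds one per-pixel noise map, scaled by a learned per-channel factor."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        # Zero init: the layer starts as the identity and learns how much noise to use.
        self.scale = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        if noise is None:
            b, _, h, w = x.shape
            # A single-channel noise image, broadcast to every feature map.
            noise = torch.randn(b, 1, h, w, device=x.device, dtype=x.dtype)
        return x + self.scale * noise


torch.manual_seed(0)
layer = NoiseInjection(num_channels=8)
x = torch.randn(2, 8, 4, 4)
out_init = layer(x)                  # scale is zero, so this equals x

with torch.no_grad():
    layer.scale.fill_(0.5)           # pretend training has learned a noise strength
fixed_noise = torch.randn(2, 1, 4, 4)
out = layer(x, fixed_noise)          # x + 0.5 * noise in every channel
```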

Progressive growing: The original StyleGAN trained progressively, starting at $4 \times 4$ resolution and gradually adding higher-resolution layers. StyleGAN2 replaced this with a fixed architecture that uses skip connections in the generator and residual connections in the discriminator.

17.6.2 Style Mixing

Because different layers control different attributes, StyleGAN supports style mixing: using one $\mathbf{w}$ vector for early layers (coarse styles) and a different $\mathbf{w}$ vector for later layers (fine styles). This enables disentangled control over different aspects of the generated image.
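The sketch below shows how a style-mixed latent can be assembled. It assumes a synthesis network that consumes one $\mathbf{w}$ per layer, which is the $\mathcal{W}^+$ representation of Section 17.6.4.

The function name is hypothetical. The 18 layers correspond to a 1024×1024 StyleGAN (two style inputs at each resolution from 4×4 to 1024×1024). Layers 0 to 3 cover the coarse $4^2$ to $8^2$ resolutions.

```python
import torch


def style_mix(w1: torch.Tensor, w2: torch.Tensor, num_layers: int, crossover: int) -> torch.Tensor:
    """Build per-layer latents: layers [0, crossover) use w1, the rest use w2.

    Args:
        w1, w2: (B, w_dim) latent codes in W.
        num_layers: Number of style inputs in the synthesis network.
        crossover: First layer index that takes its style from w2.

    Returns:
        (B, num_layers, w_dim) stack of per-layer styles (a point in W+).
    """
    ws = w1.unsqueeze(1).repeat(1, num_layers, 1)
    ws[:, crossover:] = w2.unsqueeze(1)
    return ws


torch.manual_seed(0)
w_coarse_source, w_fine_source = torch.randn(1, 512), torch.randn(1, 512)
# Coarse styles (layers 0-3) come from the first code, everything else from the second.
ws = style_mix(w_coarse_source, w_fine_source, num_layers=18, crossover=4)
print(ws.shape)  # torch.Size([1, 18, 512])
```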

17.6.3 Progressive Growing

The original Progressive GAN (Karras et al., 2018), which preceded StyleGAN, introduced the idea of progressive training: start by training the generator and discriminator at $4 \times 4$ resolution, then gradually add layers to increase the resolution to $8 \times 8$, $16 \times 16$, and so on up to $1024 \times 1024$.

The rationale is that low-resolution images capture coarse structure (overall face shape, background color), while higher resolutions add progressively finer details. By training from low to high resolution, the model first learns the global structure and then refines the details --- this is much more stable than trying to learn all resolutions simultaneously.

New layers are faded in smoothly using a linear interpolation parameter $\alpha$ that transitions from 0 to 1 over several thousand training iterations. This prevents the sudden introduction of new parameters from destabilizing training.
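The fade-in can be sketched as a blend between two images. The first is the previous generator's output, upsampled to the new resolution. The second comes from the newly added block.

The sketch assumes nearest-neighbor upsampling and 1×1 "toRGB" convolutions. All module names and sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def faded_output(
    features: torch.Tensor,
    old_to_rgb: nn.Module,
    new_block: nn.Module,
    new_to_rgb: nn.Module,
    alpha: float,
) -> torch.Tensor:
    """Blend the upsampled old-resolution image with the new block's output.

    alpha = 0 reproduces the previous (lower-resolution) generator exactly;
    alpha = 1 uses only the newly added higher-resolution block.
    """
    old_rgb = F.interpolate(old_to_rgb(features), scale_factor=2, mode="nearest")
    new_rgb = new_to_rgb(new_block(features))
    return (1 - alpha) * old_rgb + alpha * new_rgb


torch.manual_seed(0)
feats = torch.randn(2, 64, 8, 8)                  # features at the old 8x8 resolution
old_to_rgb = nn.Conv2d(64, 3, kernel_size=1)
new_block = nn.Sequential(
    nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, 3, padding=1), nn.LeakyReLU(0.2)
)
new_to_rgb = nn.Conv2d(32, 3, kernel_size=1)

for step, total in [(0, 1000), (500, 1000), (1000, 1000)]:
    alpha = step / total                          # linear ramp from 0 to 1
    img = faded_output(feats, old_to_rgb, new_block, new_to_rgb, alpha)
    print(step, alpha, tuple(img.shape))          # always (2, 3, 16, 16)
```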

17.6.4 The W and W+ Latent Spaces

StyleGAN's mapping network introduces a hierarchy of latent spaces:

  • Z space ($\mathcal{Z}$): The input noise space. $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$.
  • W space ($\mathcal{W}$): The intermediate latent space. $\mathbf{w} = f(\mathbf{z})$ where $f$ is the 8-layer MLP mapping network.
  • W+ space ($\mathcal{W}^+$): Extends $\mathcal{W}$ by allowing a different $\mathbf{w}$ vector for each layer of the synthesis network. This dramatically increases expressiveness.

The $\mathcal{W}$ and $\mathcal{W}^+$ spaces are particularly important for GAN inversion, the task of finding the latent code that corresponds to a real image. Because $\mathcal{W}$ is more disentangled than $\mathcal{Z}$, moving a latent code in $\mathcal{W}$ produces more semantically meaningful edits. Many image editing applications (aging faces, changing hairstyles, adding expressions) work in two steps. They first project a real image into $\mathcal{W}^+$ space, whose extra per-layer freedom allows more faithful reconstructions. They then manipulate the latent code.
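One common inversion strategy is optimization-based: freeze the generator, then run gradient descent on a latent code until the generated image matches the target. The sketch uses a tiny frozen placeholder generator and plain pixel-wise MSE.

Practical inversion methods often add perceptual losses, start from the mean latent, and optimize in $\mathcal{W}^+$. All of that is omitted here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

w_dim = 16
# Placeholder frozen "synthesis network" mapping w to a 3x8x8 image.
generator = nn.Sequential(nn.Linear(w_dim, 3 * 8 * 8), nn.Tanh(), nn.Unflatten(1, (3, 8, 8)))
generator.requires_grad_(False)

# A "real" image the generator can produce exactly, so inversion is solvable.
w_true = torch.randn(1, w_dim)
target = generator(w_true)

w = torch.zeros(1, w_dim, requires_grad=True)   # the latent code being optimized
optimizer = torch.optim.Adam([w], lr=0.05)

with torch.no_grad():
    initial_loss = (generator(w) - target).pow(2).mean().item()

for step in range(500):
    loss = (generator(w) - target).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

final_loss = loss.item()
print(f"reconstruction MSE: {initial_loss:.2e} -> {final_loss:.2e}")
```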

17.6.5 The Evolution to StyleGAN3

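StyleGAN2 addressed two kinds of artifacts in StyleGAN.

The first were the characteristic "water droplet" artifacts, which were traced to the instance normalization inside AdaIN. StyleGAN2 therefore replaces AdaIN with weight demodulation. The convolution weights are scaled by the incoming style and then normalized, which achieves the same effect as AdaIN without normalizing the feature maps directly.

The second were artifacts attributed to progressive growing, such as details (teeth, eyes) that stay fixed in place while the pose changes. These were removed by switching to a fixed architecture.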
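A sketch of a modulated convolution with demodulation follows the description above. The style scales each input channel of the shared kernel, and each output channel's kernel is then rescaled to unit L2 norm. For i.i.d. unit-variance inputs, that rescaling gives unit-variance outputs.

Applying a different kernel to every sample through a grouped convolution is one common implementation trick. The function name is illustrative.

```python
import torch
import torch.nn.functional as F


def modulated_conv2d(
    x: torch.Tensor, weight: torch.Tensor, style: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Modulate conv weights by a per-sample style, demodulate, then convolve.

    Args:
        x: (B, C_in, H, W) input features.
        weight: (C_out, C_in, k, k) shared convolution weight.
        style: (B, C_in) per-sample scales (from an affine map of w).
    """
    b, c_in, h, w = x.shape
    c_out, _, k, _ = weight.shape
    # Modulation: scale each input channel's weights by the sample's style.
    wgt = weight[None] * style[:, None, :, None, None]          # (B, C_out, C_in, k, k)
    # Demodulation: give every output channel's kernel unit L2 norm.
    wgt = wgt * torch.rsqrt(wgt.pow(2).sum(dim=(2, 3, 4), keepdim=True) + eps)
    # Grouped convolution applies a different weight tensor to each sample.
    out = F.conv2d(
        x.reshape(1, b * c_in, h, w),
        wgt.reshape(b * c_out, c_in, k, k),
        padding=k // 2,
        groups=b,
    )
    return out.reshape(b, c_out, out.shape[2], out.shape[3])


torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
weight = torch.randn(6, 4, 3, 3)
style = torch.rand(2, 4) + 0.5
out = modulated_conv2d(x, weight, style)
print(out.shape)  # torch.Size([2, 6, 8, 8])
```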

StyleGAN3 addressed a subtle but important issue known as texture sticking: features in StyleGAN2 were anchored to pixel coordinates rather than following the content of the image. When the content moves slightly (for example, during a latent interpolation), fine details "snap" to the pixel grid instead of moving smoothly with the surface.

StyleGAN3 redesigns the generator to be equivariant to translation, and in one configuration also to rotation. This enables smooth latent space interpolations and consistent video generation.

The key technical contribution is to treat feature maps as samples of continuous signals. Each operation, including upsampling and the pointwise nonlinearities, is made to behave like its continuous-domain counterpart through careful low-pass filtering. This suppresses the aliasing that previously let the network exploit absolute pixel positions.

17.6.6 Truncation Trick

A practical technique for improving sample quality at the cost of diversity is the truncation trick. Instead of sampling $\mathbf{w} = f(\mathbf{z})$ with $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, we truncate the distribution:

$$\mathbf{w}' = \bar{\mathbf{w}} + \psi (\mathbf{w} - \bar{\mathbf{w}})$$

where $\bar{\mathbf{w}} = \mathbb{E}_{\mathbf{z}}[f(\mathbf{z})]$ is the mean latent code and $\psi \in [0, 1]$ is the truncation parameter. Setting $\psi = 1$ gives the full distribution; $\psi = 0$ produces only the "average" face. Values around $\psi = 0.7$ are common for producing high-quality, reasonably diverse samples. This technique directly trades off the quality-diversity spectrum, which is a common theme in generative modeling (as we saw with temperature in Chapter 16's discussion of VAEs and will see again with sampling temperature in Chapter 21).
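In practice, $\bar{\mathbf{w}}$ is estimated by averaging the mapping network's output over many random $\mathbf{z}$. The sketch below does this with a two-layer placeholder MLP standing in for the real mapping network and then applies the truncation formula.

The sample count and the helper names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder mapping network f: z -> w (stands in for StyleGAN's 8-layer MLP).
mapping = nn.Sequential(nn.Linear(512, 512), nn.LeakyReLU(0.2), nn.Linear(512, 512))


@torch.no_grad()
def mean_latent(f: nn.Module, n_samples: int = 10_000, z_dim: int = 512) -> torch.Tensor:
    """Monte Carlo estimate of w_bar = E_z[f(z)]."""
    return f(torch.randn(n_samples, z_dim)).mean(dim=0, keepdim=True)


def truncate(w: torch.Tensor, w_bar: torch.Tensor, psi: float) -> torch.Tensor:
    """Pull w toward the mean latent: w' = w_bar + psi * (w - w_bar)."""
    return w_bar + psi * (w - w_bar)


w_bar = mean_latent(mapping)
with torch.no_grad():
    w = mapping(torch.randn(8, 512))
w_trunc = truncate(w, w_bar, psi=0.7)
# Every code's distance to the mean shrinks by exactly the factor psi.
ratio = (w_trunc - w_bar).norm(dim=1) / (w - w_bar).norm(dim=1)
print(ratio)
```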


17.7 Evaluating Generative Models

17.7.1 The Evaluation Challenge

Evaluating generative models is fundamentally difficult. Unlike classification (where accuracy provides a clear metric), generation requires assessing both quality (do individual samples look realistic?) and diversity (does the model capture the full variety of the data?).

Human evaluation is the gold standard but is expensive, slow, and subjective. Automated metrics aim to approximate human judgment.

17.7.2 Inception Score (IS)

The Inception Score (Salimans et al., 2016) uses a pretrained Inception-v3 classifier to evaluate generated images.

Intuition first. A good generative model should produce images that are (a) individually recognizable (high quality) and (b) collectively diverse. The IS captures both by comparing each image's class distribution to the overall class distribution.

Formula:

$$\text{IS} = \exp\left(\mathbb{E}_{\mathbf{x} \sim p_g} \left[ D_{\text{KL}}(p(y|\mathbf{x}) \| p(y)) \right]\right)$$

where:

  • $p(y|\mathbf{x})$ is the Inception network's class distribution for image $\mathbf{x}$ (a softmax over 1,000 ImageNet classes)
  • $p(y) = \mathbb{E}_{\mathbf{x} \sim p_g}[p(y|\mathbf{x})]$ is the marginal class distribution over all generated images
  • $D_{\text{KL}}$ is the KL divergence (as we defined in Chapter 4)

Worked example. Suppose we generate 3 images, and the Inception network produces these class distributions (simplified to 3 classes):

  • Image 1: $p(y|\mathbf{x}_1) = [0.9, 0.05, 0.05]$ (confidently class 1)
  • Image 2: $p(y|\mathbf{x}_2) = [0.05, 0.9, 0.05]$ (confidently class 2)
  • Image 3: $p(y|\mathbf{x}_3) = [0.05, 0.05, 0.9]$ (confidently class 3)

The marginal is $p(y) = \frac{1}{3}[1.0, 1.0, 1.0] = [0.333, 0.333, 0.333]$ (uniform).

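$D_{\text{KL}}(p(y|\mathbf{x}_1) \| p(y)) = 0.9 \ln\frac{0.9}{1/3} + 2 \times 0.05 \ln\frac{0.05}{1/3} \approx 0.894 - 0.190 = 0.704$

By symmetry, images 2 and 3 give the same value, so the expected KL divergence is also $\approx 0.704$.

The IS is $\exp(0.704) \approx 2.02$. That is well above the minimum of 1 and approaching the maximum of 3 (the number of classes). The remaining gap comes from the 0.05 probability that each prediction places on each wrong class.

Compare this to a collapsed generator that produces only class 1. Every image then has $p(y|\mathbf{x}) = [0.9, 0.05, 0.05]$, so the marginal is also $p(y) = [0.9, 0.05, 0.05]$. Every KL term is therefore $0$, and IS $= 1$.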
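The computation is short enough to check by hand. The NumPy sketch below takes a matrix of per-image class probabilities and evaluates both cases of the worked example. In practice those probabilities would come from Inception-v3.

The small `eps` guard against $\log 0$ is an implementation choice.

```python
import numpy as np


def inception_score(probs: np.ndarray, eps: float = 1e-12) -> float:
    """IS = exp(E_x[KL(p(y|x) || p(y))]) from per-image class probabilities.

    Args:
        probs: (N, K) array; row i is p(y | x_i) and sums to 1.
    """
    p_y = probs.mean(axis=0, keepdims=True)                        # marginal p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))


diverse = np.array([[0.9, 0.05, 0.05],
                    [0.05, 0.9, 0.05],
                    [0.05, 0.05, 0.9]])
collapsed = np.tile([0.9, 0.05, 0.05], (3, 1))   # every image looks like class 1

print(f"diverse:   IS = {inception_score(diverse):.3f}")     # 2.022
print(f"collapsed: IS = {inception_score(collapsed):.3f}")   # 1.000
```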

High IS requires:
  • Quality: Each image should be confidently classified (low entropy of $p(y|\mathbf{x})$).
  • Diversity: The marginal distribution $p(y)$ should be uniform across classes (high entropy of $p(y)$).

The IS ranges from 1 (worst) to the number of classes (best). On ImageNet, real images achieve an IS around 200--250.

Limitations of IS:
  • Does not compare to real data (a model could generate perfect dogs and score highly even if the dataset contains cats).
  • Sensitive to the Inception model's biases.
  • Does not capture intra-class diversity.
  • Provides only a scalar summary, losing information about the quality-diversity tradeoff.

17.7.3 Frechet Inception Distance (FID)

The Frechet Inception Distance (Heusel et al., 2017) is the most widely used metric for evaluating generative models. It compares the distribution of generated images to the distribution of real images in the feature space of a pretrained Inception-v3 network.

Both distributions are modeled as multivariate Gaussians:
  • Real data: $\mathcal{N}(\boldsymbol{\mu}_r, \boldsymbol{\Sigma}_r)$
  • Generated data: $\mathcal{N}(\boldsymbol{\mu}_g, \boldsymbol{\Sigma}_g)$

The FID is the squared Frechet distance (equivalently, the squared Wasserstein-2 distance) between these Gaussians:

$$\text{FID} = \|\boldsymbol{\mu}_r - \boldsymbol{\mu}_g\|^2 + \text{Tr}\left(\boldsymbol{\Sigma}_r + \boldsymbol{\Sigma}_g - 2(\boldsymbol{\Sigma}_r \boldsymbol{\Sigma}_g)^{1/2}\right)$$

  • FID = 0 when the distributions are identical.
  • Lower FID indicates better quality and diversity.
  • Real-world GANs achieve FID scores of 2--20 on standard benchmarks.

Advantages of FID over IS:
  • Compares to the real data distribution (not just individual sample quality).
  • More sensitive to mode dropping.
  • More robust and consistent across runs.

Best practices for FID computation:
  • Use at least 10,000 generated samples (50,000 is standard).
  • Use the same preprocessing as the Inception network.
  • Report the number of samples and the specific Inception checkpoint used.

17.7.4 Other Metrics

| Metric | Measures | Advantage | Limitation |
|---|---|---|---|
| IS | Quality + class diversity | Simple | No comparison to real data |
| FID | Quality + diversity vs. real data | Robust, widely used | Assumes Gaussian features |
| Precision | Quality (fraction of realistic samples) | Interpretable | Needs a threshold |
| Recall | Diversity (fraction of real data covered) | Interpretable | Needs a threshold |
| KID | Like FID but unbiased | Unbiased for small samples | Less established |

17.8 GAN Variants Landscape

17.8.1 A Brief Taxonomy

The GAN ecosystem is vast. Here is a high-level map of the most important variants:

By architecture:
  • DCGAN: Convolutional architecture with BatchNorm
  • ProGAN / Progressive GAN: Progressive training from low to high resolution
  • StyleGAN: Style-based generator with mapping network
  • BigGAN: Large-scale conditional generation with class embeddings

By loss function:
  • Standard GAN: Binary cross-entropy
  • WGAN / WGAN-GP: Wasserstein distance
  • LSGAN: Least squares loss (replaces log with squared error)
  • Hinge loss GAN: Uses hinge loss for the discriminator

By conditioning:
  • cGAN: Conditional on class labels
  • Pix2Pix: Image-to-image translation (paired data)
  • CycleGAN: Unpaired image-to-image translation
  • Text-to-image GANs: Conditional on text descriptions
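To make the loss-function branch of the taxonomy concrete, here is a sketch of the discriminator losses for the standard, least-squares and hinge variants. Each is written in terms of raw discriminator outputs on a real batch and a fake batch.

The 0/1 regression targets for LSGAN are one common coding choice. WGAN is omitted because Section 17.4 covers it. The example scores are made up.

```python
import torch
import torch.nn.functional as F


def d_loss_standard(real_out: torch.Tensor, fake_out: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy on logits: real -> 1, fake -> 0."""
    return (F.binary_cross_entropy_with_logits(real_out, torch.ones_like(real_out))
            + F.binary_cross_entropy_with_logits(fake_out, torch.zeros_like(fake_out)))


def d_loss_lsgan(real_out: torch.Tensor, fake_out: torch.Tensor) -> torch.Tensor:
    """Least squares: regress real outputs to 1 and fake outputs to 0."""
    return 0.5 * ((real_out - 1).pow(2).mean() + fake_out.pow(2).mean())


def d_loss_hinge(real_out: torch.Tensor, fake_out: torch.Tensor) -> torch.Tensor:
    """Hinge: zero loss once real scores exceed +1 and fake scores fall below -1."""
    return F.relu(1 - real_out).mean() + F.relu(1 + fake_out).mean()


real_out = torch.tensor([2.0, 0.5])
fake_out = torch.tensor([-2.0, 0.0])
for name, fn in [("standard", d_loss_standard), ("lsgan", d_loss_lsgan), ("hinge", d_loss_hinge)]:
    print(f"{name:8s} {fn(real_out, fake_out).item():.4f}")
```

Note how the hinge loss ignores the confidently classified samples (real score 2.0, fake score -2.0), while BCE and least squares still penalize them slightly or not at all depending on the target coding.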

17.8.2 GANs vs. Other Generative Models

| Aspect | GAN | VAE (Ch. 16) | Diffusion Model |
|---|---|---|---|
| Training | Adversarial (unstable) | ELBO maximization (stable) | Denoising score matching (stable) |
| Sample quality | High (sharp) | Lower (blurry) | Highest |
| Diversity | Mode collapse risk | Good coverage | Excellent coverage |
| Likelihood | Not available | Lower bound (ELBO) | Exact or lower bound |
| Speed | Fast generation | Fast generation | Slow generation |
| Latent space | Unstructured | Structured (KL regularized) | No explicit latent space |

GANs excel at fast, high-quality generation. Their main weaknesses---training instability and mode collapse---have been largely addressed by WGAN-GP, spectral normalization, and the architectural insights from StyleGAN. However, diffusion models have recently surpassed GANs in image quality on many benchmarks, and the field continues to evolve.

The shift toward diffusion models. Starting around 2021, diffusion models (which we will explore in Chapter 27) overtook GANs on key benchmarks like ImageNet FID. The reasons are instructive:

  1. Training stability: Diffusion models optimize a simple denoising objective (as we discussed in the context of denoising autoencoders in Chapter 16). There is no adversarial game, no mode collapse, and no need for careful balancing of two networks.

  2. Mode coverage: Diffusion models are trained with a likelihood-based objective (closely related to a variational bound on the log-likelihood). Missing a mode of the data is heavily penalized, so they tend to cover the data distribution well. GANs must rely on additional techniques (minibatch discrimination, WGAN) to avoid mode collapse.

  3. Scalability: Diffusion models scale predictably with compute --- more training almost always improves quality. GAN training does not have this monotonic relationship between compute and quality.

  4. Generation speed (GANs win here): GANs generate images in a single forward pass (milliseconds), while diffusion models require dozens to thousands of denoising steps (seconds). This speed advantage keeps GANs relevant for real-time applications.

The modern landscape is converging: techniques from GANs (discriminator-based quality metrics, adversarial training) are being combined with diffusion models to get the best of both worlds. For example, some approaches use a diffusion model for the initial generation and a GAN-like discriminator for refinement.

17.8.3 Historical Perspective: The GAN Revolution

It is worth stepping back to appreciate the impact of GANs on the field. When Goodfellow introduced GANs in 2014, the best generative models produced blurry, low-resolution images. Within five years, StyleGAN2 was generating photorealistic $1024 \times 1024$ faces that fooled human observers. This progression drove enormous investment in generative AI and directly contributed to the cultural shift toward AI-generated content.

The GAN framework also influenced research beyond image generation:

  • Drug discovery: GANs generate novel molecular structures with desired properties.
  • Data augmentation: GANs synthesize training data for rare classes.
  • Super-resolution: SRGAN and ESRGAN use adversarial training for image upscaling.
  • Anomaly detection: The discriminator can serve as an anomaly detector --- real data gets high scores, anomalies get low scores.
  • Domain adaptation: GANs align feature distributions across domains.

Even as diffusion models take the lead in raw image quality, the GAN framework remains influential, and the training dynamics and architectural principles we have studied in this chapter remain essential knowledge for any AI engineer.


17.9 Practical Considerations

17.9.1 Training Monitoring

Monitor these signals during GAN training:

  1. Discriminator loss: Should not go to zero (discriminator too strong) or diverge (discriminator too weak).
  2. Generator loss: Should decrease over time but may oscillate.
  3. Generated samples: Periodically save generated images and visually inspect quality and diversity.
  4. FID during training: If computational resources allow, compute FID every few epochs as a quantitative quality measure.

17.9.2 Common Failure Modes

| Symptom | Likely Cause | Fix |
|---|---|---|
| Generator produces noise | Discriminator too strong early | Reduce D learning rate, add label smoothing |
| Mode collapse (same output) | Generator finds a shortcut | Use minibatch discrimination, WGAN-GP |
| Checkerboard artifacts | Transposed convolution overlap | Use upsampling + convolution |
| Training diverges | Learning rate too high | Reduce learning rate, use spectral norm |
| Discriminator loss = 0 | Discriminator memorizes | Add noise to D input, reduce D capacity |

17.9.3 Choosing a GAN Variant

  • Quick prototyping: Start with DCGAN. It is simple, well-understood, and works on many datasets.
  • Stable training needed: Use WGAN-GP. It provides meaningful loss curves and reduces the risk of mode collapse.
  • High-quality face generation: Use StyleGAN2/3. The architecture is optimized for this task.
  • Controlled generation: Use conditional GAN (class labels) or Pix2Pix (image-to-image).
  • Best possible quality: Consider diffusion models instead of GANs (Chapter 27).

17.9.4 Practical Training Tips

Over years of GAN research, the community has compiled a set of practical tips that can make the difference between success and failure:

Data preprocessing:
  • Normalize images to $[-1, 1]$ (matching the Tanh output of the generator) rather than $[0, 1]$.
  • Use data augmentation carefully: random horizontal flips are generally safe, but aggressive augmentations can confuse the discriminator.
  • For small datasets, consider differentiable augmentation (DiffAugment). It applies augmentations to both real and generated images, which prevents the discriminator from memorizing the training set.

Architecture tips:
  • Use transposed convolutions with a kernel size divisible by the stride to reduce checkerboard artifacts. Better yet, use nearest-neighbor upsampling followed by a regular convolution.
  • Avoid max pooling in the discriminator; use strided convolutions instead.
  • Use spectral normalization in the discriminator by default: it is cheap and effective.
  • For the generator, batch normalization works well. For the discriminator, layer normalization or spectral normalization is preferred (especially with WGAN-GP, where batch normalization should be avoided).
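Both upsampling options from the first tip double the spatial resolution, so one can replace the other without changing any tensor shapes. The channel counts in the sketch are arbitrary.

In the resize-convolution version, every output pixel receives the same number of kernel contributions. It therefore cannot produce the uneven overlap behind checkerboard patterns, at the cost of a fixed, non-learned interpolation step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Transposed convolution: kernel 4, stride 2 (kernel size divisible by stride).
up_transposed = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

# Resize-convolution alternative: nearest-neighbor upsample, then a regular 3x3 conv.
up_resize = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(64, 32, kernel_size=3, padding=1),
)

x = torch.randn(1, 64, 16, 16)
y_transposed = up_transposed(x)
y_resize = up_resize(x)
print(y_transposed.shape)  # torch.Size([1, 32, 32, 32])
print(y_resize.shape)      # torch.Size([1, 32, 32, 32])
```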

Optimization tips:
  • Use Adam with $\beta_1 = 0.0$ and $\beta_2 = 0.9$ for WGAN-GP (note: $\beta_1 = 0$, not the default 0.9).
  • Learning rates between $10^{-4}$ and $2 \times 10^{-4}$ work well for most architectures.
  • If training diverges, reduce the learning rate rather than adding regularization.
  • Save checkpoints frequently: GAN training is non-monotonic, and the best model may occur well before the end of training.

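Monitoring:
  • Save generated samples every few hundred iterations for visual inspection.
  • Track the discriminator's accuracy on real and fake data separately.
    ◦ If it stays at 100% on both, the discriminator is overpowering the generator, whose gradients are then likely to vanish.
    ◦ If it hovers around 50% on both, the discriminator can no longer tell real from fake. This is the theoretical equilibrium, but if the samples still look poor, it suggests the discriminator has stopped providing a useful training signal.
    ◦ If it oscillates wildly, training is unstable.
  • If using WGAN-GP, the Wasserstein estimate (the negative of the critic loss, excluding the penalty term) should decrease over time and correlate with sample quality.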
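A small helper can compute the accuracy signal described above. It assumes a standard discriminator with raw logit outputs, so a logit above 0 (sigmoid above 0.5) means "real".

The example logits are made up for illustration. The helper does not apply to a WGAN critic, whose scores have no probabilistic threshold.

```python
from typing import Tuple

import torch


@torch.no_grad()
def discriminator_accuracies(
    real_logits: torch.Tensor, fake_logits: torch.Tensor
) -> Tuple[float, float]:
    """Fraction of real samples scored as real, and of fake samples scored as fake."""
    acc_real = (real_logits > 0).float().mean().item()
    acc_fake = (fake_logits < 0).float().mean().item()
    return acc_real, acc_fake


real_logits = torch.tensor([3.1, 2.4, -0.2, 1.7])
fake_logits = torch.tensor([-2.5, 0.3, -1.1, -0.8])
acc_real, acc_fake = discriminator_accuracies(real_logits, fake_logits)
print(f"D accuracy  real: {acc_real:.2f}  fake: {acc_fake:.2f}")  # real: 0.75  fake: 0.75
```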


17.9.5 GAN Evaluation in Practice

Computing FID and IS requires careful implementation. Here is a practical workflow:

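import numpy as np
from scipy.linalg import sqrtm


def compute_fid(
    real_features: np.ndarray,
    fake_features: np.ndarray,
    eps: float = 1e-6,
) -> float:
    """Compute Frechet Inception Distance between two sets of features.

    Args:
        real_features: Features from real images (N, D).
        fake_features: Features from generated images (M, D).
        eps: Diagonal offset used if the matrix square root is unstable.

    Returns:
        The FID score (lower is better).
    """
    # Gaussian statistics of each feature set
    mu_r = np.mean(real_features, axis=0)
    mu_g = np.mean(fake_features, axis=0)
    sigma_r = np.cov(real_features, rowvar=False)
    sigma_g = np.cov(fake_features, rowvar=False)

    # Matrix square root of the covariance product
    covmean = sqrtm(sigma_r @ sigma_g)
    if not np.isfinite(covmean).all():
        # Near-singular product (e.g. fewer samples than dimensions): regularize.
        offset = np.eye(sigma_r.shape[0]) * eps
        covmean = sqrtm((sigma_r + offset) @ (sigma_g + offset))
    # sqrtm can return tiny imaginary parts from round-off; discard them
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    diff = mu_r - mu_g
    fid = diff @ diff + np.trace(sigma_r + sigma_g - 2 * covmean)
    return float(fid)


# Sanity check on synthetic features: identical sets give an FID near 0, and
# shifting every dimension's mean by 0.5 adds about D * 0.25 (here 16 * 0.25 = 4).
rng = np.random.default_rng(0)
feats_a = rng.normal(size=(5000, 16))
feats_b = rng.normal(loc=0.5, size=(5000, 16))
print(f"{compute_fid(feats_a, feats_a):.4f}")  # approximately 0
print(f"{compute_fid(feats_a, feats_b):.4f}")  # roughly 4

# In practice, features are the 2048-dimensional pool features from the
# penultimate layer of an Inception-v3 network pretrained on ImageNet.
# Use at least 10,000 (ideally 50,000) samples for reliable FID.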

Important practical considerations for FID:
  • Always use the same Inception checkpoint and preprocessing for all comparisons.
  • FID is sensitive to the number of samples: fewer samples lead to higher variance and a systematically higher score. Report the number of samples used.
  • FID computed with different random seeds will vary. Compute FID multiple times and report the mean and standard deviation.
  • The pytorch-fid and cleanfid packages provide reference implementations that handle these details correctly.


17.10 Summary

This chapter explored the adversarial approach to generative modeling:

  1. The GAN framework sets up a minimax game between a generator and a discriminator. The generator learns to produce realistic data by fooling the discriminator; the discriminator learns to distinguish real from fake.

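  2. Training dynamics are fundamentally different from standard optimization: GAN training searches for an equilibrium of a two-player game rather than the minimum of a single loss. Mode collapse, training instability, and vanishing gradients are the main challenges.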

  3. DCGAN established architectural guidelines (strided convolutions, BatchNorm, ReLU/LeakyReLU) that made GAN training practical and reproducible.

  4. WGAN replaced the Jensen-Shannon divergence with the Wasserstein distance, providing meaningful gradients even when distributions have non-overlapping support. WGAN-GP replaced weight clipping with a gradient penalty, a softer way to enforce the critic's Lipschitz constraint.

  5. Conditional GANs add control to generation by conditioning on class labels, text, or other images. Pix2Pix and CycleGAN extended this to image-to-image translation.

  6. StyleGAN combined a mapping network, AdaIN-based style modulation, and per-layer noise injection to achieve state-of-the-art face generation at the time of its release, with disentangled style control.

  7. Evaluating generative models requires carefully chosen metrics. FID, the most widely used, measures the distance between Gaussian fits to the Inception features of real and generated data.

GANs represent a fundamentally different philosophy from the reconstruction-based (autoencoders) and likelihood-based (VAEs) approaches of Chapter 16. By replacing explicit density modeling with an adversarial game, GANs achieve sharper outputs at the cost of training stability. Understanding when to use each kind of generative model, and how to train it reliably, is an essential skill for the AI engineer.


References

  • Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., et al. (2014). "Generative Adversarial Nets." NeurIPS 2014.
  • Radford, A., Metz, L., and Chintala, S. (2016). "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks." ICLR 2016.
  • Arjovsky, M., Chintala, S., and Bottou, L. (2017). "Wasserstein Generative Adversarial Networks." ICML 2017.
  • Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., and Courville, A. (2017). "Improved Training of Wasserstein GANs." NeurIPS 2017.
  • Mirza, M. and Osindero, S. (2014). "Conditional Generative Adversarial Nets." arXiv:1411.1784.
  • Isola, P., Zhu, J.-Y., Zhou, T., and Efros, A. A. (2017). "Image-to-Image Translation with Conditional Adversarial Networks." CVPR 2017.
  • Karras, T., Laine, S., and Aila, T. (2019). "A Style-Based Generator Architecture for Generative Adversarial Networks." CVPR 2019.
  • Karras, T., Laine, S., Aittala, M., et al. (2020). "Analyzing and Improving the Image Quality of StyleGAN." CVPR 2020.
  • Salimans, T., Goodfellow, I., Zaremba, W., et al. (2016). "Improved Techniques for Training GANs." NeurIPS 2016.
  • Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. (2017). "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium." NeurIPS 2017.
  • Miyato, T., Kataoka, T., Koyama, M., and Yoshida, Y. (2018). "Spectral Normalization for Generative Adversarial Networks." ICLR 2018.