Case Study 2: Conditional Image Generation with cGAN
Overview
In Case Study 1, we built a DCGAN that generates random digits. But what if we want to generate a specific digit? A conditional GAN (cGAN) adds this control by providing the desired class label as input to both the generator and discriminator.
In this case study, we implement a conditional GAN for MNIST that generates digits on demand. We explore conditioning mechanisms, evaluate the accuracy of conditional generation, and visualize how the noise vector controls style while the condition controls content.
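Formally, a cGAN optimizes the standard minimax objective with both players conditioned on the label $\mathbf{y}$ (Mirza and Osindero, 2014):

$$\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\big[\log D(\mathbf{x} \mid \mathbf{y})\big] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}}\big[\log\big(1 - D(G(\mathbf{z} \mid \mathbf{y}) \mid \mathbf{y})\big)\big]$$

Conditioning amounts to feeding $\mathbf{y}$ as an extra input to both networks, which is exactly what the architecture below does.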
Problem Definition
Task: Generate images of a specified digit class (0--9) on demand.
Dataset: MNIST with labels (labels are used for conditioning, not for classification).
Success criteria:
1. Generated images conditioned on class $c$ should actually look like digit $c$.
2. Generated images should be visually realistic.
3. Different noise vectors with the same condition should produce diverse images.
Architecture
Conditioning Strategy
We use concatenation-based conditioning (shapes sketched below):
- Generator: The noise vector $\mathbf{z} \in \mathbb{R}^{100}$ is concatenated with the one-hot label $\mathbf{y} \in \mathbb{R}^{10}$, producing a $110$-dimensional input.
- Discriminator: The one-hot label is expanded to a full image-sized channel and concatenated with the input image, producing a $2$-channel input.
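A minimal shape-level sketch of these two fusion points (illustrative only; the untrained `nn.Linear` here stands in for the learned label embedding defined later):

import torch
import torch.nn as nn
import torch.nn.functional as F

z = torch.randn(4, 100)  # a batch of 4 noise vectors
y = F.one_hot(torch.tensor([3, 1, 4, 1]), num_classes=10).float()

# Generator side: concatenate noise and label along the feature dim
g_in = torch.cat([z, y], dim=1)
print(g_in.shape)  # torch.Size([4, 110])

# Discriminator side: project the label to a 28x28 map and stack it
# as a second channel next to the (here random) image
images = torch.randn(4, 1, 28, 28)
label_map = nn.Linear(10, 28 * 28)(y).view(-1, 1, 28, 28)
d_in = torch.cat([images, label_map], dim=1)
print(d_in.shape)  # torch.Size([4, 2, 28, 28])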
Conditional Generator
"""Conditional GAN generator."""
import torch
import torch.nn as nn
torch.manual_seed(42)
class ConditionalGenerator(nn.Module):
"""Generator conditioned on class label.
Concatenates noise and one-hot label before generation.
Args:
latent_dim: Dimension of noise vector.
n_classes: Number of classes for conditioning.
feature_maps: Base number of feature maps.
"""
def __init__(
self,
latent_dim: int = 100,
n_classes: int = 10,
feature_maps: int = 256,
) -> None:
super().__init__()
self.latent_dim = latent_dim
self.n_classes = n_classes
input_dim = latent_dim + n_classes
self.fc = nn.Sequential(
nn.Linear(input_dim, feature_maps * 7 * 7),
nn.BatchNorm1d(feature_maps * 7 * 7),
nn.ReLU(True),
)
self.conv = nn.Sequential(
nn.ConvTranspose2d(
feature_maps, feature_maps // 2,
4, stride=2, padding=1, bias=False
),
nn.BatchNorm2d(feature_maps // 2),
nn.ReLU(True),
nn.ConvTranspose2d(
feature_maps // 2, 1,
4, stride=2, padding=1, bias=False
),
nn.Tanh(),
)
self._feature_maps = feature_maps
def forward(
self, z: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
"""Generate images conditioned on class labels.
Args:
z: Noise tensor of shape (batch_size, latent_dim).
labels: One-hot label tensor of shape (batch_size, n_classes).
Returns:
Generated images of shape (batch_size, 1, 28, 28).
"""
x = torch.cat([z, labels], dim=1)
h = self.fc(x)
h = h.view(-1, self._feature_maps, 7, 7)
return self.conv(h)
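A quick shape check (an illustrative smoke test, not part of the case study's pipeline):

import torch.nn.functional as F

g = ConditionalGenerator()
z = torch.randn(4, 100)
y = F.one_hot(torch.tensor([0, 3, 7, 9]), num_classes=10).float()
print(g(z, y).shape)  # torch.Size([4, 1, 28, 28])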
Conditional Discriminator
class ConditionalDiscriminator(nn.Module):
    """Discriminator conditioned on class label.

    Expands the label to an image-sized channel and concatenates
    with the input image.

    Args:
        n_classes: Number of classes for conditioning.
        feature_maps: Base number of feature maps.
    """

    def __init__(
        self, n_classes: int = 10, feature_maps: int = 64
    ) -> None:
        super().__init__()
        self.n_classes = n_classes
        # Label embedding: expands class to spatial map
        self.label_embedding = nn.Sequential(
            nn.Linear(n_classes, 28 * 28),
            nn.LeakyReLU(0.2),
        )
        # 2 channels: image + label map
        self.main = nn.Sequential(
            nn.Conv2d(2, feature_maps, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(
                feature_maps, feature_maps * 2,
                4, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(feature_maps * 2 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Classify real/fake conditioned on class.

        Args:
            x: Image tensor of shape (batch_size, 1, 28, 28).
            labels: One-hot label tensor of shape (batch_size, n_classes).

        Returns:
            Probability of being real, shape (batch_size, 1).
        """
        label_map = self.label_embedding(labels)
        label_map = label_map.view(-1, 1, 28, 28)
        combined = torch.cat([x, label_map], dim=1)
        return self.main(combined)
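The same kind of smoke test for the discriminator (again illustrative only):

import torch
import torch.nn.functional as F

d = ConditionalDiscriminator()
x = torch.randn(4, 1, 28, 28)
y = F.one_hot(torch.tensor([0, 3, 7, 9]), num_classes=10).float()
print(d(x, y).shape)  # torch.Size([4, 1]), sigmoid outputs in (0, 1)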
Training
"""Conditional GAN training loop."""
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
train_loader = DataLoader(
train_dataset, batch_size=128, shuffle=True, drop_last=True
)
generator = ConditionalGenerator(latent_dim=100, n_classes=10)
discriminator = ConditionalDiscriminator(n_classes=10)
optimizer_g = torch.optim.Adam(
generator.parameters(), lr=2e-4, betas=(0.5, 0.999)
)
optimizer_d = torch.optim.Adam(
discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)
)
criterion = nn.BCELoss()
for epoch in range(25):
for real_images, real_labels_idx in train_loader:
batch_size = real_images.size(0)
# One-hot encode labels
real_labels_onehot = F.one_hot(
real_labels_idx, num_classes=10
).float()
real_target = torch.ones(batch_size, 1) * 0.9
fake_target = torch.zeros(batch_size, 1)
# --- Train Discriminator ---
optimizer_d.zero_grad()
# Real images with correct labels
output_real = discriminator(real_images, real_labels_onehot)
loss_real = criterion(output_real, real_target)
# Fake images with random labels
noise = torch.randn(batch_size, 100)
fake_labels_idx = torch.randint(0, 10, (batch_size,))
fake_labels_onehot = F.one_hot(
fake_labels_idx, num_classes=10
).float()
fake_images = generator(noise, fake_labels_onehot)
output_fake = discriminator(
fake_images.detach(), fake_labels_onehot
)
loss_fake = criterion(output_fake, fake_target)
loss_d = loss_real + loss_fake
loss_d.backward()
optimizer_d.step()
# --- Train Generator ---
optimizer_g.zero_grad()
output_fake = discriminator(fake_images, fake_labels_onehot)
loss_g = criterion(output_fake, torch.ones(batch_size, 1))
loss_g.backward()
optimizer_g.step()
if (epoch + 1) % 5 == 0:
print(
f"Epoch {epoch+1}/25 | "
f"D Loss: {loss_d.item():.4f} | "
f"G Loss: {loss_g.item():.4f}"
)
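After training, it is convenient to persist the generator for the evaluation below; a minimal sketch (the checkpoint filename is our choice, not prescribed by the case study):

torch.save(generator.state_dict(), "cgan_generator.pt")
# Reload later with:
# generator = ConditionalGenerator(latent_dim=100, n_classes=10)
# generator.load_state_dict(torch.load("cgan_generator.pt"))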
Evaluation
Conditional Generation Grid
After training, we generate a grid where each row corresponds to a digit class and each column uses a different noise vector:
"""Generate a 10x10 conditional generation grid."""
generator.eval()
n_per_class = 10
with torch.no_grad():
all_images = []
for digit in range(10):
label = F.one_hot(
torch.tensor([digit] * n_per_class), num_classes=10
).float()
noise = torch.randn(n_per_class, 100)
images = generator(noise, label)
all_images.append(images)
grid = torch.cat(all_images, dim=0)
print(f"Generated grid: {grid.shape}")
# Shape: (100, 1, 28, 28) - 10 rows x 10 columns
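To render the grid as an actual image, torchvision's saving utility works directly on this tensor; a sketch (the output filename is arbitrary):

from torchvision.utils import save_image

# 10 images per row, so each row is one class (class-major ordering);
# normalize maps the tanh output range [-1, 1] back to [0, 1].
save_image(grid, "cgan_grid.png", nrow=10, normalize=True)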
Style Consistency Test
Fix the noise vector and vary the condition to see if a consistent "style" transfers across digit classes:
"""Fix noise, vary condition to test style transfer."""
generator.eval()
fixed_z = torch.randn(1, 100)
with torch.no_grad():
for digit in range(10):
label = F.one_hot(
torch.tensor([digit]), num_classes=10
).float()
image = generator(fixed_z, label)
pixel_mean = image.mean().item()
print(
f" Digit {digit}: pixel mean = {pixel_mean:.3f}, "
f"pixel std = {image.std().item():.3f}"
)
When the same noise vector is used with different conditions, the generated digits share common stylistic attributes (stroke thickness, slant, size) while varying in identity. This demonstrates successful disentanglement of content (controlled by the condition) from style (controlled by the noise vector).
Classification Accuracy
To quantitatively evaluate conditioning accuracy, we use a pretrained digit classifier:
"""Evaluate conditioning accuracy with a pretrained classifier."""
# Train a simple classifier on real MNIST
classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
# (Training code omitted for brevity - standard supervised training)
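# A minimal sketch of that training, assuming a few epochs of
# cross-entropy on the same train_loader; the loader yields images
# normalized to [-1, 1], so we map them to [0, 1] to match the
# denormalized generated images evaluated below.
clf_optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
classifier.train()
for _ in range(3):
    for imgs, lbls in train_loader:
        clf_optimizer.zero_grad()
        loss = F.cross_entropy(classifier((imgs + 1) / 2), lbls)
        loss.backward()
        clf_optimizer.step()
classifier.eval()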
# Generate 1000 images per class and check classifier predictions
generator.eval()
total_correct = 0
total_generated = 0
with torch.no_grad():
    for digit in range(10):
        n_samples = 1000
        label = F.one_hot(
            torch.tensor([digit] * n_samples), num_classes=10
        ).float()
        noise = torch.randn(n_samples, 100)
        images = generator(noise, label)
        # Denormalize from [-1,1] to [0,1]
        images = (images + 1) / 2
        logits = classifier(images)
        preds = logits.argmax(dim=1)
        correct = (preds == digit).sum().item()
        total_correct += correct
        total_generated += n_samples
        print(f"  Digit {digit}: {correct}/{n_samples} correct "
              f"({100*correct/n_samples:.1f}%)")

accuracy = total_correct / total_generated
print(f"\nOverall conditioning accuracy: {accuracy:.4f}")
A well-trained conditional GAN achieves 85--95% conditioning accuracy, meaning the generated images are correctly classified as the intended digit class by the pretrained classifier.
Results and Analysis
What Conditioning Controls
- Condition (label): Controls which digit is generated (0, 1, 2, ..., 9).
- Noise vector: Controls all other visual attributes---stroke thickness, slant, size, position, and fine details.
Per-Class Quality
Some digit classes are easier to generate than others:
- Easy: 0, 1 (simple structure, less variation).
- Medium: 4, 7 (moderate complexity).
- Hard: 2, 8 (more complex structure, higher variation).
Comparison with Unconditional DCGAN
| Aspect | Unconditional DCGAN | Conditional GAN |
|---|---|---|
| Control | None (random digit) | Choose which digit |
| Mode coverage | May miss some digits | All classes covered (forced) |
| Quality | Good | Slightly better (label helps) |
| Evaluation | FID only | FID + conditioning accuracy |
Key Takeaways
- Conditional GANs add control to generation by providing class labels to both the generator and discriminator.
- Concatenation-based conditioning is simple and effective for discrete labels.
- The noise vector controls style (thickness, slant) while the condition controls content (which digit).
- Conditioning implicitly helps with mode coverage, since the generator must learn to produce all classes.
- Quantitative evaluation of cGANs includes both sample quality (FID) and conditioning accuracy (classifier agreement).