Case Study 2: Exploring Latent Spaces with VAEs

Overview

One of the most fascinating properties of Variational Autoencoders is their smooth, structured latent space. Unlike deterministic autoencoders, where the latent space can be fragmented and uninterpretable, the VAE's KL divergence regularization produces a continuous manifold where nearby points decode to similar outputs and interpolation yields meaningful transitions.

In this case study, we train a VAE on MNIST and systematically explore its latent space through encoding, decoding, interpolation, and generation. We use a 2D latent space for visualization (despite the quality trade-off) and then scale up to higher dimensions for practical applications. This exploration builds deep intuition for how latent variable models organize information.


Problem Definition

Task: Train a VAE on MNIST and explore the structure of the learned latent space.

Goals:

  1. Visualize how the encoder organizes different digit classes in latent space.
  2. Generate new digits by sampling from the prior.
  3. Perform smooth interpolation between digits.
  4. Understand the trade-off between latent dimension, reconstruction quality, and latent space structure.


Implementation

VAE Architecture

We implement a convolutional VAE with configurable latent dimension:

"""Variational Autoencoder for latent space exploration."""

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)


class VAE(nn.Module):
    """Convolutional VAE with configurable latent dimension.

    Args:
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, latent_dim: int = 2) -> None:
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder
        self.enc_conv1 = nn.Conv2d(1, 32, 3, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.enc_fc = nn.Linear(64 * 7 * 7, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.dec_fc1 = nn.Linear(latent_dim, 256)
        self.dec_fc2 = nn.Linear(256, 64 * 7 * 7)
        self.dec_conv1 = nn.ConvTranspose2d(
            64, 32, 3, stride=2, padding=1, output_padding=1
        )
        self.dec_conv2 = nn.ConvTranspose2d(
            32, 1, 3, stride=2, padding=1, output_padding=1
        )

    def encode(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode input to latent distribution parameters."""
        h = F.relu(self.enc_conv1(x))
        h = F.relu(self.enc_conv2(h))
        h = h.view(h.size(0), -1)
        h = F.relu(self.enc_fc(h))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(
        self, mu: torch.Tensor, logvar: torch.Tensor
    ) -> torch.Tensor:
        """Sample from q(z|x) using the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent code to reconstruction."""
        h = F.relu(self.dec_fc1(z))
        h = F.relu(self.dec_fc2(h))
        h = h.view(h.size(0), 64, 7, 7)
        h = F.relu(self.dec_conv1(h))
        return torch.sigmoid(self.dec_conv2(h))

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass returning reconstruction, mu, and logvar."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

Loss Function

The VAE loss combines a binary cross-entropy reconstruction term with a KL divergence term:

def vae_loss(
    recon_x: torch.Tensor,
    x: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute VAE loss (negative ELBO).

    Args:
        recon_x: Reconstructed images.
        x: Original images.
        mu: Mean of approximate posterior.
        logvar: Log-variance of approximate posterior.
        beta: Weight on KL divergence term.

    Returns:
        Tuple of (total_loss, reconstruction_loss, kl_divergence).
    """
    recon_loss = F.binary_cross_entropy(
        recon_x, x, reduction="sum"
    )
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total = recon_loss + beta * kl_div
    return total, recon_loss, kl_div
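
For reference, the KL term for a diagonal Gaussian posterior $q(\mathbf{z} \mid \mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2))$ against the prior $\mathcal{N}(\mathbf{0}, \mathbf{I})$ has the closed form

$$
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j=1}^{k} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right),
$$

which is exactly what `kl_div` computes from `mu` and `logvar` above.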

Training with KL Annealing

We implement linear KL annealing over the first 10 epochs to prevent posterior collapse:

"""Training loop with KL annealing."""

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

model = VAE(latent_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 30
warmup_epochs = 10

for epoch in range(n_epochs):
    model.train()
    total_loss = 0.0
    total_recon = 0.0
    total_kl = 0.0

    # Linear KL annealing
    beta = min(1.0, epoch / warmup_epochs)

    for images, _ in train_loader:
        recon, mu, logvar = model(images)
        loss, recon_loss, kl = vae_loss(recon, images, mu, logvar, beta)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon_loss.item()
        total_kl += kl.item()

    n = len(train_dataset)
    if (epoch + 1) % 5 == 0:
        print(
            f"Epoch {epoch+1}/{n_epochs} | "
            f"Loss: {total_loss/n:.2f} | "
            f"Recon: {total_recon/n:.2f} | "
            f"KL: {total_kl/n:.2f} | "
            f"Beta: {beta:.2f}"
        )

Exploration 1: Latent Space Encoding

After training, we encode all test images and plot them in 2D latent space, colored by digit class.

"""Encode test set and visualize latent space."""

import numpy as np

test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

model.eval()
all_mu = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        mu, _ = model.encode(images)
        all_mu.append(mu.numpy())
        all_labels.append(labels.numpy())

all_mu = np.concatenate(all_mu)
all_labels = np.concatenate(all_labels)

# Plot (using matplotlib if available)
# Each digit class forms a cluster, with smooth transitions between them.
# Digits with similar structure (e.g., 3, 5, 8) are closer together.
print(f"Latent space range: [{all_mu.min():.2f}, {all_mu.max():.2f}]")
print(f"Latent space std: {all_mu.std(axis=0)}")

What to observe: Each digit class forms a roughly Gaussian cluster. The clusters overlap at their boundaries, creating transition regions where the model is uncertain. Structurally similar digits (like 4 and 9, or 3 and 8) tend to be closer together in latent space.


Exploration 2: Decoding a Grid

We create a grid of points in latent space and decode each one to visualize the generative manifold.

"""Decode a grid of latent points to visualize the manifold."""

n_grid = 20
z_range = 3.0

# Create evenly spaced grid in latent space
z1 = np.linspace(-z_range, z_range, n_grid)
z2 = np.linspace(-z_range, z_range, n_grid)

grid_images = np.zeros((n_grid * 28, n_grid * 28))

model.eval()
with torch.no_grad():
    for i, z1_val in enumerate(z1):
        for j, z2_val in enumerate(z2):
            z = torch.tensor([[z1_val, z2_val]], dtype=torch.float32)
            decoded = model.decode(z).squeeze().numpy()
            grid_images[
                i * 28 : (i + 1) * 28,
                j * 28 : (j + 1) * 28
            ] = decoded

# The grid shows smooth transitions between digit types.
# Moving along one axis might transition from 0 to 6 to 5.
# Moving along the other might transition from 1 to 7 to 9.
print(f"Grid shape: {grid_images.shape}")
print("Manifold generated: smooth transitions between digit types")

What to observe: The manifold shows smooth, continuous transitions between digit types. No sharp boundaries or random noise patches appear---every point in the grid decodes to something that looks like a plausible digit or a blend between digits.


Exploration 3: Latent Space Interpolation

We pick pairs of test digits and interpolate between their latent codes.

"""Interpolate between pairs of digits in latent space."""

def interpolate(
    model: VAE,
    img1: torch.Tensor,
    img2: torch.Tensor,
    n_steps: int = 10,
) -> list[np.ndarray]:
    """Linearly interpolate between two images in latent space.

    Args:
        model: Trained VAE model.
        img1: First image tensor (1, 1, 28, 28).
        img2: Second image tensor (1, 1, 28, 28).
        n_steps: Number of interpolation steps.

    Returns:
        List of decoded images along the interpolation path.
    """
    model.eval()
    with torch.no_grad():
        mu1, _ = model.encode(img1)
        mu2, _ = model.encode(img2)

        images = []
        for alpha in np.linspace(0, 1, n_steps):
            z = (1 - alpha) * mu1 + alpha * mu2
            decoded = model.decode(z).squeeze().numpy()
            images.append(decoded)

    return images


# Example: Interpolate between a 3 and an 8
idx_3 = (test_dataset.targets == 3).nonzero(as_tuple=True)[0][0]
idx_8 = (test_dataset.targets == 8).nonzero(as_tuple=True)[0][0]

img_3 = test_dataset[idx_3][0].unsqueeze(0)
img_8 = test_dataset[idx_8][0].unsqueeze(0)

interp_images = interpolate(model, img_3, img_8, n_steps=10)
print(f"Interpolation: {len(interp_images)} steps from digit 3 to 8")
# Expected: smooth morphing where the 3 gradually closes its gaps
# to become an 8.
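
A sketch for displaying the interpolation as a single row of frames, assuming matplotlib:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(
    1, len(interp_images), figsize=(len(interp_images), 1.5)
)
for ax, frame in zip(axes, interp_images):
    ax.imshow(frame, cmap="gray")
    ax.axis("off")
plt.show()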

What to observe: The interpolation is smooth and semantically meaningful. A 3 gradually closes its openings to become an 8. A 1 gradually gains a top stroke to become a 7. These smooth transitions confirm that the latent space is well-organized and that nearby latent codes correspond to visually similar images.


Exploration 4: Random Generation

We sample random latent codes from the prior and decode them.

"""Generate new digits by sampling from the prior."""

n_samples = 100

model.eval()
with torch.no_grad():
    z_samples = torch.randn(n_samples, model.latent_dim)
    generated = model.decode(z_samples)

print(f"Generated {n_samples} images, shape: {generated.shape}")
# With a 2D latent space, generated samples are recognizable but
# often blurry. Higher latent dimensions produce sharper results.
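
One way to tile the samples into a single image, assuming torchvision and matplotlib are available:

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Arrange the 100 samples into a 10 x 10 grid; make_grid replicates the
# single channel to three, so any one channel can be shown in grayscale.
grid = make_grid(generated, nrow=10, padding=2)
plt.figure(figsize=(6, 6))
plt.imshow(grid[0].numpy(), cmap="gray")
plt.axis("off")
plt.show()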

Exploration 5: Effect of Latent Dimension

We compare VAEs with latent dimension $k \in \{2, 8, 32\}$ to understand the trade-off:

Latent Dim   Reconstruction MSE   KL Divergence   Sample Quality   Visualization
2            High                 Low             Blurry           Direct 2D plot
8            Medium               Medium          Decent           Requires t-SNE
32           Low                  Higher          Good             Requires t-SNE

Key insight: The 2D latent space is great for visualization but too constrained for high-quality reconstruction and generation. For practical applications, use 32--128 dimensions and visualize with t-SNE or UMAP (Chapter 7).
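
A sketch of the t-SNE route for higher-dimensional latent spaces, assuming scikit-learn is installed and `all_mu` holds encodings from a higher-dimensional model:

from sklearn.manifold import TSNE

# Project (n_samples, latent_dim) encodings down to 2D for plotting
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
mu_2d = tsne.fit_transform(all_mu)
print(f"t-SNE embedding shape: {mu_2d.shape}")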


Analysis: What Makes a Good Latent Space?

Smoothness

A good latent space is smooth: small movements in latent space produce small changes in the decoded output. The KL divergence term enforces this by preventing the encoder from placing different classes in isolated, disconnected regions.

Coverage

Samples drawn from the prior $\mathcal{N}(\mathbf{0}, \mathbf{I})$ should decode to meaningful data. If large regions under the prior decode to noise, the latent space has gaps between the regions the encoder actually uses. This can be checked by sampling from the prior and inspecting the quality of the decoded samples, as in Exploration 4.

Disentanglement

In an ideal latent space, each dimension controls a single interpretable factor of variation. With a standard VAE ($\beta = 1$), dimensions are typically entangled. Using $\beta > 1$ ($\beta$-VAE) encourages disentanglement but reduces reconstruction quality.
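
A common probe for disentanglement is a latent traversal: vary one dimension while holding the others fixed and inspect what changes. A minimal sketch (the dimension index and value range are arbitrary choices):

"""Latent traversal: vary one dimension, keep the rest at zero."""

dim_to_vary = 0
values = torch.linspace(-3.0, 3.0, steps=9)

model.eval()
with torch.no_grad():
    z = torch.zeros(len(values), model.latent_dim)
    z[:, dim_to_vary] = values
    traversal = model.decode(z)  # shape: (9, 1, 28, 28)

print(f"Traversal images: {traversal.shape}")
# If the dimension is disentangled, a single factor of variation
# (e.g., stroke thickness or slant) should change across the row.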

Informativeness

The latent code should carry useful information about the input. If all encoded points collapse to the prior (posterior collapse), the latent space is uninformative. We monitor this by tracking the KL divergence during training---if it drops to zero, collapse has occurred.
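
A practical diagnostic is the per-dimension KL averaged over the data: dimensions whose average KL is near zero have collapsed to the prior and carry no information. A minimal sketch, reusing `test_loader` from Exploration 1 (the 0.01-nat threshold is a common heuristic, not a fixed rule):

"""Per-dimension KL diagnostic for posterior collapse."""

model.eval()
kl_per_dim_sum = torch.zeros(model.latent_dim)
n_seen = 0

with torch.no_grad():
    for images, _ in test_loader:
        mu, logvar = model.encode(images)
        # Closed-form KL per dimension, summed over the batch
        kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        kl_per_dim_sum += kl.sum(dim=0)
        n_seen += images.size(0)

kl_per_dim = kl_per_dim_sum / n_seen
n_active = (kl_per_dim > 0.01).sum().item()
print(f"Average KL per dimension: {kl_per_dim}")
print(f"Active dimensions: {n_active}/{model.latent_dim}")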


Key Takeaways

  1. The VAE's KL regularization produces a smooth, continuous latent space where interpolation is meaningful.
  2. A 2D latent space provides excellent visualization but limited reconstruction quality. Practical applications require higher dimensions.
  3. KL annealing is essential for preventing posterior collapse and ensuring the latent codes carry information.
  4. The quality of generated samples from a VAE is limited by the simple factorized pixel-wise decoder likelihood (a Bernoulli model trained with binary cross-entropy here, a Gaussian in other formulations), which produces blurry outputs. This motivates more advanced generative models (GANs in Chapter 17, diffusion models in Chapter 18).
  5. Latent space exploration (encoding, grid decoding, interpolation) provides deep intuition for how the model organizes and generates data.