In This Chapter
- Introduction
- 26.1 From CNNs to Transformers: A Paradigm Shift
- 26.2 The Vision Transformer (ViT) Architecture
- 26.3 Understanding ViT's Behavior
- 26.4 DeiT: Data-Efficient Image Transformers
- 26.5 Swin Transformer: Hierarchical Vision Transformer
- 26.6 Object Detection with DETR
- 26.7 Semantic Segmentation with Vision Transformers
- 26.8 ViT vs. CNN: Understanding the Tradeoffs
- 26.9 Fine-Tuning ViT with HuggingFace
- 26.10 Advanced Topics
- 26.11 Practical Considerations
- 26.12 Historical Context and Future Directions
- 26.13 Summary
- 26.14 Exercises
- References
Chapter 26: Vision Transformers and Modern Computer Vision
Introduction
For decades, convolutional neural networks (CNNs) reigned as the undisputed champions of computer vision. From AlexNet's breakthrough in 2012 to the sophisticated architectures of EfficientNet and ResNeXt, convolutions provided the inductive biases — translation equivariance and locality — that made learning from images tractable. Then, in 2020, a paper from Google Research posed a provocative question: what if we simply applied a standard Transformer, with minimal modifications, directly to images?
The result was the Vision Transformer (ViT), and its success triggered a seismic shift in how we think about visual representation learning. In this chapter, you will learn how transformers have been adapted for vision tasks, why they often surpass CNNs when given sufficient data, and how modern architectures like Swin Transformer and DETR have pushed the boundaries of object detection and segmentation. By the end of this chapter, you will be able to implement, train, and fine-tune vision transformers using PyTorch and HuggingFace.
It is worth noting the remarkable speed of this revolution. The original ViT paper appeared as a preprint in October 2020, and within just three years nearly every major computer vision benchmark was dominated by transformer-based architectures. As we saw in Chapters 4-6, the transformer's flexibility and scalability had already reshaped natural language processing; this chapter tells the story of how that same architecture conquered the visual domain. We will also see how the lessons learned here — patch embeddings, hierarchical feature maps, and attention-based detection — lay the groundwork for the multimodal models discussed in Chapter 28 and the video transformers explored in Chapter 30.
26.1 From CNNs to Transformers: A Paradigm Shift
26.1.1 The CNN Foundation
Before diving into vision transformers, let us briefly recall what made CNNs so effective. The convolutional neural network story begins with Yann LeCun's LeNet (1998), which demonstrated that learned convolutional filters could recognize handwritten digits. The approach languished for a decade due to limited computing power and data, until AlexNet (Krizhevsky et al., 2012) won the ImageNet Large Scale Visual Recognition Challenge by a dramatic margin, reducing top-5 error from 25.8% to 16.4%. This triggered the deep learning revolution in computer vision.
A convolutional layer applies a set of learned filters (kernels) across spatial locations of an input feature map. This design encodes two critical inductive biases:
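- Translation equivariance: Because the same filters are applied at every location, shifting the input shifts the resulting feature map by the same amount. A cat in the top-left corner activates the same filters as a cat in the bottom-right corner, just at a different position.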
- Locality: Each neuron only "sees" a small receptive field, forcing the network to build representations hierarchically — from edges to textures to parts to objects.
These biases dramatically reduce the number of parameters compared to fully connected layers and allow CNNs to learn efficiently from relatively small datasets. A typical CNN architecture like ResNet-50 has approximately 25 million parameters and processes images through a sequence of convolutional blocks with progressively increasing channel dimensions and decreasing spatial resolution.
26.1.2 Limitations of Pure Convolution
Despite their success, CNNs have notable limitations:
- Limited receptive field: Even deep CNNs struggle to capture long-range dependencies. A pixel in one corner of an image has no direct connection to a pixel in the opposite corner until very late in the network.
- Fixed geometric structure: Convolutions operate on a rigid grid, making it difficult to model relationships that don't follow spatial proximity.
- Scaling behavior: While CNNs improve with more data, their performance gains plateau faster than transformers when scaled to massive datasets.
Researchers attempted to address these limitations through attention mechanisms added to CNNs (such as Squeeze-and-Excitation networks and non-local neural networks), but these were incremental additions rather than fundamental redesigns.
26.1.3 The Transformer Opportunity
The Transformer architecture, introduced by Vaswani et al. (2017) for natural language processing, offered a compelling alternative. Its self-attention mechanism computes pairwise interactions between all elements in a sequence, naturally capturing long-range dependencies without the hierarchical bottleneck of convolutions. The question was: how could we apply this architecture, designed for 1D token sequences, to 2D images?
Early attempts to combine attention with vision, along with a few hybrids that appeared around the same time as ViT, included:
- Non-local neural networks (Wang et al., 2018): Added self-attention blocks within CNN architectures, improving performance on video classification.
- Stand-Alone Self-Attention (Ramachandran et al., 2019): Replaced convolutions with local self-attention in ResNet-style architectures, achieving competitive results but with higher computational cost.
- BoTNet (Srinivas et al., 2021): Replaced the spatial convolutions in the last three ResNet bottleneck blocks with multi-head self-attention, improving ImageNet accuracy with minimal computational overhead.
These works demonstrated the potential of attention for vision but were incremental modifications to CNN architectures. The Vision Transformer took the more radical approach of applying the transformer architecture directly, with minimal vision-specific modifications.
26.2 The Vision Transformer (ViT) Architecture
26.2.1 Patch Embedding: Turning Images into Sequences
The key insight of the Vision Transformer is remarkably simple: treat an image as a sequence of patches, just as NLP transformers treat text as a sequence of tokens.
Given an image $\mathbf{x}$ of size $H \times W \times C$ (height, width, channels), we divide it into a grid of non-overlapping patches, each of size $P \times P$. This yields $N = HW / P^2$ patches (assuming $H$ and $W$ are divisible by $P$). Each patch is flattened into a vector of dimension $P^2 \cdot C$ and then linearly projected to a $d$-dimensional embedding:
$$\mathbf{z}_i^0 = \mathbf{x}_{\text{patch}}^{(i)} \mathbf{E} + \mathbf{e}_{\text{pos}}^{(i)}$$
where:
- $\mathbf{x}_{\text{patch}}^{(i)} \in \mathbb{R}^{P^2 C}$ is the flattened $i$-th patch
- $\mathbf{E} \in \mathbb{R}^{(P^2 C) \times d}$ is the patch embedding projection matrix
- $\mathbf{e}_{\text{pos}}^{(i)} \in \mathbb{R}^d$ is the learnable position embedding for position $i$
- $\mathbf{z}_i^0 \in \mathbb{R}^d$ is the initial embedding for patch $i$
Worked Example: For a standard ImageNet image of size $224 \times 224 \times 3$ with patch size $P = 16$:
- Number of patches: $N = (224 \times 224) / (16 \times 16) = 196$
- Flattened patch dimension: $16 \times 16 \times 3 = 768$
- With embedding dimension $d = 768$, the projection matrix $\mathbf{E}$ has shape $768 \times 768$
This is equivalent to applying a single convolutional layer with kernel size 16, stride 16, and 768 output channels — a connection that makes implementation straightforward.
Implementation in PyTorch: The patch embedding can be implemented cleanly as a convolutional layer:
```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Convert an image into a sequence of patch embeddings.

    Args:
        img_size: Input image resolution (assumes square images).
        patch_size: Size of each square patch.
        in_channels: Number of input channels (3 for RGB).
        embed_dim: Dimension of the output embeddings.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project image patches into embedding space.

        Args:
            x: Input images of shape [batch, channels, height, width].

        Returns:
            Patch embeddings of shape [batch, num_patches, embed_dim].
        """
        x = self.proj(x)  # [B, embed_dim, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        return x


# Quick sanity check
patch_embed = PatchEmbedding()
dummy = torch.randn(2, 3, 224, 224)
tokens = patch_embed(dummy)
print(f"Patch tokens shape: {tokens.shape}")  # torch.Size([2, 196, 768])
```
Why a single convolution works: The convolutional layer with kernel size equal to stride is mathematically identical to flattening each patch and multiplying by the projection matrix E. Each output "pixel" in the convolution output corresponds to one patch, and the 768 output channels correspond to the embedding dimension. This equivalence was noted in the original ViT paper and is the standard implementation strategy used in all major frameworks.
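The equivalence is easy to confirm numerically. The sketch below is an illustrative check, not code from the ViT release: it cuts the image into patches with `Tensor.unfold`, flattens each patch in channel, row, column order, and multiplies by the convolution weight reshaped into the matrix $\mathbf{E}$. The convolution's bias acts as an additive bias that the embedding equation above omits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
P, C, d = 16, 3, 768
conv = nn.Conv2d(C, d, kernel_size=P, stride=P)
x = torch.randn(2, C, 224, 224)

# Path 1: strided convolution, then flatten the 14x14 output grid into a sequence.
conv_tokens = conv(x).flatten(2).transpose(1, 2)  # [B, 196, d]

# Path 2: cut out non-overlapping P x P patches explicitly.
B, _, H, W = x.shape
patches = x.unfold(2, P, P).unfold(3, P, P)       # [B, C, H/P, W/P, P, P]
patches = patches.permute(0, 2, 3, 1, 4, 5)       # [B, H/P, W/P, C, P, P]
patches = patches.reshape(B, (H // P) * (W // P), C * P * P)  # flatten each patch

# The conv weight [d, C, P, P], flattened to [d, C*P*P] and transposed, is E.
E = conv.weight.reshape(d, -1).t()                # [C*P*P, d]
matmul_tokens = patches @ E + conv.bias           # [B, 196, d]

print(torch.allclose(conv_tokens, matmul_tokens, atol=1e-4))  # True
```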
Overlapping patch embeddings: Some later architectures, such as SegFormer, use overlapping patches (stride smaller than kernel size) to capture information at patch boundaries. For instance, a kernel size of 7 with stride 4 produces patches that overlap by 3 pixels, smoothing the transition between adjacent tokens and improving performance on dense prediction tasks.
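A minimal sketch of an overlapping patch embedding with the kernel-7, stride-4 setting described above. The padding of 3, the 64-dimensional output, and the trailing LayerNorm are assumptions chosen for illustration (similar in spirit to SegFormer's first stage, not copied from it). With padding 3, a 224x224 image yields a 56x56 token grid, the same grid that non-overlapping 4x4 patches would give, but each token now summarizes a 7x7 neighborhood.

```python
import torch
import torch.nn as nn


class OverlapPatchEmbed(nn.Module):
    """Overlapping patch embedding: kernel larger than stride (illustrative sketch)."""

    def __init__(self, in_channels=3, embed_dim=64, kernel_size=7, stride=4):
        super().__init__()
        # padding = kernel_size // 2 keeps the token grid at H / stride.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=kernel_size,
                              stride=stride, padding=kernel_size // 2)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)                   # [B, D, H', W']
        _, _, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)   # [B, H'*W', D]
        return self.norm(x), (h, w)


tokens, (h, w) = OverlapPatchEmbed()(torch.randn(1, 3, 224, 224))
print(tokens.shape, h, w)  # torch.Size([1, 3136, 64]) 56 56
```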
26.2.2 The [CLS] Token
Following BERT's convention, ViT prepends a special learnable [CLS] token to the sequence of patch embeddings. This token does not correspond to any image patch; instead, it serves as a global representation that aggregates information from all patches through the self-attention mechanism. After passing through all transformer layers, the [CLS] token's output is used as the image representation for classification.
The complete input sequence is:
$$\mathbf{Z}^0 = [\mathbf{z}_{\text{cls}}; \mathbf{z}_1^0; \mathbf{z}_2^0; \ldots; \mathbf{z}_N^0]$$
This gives a sequence of length $N + 1$, where $N = 196$ for the standard configuration.
26.2.3 Position Embeddings
Unlike CNNs, transformers have no built-in notion of spatial position. Without position embeddings, self-attention treats the patches as an unordered set: permuting the patch tokens merely permutes the outputs, so the [CLS] output is unchanged and the model could not distinguish between two images that contain the same patches in different arrangements.
ViT uses learnable 1D position embeddings, one for each position in the sequence (including the [CLS] token). These are added to the patch embeddings before being fed into the transformer encoder.
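Putting the patch embeddings, the [CLS] token, and the position embeddings together, the encoder input can be assembled as below. This is a sketch: the class name `ViTInputSequence` is hypothetical, and the truncated-normal initialization with standard deviation 0.02 is a common convention we assume here rather than a detail given in this chapter.

```python
import torch
import torch.nn as nn


class ViTInputSequence(nn.Module):
    """Prepend a learnable [CLS] token and add learnable 1D position embeddings."""

    def __init__(self, num_patches=196, embed_dim=768):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position embedding per sequence slot, including the [CLS] slot.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, patch_tokens):
        B = patch_tokens.shape[0]
        cls = self.cls_token.expand(B, -1, -1)         # [B, 1, D]
        z = torch.cat([cls, patch_tokens], dim=1)      # [B, N + 1, D]
        return z + self.pos_embed                      # broadcast over the batch


seq = ViTInputSequence()(torch.randn(2, 196, 768))
print(seq.shape)  # torch.Size([2, 197, 768])
```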
Interestingly, the authors found that learned 1D embeddings performed comparably to more sophisticated 2D-aware schemes. When visualized, the learned position embeddings exhibit a clear 2D spatial structure — nearby patches have similar embeddings — demonstrating that the model discovers spatial relationships from data alone.
Position embedding strategies compared: Several alternatives have been explored in the literature, each with distinct tradeoffs:
| Strategy | Description | Pros | Cons |
|---|---|---|---|
| Learnable 1D | One learnable vector per sequence position | Simple; works well | Fixed sequence length |
| Learnable 2D | Separate row and column embeddings, summed | Explicit 2D awareness | Marginal improvement over 1D |
| Sinusoidal 2D | Fixed sin/cos functions of 2D coordinates | No parameters; any resolution | Slightly worse than learned |
| Relative (RPB) | Bias based on relative patch positions | Resolution-flexible; used in Swin | Adds complexity |
| Rotary (RoPE) | Rotation-based encoding in attention | Extrapolates to unseen lengths | Requires implementation care |
| Conditional (CPE) | Depth-wise convolution generating position info | Handles variable resolutions | Additional computation |
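To make the "Sinusoidal 2D" row concrete, here is a minimal sketch of a fixed 2D sine/cosine embedding. It follows one common recipe: half of the channels encode the row index and half the column index, each with 1D sine/cosine features at geometrically spaced frequencies. The function names and the base of 10000 (borrowed from the original transformer's 1D scheme) are assumptions; ViT itself used learned embeddings. The final check confirms that a neighboring patch receives a more similar embedding than a distant one.

```python
import torch


def sincos_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    """1D sine/cosine features with dim/2 frequencies: [num_pos, dim]."""
    omega = torch.arange(dim // 2, dtype=torch.float32) / (dim // 2)
    omega = 1.0 / (10000 ** omega)                     # frequencies from 1 down toward 1e-4
    angles = positions.float()[:, None] * omega[None]  # [num_pos, dim/2]
    return torch.cat([angles.sin(), angles.cos()], dim=1)


def sincos_2d(grid_size: int, dim: int) -> torch.Tensor:
    """Fixed embedding for a grid_size x grid_size patch grid, in row-major order."""
    assert dim % 4 == 0, "dim must be divisible by 4"
    rows, cols = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size), indexing="ij")
    emb_rows = sincos_1d(rows.flatten(), dim // 2)     # first half encodes the row
    emb_cols = sincos_1d(cols.flatten(), dim // 2)     # second half encodes the column
    return torch.cat([emb_rows, emb_cols], dim=1)      # [grid_size**2, dim]


pos = sincos_2d(14, 768)  # ViT-B/16 at 224x224: a 14x14 grid
sim = torch.nn.functional.cosine_similarity(pos[0:1], pos, dim=1)
print(pos.shape, sim[1].item() > sim[13].item())  # torch.Size([196, 768]) True
```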
The choice of position embedding becomes especially important when fine-tuning at a different resolution than pre-training. Learnable 1D embeddings trained at 224x224 (196 patches) cannot be directly used at 384x384 (576 patches). The standard solution is bicubic interpolation: reshape the 1D position embeddings into a 2D grid (14x14 for ViT-B/16), bicubically interpolate to the new grid size (24x24 for 384x384 input), and flatten back to 1D. This works remarkably well, with fine-tuning at higher resolution typically improving accuracy by 1-2% on ImageNet.
Worked Example — Position embedding interpolation: Suppose you pre-trained ViT-B/16 at 224x224 (14x14 = 196 patches). Now you want to fine-tune at 384x384 resolution, giving 24x24 = 576 patches. The [CLS] token embedding stays unchanged. The 196 patch position embeddings are reshaped to a 14x14 grid, interpolated to 24x24 using bicubic interpolation, and flattened back to 576 vectors. The total position embedding matrix goes from shape $(197, 768)$ to $(577, 768)$.
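The interpolation recipe can be sketched with `torch.nn.functional.interpolate`. The helper `resize_pos_embed` is hypothetical, and `align_corners=False` is our assumption; libraries differ in such details.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Resize [1, 1 + g*g, D] position embeddings to [1, 1 + new_grid**2, D].

    Assumes a single [CLS] slot at index 0 followed by a square grid of patch slots.
    """
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = int(patch_pos.shape[1] ** 0.5)
    dim = patch_pos.shape[-1]
    # [1, g*g, D] -> [1, D, g, g] so that interpolate treats D as channels.
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode="bicubic", align_corners=False)
    # Back to [1, new_grid**2, D]; the [CLS] embedding is re-attached unchanged.
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)


old = torch.randn(1, 197, 768)             # ViT-B/16 at 224x224: 14x14 grid + [CLS]
new = resize_pos_embed(old, new_grid=24)   # 384 / 16 = 24
print(new.shape)                           # torch.Size([1, 577, 768])
```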
26.2.4 The Computational Cost of Self-Attention for Vision
Before examining the transformer encoder details, it is worth understanding the computational implications of applying self-attention to images. For a sequence of $N + 1$ tokens (including [CLS]), each attention layer computes query, key, and value projections, followed by the attention matrix multiplication. The dominant costs are:
- QKV projections: $3 \times (N + 1) \times d \times d$ FLOPs
- Attention matrix: $(N + 1)^2 \times d$ FLOPs
- Attention-value product: $(N + 1)^2 \times d$ FLOPs
- Output projection: $(N + 1) \times d \times d$ FLOPs
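Counting each multiply-accumulate as one FLOP (the convention behind the commonly quoted 4.1 GFLOPs for ResNet-50), these four terms sum to about 0.52 billion FLOPs per layer for ViT-Base/16 with $N = 196$ and $d = 768$. The attention sub-layer is not the whole cost, however: the FFN adds $2 \times (N + 1) \times d \times 4d \approx 0.93$ billion FLOPs, so each layer costs about 1.45 billion and the full 12-layer model about 17.5 billion, consistent with commonly reported figures for ViT-B/16 at $224 \times 224$ and roughly four times the cost of ResNet-50. At higher resolutions the quadratic terms in $(N+1)^2$ grow fastest: doubling the resolution quadruples the number of patches, so the projection and FFN costs grow about 4x while the attention-matrix terms grow about 16x, whereas a CNN's cost only quadruples.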
This analysis motivates the efficiency innovations in later architectures: Swin Transformer's window attention (Section 26.5), efficient self-attention in SegFormer (Section 26.7.2), and the various factored attention schemes for video in Chapter 30.
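The per-layer costs from Section 26.2.4 can be tabulated with a few lines of Python. This sketch counts multiply-accumulates only (ignoring LayerNorm, biases, softmax, and the patch embedding) and assumes the standard FFN expansion ratio of 4. For ViT-B/16 it reports about 0.52 G for the attention sub-layer, 1.45 G per layer, and 17.4 G for 12 layers, and it shows the roughly 16x growth of the quadratic terms when the resolution doubles.

```python
def vit_layer_macs(num_patches: int, d: int, mlp_ratio: int = 4) -> dict:
    """Multiply-accumulate counts for one ViT encoder layer."""
    n = num_patches + 1  # +1 for the [CLS] token
    costs = {
        "qkv": 3 * n * d * d,
        "attn_scores": n * n * d,   # Q K^T, summed over heads
        "attn_values": n * n * d,   # softmax(.) V, summed over heads
        "out_proj": n * d * d,
        "mlp": 2 * n * d * (mlp_ratio * d),
    }
    costs["total"] = sum(costs.values())
    return costs


base = vit_layer_macs(196, 768)
attention = base["total"] - base["mlp"]
print(f"attention sub-layer: {attention / 1e9:.2f} G, full layer: {base['total'] / 1e9:.2f} G")
print(f"12 layers: {12 * base['total'] / 1e9:.1f} G MACs")

# Doubling the resolution (448x448 -> 784 patches): quadratic terms ~16x, linear terms ~4x.
big = vit_layer_macs(784, 768)
print(f"attn_scores growth: {big['attn_scores'] / base['attn_scores']:.1f}x, "
      f"qkv growth: {big['qkv'] / base['qkv']:.1f}x")
```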
26.2.5 The Transformer Encoder
The transformer encoder in ViT follows the standard architecture with $L$ layers, each consisting of:
- Layer Normalization (Pre-Norm): Applied before each sub-layer, which differs from the post-norm convention in the original transformer.
- Multi-Head Self-Attention (MHSA):
$$\text{MHSA}(\mathbf{Z}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O$$
$$\text{head}_j = \text{Attention}(\mathbf{Z}\mathbf{W}_j^Q, \mathbf{Z}\mathbf{W}_j^K, \mathbf{Z}\mathbf{W}_j^V)$$
$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}$$
- Feed-Forward Network (FFN): A two-layer MLP with GELU activation:
$$\text{FFN}(\mathbf{z}) = \text{GELU}(\mathbf{z}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2$$
The complete forward pass for layer $\ell$ is:
$$\mathbf{Z}'^{\ell} = \text{MHSA}(\text{LN}(\mathbf{Z}^{\ell-1})) + \mathbf{Z}^{\ell-1}$$
$$\mathbf{Z}^{\ell} = \text{FFN}(\text{LN}(\mathbf{Z}'^{\ell})) + \mathbf{Z}'^{\ell}$$
where LN denotes Layer Normalization and residual connections are used in both sub-layers.
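The two update equations translate almost line by line into PyTorch. This is a minimal sketch of one pre-norm block built on `nn.MultiheadAttention`; the class name `ViTBlock` is ours, and dropout defaults to 0. With ViT-Base settings it has 7,087,872 parameters per layer.

```python
import torch
import torch.nn as nn


class ViTBlock(nn.Module):
    """One pre-norm encoder layer: Z' = MHSA(LN(Z)) + Z, then Z_out = FFN(LN(Z')) + Z'."""

    def __init__(self, dim=768, num_heads=12, mlp_dim=3072, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, z):
        h = self.norm1(z)
        attn_out, _ = self.attn(h, h, h, need_weights=False)  # self-attention on LN(z)
        z = z + attn_out                                       # first residual connection
        z = z + self.ffn(self.norm2(z))                        # second residual connection
        return z


block = ViTBlock()
out = block(torch.randn(2, 197, 768))
print(out.shape)  # torch.Size([2, 197, 768])
```

PyTorch's built-in `nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072, activation="gelu", batch_first=True, norm_first=True)` implements the same pre-norm arrangement and can be stacked with `nn.TransformerEncoder`.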
26.2.6 Classification Head
The final classification is performed using the [CLS] token output from the last layer:
$$\hat{y} = \text{MLP}_{\text{head}}(\text{LN}(\mathbf{z}_{\text{cls}}^L))$$
During pre-training, the MLP head consists of one hidden layer with a tanh activation. During fine-tuning, it is replaced by a single linear layer.
Alternative pooling strategies: While the [CLS] token is the standard approach, alternatives exist:
- Global Average Pooling (GAP): Average all patch token outputs from the last layer to produce a single representation. Some implementations find that this slightly outperforms the [CLS] token, particularly when the model is used as a feature extractor for downstream tasks.
- Attention pooling: Use a learned attention mechanism to produce a weighted average of patch tokens, allowing the model to focus on the most informative patches for classification.
- Multi-layer pooling: Use outputs from multiple layers (not just the last) by concatenating or averaging them, capturing both low-level and high-level features.
26.2.7 ViT Model Variants
The original paper defined three model sizes:
| Model | Layers ($L$) | Hidden Dim ($d$) | Heads ($h$) | MLP Dim | Params |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 12 | 3072 | 86M |
| ViT-Large | 24 | 1024 | 16 | 4096 | 307M |
| ViT-Huge | 32 | 1280 | 16 | 5120 | 632M |
The naming convention includes patch size, e.g., ViT-B/16 means ViT-Base with 16x16 patches.
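As a sanity check on the Params column, the parameter count can be derived from $L$, $d$, and the MLP dimension alone. The sketch below assumes a 1000-class linear head, biases on every linear layer, and 224x224 inputs; under those assumptions it gives about 86.6M for ViT-B/16 and 632.0M for ViT-H/14. Published counts can differ slightly depending on the head and input resolution.

```python
def vit_param_count(layers: int, d: int, mlp_dim: int, patch: int = 16,
                    img_size: int = 224, in_channels: int = 3,
                    num_classes: int = 1000) -> int:
    """Count ViT parameters, assuming a linear classification head."""
    num_patches = (img_size // patch) ** 2
    patch_embed = (patch * patch * in_channels) * d + d   # E plus bias
    cls_and_pos = d + (num_patches + 1) * d               # [CLS] token + position embeddings
    per_layer = (
        2 * (2 * d)                                       # two LayerNorms (scale + shift)
        + (d * 3 * d + 3 * d)                             # fused Q, K, V projections
        + (d * d + d)                                     # attention output projection
        + (d * mlp_dim + mlp_dim) + (mlp_dim * d + d)     # two-layer MLP
    )
    final_norm = 2 * d
    head = d * num_classes + num_classes
    return patch_embed + cls_and_pos + layers * per_layer + final_norm + head


print(f"ViT-B/16: {vit_param_count(12, 768, 3072) / 1e6:.1f}M")            # 86.6M
print(f"ViT-H/14: {vit_param_count(32, 1280, 5120, patch=14) / 1e6:.1f}M")  # 632.0M
```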
26.3 Understanding ViT's Behavior
26.3.1 The Data Hunger Problem
The original ViT paper revealed a crucial finding: ViT underperforms ResNets when trained on mid-sized datasets like ImageNet-1K (~1.3M images), but surpasses them when pre-trained on larger datasets like ImageNet-21K (~14M images) or JFT-300M (~300M images).
This makes intuitive sense. CNNs encode strong inductive biases about images (locality, translation equivariance) that provide an excellent learning prior when data is scarce. Transformers, lacking these biases, must learn spatial relationships entirely from data. With sufficient data, this flexibility becomes an advantage — the model discovers representations unconstrained by architectural assumptions.
26.3.2 Attention Patterns
Visualizing attention maps reveals that ViT learns meaningful spatial patterns:
- Early layers: Many heads attend to local neighborhoods, mimicking small convolutional filters, while some heads attend to most of the image even in the lowest layers.
- Middle layers: Attention becomes more diverse, with some heads capturing medium-range dependencies while others develop global patterns.
- Late layers: Many heads exhibit long-range, global attention, integrating information across the entire image.
This progression from local to global is reminiscent of CNN feature hierarchies, but emerges entirely from the data rather than being imposed by architecture.
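One way to quantify this progression, used in the original ViT analysis, is the mean attention distance: the image-space distance between a query patch and the patches it attends to, weighted by the attention weights and averaged over queries. A minimal sketch, assuming you have already extracted one layer's attention weights over patch tokens only ([CLS] row and column removed) as a `[heads, N, N]` tensor; the function name is hypothetical. The two synthetic heads are extreme cases: a purely local head scores 0, and a uniform head scores roughly half the image width.

```python
import torch


def mean_attention_distance(attn: torch.Tensor, grid_size: int, patch_size: int) -> torch.Tensor:
    """Per-head mean attention distance in pixels.

    attn: [heads, N, N] attention weights over N = grid_size**2 patch tokens (rows sum to 1).
    """
    rows, cols = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size), indexing="ij")
    coords = torch.stack([rows.flatten(), cols.flatten()], dim=1).float()       # [N, 2]
    dist = (coords[:, None, :] - coords[None, :, :]).norm(dim=-1) * patch_size  # [N, N] pixels
    # Weight each query-key distance by its attention, sum over keys, average over queries.
    return (attn * dist).sum(dim=-1).mean(dim=-1)                              # [heads]


n = 14 * 14                                  # ViT-B/16 at 224x224
local_head = torch.eye(n)                    # each patch attends only to itself
global_head = torch.full((n, n), 1.0 / n)    # uniform attention over all patches
print(mean_attention_distance(torch.stack([local_head, global_head]), grid_size=14, patch_size=16))
```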
26.3.3 Effective Receptive Field
Unlike CNNs, where the receptive field grows gradually with depth (linearly for stacked stride-1 convolutions, faster when pooling or strided convolutions are used), every ViT layer has a global receptive field. Each patch token can attend to every other patch token in a single layer. This is both a strength (capturing long-range dependencies) and a computational challenge (quadratic complexity in sequence length).
26.3.4 Positional Embedding Analysis
When we visualize the learned position embeddings by computing cosine similarities between all pairs, we see that:
- Nearby patches have similar position embeddings (capturing locality).
- Patches in the same row or column show stronger similarity (capturing 2D grid structure).
- The model effectively learns a 2D positional encoding from a 1D parameterization.
26.3.5 Scaling Laws for ViT
One of the most important findings from the original ViT paper concerns scaling behavior. When plotted against pre-training dataset size, the performance curves of ViT and ResNet cross: below approximately 10 million images, ResNet consistently outperforms ViT; above that threshold, ViT begins to pull ahead, with the gap widening as data increases.
This finding was later quantified more precisely by Zhai et al. (2022) in their "Scaling Vision Transformers" paper, which trained ViT models with up to 2 billion parameters (ViT-G/14) and found:
- Performance improves predictably with model size, data, and compute, following a saturating power law analogous to the scaling laws observed for language models.
- The largest model, ViT-G/14, reaches 90.45% top-1 on ImageNet after fine-tuning, leaving relatively little headroom on this benchmark. Later work pushed scale further still, to the 22-billion-parameter ViT-22B (Dehghani et al., 2023).
- These scaling laws suggest that further improvements in vision may come primarily from scale (more data, more parameters, more compute) rather than architectural innovation, echoing the findings for language models discussed in Chapter 14.
26.4 DeiT: Data-Efficient Image Transformers
26.4.1 Motivation
The requirement for massive pre-training datasets was a significant practical limitation of ViT. The DeiT (Data-efficient Image Transformers) paper by Touvron et al. (2021) demonstrated that with the right training recipe, ViT can be trained competitively on ImageNet-1K alone.
26.4.2 Training Strategy
DeiT's key contributions are in training methodology rather than architecture:
- Strong data augmentation: RandAugment, random erasing, Mixup, and CutMix provide the regularization that compensates for the lack of convolutional inductive biases.
- Regularization: Stochastic depth (dropping entire transformer layers during training), label smoothing, and repeated augmentation prevent overfitting.
- Optimization: AdamW optimizer with a cosine learning rate schedule and warm-up, trained for 300 epochs on ImageNet-1K (by comparison, ViT's JFT-300M pre-training ran for only a handful of epochs, but over a dataset more than 200 times larger).
- Distillation token: DeiT introduces a novel distillation token in addition to the [CLS] token. This token is trained to match the output of a CNN teacher (typically a RegNet), providing an architectural mechanism for knowledge distillation.
26.4.3 Distillation Mechanisms
DeiT explored two forms of distillation:
Soft distillation uses the KL divergence between teacher and student softmax outputs:
$$\mathcal{L}_{\text{soft}} = (1 - \lambda) \mathcal{L}_{\text{CE}}(\psi(\mathbf{z}_s), y) + \lambda \tau^2 \text{KL}(\psi(\mathbf{z}_s / \tau), \psi(\mathbf{z}_t / \tau))$$
where:
- $\psi$ is the softmax function
- $\mathbf{z}_s, \mathbf{z}_t$ are student and teacher logits
- $y$ is the ground truth label
- $\tau$ is the temperature parameter
- $\lambda$ balances the two loss terms
Hard-label distillation uses the teacher's argmax prediction as a pseudo-label:
$$\mathcal{L}_{\text{hard}} = \frac{1}{2} \mathcal{L}_{\text{CE}}(\psi(\mathbf{z}_s^{\text{cls}}), y) + \frac{1}{2} \mathcal{L}_{\text{CE}}(\psi(\mathbf{z}_s^{\text{dist}}), y_t)$$
where $y_t = \arg\max_c z_t^{(c)}$ is the hard label from the teacher.
Remarkably, hard-label distillation outperformed soft distillation, and the distillation token learned complementary representations to the [CLS] token — their predictions agreed only 60% of the time on ImageNet, yet ensembling them improved accuracy.
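Both objectives are short to write in PyTorch. This sketch follows the two equations above. Following the usual knowledge-distillation convention, the KL term compares the teacher's softened distribution against the student's (`F.kl_div` takes student log-probabilities as input and teacher probabilities as target). The batch-mean reduction and the default values of `tau` and `lam` are our assumptions, not DeiT's published settings.

```python
import torch
import torch.nn.functional as F


def soft_distillation_loss(student_logits, teacher_logits, labels, tau=3.0, lam=0.1):
    """(1 - lam) * CE(student, y) + lam * tau^2 * KL on temperature-softened distributions."""
    ce = F.cross_entropy(student_logits, labels)
    # kl_div(input=log q_student, target=p_teacher) computes KL(p_teacher || q_student).
    kl = F.kl_div(
        F.log_softmax(student_logits / tau, dim=-1),
        F.softmax(teacher_logits / tau, dim=-1),
        reduction="batchmean",
    )
    return (1 - lam) * ce + lam * tau**2 * kl


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """[CLS] head vs. true labels, distillation head vs. the teacher's argmax labels."""
    teacher_labels = teacher_logits.argmax(dim=-1)  # y_t = argmax_c z_t^(c)
    return 0.5 * F.cross_entropy(cls_logits, labels) + 0.5 * F.cross_entropy(dist_logits, teacher_labels)


torch.manual_seed(0)
labels = torch.randint(0, 10, (4,))
student, teacher = torch.randn(4, 10), torch.randn(4, 10)
print(soft_distillation_loss(student, teacher, labels).item())
print(hard_distillation_loss(student, torch.randn(4, 10), teacher, labels).item())
```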
26.4.4 Results
Without distillation, DeiT-Base achieved 81.8% top-1 accuracy on ImageNet at 224x224 resolution using only ImageNet-1K training data, compared to ViT-Base's 77.9% trained on the same data. With hard-label distillation it reached 83.4%, and up to 85.2% with longer training and fine-tuning at 384x384 resolution, competitive with the state-of-the-art EfficientNet-B7 while being more GPU-efficient.
26.4.5 Lessons from DeiT for Practitioners
DeiT's contributions extend beyond any single architecture — they demonstrated that training methodology matters as much as architecture. The specific recipe provides a template for training any ViT from scratch on moderate datasets:
- Data augmentation is non-negotiable: Without RandAugment, Mixup, and CutMix, ViT-Base accuracy drops by more than 5% on ImageNet-1K. The augmentations compensate for the lack of convolutional inductive biases by showing the model diverse transformations of the same image.
- Regularization must be aggressive: Stochastic depth with drop rate 0.1, label smoothing of 0.1, and dropout of 0.0 (surprisingly, dropout hurts in combination with stochastic depth) form the regularization baseline. The rationale is that transformers have enormous capacity and will overfit without strong constraints.
- Longer training pays off: DeiT trains for 300 epochs compared to the typical 90-epoch CNN schedule. The learning rate follows a cosine decay after a 5-epoch linear warmup. Transformers appear to need more training iterations to develop the spatial inductive biases that CNNs get for free.
- Teacher choice matters for distillation: DeiT found that a CNN teacher (RegNetY-16GF) produced a better student than a transformer teacher. The intuition is that the CNN provides complementary inductive biases that the student lacks, whereas a transformer teacher would simply reinforce the same failure modes.
These principles have been adopted widely and form the basis of most modern ViT training pipelines, as we will also see in the training strategies for video transformers in Chapter 30.
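The schedule described in the DeiT recipe (linear warmup for 5 epochs, then cosine decay over the rest of a 300-epoch run) can be sketched as a plain function of the epoch. The `min_lr` floor, the warmup starting from zero, and the base learning rate in the usage loop are illustrative assumptions.

```python
import math


def lr_at_epoch(epoch: float, base_lr: float, total_epochs: int = 300,
                warmup_epochs: int = 5, min_lr: float = 1e-5) -> float:
    """Linear warmup from 0 to base_lr, then cosine decay from base_lr to min_lr."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)  # goes from 0 to 1
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for epoch in (0, 2.5, 5, 150, 300):
    print(epoch, f"{lr_at_epoch(epoch, base_lr=5e-4):.2e}")
```

In a training loop, the returned value would be written into each optimizer parameter group's `"lr"` entry at the start of every epoch (or every iteration, using fractional epochs).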
26.5 Swin Transformer: Hierarchical Vision Transformer
26.5.1 Motivation and Design Philosophy
While ViT demonstrated the viability of transformers for vision, it had two practical limitations for general vision tasks:
- Quadratic complexity: Self-attention over all $N$ patches scales as $O(N^2)$, limiting application to high-resolution images.
- Single-scale features: ViT produces feature maps at a single resolution, but tasks like object detection and segmentation require multi-scale feature pyramids.
The Swin Transformer (Shifted Window Transformer) by Liu et al. (2021) addresses both issues with an elegant hierarchical architecture.
26.5.2 Architecture Overview
Swin Transformer builds a hierarchical feature representation with the following structure:
Stage 1: Input image ($H \times W \times 3$) is split into $4 \times 4$ patches, producing $\frac{H}{4} \times \frac{W}{4}$ tokens of dimension $C = 96$.
Stage 2: Patch merging reduces spatial resolution by 2x (concatenating $2 \times 2$ groups and projecting), yielding $\frac{H}{8} \times \frac{W}{8}$ tokens of dimension $2C$.
Stage 3: Another patch merging step produces $\frac{H}{16} \times \frac{W}{16}$ tokens of dimension $4C$.
Stage 4: Final patch merging produces $\frac{H}{32} \times \frac{W}{32}$ tokens of dimension $8C$.
This hierarchical structure mirrors CNNs and produces feature maps at multiple scales, making Swin Transformer directly compatible with Feature Pyramid Networks (FPN) and other multi-scale architectures.
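Patch merging is the only new operation needed to move between stages. A minimal sketch, assuming a channels-last `[B, H, W, C]` layout: each 2x2 neighborhood is concatenated into a $4C$ vector, normalized, and linearly projected to $2C$. The transformer blocks that run between merges are omitted, so the loop only traces how the token grid and channel width evolve for a 224x224 input with $C = 96$.

```python
import torch
import torch.nn as nn


class PatchMerging(nn.Module):
    """Concatenate each 2x2 neighborhood (4C), normalize, and project to 2C."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, H, W, C] with even H and W.
        x0 = x[:, 0::2, 0::2, :]  # top-left token of each 2x2 group
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # [B, H/2, W/2, 4C]
        return self.reduction(self.norm(x))       # [B, H/2, W/2, 2C]


x = torch.randn(1, 56, 56, 96)   # stage-1 tokens for a 224x224 image
shapes = [tuple(x.shape)]
for dim in (96, 192, 384):
    x = PatchMerging(dim)(x)
    shapes.append(tuple(x.shape))
print(shapes)
# [(1, 56, 56, 96), (1, 28, 28, 192), (1, 14, 14, 384), (1, 7, 7, 768)]
```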
26.5.3 Window-Based Self-Attention
Instead of computing global self-attention, Swin Transformer partitions the feature map into non-overlapping windows of size $M \times M$ (default $M = 7$) and computes self-attention within each window.
For a feature map of $h \times w$ tokens, this reduces the complexity from $O((hw)^2)$ to $O(hw \cdot M^2)$ — linear in the number of tokens for a fixed window size.
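Window partitioning itself is just a reshape. The sketch below (helper names are ours) splits a `[B, H, W, C]` feature map into $M \times M$ windows so that ordinary attention can run on each window as an independent short sequence, then compares the number of attention scores with global attention for Swin's first stage at 224x224.

```python
import torch


def window_partition(x: torch.Tensor, M: int) -> torch.Tensor:
    """[B, H, W, C] -> [B * num_windows, M*M, C]; assumes H and W are multiples of M."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // M, M, W // M, M, C)
    x = x.permute(0, 1, 3, 2, 4, 5)  # [B, H/M, W/M, M, M, C]: group tokens by window
    return x.reshape(-1, M * M, C)


def window_reverse(windows: torch.Tensor, M: int, H: int, W: int) -> torch.Tensor:
    """Inverse of window_partition: [B * num_windows, M*M, C] -> [B, H, W, C]."""
    C = windows.shape[-1]
    B = windows.shape[0] // ((H // M) * (W // M))
    x = windows.reshape(B, H // M, W // M, M, M, C).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(B, H, W, C)


x = torch.randn(2, 56, 56, 96)   # Swin stage-1 feature map for a 224x224 image
windows = window_partition(x, M=7)
print(windows.shape)  # torch.Size([128, 49, 96]): 2 images x 64 windows, 49 tokens each

# Attention scores per image and head: 64 windows x 49^2 vs. 3136^2 for global attention.
print(64 * 49**2, 3136**2)  # 153664 9834496
```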
However, non-overlapping windows prevent information flow between windows. Swin Transformer solves this with shifted windows.
26.5.4 Shifted Window Mechanism
In consecutive transformer layers, the window partition alternates between a regular layout and one shifted by $(\lfloor M/2 \rfloor, \lfloor M/2 \rfloor)$ tokens (feature-map positions). This creates cross-window connections at every other layer:
- Layer $\ell$ (regular windows): Self-attention within standard $M \times M$ windows.
- Layer $\ell + 1$ (shifted windows): Windows are shifted by half the window size, creating new windows that span the boundaries of the previous partition.
The two consecutive layers are computed as:
$$\hat{\mathbf{z}}^{\ell} = \text{W-MSA}(\text{LN}(\mathbf{z}^{\ell-1})) + \mathbf{z}^{\ell-1}$$
$$\mathbf{z}^{\ell} = \text{FFN}(\text{LN}(\hat{\mathbf{z}}^{\ell})) + \hat{\mathbf{z}}^{\ell}$$
$$\hat{\mathbf{z}}^{\ell+1} = \text{SW-MSA}(\text{LN}(\mathbf{z}^{\ell})) + \mathbf{z}^{\ell}$$
$$\mathbf{z}^{\ell+1} = \text{FFN}(\text{LN}(\hat{\mathbf{z}}^{\ell+1})) + \hat{\mathbf{z}}^{\ell+1}$$
where W-MSA and SW-MSA denote window-based and shifted-window-based multi-head self-attention, respectively.
26.5.5 Efficient Batch Computation
The shifted window approach creates windows of varying sizes at the borders. Rather than padding (which wastes computation), Swin Transformer uses a clever cyclic shift strategy:
- Cyclically shift the feature map by $(-\lfloor M/2 \rfloor, -\lfloor M/2 \rfloor)$.
- Apply standard window partitioning to the shifted feature map.
- Use attention masks to prevent interaction between tokens from different original windows that now share a shifted window.
- Reverse the cyclic shift on the output.
This maintains efficient batched computation while correctly implementing shifted-window attention.
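The four steps can be sketched as follows for $H = W = 56$, $M = 7$, and a shift of $\lfloor M/2 \rfloor = 3$. The mask labels every position of the shifted map with the region it originally came from and blocks attention between different regions. Helper names are ours, we use $-\infty$ where Swin's reference implementation uses a large negative constant (-100), and the attention computation itself is elided.

```python
import torch


def window_partition(x: torch.Tensor, M: int) -> torch.Tensor:
    """[B, H, W, C] -> [B * num_windows, M*M, C]; assumes H and W are multiples of M."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // M, M, W // M, M, C).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, M * M, C)


def shifted_window_mask(H: int, W: int, M: int, shift: int) -> torch.Tensor:
    """Additive mask [num_windows, M*M, M*M]: 0 within a region, -inf between regions."""
    # Label each position of the cyclically shifted map with the region it came from.
    region = torch.zeros(1, H, W, 1)
    label = 0
    for hs in (slice(0, -M), slice(-M, -shift), slice(-shift, None)):
        for ws in (slice(0, -M), slice(-M, -shift), slice(-shift, None)):
            region[:, hs, ws, :] = label
            label += 1
    labels = window_partition(region, M).squeeze(-1)   # [num_windows, M*M]
    same = labels.unsqueeze(1) == labels.unsqueeze(2)  # [num_windows, M*M, M*M]
    return torch.zeros(same.shape).masked_fill(~same, float("-inf"))


H = W = 56
M = 7
shift = M // 2
x = torch.randn(1, H, W, 96)
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))  # 1. cyclic shift
windows = window_partition(shifted, M)                          # 2. regular window partition
mask = shifted_window_mask(H, W, M, shift)                      # 3. mask cross-region pairs
# (windowed attention would add `mask` to Q K^T / sqrt(d_k) before the softmax here)
# 4. reverse the shift; a real block rolls the attention output, here we roll the input back.
restored = torch.roll(shifted, shifts=(shift, shift), dims=(1, 2))
print(windows.shape, mask.shape, torch.equal(restored, x))
```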
26.5.6 Relative Position Bias
Rather than using absolute position embeddings, Swin Transformer adds a relative position bias to each attention head:
$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}} + \mathbf{B}\right)\mathbf{V}$$
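where $\mathbf{B} \in \mathbb{R}^{M^2 \times M^2}$ is the relative position bias matrix. Since the relative position along each axis lies in $[-M+1, M-1]$, there are only $2M-1$ possible offsets per axis, so the bias is parameterized by a smaller matrix $\hat{\mathbf{B}} \in \mathbb{R}^{(2M-1) \times (2M-1)}$, and each entry of $\mathbf{B}$ is looked up from $\hat{\mathbf{B}}$ using the relative row and column offsets of the corresponding token pair.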
This relative bias significantly improves performance (1-2% on ImageNet) and enables generalization to different input resolutions.
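The lookup from $\hat{\mathbf{B}}$ into $\mathbf{B}$ can be sketched as follows. Each pair of tokens in a window gets an integer index computed from its row and column offsets, and that index selects a learnable entry (one per head) from the $(2M-1)^2$-entry table. The function name and the choice of 3 heads are illustrative.

```python
import torch
import torch.nn as nn


def relative_position_index(M: int) -> torch.Tensor:
    """[M*M, M*M] index into a flattened (2M-1) x (2M-1) bias table."""
    coords = torch.stack(torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij"))  # [2, M, M]
    coords = coords.flatten(1)                        # [2, M*M]: (row, col) of each token
    rel = coords[:, :, None] - coords[:, None, :]     # [2, M*M, M*M], offsets in [-(M-1), M-1]
    rel = rel.permute(1, 2, 0) + (M - 1)              # shift offsets into [0, 2M-2]
    return rel[..., 0] * (2 * M - 1) + rel[..., 1]    # row-major index into the table


M, num_heads = 7, 3
table = nn.Parameter(torch.zeros((2 * M - 1) ** 2, num_heads))  # B-hat, one column per head
idx = relative_position_index(M)
bias = table[idx.flatten()].view(M * M, M * M, num_heads).permute(2, 0, 1)  # [heads, M*M, M*M]
print(idx.shape, int(idx.max()), bias.shape)  # torch.Size([49, 49]) 168 torch.Size([3, 49, 49])
```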
26.5.7 Model Variants
| Model | Channels ($C$) | Layers per Stage | Params | ImageNet Top-1 |
|---|---|---|---|---|
| Swin-T | 96 | (2, 2, 6, 2) | 28M | 81.3% |
| Swin-S | 96 | (2, 2, 18, 2) | 50M | 83.0% |
| Swin-B | 128 | (2, 2, 18, 2) | 88M | 83.5% |
| Swin-L | 192 | (2, 2, 18, 2) | 197M | 86.4%* |
*Pre-trained on ImageNet-22K and evaluated at 384x384 resolution.
26.5.8 Swin Transformer for Dense Prediction
Swin Transformer's hierarchical design makes it a direct replacement for CNN backbones in dense prediction frameworks. For object detection, Swin is typically paired with Feature Pyramid Network (FPN) or its variants:
- Swin + HTC++: On COCO test-dev, Swin-L achieves 58.7 box AP and 51.1 mask AP, surpassing the previous best results by 2.7 box AP and 2.6 mask AP. Swin backbones are also commonly paired with Cascade Mask R-CNN.
- Swin + UperNet: For semantic segmentation on ADE20K, Swin-L achieves 53.5 mIoU on the validation set, 3.2 mIoU above the previous state of the art.
The key reason Swin excels at these tasks is that its multi-scale feature maps (at 1/4, 1/8, 1/16, and 1/32 of the input resolution) naturally provide the feature pyramid that detection and segmentation frameworks require. By contrast, a plain ViT produces only single-scale features, requiring additional engineering to create multi-scale representations.
26.5.9 Swin V2 Improvements
Swin Transformer V2 (Liu et al., 2022) introduced three improvements for scaling to larger models and higher resolutions:
- Residual post-norm instead of pre-norm: The output of each attention and FFN branch is layer-normalized before being added back to the residual stream (a variant of post-norm placement), which stabilizes training for models with over 1 billion parameters.
- Scaled cosine attention: Replacing dot-product attention logits with cosine similarity divided by a learnable temperature keeps the logits bounded, preventing a few token pairs from dominating the attention maps of large models.
- Log-spaced continuous position bias (Log-CPB): Instead of learning a fixed bias table, Swin V2 generates position biases using a small MLP on log-spaced relative coordinates, enabling smooth transfer to unseen window sizes.
These improvements enabled training Swin V2 at up to 3 billion parameters and 1,536x1,536 resolution — an unprecedented scale for vision transformers.
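A minimal sketch of scaled cosine attention scores. We assume one learnable log-scale per head, initialized to $\log 10$ (so $\tau = 0.1$) and clamped so that $\tau \ge 0.01$; these details reflect our reading of the Swin V2 design and should be checked against the reference code. Because cosine similarity lies in $[-1, 1]$, the logits stay bounded even when the query and key activations are very large, as in the usage example.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineAttentionScores(nn.Module):
    """softmax(cos(q, k) / tau + B), with a learnable tau per head."""

    def __init__(self, num_heads: int):
        super().__init__()
        # Learnable log(1 / tau) per head, initialized to log(10), i.e. tau = 0.1.
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads, 1, 1)))

    def forward(self, q, k, bias=None):
        # q, k: [B, heads, N, d_head]; bias: optional [heads, N, N] relative position bias.
        cos = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)  # in [-1, 1]
        scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()    # 1/tau <= 100
        logits = cos * scale
        if bias is not None:
            logits = logits + bias
        return logits.softmax(dim=-1)


torch.manual_seed(0)
q = 50 * torch.randn(2, 3, 49, 32)  # deliberately large activations
k = 50 * torch.randn(2, 3, 49, 32)
attn = CosineAttentionScores(num_heads=3)(q, k)
print(attn.shape)  # torch.Size([2, 3, 49, 49])
```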
26.6 Object Detection with DETR
26.6.1 The Problem with Traditional Detectors
Traditional object detection pipelines involve complex, hand-designed components:
- Anchor generation: Pre-defined bounding box templates at multiple scales and aspect ratios.
- Non-Maximum Suppression (NMS): Post-processing to remove duplicate detections.
- Region proposal networks: Separate sub-networks to propose candidate regions.
These components require careful tuning and inject significant domain-specific engineering into the pipeline.
26.6.2 DETR: DEtection TRansformer
DETR (Carion et al., 2020) eliminates these components with an elegant end-to-end architecture:
- CNN Backbone: A standard CNN (e.g., ResNet-50) extracts feature maps from the input image.
- Transformer Encoder: Flattened spatial features with positional encodings are processed by a standard transformer encoder.
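- Transformer Decoder: A set of $N$ learnable object queries (typically $N = 100$) attend to the encoder output through cross-attention and to each other through self-attention. Each query yields one prediction, either an object or "no object", and the self-attention among queries lets the model reason about all objects jointly (for example, to avoid predicting the same object twice).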
- Prediction Heads: Simple FFNs predict class labels and bounding box coordinates for each query.
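The four components above can be wired together in a short PyTorch sketch, loosely modeled on the compact DETR listing in the original paper. Assumptions, for a self-contained example: a two-layer convolution stands in for ResNet-50, only 10 queries are used, learned row/column embeddings serve as the 2D positional encoding, and `MinimalDETR` is a hypothetical name rather than a library class. Boxes are emitted as normalized (cx, cy, w, h) through a sigmoid.

```python
import torch
import torch.nn as nn


class MinimalDETR(nn.Module):
    """Backbone -> transformer encoder/decoder -> per-query class and box heads."""

    def __init__(self, num_classes: int, d_model: int = 64, num_queries: int = 10):
        super().__init__()
        # Stand-in backbone with total stride 8 (real DETR uses ResNet-50, stride 32)
        self.backbone = nn.Sequential(
            nn.Conv2d(3, d_model, kernel_size=4, stride=4), nn.ReLU(),
            nn.Conv2d(d_model, d_model, kernel_size=2, stride=2), nn.ReLU(),
        )
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=4, num_encoder_layers=2,
            num_decoder_layers=2, dim_feedforward=128, batch_first=True,
        )
        self.query_embed = nn.Parameter(torch.randn(num_queries, d_model))
        # Learned row/column embeddings form a simple 2D positional encoding
        self.row_embed = nn.Parameter(torch.randn(50, d_model // 2))
        self.col_embed = nn.Parameter(torch.randn(50, d_model // 2))
        self.class_head = nn.Linear(d_model, num_classes + 1)  # +1 for "no object"
        self.box_head = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 4)
        )

    def forward(self, images: torch.Tensor):
        feats = self.backbone(images)                          # (B, d, H, W)
        b, d, h, w = feats.shape
        pos = torch.cat([
            self.col_embed[:w].unsqueeze(0).expand(h, w, d // 2),
            self.row_embed[:h].unsqueeze(1).expand(h, w, d // 2),
        ], dim=-1).flatten(0, 1)                               # (H*W, d)
        src = feats.flatten(2).transpose(1, 2) + pos           # (B, H*W, d)
        tgt = self.query_embed.unsqueeze(0).expand(b, -1, -1)  # (B, N, d)
        hs = self.transformer(src, tgt)                        # (B, N, d)
        return self.class_head(hs), self.box_head(hs).sigmoid()


model = MinimalDETR(num_classes=5)
logits, boxes = model(torch.randn(2, 3, 64, 64))
print(logits.shape, boxes.shape)  # torch.Size([2, 10, 6]) torch.Size([2, 10, 4])
```

Note that nothing here is detection-specific except the two heads; the set-prediction behavior comes entirely from how the loss assigns queries to ground-truth objects.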
26.6.3 Bipartite Matching Loss
DETR's training uses Hungarian matching to find the optimal one-to-one assignment between predictions and ground truth objects. Given a set of $N$ predictions and a set of ground truth objects (padded with "no object" to size $N$), we find the permutation $\sigma$ that minimizes:
$$\hat{\sigma} = \arg\min_{\sigma \in \mathfrak{S}_N} \sum_{i=1}^{N} \mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)})$$
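where $c_i$ and $b_i$ are the class and box of the $i$-th ground-truth entry, $\hat{p}_{\sigma(i)}(c_i)$ is the probability the matched prediction assigns to class $c_i$, and the matching cost combines classification and box regression:
$$\mathcal{L}_{\text{match}}(y_i, \hat{y}_{\sigma(i)}) = -\mathbb{1}_{\{c_i \neq \varnothing\}} \hat{p}_{\sigma(i)}(c_i) + \mathbb{1}_{\{c_i \neq \varnothing\}} \mathcal{L}_{\text{box}}(b_i, \hat{b}_{\sigma(i)})$$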
The box loss combines L1 loss with generalized IoU (GIoU) loss:
$$\mathcal{L}_{\text{box}} = \lambda_{\text{L1}} \|b_i - \hat{b}_{\sigma(i)}\|_1 + \lambda_{\text{GIoU}} \mathcal{L}_{\text{GIoU}}(b_i, \hat{b}_{\sigma(i)})$$
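The matching step can be sketched with NumPy and SciPy's `linear_sum_assignment`, the assignment solver that DETR's reference matcher also relies on. Assumptions: boxes are in (x1, y1, x2, y2) format, the probability array has a final "no object" column, and the cost weights default to $\lambda_{\text{L1}} = 5$ and $\lambda_{\text{GIoU}} = 2$ as reported in the DETR paper. The helper names `giou` and `hungarian_match` are hypothetical. Rather than padding the ground truth to size $N$, the solver is given a rectangular cost matrix, and any prediction left unmatched is implicitly assigned "no object".

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def giou(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise generalized IoU between boxes a (P, 4) and b (G, 4), (x1, y1, x2, y2)."""
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    lt = np.maximum(a[:, None, :2], b[None, :, :2])
    rb = np.minimum(a[:, None, 2:], b[None, :, 2:])
    inter = np.clip(rb - lt, 0, None).prod(axis=-1)
    union = area_a[:, None] + area_b[None, :] - inter
    # Smallest axis-aligned box enclosing each pair
    lt_c = np.minimum(a[:, None, :2], b[None, :, :2])
    rb_c = np.maximum(a[:, None, 2:], b[None, :, 2:])
    enclose = (rb_c - lt_c).prod(axis=-1)
    return inter / union - (enclose - union) / enclose


def hungarian_match(probs, pred_boxes, gt_labels, gt_boxes,
                    w_cls=1.0, w_l1=5.0, w_giou=2.0):
    """Return (pred_idx, gt_idx) minimizing the DETR-style matching cost.

    probs: (N, C + 1) class probabilities; pred_boxes: (N, 4);
    gt_labels: (G,) integer classes; gt_boxes: (G, 4).
    """
    cost_cls = -probs[:, gt_labels]                                          # (N, G)
    cost_l1 = np.abs(pred_boxes[:, None, :] - gt_boxes[None, :, :]).sum(-1)  # (N, G)
    cost_giou = 1.0 - giou(pred_boxes, gt_boxes)                             # GIoU loss
    cost = w_cls * cost_cls + w_l1 * cost_l1 + w_giou * cost_giou
    return linear_sum_assignment(cost)


probs = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.2, 0.1, 0.7]])  # last column = no object
pred_boxes = np.array([[0.0, 0.0, 0.4, 0.4], [0.5, 0.5, 1.0, 1.0], [0.1, 0.1, 0.2, 0.2]])
gt_labels = np.array([0, 1])
gt_boxes = np.array([[0.5, 0.5, 0.9, 0.9], [0.0, 0.0, 0.4, 0.4]])
rows, cols = hungarian_match(probs, pred_boxes, gt_labels, gt_boxes)
print(rows, cols)  # [0 1] [1 0]: prediction 0 -> object 1, prediction 1 -> object 0
```

Prediction 2 is left unmatched and would be trained toward the "no object" class.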
26.6.4 Object Queries as Learned Anchors
Each object query can be thought of as a learned, adaptive anchor. Through training, different queries specialize in particular spatial regions and box sizes. Plotting the boxes each query predicts across a validation set shows that each query covers a distinct set of locations and scales, effectively dividing the image into detection zones without explicit anchor design.
26.6.5 Strengths and Limitations
Strengths:
- Eliminates NMS, anchors, and other hand-designed components.
- Excels at detecting large objects due to global attention.
- Simple, elegant architecture with fewer hyperparameters.

Limitations:
- Slow convergence (500 training epochs vs. ~36 for Faster R-CNN).
- Struggles with small objects due to coarse feature resolution.
- A fixed number of queries caps the maximum number of detections per image.
Later works like Deformable DETR and DINO addressed these limitations through deformable attention, mixed query selection, and contrastive denoising training.
26.6.6 Deformable DETR and Improvements
Deformable DETR (Zhu et al., 2021) addresses DETR's convergence and small-object limitations by replacing global attention with deformable attention, where each query attends to only a small set of key sampling points around a reference point:
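$$\text{DeformAttn}(\mathbf{q}, \mathbf{p}, \mathbf{x}) = \sum_{m=1}^{M} \mathbf{W}_m \left[ \sum_{k=1}^{K} A_{mk} \cdot \mathbf{W}'_m \, \mathbf{x}(\mathbf{p} + \Delta\mathbf{p}_{mk}) \right]$$
where:
- $\mathbf{q}$ is the query feature
- $\mathbf{p}$ is the reference point
- $\mathbf{x}$ is the input feature map; since the offsets are fractional, $\mathbf{x}(\cdot)$ is read by bilinear interpolation
- $M$ is the number of attention heads
- $K$ is the number of sampled keys per attention head (typically 4)
- $A_{mk}$ is the attention weight (predicted from $\mathbf{q}$ and normalized so that $\sum_k A_{mk} = 1$)
- $\Delta\mathbf{p}_{mk}$ is the sampling offset (also predicted from $\mathbf{q}$)
- $\mathbf{W}'_m$ and $\mathbf{W}_m$ are the per-head value and output projections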
The intuition is that each query "looks at" only a few relevant locations rather than the entire feature map. This reduces complexity from $O(N^2)$ to $O(NK)$ where $K \ll N$, and enables multi-scale feature aggregation by attending to points across different feature pyramid levels.
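A single-scale version of this formula can be sketched with `torch.nn.functional.grid_sample`, which performs the bilinear reads at fractional locations. Assumptions: one feature level only (Deformable DETR also samples across pyramid levels), offsets predicted in pixel units and then normalized, and the per-head projections $\mathbf{W}'_m$ and $\mathbf{W}_m$ implemented as one linear layer each over the concatenated heads, which is equivalent. `SingleScaleDeformAttn` is a hypothetical name, not the reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SingleScaleDeformAttn(nn.Module):
    """Each query samples K points per head around its reference point."""

    def __init__(self, d_model: int = 32, n_heads: int = 4, n_points: int = 4):
        super().__init__()
        self.h, self.k, self.dh = n_heads, n_points, d_model // n_heads
        self.offsets = nn.Linear(d_model, n_heads * n_points * 2)  # Δp_mk
        self.weights = nn.Linear(d_model, n_heads * n_points)      # A_mk
        self.value_proj = nn.Linear(d_model, d_model)               # W'_m for all heads
        self.out_proj = nn.Linear(d_model, d_model)                 # W_m for all heads

    def forward(self, query, ref_points, feat):
        # query: (B, Q, d); ref_points: (B, Q, 2) as (x, y) in [0, 1]; feat: (B, d, H, W)
        b, q, _ = query.shape
        _, _, hgt, wid = feat.shape
        value = self.value_proj(feat.flatten(2).transpose(1, 2))           # (B, HW, d)
        value = value.transpose(1, 2).reshape(b * self.h, self.dh, hgt, wid)
        # Offsets are predicted in pixels, then converted to normalized units
        off = self.offsets(query).view(b, q, self.h, self.k, 2)
        scale = torch.tensor([wid, hgt], dtype=query.dtype, device=query.device)
        loc = ref_points[:, :, None, None, :] + off / scale                # (B, Q, h, K, 2)
        attn = self.weights(query).view(b, q, self.h, self.k).softmax(-1)  # sums to 1 over K
        # grid_sample expects (x, y) in [-1, 1] and interpolates bilinearly
        grid = (2 * loc - 1).permute(0, 2, 1, 3, 4).reshape(b * self.h, q, self.k, 2)
        sampled = F.grid_sample(value, grid, mode="bilinear", align_corners=False)  # (B*h, dh, Q, K)
        attn = attn.permute(0, 2, 1, 3).reshape(b * self.h, 1, q, self.k)
        out = (sampled * attn).sum(-1)                                     # (B*h, dh, Q)
        out = out.view(b, self.h * self.dh, q).transpose(1, 2)             # (B, Q, d)
        return self.out_proj(out)


attn_layer = SingleScaleDeformAttn()
out = attn_layer(torch.randn(2, 10, 32), torch.rand(2, 10, 2), torch.randn(2, 32, 16, 16))
print(out.shape)  # torch.Size([2, 10, 32])
```

Each query touches only `n_heads * n_points` locations regardless of the feature map size, which is where the $O(NK)$ cost comes from.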
Deformable DETR converges 10x faster than DETR (50 vs. 500 epochs) and significantly improves small-object detection, making it practical for real-world deployment.
DINO (DETR with Improved deNoising anchOr boxes, not to be confused with the self-supervised DINO of Section 26.10.2) further advances the DETR family through contrastive denoising training and mixed query selection, reaching 63.3 AP on COCO test-dev with a Swin-L backbone, a landmark result for end-to-end detection.
26.6.7 Practical Object Detection with HuggingFace
You can use DETR models directly from HuggingFace for inference:
```python
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor
from PIL import Image


def detect_objects(
    image_path: str,
    model_name: str = "facebook/detr-resnet-50",
    threshold: float = 0.9,
) -> list[dict]:
    """Detect objects in an image using DETR.

    Args:
        image_path: Path to the input image.
        model_name: HuggingFace model identifier.
        threshold: Confidence threshold for detections.

    Returns:
        List of dicts with 'label', 'score', and 'box' keys, where 'box' is
        [x_min, y_min, x_max, y_max] in pixel coordinates.
    """
    processor = DetrImageProcessor.from_pretrained(model_name)
    model = DetrForObjectDetection.from_pretrained(model_name)
    # Convert to RGB so grayscale or RGBA files are handled
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); the post-processor expects (height, width)
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold,
    )[0]
    detections = []
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        detections.append({
            "label": model.config.id2label[label.item()],
            "score": score.item(),
            "box": box.tolist(),
        })
    return detections
```
26.7 Semantic Segmentation with Vision Transformers
26.7.1 From Classification to Dense Prediction
Semantic segmentation requires assigning a class label to every pixel in an image, producing a dense prediction map rather than a single class label. This requires fine spatial detail, which is challenging for a plain ViT because its patch tokens live on a single coarse grid (1/16 of the input resolution for 16x16 patches).
26.7.2 SegFormer
SegFormer (Xie et al., 2021) is a transformer-based segmentation architecture that combines hierarchical features with a simple MLP decoder:
Hierarchical Encoder: Similar to Swin Transformer, SegFormer produces multi-scale features at 1/4, 1/8, 1/16, and 1/32 of the input resolution. It uses overlapping patch embeddings (strided convolutions whose kernels are larger than their stride) and Mix-FFN (an FFN containing a 3x3 depthwise convolution), which supplies positional information in place of explicit positional encodings.
Efficient Self-Attention: SegFormer shortens the key and value sequences by a reduction ratio $R$ before computing attention, reducing the attention complexity for $N$ tokens from $O(N^2)$ to $O(N^2/R)$.
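A minimal sketch of this reduction, assuming the common implementation strategy of a stride-`sr` convolution on the 2D token map (so the key/value sequence is $R = \text{sr}^2$ times shorter) followed by a LayerNorm, and using PyTorch's `nn.MultiheadAttention` for the attention itself. `EfficientSelfAttention` is a hypothetical name; in SegFormer's configuration $R$ shrinks from 64 in the first stage to 1 in the last.

```python
import torch
import torch.nn as nn


class EfficientSelfAttention(nn.Module):
    """Full-length queries attend to a spatially reduced set of keys/values."""

    def __init__(self, dim: int, num_heads: int, sr: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # A stride-sr convolution maps an H x W map to (H/sr) x (W/sr): R = sr**2
        self.reduce = nn.Conv2d(dim, dim, kernel_size=sr, stride=sr) if sr > 1 else nn.Identity()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, h: int, w: int):
        # x: (B, N, C) tokens laid out row by row, with N = h * w
        b, n, c = x.shape
        kv = x.transpose(1, 2).reshape(b, c, h, w)
        kv = self.norm(self.reduce(kv).flatten(2).transpose(1, 2))  # (B, N / R, C)
        out, weights = self.attn(x, kv, kv)  # weights: (B, N, N / R), averaged over heads
        return out, weights


layer = EfficientSelfAttention(dim=32, num_heads=2, sr=4)
out, weights = layer(torch.randn(2, 32 * 32, 32), h=32, w=32)
print(out.shape, weights.shape)  # torch.Size([2, 1024, 32]) torch.Size([2, 1024, 64])
```

The attention matrix is 1024 x 64 instead of 1024 x 1024, a 16x reduction for `sr=4`.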
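MLP Decoder: Each of the four feature maps is projected to a common channel width $C$ by a linear layer and upsampled to 1/4 resolution; the results are concatenated, fused by another linear layer, and classified per pixel:
$$\hat{\mathbf{F}}_i = \text{Upsample}\big(\text{Linear}(C_i, C)(\mathbf{F}_i)\big), \quad \mathbf{F} = \text{Linear}(4C, C)\big(\text{Concat}(\hat{\mathbf{F}}_1, \ldots, \hat{\mathbf{F}}_4)\big), \quad \hat{m} = \text{Linear}(C, N_{\text{cls}})(\mathbf{F})$$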
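The decoder is small enough to sketch in full. In this illustration 1x1 convolutions play the role of per-pixel linear layers, the channel widths and the 19-class output are arbitrary choices, and the normalization and activation that an actual implementation places in the fusion layer are omitted. `AllMLPDecoder` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AllMLPDecoder(nn.Module):
    """Project each scale, upsample to 1/4 resolution, concatenate, fuse, classify."""

    def __init__(self, in_dims, embed_dim: int, num_classes: int):
        super().__init__()
        # 1x1 convolutions act as linear layers applied at every pixel
        self.proj = nn.ModuleList(nn.Conv2d(c, embed_dim, 1) for c in in_dims)
        self.fuse = nn.Conv2d(len(in_dims) * embed_dim, embed_dim, 1)
        self.classify = nn.Conv2d(embed_dim, num_classes, 1)

    def forward(self, feats):
        size = feats[0].shape[-2:]  # the 1/4-resolution map sets the output grid
        ups = [
            F.interpolate(p(f), size=size, mode="bilinear", align_corners=False)
            for p, f in zip(self.proj, feats)
        ]
        return self.classify(self.fuse(torch.cat(ups, dim=1)))


# Stand-in encoder outputs for a 128x128 image at strides 4, 8, 16, 32
in_dims = (32, 64, 160, 256)
feats = [torch.randn(1, c, 128 // s, 128 // s) for c, s in zip(in_dims, (4, 8, 16, 32))]
decoder = AllMLPDecoder(in_dims, embed_dim=64, num_classes=19)
print(decoder(feats).shape)  # torch.Size([1, 19, 32, 32])
```

The logits come out at 1/4 resolution; a final bilinear upsampling to the input size is typically applied before computing the loss or the prediction.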
26.7.3 DPT (Dense Prediction Transformer)
DPT (Ranftl et al., 2021) takes a different approach, using a plain ViT encoder with a convolutional decoder:
- Extract tokens from multiple ViT layers (e.g., layers 3, 6, 9, 12).
- Reshape tokens back to 2D spatial maps.
- Apply a convolutional refinement network that progressively upsamples and fuses multi-layer features.
- Produce the final dense prediction at full resolution.
This approach demonstrates that even a plain ViT, without hierarchical design, can be adapted for dense prediction through appropriate decoding.
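The "reshape tokens back to 2D" step and the multi-resolution resampling can be sketched as follows. Assumptions: the tokens are random stand-ins for hidden states of a ViT-B/16 at 384x384 (a 24x24 patch grid plus [CLS]); the readout handling shows two of the options DPT considers (ignoring or adding the [CLS] token); and the resampling layers and channel widths are illustrative rather than DPT's exact configuration. `reassemble` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def reassemble(tokens: torch.Tensor, grid_hw: tuple[int, int], readout: str = "ignore") -> torch.Tensor:
    """Turn ViT tokens (B, 1 + H*W, D) into a spatial map (B, D, H, W).

    readout="ignore" drops the [CLS] token; readout="add" adds it to every patch token.
    """
    cls_tok, patches = tokens[:, :1], tokens[:, 1:]
    if readout == "add":
        patches = patches + cls_tok
    b, n, d = patches.shape
    h, w = grid_hw
    assert n == h * w, "token count must match the patch grid"
    # Patch tokens are stored row by row, so a reshape restores the 2D layout
    return patches.transpose(1, 2).reshape(b, d, h, w)


# Stand-ins for the tokens of four ViT layers (e.g., 3, 6, 9, 12)
layer_tokens = [torch.randn(1, 1 + 24 * 24, 768) for _ in range(4)]
# Illustrative resampling to strides 4, 8, 16, 32 relative to the 384x384 input
resample = nn.ModuleList([
    nn.ConvTranspose2d(768, 96, kernel_size=4, stride=4),     # 24x24 -> 96x96
    nn.ConvTranspose2d(768, 192, kernel_size=2, stride=2),    # 24x24 -> 48x48
    nn.Conv2d(768, 384, kernel_size=1),                       # stays 24x24
    nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1),  # 24x24 -> 12x12
])
pyramid = [r(reassemble(t, (24, 24))) for r, t in zip(resample, layer_tokens)]
print([tuple(p.shape[-2:]) for p in pyramid])  # [(96, 96), (48, 48), (24, 24), (12, 12)]
```

The resulting pyramid can then be fed to a convolutional fusion decoder, which is how a single-scale ViT ends up supplying multi-scale features.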
26.7.4 Segment Anything Model (SAM)
The Segment Anything Model (SAM) by Meta AI represents a paradigm shift in segmentation. Trained on over 1 billion masks, SAM can segment objects across a very wide range of images given prompts such as points, boxes, or masks (the paper also explores free-form text prompts as a proof of concept). Its architecture consists of:
- Image Encoder: A ViT-H backbone that processes the image once. This is the computationally expensive step, producing image embeddings of shape 64x64x256 from a 1024x1024 input.
- Prompt Encoder: Encodes sparse prompts (points, boxes) with positional embeddings and dense prompts (masks) with convolutions. Point prompts are encoded as the sum of a learned embedding for the prompt type (foreground/background) and a positional encoding.
- Mask Decoder: A lightweight transformer decoder (only two layers) that combines image and prompt embeddings to predict segmentation masks. It uses bidirectional cross-attention — prompt tokens attend to image tokens and vice versa — followed by an MLP that predicts three mask candidates at different granularities along with their quality scores.
SAM's data engine: The 1 billion masks in the SA-1B dataset were collected through a three-phase process that illustrates the power of human-AI collaboration: (1) manual annotation assisted by SAM, (2) semi-automatic annotation where SAM proposed masks and annotators refined them, and (3) fully automatic mask generation where SAM segmented everything in millions of images without human intervention. This data flywheel is a design pattern worth studying for any large-scale annotation effort.
SAM 2 extended the model to video segmentation, adding a memory mechanism that propagates object information across frames. It uses a memory attention module that conditions the current frame's segmentation on stored memory tokens from previous frames, enabling real-time interactive video segmentation.
Practical usage with HuggingFace:
```python
import torch
from transformers import SamModel, SamProcessor
from PIL import Image


def segment_with_point(
    image_path: str,
    point_coords: list[list[int]],
    point_labels: list[int],
    model_name: str = "facebook/sam-vit-base",
) -> torch.Tensor:
    """Segment an image using point prompts with SAM.

    Args:
        image_path: Path to the input image.
        point_coords: List of [x, y] pixel coordinates for prompt points.
        point_labels: List of labels (1=foreground, 0=background).
        model_name: HuggingFace model identifier.

    Returns:
        Mask tensor at the original image size containing the candidate
        masks SAM predicts for the prompt (outputs.iou_scores ranks them).
    """
    processor = SamProcessor.from_pretrained(model_name)
    model = SamModel.from_pretrained(model_name)
    image = Image.open(image_path).convert("RGB")
    # One image whose prompt consists of all the given points
    inputs = processor(
        image,
        input_points=[point_coords],
        input_labels=[point_labels],
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Upscale the low-resolution mask predictions to the original image size
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"],
    )
    return masks[0]
```
26.8 ViT vs. CNN: Understanding the Tradeoffs
26.8.1 Inductive Biases
| Property | CNN | ViT |
|---|---|---|
| Translation equivariance | Built-in (weight sharing) | Must be learned |
| Locality | Built-in (small kernels) | Must be learned |
| Scale invariance | Partial (pooling) | Must be learned |
| Global context | Requires depth | Available in every layer |
| Permutation sensitivity | Always (spatial structure) | Only with position embeddings |
26.8.2 Computational Characteristics
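FLOPs and Memory: For an image with $N$ patches and model dimension $d$:
- Self-attention: $O(N^2 d)$ FLOPs for the attention scores and the weighted sum of values (plus $O(N d^2)$ for the Q, K, V, and output projections), and $O(N^2)$ memory for the attention maps.
- Convolution: $O(N k^2 d^2)$ FLOPs, where $k$ is the kernel size, and $O(Nd)$ memory.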
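The ratio of the quadratic attention term ($N^2 d$) to a single $d \times d$ projection ($N d^2$) is $N/d$. At typical ImageNet resolution (196 patches, with $d = 768$ for ViT-Base) this ratio is about 0.25, so the linear projections and the MLP dominate the cost. At higher resolutions (e.g., 1024x1024 with patch size 16 gives 4096 patches, a ratio above 5), the quadratic cost becomes significant.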
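A back-of-the-envelope calculation makes the crossover concrete. This sketch counts multiply-accumulates (MACs) for one ViT-Base encoder layer under stated simplifications: LayerNorm, softmax, and biases are ignored, the MLP has hidden width $4d$, and the [CLS] token is left out. The helper names are hypothetical.

```python
def vit_layer_macs(num_tokens: int, dim: int, mlp_ratio: int = 4) -> dict[str, int]:
    """Approximate multiply-accumulates for one ViT encoder layer."""
    n, d = num_tokens, dim
    projections = 4 * n * d * d        # Q, K, V and output projections
    mlp = 2 * n * d * (mlp_ratio * d)  # two linear layers with hidden width mlp_ratio * d
    attention = 2 * n * n * d          # Q @ K^T scores plus (weights @ V)
    return {"projections": projections, "mlp": mlp, "attention": attention}


def attention_share(num_tokens: int, dim: int) -> float:
    """Fraction of a layer's MACs spent on the quadratic attention term."""
    macs = vit_layer_macs(num_tokens, dim)
    return macs["attention"] / sum(macs.values())


for image_size in (224, 512, 1024):
    n = (image_size // 16) ** 2  # patch size 16
    print(f"{image_size}px: {n} patches, attention = {attention_share(n, 768):.1%} of layer MACs")
```

Under these assumptions the share simplifies to $N / (N + 6d)$: about 4% at 224px, about 18% at 512px, and about 47% at 1024px, which is why windowed or otherwise sub-quadratic attention matters mostly for high-resolution inputs.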
26.8.3 Data Efficiency vs. Scalability
CNNs are more data-efficient: they achieve strong performance with modest training sets thanks to their built-in inductive biases. ViTs require either large-scale pre-training or sophisticated data augmentation and regularization (as in DeiT) to compete.
However, ViTs scale more favorably with increasing data and model size. Pre-trained on the JFT-300M dataset and then fine-tuned, ViT-H/14 reached 88.55% top-1 accuracy on ImageNet, slightly surpassing the strongest CNN baselines reported in the original ViT paper. This suggests that inductive biases help when data is limited but can become constraints when data is abundant.
26.8.4 Robustness
Several studies report that ViTs tend to be more robust than CNNs to:
- Distribution shift: ViTs trained on ImageNet generalize better to stylized or corrupted versions of the data.
- Occlusion: Because attention is global, ViTs can often recognize objects even when significant portions are occluded.
- Texture bias: CNNs tend to rely heavily on texture cues, while ViTs develop more shape-based representations.

However, ViTs can be more sensitive to patch-wise perturbations: adversarially corrupting a small number of patches can disproportionately affect performance, whereas CNNs tend to degrade more gracefully.
26.8.5 Hybrid Architectures
The boundary between CNNs and ViTs is increasingly blurred:
- CoAtNet: Combines depthwise convolutions in early stages with self-attention in later stages.
- ConvNeXt: Modernizes the ResNet design with ViT-inspired training recipes and architectural choices, achieving competitive performance with a pure CNN.
- MaxViT: Uses multi-axis attention that combines local (window) and global (grid) attention patterns.
These hybrids suggest that the future of vision architectures may not be a binary choice between convolutions and attention but a thoughtful combination of both.
26.8.6 When to Choose What: A Decision Framework
For practitioners deciding between CNN and ViT architectures, the following guidelines can help:
Choose a CNN (ConvNeXt, EfficientNet) when:
- Your training dataset is small (under 10,000 images) and no suitable pre-trained ViT exists for your domain.
- You want a strong built-in prior of locality and translation equivariance (e.g., certain medical or satellite imaging tasks).
- Deployment targets have limited memory and you need a highly compact model.

Choose a standard ViT when:
- You can leverage a strong pre-trained checkpoint (ImageNet-21K, CLIP, DINOv2).
- Your task benefits from global context (e.g., scene classification, fine-grained recognition).
- You plan to integrate the vision encoder into a multimodal pipeline, as we will explore in Chapter 28.

Choose a hierarchical ViT (Swin, SegFormer) when:
- You need multi-scale feature maps (detection, segmentation, panoptic tasks).
- You work with high-resolution inputs where quadratic attention cost is prohibitive.
- You want a drop-in replacement for CNN backbones in existing FPN or U-Net pipelines.

Choose a hybrid (CoAtNet, MaxViT) when:
- You want the data efficiency of convolutions in early layers with the global modeling of attention in later layers.
- You operate at medium scale and cannot afford the massive pre-training that pure ViTs benefit from.
26.9 Fine-Tuning ViT with HuggingFace
26.9.1 Setting Up the Environment
Fine-tuning a pre-trained ViT for a custom classification task is straightforward with HuggingFace's transformers library. The pipeline involves:
- Loading a pre-trained ViT model and its associated image processor.
- Replacing the classification head for the target number of classes.
- Setting up the dataset with appropriate transforms.
- Training with the Trainer API or a custom loop.
26.9.2 Image Preprocessing
ViT models expect images in a specific format determined by the pre-training configuration. The HuggingFace ViTImageProcessor handles:
- Resizing to the expected resolution (e.g., 224x224).
- Normalizing with the appropriate mean and standard deviation.
- Rescaling pixel values from [0, 255] to [0, 1] and returning channels-first tensors.
```python
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
# processor.size: {'height': 224, 'width': 224}
# processor.image_mean: [0.5, 0.5, 0.5]
# processor.image_std: [0.5, 0.5, 0.5]
```
26.9.3 Dataset Preparation
For fine-tuning, you will typically use HuggingFace's datasets library or a custom torch.utils.data.Dataset. The key is applying the image processor to each image:
```python
from datasets import load_dataset
from torchvision.transforms import (
    Compose, RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize
)

# Load a dataset (`processor` comes from the previous snippet)
dataset = load_dataset("food101", split="train[:5000]")

# Define transforms that reuse the pre-trained model's resolution and statistics
train_transforms = Compose([
    RandomResizedCrop(processor.size["height"]),
    RandomHorizontalFlip(),
    ToTensor(),
    Normalize(mean=processor.image_mean, std=processor.image_std),
])


def preprocess(examples):
    # Convert to RGB so grayscale images also get three channels
    examples["pixel_values"] = [
        train_transforms(img.convert("RGB")) for img in examples["image"]
    ]
    return examples


# with_transform applies preprocess lazily whenever examples are accessed
dataset = dataset.with_transform(preprocess)
```
26.9.4 Model Configuration
When fine-tuning for a new classification task, you need to replace the pre-trained head:
```python
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=101,                # Number of classes in Food-101
    ignore_mismatched_sizes=True,  # Allow replacing the head
)
```
26.9.5 Training
Using HuggingFace's Trainer:
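```python
from transformers import TrainingArguments, Trainer

# train_dataset and val_dataset are prepared as above; collate_fn and
# compute_metrics are defined in the complete example in Section 26.9.7.
training_args = TrainingArguments(
    output_dir="./vit-food101",
    num_train_epochs=5,
    per_device_train_batch_size=32,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_steps=100,
    save_strategy="epoch",
    eval_strategy="epoch",          # named evaluation_strategy in older releases
    load_best_model_at_end=True,
    remove_unused_columns=False,    # keep the raw "image" column for the transform
    fp16=True,                      # requires a CUDA GPU
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
trainer.train()
```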
26.9.6 Fine-Tuning Best Practices
-
Learning rate: Use a small learning rate (1e-5 to 5e-5) with warmup. The pre-trained features are already good; large updates can destroy them.
-
Layer-wise learning rate decay: Apply progressively smaller learning rates to earlier layers. A typical decay factor of 0.65-0.75 works well, meaning the first layer learns at 0.65^L times the base rate.
-
Data augmentation: Even for fine-tuning, augmentations like RandAugment, Mixup, and CutMix help, especially for small datasets.
-
Resolution: ViT can be fine-tuned at higher resolutions than the pre-training resolution. The position embeddings are interpolated (typically bicubic interpolation of the 2D grid) to accommodate the new sequence length.
-
Freezing strategies: For very small datasets, freeze the early layers and only fine-tune the later transformer blocks and classification head. Gradually unfreeze layers as training progresses.
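The position-embedding interpolation mentioned under Resolution can be sketched as a standalone function. Assumptions: a single [CLS] token comes first, the original patch grid is square, and bicubic interpolation is used as described above; `interpolate_pos_embed` is a hypothetical helper. Recent HuggingFace ViT models also accept an `interpolate_pos_encoding` argument in `forward` that performs a similar resize on the fly.

```python
import math
import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed: torch.Tensor, new_grid: tuple[int, int]) -> torch.Tensor:
    """Resize ViT position embeddings (1, 1 + N, D) to a new patch grid."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    n, d = patch_pos.shape[1], patch_pos.shape[2]
    old = math.isqrt(n)
    assert old * old == n, "expected a square grid of patch positions"
    # Treat the embedding dimension as channels of a 2D image and resize it
    grid = patch_pos.reshape(1, old, old, d).permute(0, 3, 1, 2)       # (1, D, old, old)
    grid = F.interpolate(grid, size=new_grid, mode="bicubic", align_corners=False)
    patch_pos = grid.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], d)
    return torch.cat([cls_pos, patch_pos], dim=1)  # [CLS] embedding is kept unchanged


# ViT-B/16 pre-trained at 224x224 (14x14 grid), fine-tuned at 384x384 (24x24 grid)
pos = torch.randn(1, 1 + 14 * 14, 768)
print(interpolate_pos_embed(pos, (24, 24)).shape)  # torch.Size([1, 577, 768])
```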
26.9.7 Complete Fine-Tuning Example
Below is a more complete example that includes metric computation and a custom data collator:
```python
import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms import (
    Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor,
)
from transformers import (
    Trainer,
    TrainingArguments,
    ViTForImageClassification,
    ViTImageProcessor,
)


def build_vit_trainer(
    dataset_name: str = "food101",
    model_name: str = "google/vit-base-patch16-224",
    num_train_samples: int = 5000,
    num_val_samples: int = 1000,
    num_epochs: int = 5,
    batch_size: int = 32,
    learning_rate: float = 2e-5,
) -> Trainer:
    """Build a HuggingFace Trainer for ViT fine-tuning.

    Args:
        dataset_name: Name of the HuggingFace dataset.
        model_name: Pre-trained ViT model identifier.
        num_train_samples: Number of training samples to use.
        num_val_samples: Number of validation samples.
        num_epochs: Number of training epochs.
        batch_size: Per-device batch size.
        learning_rate: Peak learning rate.

    Returns:
        Configured Trainer ready for training.
    """
    processor = ViTImageProcessor.from_pretrained(model_name)

    # Training images get random crops and flips; validation images use the
    # processor's deterministic resize and normalization instead.
    train_transforms = Compose([
        RandomResizedCrop(processor.size["height"]),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=processor.image_mean, std=processor.image_std),
    ])

    train_ds = load_dataset(dataset_name, split=f"train[:{num_train_samples}]")
    val_ds = load_dataset(dataset_name, split=f"validation[:{num_val_samples}]")

    # Read the classes from the dataset schema; a subset may not contain
    # every class, so counting the labels it happens to include is unsafe.
    label_names = train_ds.features["label"].names

    def preprocess_train(examples):
        examples["pixel_values"] = [
            train_transforms(img.convert("RGB")) for img in examples["image"]
        ]
        return examples

    def preprocess_val(examples):
        pixel_values = processor(
            [img.convert("RGB") for img in examples["image"]],
            return_tensors="pt",
        )["pixel_values"]
        examples["pixel_values"] = list(pixel_values)
        return examples

    train_ds = train_ds.with_transform(preprocess_train)
    val_ds = val_ds.with_transform(preprocess_val)

    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=len(label_names),
        id2label=dict(enumerate(label_names)),
        label2id={name: i for i, name in enumerate(label_names)},
        ignore_mismatched_sizes=True,  # replace the 1000-class ImageNet head
    )

    def collate_fn(batch):
        return {
            "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
            "labels": torch.tensor([x["label"] for x in batch]),
        }

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return {"accuracy": float((predictions == labels).mean())}

    training_args = TrainingArguments(
        output_dir="./vit-finetuned",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        remove_unused_columns=False,  # the transforms need the raw "image" column
        fp16=torch.cuda.is_available(),
    )

    return Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )


# Usage:
# trainer = build_vit_trainer()
# trainer.train()
# metrics = trainer.evaluate()
# print(f"Validation accuracy: {metrics['eval_accuracy']:.4f}")
```
26.9.8 Layer-Wise Learning Rate Decay
Layer-wise learning rate decay (LLRD) is a technique that assigns progressively smaller learning rates to earlier layers, reflecting the intuition that earlier layers capture more general features that need less adaptation. Number the components from the bottom up, with the embeddings as $l = 0$, the encoder layers as $l = 1, \ldots, L$, and the classification head as $l = L + 1$. The implementation below assigns component $l$ the learning rate $\eta_l = \eta_{\text{base}} \times \gamma^{L + 1 - l}$, where $\gamma$ is the decay factor and $L$ is the total number of encoder layers.
```python
from transformers import ViTForImageClassification


def create_layer_wise_lr_groups(
    model: ViTForImageClassification,
    base_lr: float = 2e-5,
    decay_factor: float = 0.7,
) -> list[dict]:
    """Create optimizer parameter groups with layer-wise LR decay.

    Args:
        model: The ViT model to configure.
        base_lr: Learning rate for the classification head.
        decay_factor: Multiplicative decay per layer (0 to 1).

    Returns:
        List of parameter group dicts for the optimizer.
    """
    num_layers = model.config.num_hidden_layers
    param_groups = []

    # Classification head and the final LayerNorm get the full learning rate
    param_groups.append({
        "params": list(model.classifier.parameters())
        + list(model.vit.layernorm.parameters()),
        "lr": base_lr,
    })

    # Encoder layers, from the top (next to the head) down to the bottom
    for layer_idx in range(num_layers - 1, -1, -1):
        layer = model.vit.encoder.layer[layer_idx]
        depth_from_top = num_layers - 1 - layer_idx
        lr = base_lr * (decay_factor ** (depth_from_top + 1))
        param_groups.append({
            "params": list(layer.parameters()),
            "lr": lr,
        })

    # Embeddings (patch projection, [CLS] token, positions): smallest LR
    lr = base_lr * (decay_factor ** (num_layers + 1))
    param_groups.append({
        "params": list(model.vit.embeddings.parameters()),
        "lr": lr,
    })
    return param_groups


# Usage:
# optimizer = torch.optim.AdamW(create_layer_wise_lr_groups(model), weight_decay=0.01)
```
Worked Example: For ViT-Base (12 layers) with base LR = 2e-5 and decay = 0.7, the classification head trains at 2e-5, layer 11 at 1.4e-5, layer 10 at 9.8e-6, and so on down to layer 0 at about 2.8e-7 and the embeddings at about 1.9e-7. This roughly 100x spread between the highest and lowest learning rates helps limit catastrophic forgetting of pre-trained features while still allowing meaningful adaptation across the network.
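The numbers in the worked example can be checked without downloading a model. The hypothetical helper below reproduces only the exponents used by `create_layer_wise_lr_groups` (top encoder layer at $\gamma^1$, layer 0 at $\gamma^L$, embeddings at $\gamma^{L+1}$) and returns the resulting learning rates by name.

```python
def llrd_schedule(num_layers: int, base_lr: float, decay: float) -> dict[str, float]:
    """Learning rates that create_layer_wise_lr_groups would assign, by component."""
    lrs = {"head": base_lr}
    for layer_idx in range(num_layers - 1, -1, -1):
        # Same exponent as depth_from_top + 1 in the function above
        lrs[f"layer_{layer_idx}"] = base_lr * decay ** (num_layers - layer_idx)
    lrs["embeddings"] = base_lr * decay ** (num_layers + 1)
    return lrs


lrs = llrd_schedule(num_layers=12, base_lr=2e-5, decay=0.7)
for name in ("head", "layer_11", "layer_10", "layer_0", "embeddings"):
    print(f"{name:>10}: {lrs[name]:.2e}")
print(f"spread: {lrs['head'] / lrs['embeddings']:.0f}x")
```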
26.10 Advanced Topics
26.10.1 Masked Image Modeling (MAE)
Masked Autoencoders (MAE) by He et al. (2022) apply the masked language modeling paradigm to images:
- Randomly mask a large portion (75%) of image patches.
- Encode only the visible patches with a ViT encoder.
- Decode the full set of patches (visible + masked) with a lightweight decoder.
- Reconstruct the masked patches in pixel space.
The high masking ratio (much higher than BERT's 15%) is possible because images have high spatial redundancy. This creates a challenging pre-training task that forces the model to learn deep visual representations.
The asymmetric encoder-decoder design (a heavy encoder that sees only the 25% of patches left visible, and a light decoder that handles full reconstruction) makes pre-training efficient: the MAE paper reports a speedup of 3x or more compared with running the encoder on all patches.
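The per-sample random masking that makes this asymmetry possible can be sketched with a random permutation: sort random noise, keep the first 25% of indices, and record the inverse permutation so the decoder can later restore the original patch order. This follows the approach used in the public MAE code as we understand it, but `random_masking` here is a standalone illustration operating on already-embedded patches.

```python
import torch


def random_masking(patches: torch.Tensor, mask_ratio: float = 0.75):
    """MAE-style per-sample random masking.

    patches: (B, N, D). Returns the visible patches (B, N_keep, D), a binary
    mask (B, N) with 1 = masked, and the indices that restore original order.
    """
    b, n, d = patches.shape
    n_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=patches.device)
    ids_shuffle = noise.argsort(dim=1)        # a random permutation per sample
    ids_restore = ids_shuffle.argsort(dim=1)  # its inverse
    ids_keep = ids_shuffle[:, :n_keep]
    visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, d))
    mask = torch.ones(b, n, device=patches.device)
    mask[:, :n_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)  # mask in the original patch order
    return visible, mask, ids_restore


patches = torch.randn(2, 196, 768)  # 14x14 patch embeddings of a 224x224 image
visible, mask, ids_restore = random_masking(patches)
print(visible.shape, int(mask[0].sum()))  # torch.Size([2, 49, 768]) 147
```

Only the 49 visible patches pass through the encoder; the reconstruction loss is then computed on the 147 positions where `mask` is 1.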
Why 75% masking works for images: In NLP, BERT masks only 15% of tokens because language has relatively low redundancy — each word contributes significant information. Images, by contrast, have high spatial redundancy: neighboring patches share similar textures, colors, and structures. A 75% masking ratio forces the encoder to learn abstract, semantic features rather than relying on low-level interpolation from nearby visible patches.
MAE vs. supervised pre-training: On ImageNet, MAE pre-trained ViT-Large achieves 85.9% top-1 accuracy when fine-tuned, compared to 82.6% for supervised training from scratch. The gains are even larger for ViT-Huge (86.9% vs. 83.1%), suggesting that self-supervised pre-training becomes increasingly important as model size grows. Because it needs no labels and its encoder processes only a quarter of the patches, MAE pre-training is also relatively inexpensive, which has made it a popular choice for pre-training large ViT models.
Connection to BERT: The success of MAE and its auditory counterpart AudioMAE (discussed in Chapter 29) highlights a general principle: masked prediction, whether of text tokens, image patches, or audio segments, is an effective self-supervised objective across modalities. This universality is one reason transformers have spread to so many modalities.
26.10.2 DINO and Self-Supervised ViT
DINO (Self-Distillation with No Labels) trains a ViT using self-supervised distillation. A student network is trained to match the output of a teacher network (an exponential moving average of the student) on different augmented views of the same image. The resulting features exhibit remarkable properties:
- Attention maps in the final layer segment objects without any supervision.
- The features work well for k-NN classification without any fine-tuning.
- The [CLS] token naturally encodes semantic information suitable for retrieval.
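A minimal sketch of the DINO objective. Assumptions: one student view and one teacher view per step (the paper uses multi-crop and sums the loss over pairs of different views), linear layers as stand-ins for the ViT backbones and projection heads, and hyperparameter values commonly used for DINO (student temperature 0.1, teacher temperature 0.04, center momentum 0.9, teacher EMA momentum 0.996). `DINOLoss` and `ema_update` are hypothetical names.

```python
import torch
import torch.nn.functional as F


class DINOLoss(torch.nn.Module):
    """Cross-entropy between a centered, sharpened teacher distribution and the student."""

    def __init__(self, out_dim: int, student_temp: float = 0.1,
                 teacher_temp: float = 0.04, center_momentum: float = 0.9):
        super().__init__()
        self.student_temp, self.teacher_temp = student_temp, teacher_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_out: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
        # Centering and a low teacher temperature (sharpening) together prevent collapse
        t = F.softmax((teacher_out - self.center) / self.teacher_temp, dim=-1).detach()
        log_s = F.log_softmax(student_out / self.student_temp, dim=-1)
        loss = -(t * log_s).sum(dim=-1).mean()
        with torch.no_grad():  # running average of the teacher outputs
            self.center.mul_(self.center_momentum).add_(
                teacher_out.mean(dim=0, keepdim=True), alpha=1 - self.center_momentum)
        return loss


@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, momentum: float = 0.996):
    """Teacher weights follow an exponential moving average of the student's."""
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.mul_(momentum).add_(ps.detach(), alpha=1 - momentum)


student, teacher = torch.nn.Linear(16, 32), torch.nn.Linear(16, 32)
teacher.load_state_dict(student.state_dict())
for p in teacher.parameters():
    p.requires_grad_(False)  # the teacher is never updated by gradients
loss_fn = DINOLoss(out_dim=32)
view1, view2 = torch.randn(8, 16), torch.randn(8, 16)  # stand-ins for two augmented views
loss = loss_fn(student(view1), teacher(view2))
loss.backward()
ema_update(teacher, student)  # normally called after each optimizer step
```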
DINOv2, the successor, scaled this approach with curated datasets and improved training, producing features that rival or exceed supervised pre-training across many tasks.
26.10.3 FlashAttention for Vision
FlashAttention, an IO-aware exact attention algorithm, is particularly impactful for vision transformers processing high-resolution images. By reducing memory reads/writes through tiling and kernel fusion, FlashAttention enables:
- Training ViTs on higher resolution images within the same memory budget.
- Faster attention, because the full N x N score matrix is never written to GPU memory; the gains are largest for long token sequences.
- Enabling longer sequence lengths for dense prediction tasks.
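In PyTorch 2.x, `torch.nn.functional.scaled_dot_product_attention` can dispatch to a FlashAttention kernel on supported GPUs (typically with FP16 or BF16 inputs) and falls back to other exact implementations elsewhere. The sketch below, which assumes PyTorch 2.0 or later, checks that the fused call matches a naive implementation that materializes the full score matrix; because FlashAttention is exact, the two agree up to floating-point error.

```python
import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    """Reference attention that materializes the full (N x N) score matrix."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v


# ViT-B/16 at 512x512: 1024 patch tokens, 12 heads of width 64
q, k, v = (torch.randn(1, 12, 1024, 64) for _ in range(3))
fused = F.scaled_dot_product_attention(q, k, v)  # picks the best available backend
print(torch.allclose(fused, naive_attention(q, k, v), atol=1e-4))  # True
```

In a ViT, swapping the explicit softmax(QK^T)V computation for this call is usually enough to benefit from fused kernels when they are available.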
26.10.4 Resolution Flexibility with NaViT
NaViT (Native Resolution ViT) from Google addresses ViT's fixed-resolution limitation. Instead of resizing all images to the same resolution, NaViT:
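- Patchifies each image at its native resolution and aspect ratio, so different images yield different numbers of tokens.
- Packs the variable-length token sequences of several images into one sequence to maximize hardware utilization.
- Masks self-attention (and pooling) so that tokens interact only with tokens from the same image.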
This eliminates the information loss from resizing and improves both training efficiency and accuracy.
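The masking that makes packing safe can be sketched with a block-diagonal attention mask. Assumptions: PyTorch 2.x's `scaled_dot_product_attention`, where a boolean `True` means "may attend"; image sizes chosen for illustration with patch size 16; and `pack_examples` as a hypothetical helper, not NaViT's implementation.

```python
import torch
import torch.nn.functional as F


def pack_examples(token_counts: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
    """Image ids and a block-diagonal attention mask for one packed sequence.

    Returns (image_ids (L,), allowed (L, L)), where allowed[i, j] is True only
    if tokens i and j come from the same image.
    """
    image_ids = torch.repeat_interleave(
        torch.arange(len(token_counts)), torch.tensor(token_counts))
    allowed = image_ids[:, None] == image_ids[None, :]
    return image_ids, allowed


# 224x224 -> 14*14 tokens, 160x320 -> 10*20 tokens, 96x128 -> 6*8 tokens
counts = [14 * 14, 10 * 20, 6 * 8]
image_ids, allowed = pack_examples(counts)
x = torch.randn(1, sum(counts), 64)
out = F.scaled_dot_product_attention(x, x, x, attn_mask=allowed)
print(image_ids.shape, out.shape)  # torch.Size([444]) torch.Size([1, 444, 64])
```

Because of the mask, each image's outputs are exactly what it would get if processed alone, even though all three share one sequence.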
26.11 Practical Considerations
26.11.1 Choosing the Right Architecture
- Classification on small datasets (<10K images): Fine-tune a pre-trained ViT-Base or use a CNN like EfficientNet.
- Classification on large datasets (>1M images): Pre-train or fine-tune ViT-Large or Swin-B.
- Object detection: Use DETR variants (Deformable DETR, DINO) or Swin Transformer with standard detection heads.
- Semantic segmentation: SegFormer for efficiency, Swin + UperNet for accuracy, SAM for zero-shot versatility.
- Edge deployment: Use smaller variants (ViT-Tiny, Swin-T) with quantization and distillation.
26.11.2 Pre-training Strategies
- Supervised pre-training (ImageNet-21K): A strong, widely available default for many downstream tasks.
- Self-supervised pre-training (MAE, DINO): Competitive with supervised, no labels needed.
- CLIP pre-training: Best for zero-shot transfer and multimodal applications (covered in Chapter 28).
26.11.3 Hardware and Efficiency
For training ViTs efficiently:
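- Mixed precision: Train in BF16 where the hardware supports it, or in FP16 with gradient (loss) scaling to avoid underflow; BF16's wider exponent range generally makes scaling unnecessary.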
- Gradient checkpointing: Reduces memory by recomputing activations during backward pass.
- FlashAttention: Use when available for significant speedup.
- Data loading: Vision workloads are often data-loading bound; use efficient image decoders (DALI, ffcv) and sufficient data workers.
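Mixed precision and gradient checkpointing combine naturally in plain PyTorch. This sketch assumes a recent PyTorch (for `torch.autocast` with BF16 and the `use_reentrant` argument of `torch.utils.checkpoint.checkpoint`) and uses a small hypothetical pre-norm transformer `Block` in place of a real ViT. On CPU it runs BF16 autocast; with FP16 on a GPU you would additionally use a gradient scaler.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Small pre-norm transformer block (attention + MLP) for illustration."""

    def __init__(self, dim: int = 64, heads: int = 4):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))


device = "cuda" if torch.cuda.is_available() else "cpu"
blocks = nn.ModuleList(Block() for _ in range(4)).to(device)
x = torch.randn(2, 196, 64, device=device, requires_grad=True)

# BF16 autocast: matmuls run in bfloat16 and no loss scaling is needed
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    h = x
    for blk in blocks:
        # Activations inside each block are recomputed in backward instead of stored
        h = checkpoint(blk, h, use_reentrant=False)
    loss = h.float().pow(2).mean()
loss.backward()
```

Checkpointing every block trades roughly one extra forward pass for not storing intermediate activations, which is often what makes high-resolution ViT training fit in memory.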
26.12 Historical Context and Future Directions
26.12.1 The Timeline of Vision Transformers
The evolution of vision transformers proceeded at a remarkable pace:
- May 2020: DETR eliminates hand-designed detection components (published before ViT but foundational to the transformer-in-vision movement).
- October 2020: ViT preprint demonstrates transformers can match CNNs on ImageNet when pre-trained at scale.
- December 2020: DeiT shows ViTs can be trained data-efficiently on ImageNet-1K alone.
- March 2021: Swin Transformer introduces hierarchical design and shifted windows, making ViTs practical for dense prediction.
- April 2021: DINO reveals that self-supervised ViTs learn to segment objects without any supervision.
- November 2021: MAE demonstrates masked image modeling as a powerful self-supervised objective for ViTs.
- January 2022: ConvNeXt shows that CNNs, when modernized with ViT-inspired design choices, remain competitive.
- April 2023: SAM demonstrates promptable segmentation with a ViT-H backbone.
- April 2023: DINOv2 produces vision features that rival or exceed CLIP-style models on many downstream tasks, all without language supervision.
This timeline illustrates a recurring pattern in deep learning: a new architecture family (transformers) initially underperforms the established approach (CNNs) on standard benchmarks, but rapidly catches up and surpasses it once the training methodology, data regime, and architectural refinements are developed.
26.12.2 Open Challenges
Despite their success, vision transformers face several open challenges:
-
Efficiency at extreme resolutions: Processing 4K or higher-resolution images remains expensive even with windowed attention. Emerging approaches like token pruning (dynamically dropping uninformative tokens) and token merging (combining similar tokens) show promise.
-
Robustness to adversarial perturbations: While ViTs are often reported to be more robust than CNNs to common corruptions, they remain vulnerable to adversarial attacks, though the attack patterns that work best differ from those effective against CNNs.
-
Interpretability: Attention maps provide some interpretability, but they do not always align with human-intuitive explanations. Understanding what ViTs learn at each layer remains an active research area.
-
Deployment on edge devices: ViTs are compute-intensive and memory-hungry compared to lightweight CNNs like MobileNet. Model distillation, quantization, and architecture search for efficient ViTs are important practical research directions.
26.13 Summary
Vision transformers have fundamentally changed the landscape of computer vision. Starting from the simple idea of treating images as sequences of patches, the field has rapidly evolved to produce architectures that match or exceed CNN performance across virtually all vision tasks. Key takeaways from this chapter include:
- ViT demonstrated that a standard transformer, with minimal modifications, can process images effectively when given sufficient pre-training data. The patch embedding mechanism — equivalent to a strided convolution — turns 2D images into 1D token sequences that a standard transformer can process.
- DeiT showed that careful training recipes — strong augmentation, regularization, and distillation — can make ViTs data-efficient. The training methodology proved as important as the architecture itself.
- Swin Transformer introduced hierarchical features and shifted-window attention, making transformers practical for dense prediction tasks. Its multi-scale design provides a direct replacement for CNN backbones in detection and segmentation pipelines.
- DETR eliminated hand-designed components in object detection with an elegant set-prediction formulation. The bipartite matching loss and object queries replaced anchors, NMS, and region proposal networks. Deformable DETR and DINO further improved convergence speed and accuracy.
- SAM demonstrated that a ViT backbone trained on massive mask data can generalize to segment any object in any image, establishing a new paradigm of promptable segmentation.
- The CNN vs. ViT debate is nuanced: CNNs offer better inductive biases for small data, while ViTs scale better with data and model size. Hybrid architectures combine the strengths of both. In practice, the choice depends on dataset size, task requirements, and deployment constraints.
- Self-supervised methods (MAE, DINO, DINOv2) have shown that ViTs can learn powerful visual representations without any labeled data, and these representations rival or exceed supervised pre-training for many downstream tasks.
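The patch-embedding takeaway can be checked directly in a few lines. The PyTorch sketch below builds the embedding as a `Conv2d` with `kernel_size=stride=P` and prepends a [CLS] token. It then verifies numerically that the convolution equals a linear projection of flattened, non-overlapping patches. The width of 192 (ViT-Tiny's) and the names `PatchEmbed` and `patchify` are illustrative choices, not taken from a specific codebase.

```python
import torch
import torch.nn as nn


def patchify(images: torch.Tensor, patch: int) -> torch.Tensor:
    """Split (B, C, H, W) images into flattened non-overlapping patches (B, N, C*P*P)."""
    b, c, h, w = images.shape
    assert h % patch == 0 and w % patch == 0, "H and W must be divisible by the patch size"
    x = images.reshape(b, c, h // patch, patch, w // patch, patch)
    # Reorder to (B, H/P, W/P, C, P, P) so each patch is flattened in (C, P, P) order,
    # the same order as a Conv2d weight of shape (D, C, P, P).
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // patch) * (w // patch), c * patch * patch)


class PatchEmbed(nn.Module):
    """ViT-style patch embedding: a Conv2d whose kernel size and stride equal the patch size."""

    def __init__(self, in_chans: int = 3, embed_dim: int = 192, patch: int = 16):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch, stride=patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        x = self.proj(images)             # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D) with N = HW / P^2
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat([cls, x], dim=1)  # (B, N + 1, D)


torch.manual_seed(0)
embed = PatchEmbed(embed_dim=192, patch=16)
imgs = torch.randn(2, 3, 224, 224)
tokens = embed(imgs)
print(tokens.shape)  # torch.Size([2, 197, 192])

# The same projection written as a matrix multiply on flattened patches.
weight = embed.proj.weight.reshape(192, -1)  # (D, C*P*P)
linear_out = patchify(imgs, 16) @ weight.T + embed.proj.bias
print(torch.allclose(tokens[:, 1:], linear_out, atol=1e-5))  # True
```

The convolutional form is common in practice because it avoids the explicit reshape and permute. The two forms compute the same tokens up to floating-point rounding.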
The vision transformer revolution continues to accelerate, with new architectures and training methods appearing regularly. The foundations covered in this chapter — patch embeddings, position encoding strategies, hierarchical designs, set-based detection, and promptable segmentation — will equip you to understand and apply these advances as they emerge. In the next chapter, we will explore how diffusion models leverage these visual representations for image generation, and in Chapter 28, we will see how vision transformers form the visual backbone of multimodal models that bridge vision and language.
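The bipartite matching at the heart of DETR's set-prediction loss fits in a few lines. The sketch below handles a single image. It builds a cost matrix from class probabilities and L1 box distances, then solves the assignment with SciPy's `linear_sum_assignment`. SciPy uses a Jonker-Volgenant variant rather than the textbook Hungarian algorithm, but both find the same optimal matching. This is a simplification under stated assumptions: the full DETR cost also includes a generalized-IoU term, the weights 1.0 and 5.0 are illustrative, and `hungarian_match` is our own name.

```python
import torch
from scipy.optimize import linear_sum_assignment


@torch.no_grad()
def hungarian_match(pred_logits, pred_boxes, tgt_labels, tgt_boxes,
                    w_class: float = 1.0, w_bbox: float = 5.0):
    """Match N predictions to M ground-truth objects (M <= N) for one image.

    pred_logits: (N, num_classes + 1) logits, last index = "no object"
    pred_boxes:  (N, 4) normalized (cx, cy, w, h) boxes
    tgt_labels:  (M,) integer class ids
    tgt_boxes:   (M, 4) normalized (cx, cy, w, h) boxes
    Returns (pred_idx, tgt_idx) such that pred_idx[k] is matched to tgt_idx[k].
    """
    prob = pred_logits.softmax(-1)                       # (N, C + 1)
    cost_class = -prob[:, tgt_labels]                    # (N, M): higher prob, lower cost
    cost_bbox = torch.cdist(pred_boxes, tgt_boxes, p=1)  # (N, M): L1 box distance
    cost = w_class * cost_class + w_bbox * cost_bbox
    pred_idx, tgt_idx = linear_sum_assignment(cost.cpu().numpy())
    return torch.as_tensor(pred_idx), torch.as_tensor(tgt_idx)


# Toy example: 4 object queries, 2 ground-truth objects, 3 classes + "no object".
pred_logits = torch.tensor([[0.1, 0.2, 0.1, 3.0],   # query 0: mostly "no object"
                            [4.0, 0.0, 0.0, 0.0],   # query 1: class 0
                            [0.0, 0.0, 4.0, 0.0],   # query 2: class 2
                            [0.0, 0.0, 0.0, 3.0]])  # query 3: "no object"
pred_boxes = torch.tensor([[0.5, 0.5, 0.2, 0.2],
                           [0.2, 0.3, 0.1, 0.1],
                           [0.7, 0.6, 0.3, 0.2],
                           [0.5, 0.5, 0.5, 0.5]])
tgt_labels = torch.tensor([2, 0])
tgt_boxes = torch.tensor([[0.7, 0.6, 0.3, 0.2],
                          [0.2, 0.3, 0.1, 0.1]])
pred_idx, tgt_idx = hungarian_match(pred_logits, pred_boxes, tgt_labels, tgt_boxes)
print(pred_idx.tolist(), tgt_idx.tolist())  # [1, 2] [1, 0]
```

Queries left unmatched (0 and 3 here) are trained toward the "no object" class. Because each target claims exactly one query, duplicate predictions are penalized during training. That is why DETR needs no NMS at inference.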
26.14 Exercises
- Patch embedding analysis: Implement the patch embedding of ViT-Tiny (12 layers, 192-dim, 3 heads) and verify that the number of output tokens equals $HW / P^2 + 1$ for several image sizes $H \times W$ and patch sizes $P$ that divide them evenly. The $+1$ is the [CLS] token.
- Position embedding visualization: Load a pre-trained ViT-Base/16 from HuggingFace and compute the pairwise cosine similarity matrix of its learned position embeddings. Drop the [CLS] position and reshape each patch's row of similarities into the patch grid (14x14 for 224x224 inputs). Visualize the grids to check for the emergence of 2D spatial structure.
- Swin vs. ViT efficiency: Compute the theoretical FLOPs of self-attention in ViT-Base (196 tokens, global attention) and in the first stage of Swin-T (56x56 tokens, 7x7 windows) at 224x224 resolution. At what input resolution does Swin's windowed attention become 10x cheaper than global attention?
- DeiT training ablation: Using the DeiT training recipe, measure the impact of removing individual components (Mixup, CutMix, stochastic depth) on ImageNet validation accuracy. Which single component causes the largest accuracy drop?
- Transfer learning: Fine-tune a pre-trained ViT-Base on a small dataset (e.g., Oxford Flowers-102, with 1,020 training images) using the layer-wise learning rate decay technique described in Section 26.9.8. Compare against a uniform learning rate and against a linear probe on a frozen backbone.
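As a starting point for the transfer-learning exercise, the sketch below builds PyTorch optimizer parameter groups with one common form of layer-wise learning rate decay. Section 26.9.8's exact formulation may differ. The sketch gives the embeddings layer id 0, encoder block $i$ id $i + 1$, and everything else (final norm, head) id $L + 1$. Each group's learning rate is then $\text{lr}_{\text{base}} \cdot \gamma^{L+1-\text{id}}$. It assumes HuggingFace-ViT-style parameter names (`embeddings`, `encoder.layer.<i>.`). `ToyViT` is a hypothetical stand-in that only mimics those names, so the example runs without downloading weights.

```python
import re

import torch
import torch.nn as nn


def layerwise_lr_groups(model: nn.Module, num_layers: int, base_lr: float, decay: float = 0.75):
    """Group parameters by depth and scale each group's learning rate.

    Assumes HuggingFace-ViT-style parameter names: anything containing
    "embeddings" is layer 0, "encoder.layer.<i>." is layer i + 1, and all
    remaining parameters (final norm, classifier head) form layer num_layers + 1.
    """
    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        block = re.search(r"encoder\.layer\.(\d+)\.", name)
        if "embeddings" in name:
            layer_id = 0
        elif block is not None:
            layer_id = int(block.group(1)) + 1
        else:
            layer_id = num_layers + 1
        # The head keeps base_lr; each step toward the input multiplies by `decay`.
        lr = base_lr * decay ** (num_layers + 1 - layer_id)
        groups.setdefault(layer_id, {"params": [], "lr": lr})["params"].append(param)
    return [groups[k] for k in sorted(groups)]


class ToyViT(nn.Module):
    """Hypothetical stand-in that only mimics HF ViT parameter names."""

    def __init__(self, dim: int = 8, depth: int = 4, num_classes: int = 102):
        super().__init__()
        self.embeddings = nn.Linear(dim, dim)
        self.encoder = nn.Module()
        self.encoder.layer = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.classifier = nn.Linear(dim, num_classes)


model = ToyViT(depth=4)
groups = layerwise_lr_groups(model, num_layers=4, base_lr=1e-3, decay=0.5)
print([g["lr"] for g in groups])
# [3.125e-05, 6.25e-05, 0.000125, 0.00025, 0.0005, 0.001]
optimizer = torch.optim.AdamW(groups, lr=1e-3, weight_decay=0.05)
```

The decay factor of 0.5 is chosen only to make the printed rates easy to read. Values closer to 1 (such as the default 0.75) are more typical, so that lower layers still adapt. A frozen-backbone linear probe is the limiting case in which every non-head group gets a learning rate of zero.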
References
- Dosovitskiy, A., et al. (2021). "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." ICLR 2021.
- Touvron, H., et al. (2021). "Training Data-Efficient Image Transformers & Distillation through Attention." ICML 2021.
- Liu, Z., et al. (2021). "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows." ICCV 2021.
- Carion, N., et al. (2020). "End-to-End Object Detection with Transformers." ECCV 2020.
- Xie, E., et al. (2021). "SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers." NeurIPS 2021.
- He, K., et al. (2022). "Masked Autoencoders Are Scalable Vision Learners." CVPR 2022.
- Caron, M., et al. (2021). "Emerging Properties in Self-Supervised Vision Transformers." ICCV 2021.
- Kirillov, A., et al. (2023). "Segment Anything." ICCV 2023.
- Ranftl, R., et al. (2021). "Vision Transformers for Dense Prediction." ICCV 2021.
- Liu, Z., et al. (2022). "A ConvNet for the 2020s." CVPR 2022.