In This Chapter
- Learning Objectives
- 26.1 When a Single GPU Is Not Enough
- 26.2 Data Parallelism: The Foundation
- 26.3 Beyond Data Parallelism: Model, Pipeline, and Tensor Parallelism
- 26.4 GPU Memory Hierarchy and the Memory Wall
- 26.5 Mixed-Precision Training and AMP
- 26.6 Gradient Checkpointing: Trading Compute for Memory
- 26.7 FlashAttention: Algorithmic Memory Efficiency
- 26.8 Efficient Large-Batch Training
- 26.9 DeepSpeed and Fully Sharded Data Parallelism (FSDP)
- 26.10 Cost Estimation and Management
- 26.11 Putting It All Together: The Training Pipeline
- 26.12 Measuring Training Efficiency
- 26.13 Progressive Project: Scaling StreamRec Training
- 26.14 Chapter Summary
Chapter 26: Training at Scale — Distributed Training, GPU Optimization, and Managing Compute Costs
"As model size increases, training time increases — but only if you fail to add proportionally more compute. The goal of distributed training is to turn money into time." — Jeff Dean, "Large-Scale Distributed Systems and Infrastructure" (Google AI, 2020)
Learning Objectives
By the end of this chapter, you will be able to:
- Implement data-parallel training with PyTorch DistributedDataParallel (DDP), explain gradient synchronization via all-reduce, and diagnose common distributed training failures
- Compare and select among parallelism strategies — data parallelism, model parallelism, pipeline parallelism, and tensor parallelism — based on model architecture, memory constraints, and hardware topology
- Optimize GPU utilization through mixed-precision training (AMP), gradient checkpointing, and memory-efficient attention (FlashAttention), with quantitative understanding of the memory-compute trade-offs each entails
- Estimate and manage compute costs for large-scale training runs using cloud pricing models, spot/preemptible instances, and checkpoint-based fault tolerance
- Apply efficient large-batch training techniques — learning rate warmup, linear scaling rule, LARS, and LAMB — to maintain model quality as batch size increases
26.1 When a Single GPU Is Not Enough
In Chapter 7, you trained the StreamRec click-prediction MLP on a single GPU. The model had 2.3 million parameters. Training completed in under 10 minutes. The entire dataset fit in GPU memory with room to spare.
That era is over.
The two-tower retrieval model from Chapter 13 has 45 million parameters and trains on 1.2 billion user-item interaction pairs. The deep ranking model from Chapter 24's system design has 120 million parameters with cross-attention layers that require quadratic memory in sequence length. And these are small by modern standards: GPT-3 has 175 billion parameters, and climate foundation models (our Climate DL anchor) routinely train on petabyte-scale datasets with models exceeding 1 billion parameters.
Three forces drive the need for distributed training:
Data scale. StreamRec's full interaction dataset — 1.2 billion events from 50 million users over 18 months — does not fit in the memory of a single GPU. Even if it did, training would require hundreds of epochs to converge, taking weeks on a single accelerator. Distributing the data across multiple GPUs reduces wall-clock time proportionally (in the ideal case).
Model scale. When the model itself does not fit in the memory of a single GPU, you have no choice: the model must be split across devices. A transformer with 1 billion parameters in FP32 requires 4 GB for parameters alone, plus 4 GB for gradients and 8 GB for Adam optimizer state. On an A100 with 80 GB of HBM, the model state fits comfortably, but activations, which scale with batch size and sequence length, consume the remainder quickly, leaving little room for a meaningful batch size.
Time pressure. The business does not wait. StreamRec retrains daily (Chapter 24, Section 24.2). If training takes 20 hours on a single GPU, the model is always stale by the time it reaches production. Distributing across 8 GPUs cuts training to roughly 2.5 hours under ideal linear scaling (communication overhead adds a little on top), fitting comfortably within the daily retraining window.
Production ML = Software Engineering: Distributed training introduces the same challenges that any distributed system faces: partial failures, communication overhead, synchronization cost, and non-determinism. The techniques in this chapter — gradient synchronization, fault-tolerant checkpointing, cost management — are software engineering problems that happen to involve GPUs. A production training pipeline is a distributed system first and a machine learning system second.
This chapter provides the complete toolkit for training at scale: data parallelism (the workhorse strategy for most production systems), model and pipeline parallelism (for models too large for a single device), GPU memory optimization (mixed precision, gradient checkpointing, FlashAttention), efficient large-batch training (linear scaling rule, LARS, LAMB), and cost management (cloud pricing, spot instances, checkpoint-based fault tolerance).
We ground the discussion in two anchor examples:
- Climate DL. Training a 1.2-billion-parameter Vision Transformer for global weather prediction on a cluster of 64 A100 GPUs. The model processes $1024 \times 2048$ resolution atmospheric fields at 6-hour intervals, and the training data (ERA5 reanalysis) spans 40 years at hourly resolution — 350,000 time steps, approximately 200 TB uncompressed.
- Content Platform Recommender (StreamRec). Scaling the two-tower retrieval model from single-GPU training (Chapter 13) to the full 1.2-billion-event dataset using multi-GPU data parallelism. The progressive project for this chapter.
26.2 Data Parallelism: The Foundation
Data parallelism is the simplest and most widely used distributed training strategy. The idea is conceptually straightforward: replicate the model on every GPU, partition the data across GPUs, and synchronize gradients after each backward pass so that all replicas stay in sync.
26.2.1 The Algorithm
Given $N$ GPUs, each holding a full copy of the model with parameters $\theta$:
- Partition the global batch $\mathcal{B}$ of size $B$ into $N$ micro-batches, each of size $B/N$.
- Forward pass: Each GPU $k$ computes the loss on its micro-batch: $$\mathcal{L}_k = \frac{1}{B/N} \sum_{(x,y) \in \mathcal{B}_k} \ell(f_\theta(x), y)$$
- Backward pass: Each GPU $k$ computes local gradients $g_k = \nabla_\theta \mathcal{L}_k$.
- All-reduce: Compute the average gradient across all GPUs: $$\bar{g} = \frac{1}{N} \sum_{k=1}^{N} g_k$$
- Update: Each GPU applies the same optimizer step using $\bar{g}$, producing identical updated parameters $\theta'$.
Because all GPUs start with the same parameters and apply the same averaged gradient, they remain synchronized — no parameter server is needed, and no GPU holds a "master" copy.
The key insight is that this procedure computes exactly the same gradient as single-GPU training with global batch size $B$. The distributed version is mathematically equivalent; only the wall-clock time differs.
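This equivalence is easy to check numerically. Here is a minimal sketch (pure Python, a toy one-parameter least-squares model; all names and data are illustrative) comparing the full-batch gradient with the average of per-shard gradients:

```python
def grad_full(w, xs, ys):
    """Gradient of L(w) = mean((w*x - y)^2) over the whole batch."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

def grad_data_parallel(w, xs, ys, num_gpus):
    """Average of per-shard gradients over equal-size micro-batches."""
    shard = len(xs) // num_gpus
    local_grads = [
        grad_full(w, xs[k * shard:(k + 1) * shard],
                  ys[k * shard:(k + 1) * shard])
        for k in range(num_gpus)  # each "GPU" computes g_k on its shard
    ]
    return sum(local_grads) / num_gpus  # simulated all-reduce (mean)

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.0, 2.1, 2.9, 4.2, 5.1, 5.8, 7.2, 8.1]
g_single = grad_full(2.0, xs, ys)
g_ddp = grad_data_parallel(2.0, xs, ys, num_gpus=4)
assert abs(g_single - g_ddp) < 1e-9  # identical up to FP rounding
```

The agreement is exact (up to floating-point rounding) precisely because the micro-batches are equal-sized, which is why DDP requires the global batch to divide evenly across ranks.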
26.2.2 All-Reduce: The Synchronization Primitive
The all-reduce operation is the critical communication step. Its job is to take $N$ gradient tensors (one per GPU) and produce the element-wise sum (or mean) on every GPU.
Naive all-reduce. Send all gradients to a single GPU, sum them, and broadcast the result. Communication cost: $2(N-1) \cdot M$ bytes transferred through a single bottleneck, where $M$ is the gradient size in bytes. This is bandwidth-limited by the bottleneck GPU's network interface.
Ring all-reduce. Arrange the $N$ GPUs in a logical ring. Divide the gradient tensor into $N$ chunks. In the reduce-scatter phase, each GPU sends one chunk to its neighbor and accumulates the received chunk — after $N-1$ steps, each GPU holds the fully reduced version of one chunk. In the all-gather phase, each GPU sends its fully reduced chunk around the ring — after $N-1$ more steps, every GPU has the complete reduced gradient. Total data transferred per GPU: $2 \cdot \frac{N-1}{N} \cdot M$ bytes.
$$T_{\text{ring}} = 2(N-1) \cdot \alpha + 2 \cdot \frac{N-1}{N} \cdot \frac{M}{\beta}$$
where $\alpha$ is the per-message latency and $\beta$ is the per-GPU bandwidth (bytes/second). The critical property: as $N$ grows, the bandwidth term $\frac{N-1}{N} \cdot \frac{M}{\beta}$ approaches $\frac{M}{\beta}$ — the communication cost per GPU is nearly independent of the number of GPUs. This is why ring all-reduce scales efficiently.
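Plugging representative numbers into $T_{\text{ring}}$ makes the plateau concrete. A sketch in pure Python, with illustrative latency and bandwidth values (roughly NVLink-class assumptions) and StreamRec-sized gradients:

```python
def ring_allreduce_time(n_gpus, msg_bytes, latency_s=5e-6, bw_bytes=150e9):
    """T_ring = 2(N-1)*alpha + 2*((N-1)/N)*(M/beta)."""
    latency_term = 2 * (n_gpus - 1) * latency_s
    bandwidth_term = 2 * ((n_gpus - 1) / n_gpus) * (msg_bytes / bw_bytes)
    return latency_term + bandwidth_term

M = 45e6 * 4  # StreamRec: 45M FP32 gradients = 180 MB
for n in (2, 8, 64):
    print(f"N={n:3d}: {ring_allreduce_time(n, M) * 1e3:.2f} ms")
# The bandwidth term saturates at 2M/beta; only the latency term grows with N.
```

Going from 8 to 64 GPUs increases the per-GPU communication time only modestly, which is the scaling property the formula promises.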
Ring All-Reduce: Reduce-Scatter Phase (4 GPUs, 4 chunks)
Step 0: Each GPU holds its local gradient, divided into 4 chunks
GPU 0: [A0] [B0] [C0] [D0]
GPU 1: [A1] [B1] [C1] [D1]
GPU 2: [A2] [B2] [C2] [D2]
GPU 3: [A3] [B3] [C3] [D3]
Step 1: Each GPU sends one chunk to its right neighbor, accumulates received
GPU 0: [A0] [B0] [C0] [D0+D3] ← received D3 from GPU 3
GPU 1: [A1+A0] [B1] [C1] [D1] ← received A0 from GPU 0
GPU 2: [A2] [B2+B1] [C2] [D2] ← received B1 from GPU 1
GPU 3: [A3] [B3] [C3+C2] [D3] ← received C2 from GPU 2
Step 2: Continue rotating
GPU 0: [A0] [B0] [C0+C3+C2] [D0+D3]
GPU 1: [A1+A0] [B1] [C1] [D1+D0+D3]
GPU 2: [A2+A1+A0][B2+B1] [C2] [D2]
GPU 3: [A3] [B3+B2+B1][C3+C2] [D3]
Step 3: Final reduce-scatter step
GPU 0: [A0] [B0+B3+B2+B1] [C0+C3+C2] [D0+D3]
GPU 1: [A1+A0] [B1] [C1+C0+C3+C2] [D1+D0+D3]
GPU 2: [A2+A1+A0] [B2+B1] [C2] [D2+D1+D0+D3]
GPU 3: [A3+A2+A1+A0][B3+B2+B1] [C3+C2] [D3]
Result: Each GPU holds the fully reduced version of one chunk.
GPU 0: chunk B fully reduced (sum of B0+B1+B2+B3)
GPU 1: chunk C fully reduced
GPU 2: chunk D fully reduced
GPU 3: chunk A fully reduced
All-Gather phase then distributes the fully reduced chunks to all GPUs
(reverse direction, same number of steps).
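The schedule above can be simulated in a few lines of pure Python to confirm that every GPU ends up with the full sum (lists stand in for GPUs, one number per chunk; this is a correctness sketch, not a performance model):

```python
def ring_allreduce(grads):
    """Simulate ring all-reduce; grads[k][c] is GPU k's value for chunk c.

    Requires as many chunks as GPUs. Returns per-GPU buffers in which
    every chunk holds the sum over all GPUs.
    """
    n = len(grads)
    bufs = [list(g) for g in grads]
    # Reduce-scatter: at step s, GPU k sends chunk (k - s) mod n to GPU k+1,
    # which adds it into its own buffer.
    for step in range(n - 1):
        sends = [((k + 1) % n, (k - step) % n, bufs[k][(k - step) % n])
                 for k in range(n)]
        for dst, chunk, val in sends:
            bufs[dst][chunk] += val
    # All-gather: at step s, GPU k forwards chunk (k + 1 - s) mod n, the one
    # it most recently completed or received, and the receiver overwrites.
    for step in range(n - 1):
        sends = [((k + 1) % n, (k + 1 - step) % n, bufs[k][(k + 1 - step) % n])
                 for k in range(n)]
        for dst, chunk, val in sends:
            bufs[dst][chunk] = val
    return bufs

grads = [[10 * k + c for c in range(4)] for k in range(4)]
assert ring_allreduce(grads) == [[60, 64, 68, 72]] * 4
```

After reduce-scatter, GPU $k$ owns the fully reduced chunk $(k+1) \bmod N$, matching the diagram (GPU 0 owns chunk B, GPU 3 owns chunk A).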
NCCL (NVIDIA Collective Communications Library). In practice, you never implement ring all-reduce yourself. NCCL provides optimized collective operations that exploit the GPU interconnect topology — NVLink between GPUs within a node (600 GB/s on A100 nodes), InfiniBand across nodes (200 Gb/s with HDR). NCCL automatically selects the best algorithm (ring, tree, or hybrid) based on message size and topology.
26.2.3 PyTorch DistributedDataParallel (DDP)
PyTorch DDP is the standard API for data-parallel training. It wraps a model, hooks into the backward pass to overlap gradient computation with gradient communication, and handles the all-reduce automatically.
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Tuple
@dataclass
class DDPConfig:
    """Configuration for distributed data-parallel training.

    Attributes:
        world_size: Total number of processes (GPUs).
        backend: Communication backend ('nccl' for GPU, 'gloo' for CPU).
        master_addr: Address of the rank-0 process.
        master_port: Port for rendezvous.
        find_unused_parameters: Whether to detect unused parameters in
            the forward pass. Set to True if not all parameters receive
            gradients every iteration (e.g., conditional computation).
    """

    world_size: int = 1
    backend: str = "nccl"
    master_addr: str = "localhost"
    master_port: str = "29500"
    find_unused_parameters: bool = False
def setup_ddp(rank: int, config: DDPConfig) -> None:
    """Initialize the distributed process group.

    Must be called once per process before any distributed operations.

    Args:
        rank: This process's rank (0 to world_size - 1).
        config: Distributed training configuration.
    """
    os.environ["MASTER_ADDR"] = config.master_addr
    os.environ["MASTER_PORT"] = config.master_port
    dist.init_process_group(
        backend=config.backend,
        rank=rank,
        world_size=config.world_size,
    )
    torch.cuda.set_device(rank)
def cleanup_ddp() -> None:
    """Destroy the distributed process group.

    Call this at the end of training to release resources.
    """
    dist.destroy_process_group()
class TwoTowerModel(nn.Module):
    """Simplified two-tower retrieval model for StreamRec.

    User and item towers produce embeddings; similarity is their dot product.

    Args:
        num_users: Number of users in the vocabulary.
        num_items: Number of items in the vocabulary.
        embedding_dim: Dimension of user and item embeddings.
        hidden_dim: Width of the tower hidden layers.
    """

    def __init__(
        self,
        num_users: int,
        num_items: int,
        embedding_dim: int = 128,
        hidden_dim: int = 256,
    ) -> None:
        super().__init__()
        # Note: L2 normalization is applied in forward_user/forward_item,
        # not here — nn.functional.normalize is a function, not a Module,
        # so it cannot be placed inside nn.Sequential.
        self.user_tower = nn.Sequential(
            nn.Embedding(num_users, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )
        self.item_tower = nn.Sequential(
            nn.Embedding(num_items, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )
        self.embedding_dim = embedding_dim

    def forward_user(self, user_ids: torch.Tensor) -> torch.Tensor:
        """Compute user embeddings.

        Args:
            user_ids: Tensor of shape (batch_size,).

        Returns:
            L2-normalized user embeddings of shape (batch_size, embedding_dim).
        """
        return nn.functional.normalize(self.user_tower(user_ids), p=2, dim=-1)

    def forward_item(self, item_ids: torch.Tensor) -> torch.Tensor:
        """Compute item embeddings.

        Args:
            item_ids: Tensor of shape (batch_size,).

        Returns:
            L2-normalized item embeddings of shape (batch_size, embedding_dim).
        """
        return nn.functional.normalize(self.item_tower(item_ids), p=2, dim=-1)

    def forward(
        self, user_ids: torch.Tensor, item_ids: torch.Tensor
    ) -> torch.Tensor:
        """Compute similarity scores for user-item pairs.

        Args:
            user_ids: Tensor of shape (batch_size,).
            item_ids: Tensor of shape (batch_size,).

        Returns:
            Dot-product similarity of shape (batch_size,).
        """
        user_emb = self.forward_user(user_ids)
        item_emb = self.forward_item(item_ids)
        return (user_emb * item_emb).sum(dim=-1)
def create_ddp_model(
    model: nn.Module, rank: int, config: DDPConfig
) -> DDP:
    """Wrap a model with DistributedDataParallel.

    Args:
        model: The model to distribute.
        rank: This process's GPU rank.
        config: DDP configuration.

    Returns:
        DDP-wrapped model.
    """
    model = model.to(rank)
    return DDP(
        model,
        device_ids=[rank],
        find_unused_parameters=config.find_unused_parameters,
    )
def train_one_epoch_ddp(
    model: DDP,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    rank: int,
    epoch: int,
) -> Dict[str, float]:
    """Train for one epoch using DDP.

    Args:
        model: DDP-wrapped model.
        dataloader: DataLoader with DistributedSampler.
        optimizer: Optimizer (same on every rank).
        rank: This process's GPU rank.
        epoch: Current epoch (used to set sampler seed).

    Returns:
        Dictionary of training metrics for this rank.
    """
    model.train()
    dataloader.sampler.set_epoch(epoch)  # Ensure different shuffling per epoch
    total_loss = 0.0
    num_batches = 0
    loss_fn = nn.BCEWithLogitsLoss()
    for batch in dataloader:
        user_ids = batch["user_id"].to(rank)
        item_ids = batch["item_id"].to(rank)
        labels = batch["label"].float().to(rank)
        optimizer.zero_grad()
        scores = model(user_ids, item_ids)
        loss = loss_fn(scores, labels)
        loss.backward()  # DDP hooks trigger all-reduce here
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    avg_loss = total_loss / max(num_batches, 1)
    # Aggregate loss across all ranks for logging
    loss_tensor = torch.tensor([avg_loss], device=rank)
    dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
    return {"loss": loss_tensor.item(), "batches": num_batches}
The key line is loss.backward(). When DDP wraps a model, it registers hooks on every parameter. During the backward pass, as soon as a gradient is computed for a parameter, DDP initiates the all-reduce for that parameter's gradient — overlapping communication with computation. By the time the backward pass finishes, gradient synchronization is already complete (or nearly so). This overlap is the reason DDP is faster than a naive "compute all gradients, then all-reduce" approach.
26.2.4 The DistributedSampler
A subtle but critical detail: each GPU must see a different partition of the data. The DistributedSampler handles this by shuffling the full index list with a seed shared by all ranks — so every rank computes the identical permutation without communicating — and then taking a strided slice: rank $k$ keeps indices $k, k+N, k+2N, \ldots$. The set_epoch() call folds the epoch number into the shuffling seed, ensuring that each GPU sees a different subset of the data every epoch rather than the same partition repeatedly.
Without set_epoch(), every epoch uses the same permutation and therefore the same GPU-to-data mapping, which reduces the effective data diversity and can degrade convergence.
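A minimal sketch of this partitioning logic in pure Python (illustrative, not the actual PyTorch implementation):

```python
import random

def partition_indices(dataset_len, world_size, rank, epoch, seed=0):
    """Mimic DistributedSampler's shuffle-then-stride partitioning."""
    rng = random.Random(seed + epoch)        # epoch folded into the seed
    indices = list(range(dataset_len))
    rng.shuffle(indices)                     # identical shuffle on every rank
    pad = (-len(indices)) % world_size       # pad so length divides evenly
    indices += indices[:pad]
    return indices[rank::world_size]         # strided split across ranks

parts = [partition_indices(64, 4, r, epoch=0) for r in range(4)]
assert sorted(i for p in parts for i in p) == list(range(64))  # disjoint cover
assert partition_indices(64, 4, 0, epoch=1) != parts[0]        # reshuffled
```

Because every rank runs the same seeded shuffle, the partition is consistent across processes without any communication — the property that makes the sampler safe to use in DDP.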
26.2.5 Gradient Synchronization: Correctness Guarantee
The mathematical equivalence of DDP and single-GPU training relies on one assumption: all GPUs start with identical parameters. If the parameters diverge — even slightly — the averaged gradient is no longer the gradient of any single loss function, and training can diverge.
Sources of divergence:
- Different initialization. Always seed the model identically on all ranks, or broadcast parameters from rank 0 after initialization.
- Non-deterministic operations. CUDA operations like atomicAdd are non-deterministic. Set torch.backends.cudnn.deterministic = True for reproducibility (at a performance cost).
- Floating-point non-associativity. The order of gradient accumulation in all-reduce can affect the result due to floating-point rounding. In practice, this drift is negligible.
def verify_parameter_sync(model: DDP, rank: int) -> bool:
    """Verify that all ranks have identical parameters.

    Computes a fingerprint of model parameters on each rank and checks
    that all fingerprints agree. Use after initialization and periodically
    during training to catch synchronization bugs.

    Args:
        model: DDP-wrapped model.
        rank: This process's GPU rank.

    Returns:
        True if all ranks have identical parameters.
    """
    # Compute a scalar fingerprint of all parameters
    param_sum = sum(
        p.data.sum().item() for p in model.parameters()
    )
    param_tensor = torch.tensor([param_sum], device=rank)
    # all_gather places every rank's fingerprint on every rank
    gathered = [
        torch.zeros(1, device=rank) for _ in range(dist.get_world_size())
    ]
    dist.all_gather(gathered, param_tensor)
    if rank == 0:
        values = [t.item() for t in gathered]
        synced = all(abs(v - values[0]) < 1e-4 for v in values)
        if not synced:
            print(f"WARNING: Parameter divergence detected: {values}")
        return synced
    return True  # Only rank 0 checks
26.3 Beyond Data Parallelism: Model, Pipeline, and Tensor Parallelism
Data parallelism works when the model fits on a single GPU. When it does not, you need to partition the model itself across devices.
26.3.1 A Taxonomy of Parallelism
| Strategy | What is distributed | Communication | When to use |
|---|---|---|---|
| Data parallelism | Data (batches) | Gradient all-reduce | Model fits on one GPU; data is large |
| Model parallelism | Model layers | Activations between stages | Model does not fit on one GPU |
| Pipeline parallelism | Model stages across micro-batches | Activations between stages | Very deep models; amortizes pipeline bubbles |
| Tensor parallelism | Individual layers (weight matrices) | Partial activations within a layer | Very wide layers (e.g., large transformer MLP) |
In practice, these strategies are combined. A 175-billion-parameter model might use tensor parallelism within each node (splitting large matrix multiplications across 8 GPUs connected via NVLink), pipeline parallelism across nodes (splitting the 96 transformer layers into 8 stages of 12 layers each), and data parallelism across pipeline replicas (running 4 pipeline replicas on different data partitions). This is 3D parallelism.
26.3.2 Model Parallelism (Naive Layer Partitioning)
The simplest form of model parallelism assigns different layers to different GPUs. For a 48-layer transformer:
- GPU 0: Embedding + Layers 1-12
- GPU 1: Layers 13-24
- GPU 2: Layers 25-36
- GPU 3: Layers 37-48 + Output head
The forward pass is sequential: GPU 0 computes its layers, sends the output activation tensor to GPU 1, which computes its layers, sends to GPU 2, and so on. The backward pass reverses the direction.
The problem: the pipeline bubble. At any given moment, only one GPU is active — the others are idle, waiting for activations or gradients. If each GPU takes time $T$ for its stage's forward pass, the full forward pass takes $4T$ of wall-clock time, during which each GPU is busy for only $T$. GPU utilization is $1/N = 25\%$.
Naive Model Parallelism (4 GPUs, 1 micro-batch):
Time ──────────────────────────────────▶
GPU 0: [F ]                        [B ]
GPU 1:     [F ]                [B ]
GPU 2:         [F ]        [B ]
GPU 3:             [F ][B ]
F = forward, B = backward. At any instant only one GPU is busy;
everything else is pipeline bubble (wasted compute).
26.3.3 Pipeline Parallelism: Filling the Bubble
Pipeline parallelism reduces the bubble by splitting the global batch into $M$ micro-batches and feeding them through the pipeline in sequence. While GPU 1 processes micro-batch 1, GPU 0 can start processing micro-batch 2.
GPipe (Huang et al., 2019). The GPipe schedule processes all micro-batches in the forward direction, then all in the backward direction. The bubble fraction is:
$$\text{Bubble fraction} = \frac{N - 1}{N + M - 1}$$
where $N$ is the number of pipeline stages and $M$ is the number of micro-batches. With $N=4$ and $M=16$, the bubble is $3/19 \approx 16\%$ — a dramatic improvement over 75%.
GPipe Schedule (4 GPUs, 8 micro-batches):
Time ──────────────────────────────────────────────────────────────────────────────▶
GPU 0: [F1][F2][F3][F4][F5][F6][F7][F8][         idle         ][B8][B7][B6][B5][B4][B3][B2][B1]
GPU 1:     [F1][F2][F3][F4][F5][F6][F7][F8][     idle     ][B8][B7][B6][B5][B4][B3][B2][B1]
GPU 2:         [F1][F2][F3][F4][F5][F6][F7][F8][ idle ][B8][B7][B6][B5][B4][B3][B2][B1]
GPU 3:             [F1][F2][F3][F4][F5][F6][F7][F8][B8][B7][B6][B5][B4][B3][B2][B1]
Fn = forward pass on micro-batch n
Bn = backward pass on micro-batch n
Bubble fraction = (N-1)/(N+M-1) = 3/11 ≈ 27% (with N=4 stages, M=8 micro-batches)
1F1B (One Forward, One Backward). An alternative schedule that interleaves forward and backward micro-batches, reducing the peak activation memory. Each GPU runs one forward pass, then alternates between forward and backward passes for subsequent micro-batches. DeepSpeed and Megatron-LM use 1F1B as the default pipeline schedule.
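The bubble formula is worth tabulating. A few lines of Python confirm the numbers quoted above:

```python
def bubble_fraction(num_stages, num_microbatches):
    """GPipe pipeline bubble fraction: (N - 1) / (N + M - 1)."""
    n, m = num_stages, num_microbatches
    return (n - 1) / (n + m - 1)

# Four pipeline stages, as in the diagrams above:
assert bubble_fraction(4, 1) == 0.75                  # naive model parallelism
assert abs(bubble_fraction(4, 8) - 3 / 11) < 1e-12    # ~27%
assert abs(bubble_fraction(4, 16) - 3 / 19) < 1e-12   # ~16%
```

With $M = 1$ the formula recovers the $1 - 1/N$ idle fraction of naive model parallelism; increasing $M$ drives the bubble toward zero, at the cost of smaller (less efficient) per-device micro-batches.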
26.3.4 Tensor Parallelism
Tensor parallelism splits individual layers across GPUs. Consider a transformer's feed-forward network:
$$\text{FFN}(x) = \text{GeLU}(xW_1)W_2$$
where $W_1 \in \mathbb{R}^{d \times 4d}$ and $W_2 \in \mathbb{R}^{4d \times d}$. With 2-way tensor parallelism:
- Split $W_1$ column-wise: GPU 0 holds $W_1^{(0)} \in \mathbb{R}^{d \times 2d}$, GPU 1 holds $W_1^{(1)} \in \mathbb{R}^{d \times 2d}$.
- Each GPU computes $\text{GeLU}(xW_1^{(k)})$ locally — no communication needed because GeLU is element-wise.
- Split $W_2$ row-wise: GPU 0 holds $W_2^{(0)} \in \mathbb{R}^{2d \times d}$, GPU 1 holds $W_2^{(1)} \in \mathbb{R}^{2d \times d}$.
- Each GPU computes a partial output $y_k = \text{GeLU}(xW_1^{(k)}) W_2^{(k)}$.
- All-reduce to sum the partial outputs: $y = y_0 + y_1$.
The all-reduce happens within each transformer layer, so tensor parallelism requires extremely high inter-GPU bandwidth. It is only practical within a single node connected by NVLink (600 GB/s on DGX A100) — not across nodes connected by InfiniBand (25-50 GB/s).
For attention heads, tensor parallelism is even more natural: each GPU computes a subset of attention heads, and the results are concatenated. A 32-head attention layer with 8-way tensor parallelism assigns 4 heads per GPU.
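The column-split/row-split arithmetic can be checked directly. The sketch below uses plain Python lists and ReLU in place of GeLU (any element-wise activation behaves the same way); the toy sizes $d = 2$ with inner width 4 stand in for the real $d \times 4d$ shapes:

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def relu(M):
    return [[max(0.0, v) for v in row] for row in M]

def add(A, B):
    return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(A, B)]

x = [[1.0, -2.0]]                                        # (1, d), d = 2
W1 = [[0.5, -1.0, 2.0, 0.0],
      [1.5, 0.5, -0.5, 1.0]]                             # (d, h), h = 4
W2 = [[1.0, 0.0], [0.5, -1.0], [2.0, 1.0], [0.0, 0.5]]   # (h, d)

# Unpartitioned FFN: act(x W1) W2
full = matmul(relu(matmul(x, W1)), W2)

# 2-way tensor parallelism: W1 split by columns, W2 split by rows
W1_0 = [row[:2] for row in W1]
W1_1 = [row[2:] for row in W1]
W2_0, W2_1 = W2[:2], W2[2:]
y0 = matmul(relu(matmul(x, W1_0)), W2_0)  # GPU 0's partial output
y1 = matmul(relu(matmul(x, W1_1)), W2_1)  # GPU 1's partial output

assert add(y0, y1) == full  # the all-reduce (sum) recovers FFN(x)
```

The column split keeps the activation local (element-wise functions never mix columns), and the row split makes each partial product a valid summand — which is exactly why a single all-reduce per FFN suffices.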
26.3.5 Choosing a Strategy
@dataclass
class ParallelismDecision:
    """Decision framework for selecting a parallelism strategy.

    Attributes:
        model_size_gb: Model size in GB (parameters + optimizer state).
        activation_memory_gb: Estimated activation memory in GB at the
            target per-GPU batch size.
        single_gpu_memory_gb: Available memory on a single GPU.
        num_gpus: Total available GPUs.
        inter_gpu_bandwidth_gbps: Bandwidth between GPUs in GB/s
            (NVLink or InfiniBand).
        training_data_size_gb: Total dataset size in GB.
        target_batch_size: Desired global batch size.
    """

    model_size_gb: float
    activation_memory_gb: float = 2.0
    single_gpu_memory_gb: float = 80.0  # A100 80GB
    num_gpus: int = 8
    inter_gpu_bandwidth_gbps: float = 600.0  # NVLink, GB/s
    training_data_size_gb: float = 100.0
    target_batch_size: int = 4096

    def recommend_strategy(self) -> str:
        """Recommend a parallelism strategy based on hardware and model.

        Returns:
            Human-readable strategy recommendation.
        """
        # Each GPU must hold model state AND activations during training.
        required_gb = self.model_size_gb + self.activation_memory_gb
        fits_single_gpu = required_gb < self.single_gpu_memory_gb * 0.6
        if fits_single_gpu:
            return (
                "Data Parallelism (DDP): Model fits on a single GPU. "
                "Replicate across all GPUs, partition data, synchronize "
                "gradients via all-reduce. This is the simplest and most "
                "efficient strategy."
            )
        fits_with_optimizations = (
            required_gb < self.single_gpu_memory_gb * 0.9
        )
        if fits_with_optimizations:
            return (
                "Data Parallelism + Memory Optimization: Model barely fits "
                "on a single GPU. Use DDP with mixed precision (halves "
                "activation memory), gradient checkpointing (trades compute "
                "for memory), and/or FSDP (shards optimizer state). See "
                "Sections 26.4-26.5."
            )
        # Model doesn't fit on a single GPU
        high_bandwidth = self.inter_gpu_bandwidth_gbps > 100  # NVLink-class
        if high_bandwidth and self.num_gpus <= 8:
            return (
                "Tensor Parallelism + Data Parallelism: Model does not fit "
                "on a single GPU but GPUs have high-bandwidth interconnect. "
                "Split large layers across GPUs within a node (tensor "
                "parallelism), replicate across nodes (data parallelism)."
            )
        return (
            "3D Parallelism: Model is very large and spans multiple nodes. "
            "Use tensor parallelism within nodes, pipeline parallelism "
            "across nodes, and data parallelism across pipeline replicas. "
            "Frameworks: DeepSpeed ZeRO Stage 3, Megatron-LM."
        )

# Example: StreamRec two-tower model
streamrec_decision = ParallelismDecision(
    model_size_gb=0.5,  # 45M params * 4 bytes + Adam state ≈ 0.5 GB
    single_gpu_memory_gb=80.0,
    num_gpus=4,
    training_data_size_gb=150.0,  # 1.2B interactions
)
print("StreamRec:", streamrec_decision.recommend_strategy())

# Example: Climate DL foundation model
climate_decision = ParallelismDecision(
    model_size_gb=14.4,  # 1.2B params * 4 bytes + Adam state ≈ 14.4 GB
    activation_memory_gb=50.0,  # rough estimate for high-res ViT activations
    single_gpu_memory_gb=80.0,
    num_gpus=64,
    inter_gpu_bandwidth_gbps=600.0,
    training_data_size_gb=200_000.0,  # 200 TB
)
print("Climate DL:", climate_decision.recommend_strategy())
StreamRec: Data Parallelism (DDP): Model fits on a single GPU. Replicate
across all GPUs, partition data, synchronize gradients via all-reduce.
This is the simplest and most efficient strategy.
Climate DL: Data Parallelism + Memory Optimization: Model barely fits on
a single GPU. Use DDP with mixed precision (halves activation memory),
gradient checkpointing (trades compute for memory), and/or FSDP (shards
optimizer state). See Sections 26.4-26.5.
StreamRec's two-tower model is small enough for pure data parallelism. The Climate DL model, at 1.2 billion parameters (approximately 14.4 GB including Adam optimizer state), fits on an A100 80 GB with mixed precision but requires memory optimizations for reasonable batch sizes.
26.4 GPU Memory Hierarchy and the Memory Wall
Before we can optimize GPU utilization, we need to understand what consumes GPU memory during training and where the bottlenecks are.
26.4.1 What Consumes GPU Memory?
For a model with $P$ parameters trained with Adam in FP32:
| Component | Memory | StreamRec (45M params) | Climate DL (1.2B params) |
|---|---|---|---|
| Parameters ($\theta$) | $4P$ bytes | 180 MB | 4.8 GB |
| Gradients ($\nabla\theta$) | $4P$ bytes | 180 MB | 4.8 GB |
| Adam first moment ($m$) | $4P$ bytes | 180 MB | 4.8 GB |
| Adam second moment ($v$) | $4P$ bytes | 180 MB | 4.8 GB |
| Optimizer total | $16P$ bytes | 720 MB | 19.2 GB |
| Activations | $O(B \cdot L \cdot d^2)$ | Variable | Variable |
The surprise: activations dominate memory for large batches. A single transformer layer with hidden dimension $d = 4096$, sequence length $s = 1024$, and batch size $B = 32$ stores activations of approximately:
$$\text{Activation memory per layer} \approx 2Bsd + 2Bs^2 \text{ (attention scores)} = 2 \times 32 \times 1024 \times 4096 + 2 \times 32 \times 1024^2 \approx 335 \text{ MB}$$
For a 48-layer transformer, that is 16 GB of activations — more than the model parameters and optimizer state combined.
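The batch-independent part of this budget is simple to compute. A sketch, assuming FP32 training with Adam as in the table:

```python
def training_state_bytes(num_params, bytes_per_param=4):
    """Fixed training-state memory: params, grads, Adam m and v (4 tensors)."""
    per_tensor = num_params * bytes_per_param
    return {
        "params": per_tensor,
        "grads": per_tensor,
        "adam_m": per_tensor,
        "adam_v": per_tensor,
        "total": 4 * per_tensor,  # the 16P bytes from the table
    }

assert training_state_bytes(45_000_000)["total"] == 720_000_000        # 720 MB
assert training_state_bytes(1_200_000_000)["total"] == 19_200_000_000  # 19.2 GB
```

These are the floor; activations come on top and, unlike the fixed state, grow linearly with batch size.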
26.4.2 The GPU Memory Hierarchy
Modern GPUs have a three-level memory hierarchy:
GPU Memory Hierarchy (NVIDIA A100)
┌──────────────────────────────────┐
│ SRAM (On-chip) │
│ 20 MB total │
│ 19 TB/s bandwidth │
│ Per-SM: 192 KB │ ← Registers + shared memory
│ Latency: ~1 cycle │
├──────────────────────────────────┤
│ HBM (Device Memory) │
│ 80 GB (A100 80GB) │
│ 2.0 TB/s bandwidth │ ← Parameters, activations, gradients
│ Latency: ~400 cycles │
├──────────────────────────────────┤
│ CPU DRAM (Host Memory) │
│ 512+ GB typical │
│ ~50 GB/s (PCIe Gen4) │ ← CPU offloading target
│ Latency: ~10,000 cycles │
└──────────────────────────────────┘
Key insight: SRAM is 10x faster than HBM, but 4,000x smaller.
HBM is 40x faster than CPU DRAM, but often 6-8x smaller.
The performance of most deep learning operations is determined by which level of the hierarchy they are bound by:
Compute-bound operations. The operation has enough data to keep the GPU's arithmetic units busy. Matrix multiplications (GEMMs) with large dimensions are typically compute-bound: the $O(n^3)$ arithmetic on $O(n^2)$ data gives high arithmetic intensity (FLOPs per byte accessed from memory).
Memory-bound operations. The operation cannot keep the arithmetic units busy because it spends most of its time reading and writing data. Element-wise operations (activation functions, layer normalization, dropout) have arithmetic intensity of $O(1)$ — one FLOP per byte — and are memory-bound. They execute at HBM bandwidth speed, far below the GPU's peak FLOP rate.
26.4.3 Arithmetic Intensity and the Roofline Model
Arithmetic intensity $I$ is the ratio of FLOPs to bytes accessed:
$$I = \frac{\text{FLOPs}}{\text{Bytes accessed from memory}}$$
The roofline model defines the maximum achievable performance:
$$\text{Attainable FLOP/s} = \min(\text{Peak FLOP/s}, \; I \times \text{Memory Bandwidth})$$
For an A100 (FP16):
- Peak compute: 312 TFLOP/s
- HBM bandwidth: 2.0 TB/s
- Ridge point: $312 / 2.0 = 156$ FLOP/byte
Operations with $I > 156$ are compute-bound (limited by the 312 TFLOP/s ceiling). Operations with $I < 156$ are memory-bound (limited by the 2.0 TB/s memory bandwidth).
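The roofline itself is two lines of code. A sketch using the A100 FP16 numbers above:

```python
def attainable_flops(intensity, peak_flops=312e12, mem_bw_bytes=2.0e12):
    """Roofline: performance is min(peak compute, I * memory bandwidth)."""
    return min(peak_flops, intensity * mem_bw_bytes)

assert attainable_flops(156) == 312e12   # ridge point: exactly at the roof
assert attainable_flops(5) == 10e12      # LayerNorm-like op: memory-bound
assert attainable_flops(1000) == 312e12  # large GEMM: compute-bound
```

An operation with intensity 5 FLOP/byte tops out at 10 TFLOP/s — about 3% of peak — no matter how well-tuned the kernel is, which is why fusing memory-bound operations (Section 26.7) matters so much.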
| Operation | Arithmetic Intensity | Bound |
|---|---|---|
| Large GEMM ($M=N=K=4096$) | ~4096 FLOP/byte | Compute |
| Small GEMM ($M=N=K=128$) | ~128 FLOP/byte | Memory |
| ReLU | 0.25 FLOP/byte | Memory |
| LayerNorm | ~5 FLOP/byte | Memory |
| Softmax | ~5 FLOP/byte | Memory |
| Standard attention ($QK^T$, softmax, $\times V$) | ~$s/4$ FLOP/byte | Memory for short $s$ |
26.4.4 Model FLOP Utilization (MFU)
Model FLOP Utilization (MFU), introduced by Chowdhery et al. (2022) in the PaLM paper, measures what fraction of a GPU's theoretical peak FLOP/s is actually used for model computation:
$$\text{MFU} = \frac{\text{Model FLOPs per second (observed)}}{\text{GPU peak FLOP/s}}$$
MFU excludes communication overhead, memory-bound operations, and pipeline bubbles. Typical values:
| Configuration | Typical MFU |
|---|---|
| Single GPU, large GEMM | 60-75% |
| 8-GPU DDP (intra-node NVLink) | 50-65% |
| 64-GPU DDP (multi-node) | 35-55% |
| 3D parallelism, large model | 40-55% |
| Unoptimized code | 10-25% |
An MFU below 30% indicates significant optimization opportunities. The techniques in Sections 26.5 and 26.6 target this metric.
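MFU is straightforward to estimate from a training log. The sketch below uses the common approximation of about $6P$ FLOPs per token for a transformer forward-plus-backward pass; the throughput number is hypothetical:

```python
def model_flop_utilization(num_params, tokens_per_sec, num_gpus,
                           peak_flops_per_gpu=312e12):
    """MFU using the ~6*P FLOPs-per-token estimate (forward + backward)."""
    observed = 6 * num_params * tokens_per_sec
    return observed / (num_gpus * peak_flops_per_gpu)

# Hypothetical Climate DL run: 1.2B params on 64 A100s, 1.1M tokens/s observed
u = model_flop_utilization(1.2e9, 1.1e6, 64)
print(f"MFU = {u:.1%}")  # lands in the multi-node DDP band of the table
```

Tracking this single number over time is often the fastest way to notice that a code change (a new data loader, a logging call inside the training loop) has silently degraded throughput.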
26.5 Mixed-Precision Training and AMP
26.5.1 The FP16 Opportunity
FP16 (16-bit floating point) uses half the memory of FP32 for both parameters and activations, and modern GPUs have specialized hardware (Tensor Cores on NVIDIA GPUs) that execute FP16 matrix multiplications at 2-8x the throughput of FP32.
| Data Type | Bits | Range | Precision | A100 Peak |
|---|---|---|---|---|
| FP32 | 32 | $\pm 3.4 \times 10^{38}$ | $\sim 7$ decimal digits | 19.5 TFLOP/s |
| FP16 | 16 | $\pm 65504$ | $\sim 3.3$ decimal digits | 312 TFLOP/s |
| BF16 | 16 | $\pm 3.4 \times 10^{38}$ | $\sim 2.4$ decimal digits | 312 TFLOP/s |
| TF32 | 19 | $\pm 3.4 \times 10^{38}$ | $\sim 3.3$ decimal digits | 156 TFLOP/s |
BF16 (bfloat16) has the same exponent range as FP32, making it robust to overflow — the primary failure mode of FP16 training. On Ampere and later GPUs, BF16 is generally preferred over FP16.
26.5.2 The Problem: Precision Loss
Naively replacing all FP32 operations with FP16 fails for two reasons:
1. Gradient underflow. Small gradients (common in early layers of deep networks) fall below FP16's minimum representable positive value ($\approx 5.96 \times 10^{-8}$) and become zero. Once zeroed, the parameter stops updating.
2. Accumulation error. Summing many small FP16 values (e.g., in a reduction or batch normalization) accumulates rounding error that degrades the result.
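The underflow failure mode can be reproduced without a GPU: Python's struct module supports IEEE 754 binary16 (the "e" format), so a round-trip through it shows exactly what FP16 storage does to a small gradient — and previews why the loss scaling of Section 26.5.3 fixes it:

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip a float through IEEE 754 binary16 (struct's 'e' format)."""
    return struct.unpack("e", struct.pack("e", x))[0]


grad = 1e-8                     # below FP16's minimum subnormal (~5.96e-8)
print(to_fp16(grad))            # 0.0 -- the gradient silently vanishes

scaled = to_fp16(grad * 2**16)  # loss scaling shifts it into range
print(scaled > 0)               # True; dividing by 2**16 after backward recovers it
```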
26.5.3 Automatic Mixed Precision (AMP)
AMP, introduced by Micikevicius et al. (2018) and implemented in PyTorch as torch.amp (formerly torch.cuda.amp), solves both problems through two mechanisms:
Selective precision. AMP maintains a "master copy" of weights in FP32. Forward and backward passes use FP16 (or BF16) for compute-bound operations (matrix multiplications, convolutions) and FP32 for operations that require higher precision (reductions, layer normalization, softmax, loss computation).
Loss scaling. Before the backward pass, the loss is multiplied by a large constant $S$ (the loss scale). This shifts all gradients into the representable range of FP16, preventing underflow. After the backward pass, gradients are divided by $S$ before the optimizer step. If any gradient overflows (becomes inf), the optimizer step is skipped and $S$ is halved. If no overflow occurs for a window of iterations, $S$ is doubled. This dynamic loss scaling adapts to the gradient distribution automatically.
import torch
from torch.amp import GradScaler, autocast
from dataclasses import dataclass
from typing import Dict, Optional
@dataclass
class AMPConfig:
"""Configuration for Automatic Mixed Precision training.
Attributes:
enabled: Whether to use AMP.
dtype: The reduced-precision dtype (float16 or bfloat16).
initial_scale: Initial loss scale for dynamic scaling.
growth_interval: Steps between loss scale increases.
"""
enabled: bool = True
dtype: torch.dtype = torch.bfloat16
initial_scale: float = 2.0 ** 16
growth_interval: int = 2000
def train_one_epoch_amp(
model: torch.nn.Module,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
amp_config: AMPConfig,
scaler: Optional[GradScaler] = None,
device: str = "cuda",
) -> Dict[str, float]:
"""Train for one epoch with Automatic Mixed Precision.
Uses torch.amp.autocast for the forward pass and GradScaler
for loss scaling. Compatible with DDP (wrap model with DDP first,
then use this training loop).
Args:
model: The model (may be DDP-wrapped).
dataloader: Training data loader.
optimizer: Optimizer.
amp_config: AMP configuration.
scaler: GradScaler instance. If None and AMP is enabled,
creates one. Reuse across epochs for correct scaling state.
device: Device string.
Returns:
Dictionary with loss, skipped_steps (overflow count),
and current loss scale.
"""
model.train()
loss_fn = torch.nn.BCEWithLogitsLoss()
if amp_config.enabled and scaler is None:
scaler = GradScaler(
init_scale=amp_config.initial_scale,
growth_interval=amp_config.growth_interval,
)
total_loss = 0.0
num_batches = 0
skipped_steps = 0
for batch in dataloader:
user_ids = batch["user_id"].to(device)
item_ids = batch["item_id"].to(device)
labels = batch["label"].float().to(device)
optimizer.zero_grad()
if amp_config.enabled:
# Forward pass in reduced precision
with autocast(device_type="cuda", dtype=amp_config.dtype):
scores = model(user_ids, item_ids)
loss = loss_fn(scores, labels)
# Backward pass with loss scaling
scaler.scale(loss).backward()
# Unscale gradients, check for overflow, step
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scale_before = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
if scaler.get_scale() < scale_before:
skipped_steps += 1
else:
scores = model(user_ids, item_ids)
loss = loss_fn(scores, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
num_batches += 1
return {
"loss": total_loss / max(num_batches, 1),
"skipped_steps": skipped_steps,
"loss_scale": scaler.get_scale() if scaler else 1.0,
}
26.5.4 Memory Savings in Practice
AMP reduces memory consumption through two mechanisms:
- Activation memory. Activations stored for the backward pass are in FP16/BF16, halving activation memory.
- Weight memory. The master weights remain in FP32, but the working copies used in the forward pass are FP16/BF16. Net savings depend on whether optimizer states are also reduced.
For the Climate DL model (1.2B parameters, 48-layer transformer, batch size 32, sequence length 1024):
| Configuration | Parameter Memory | Activation Memory | Total |
|---|---|---|---|
| FP32 baseline | 19.2 GB | ~16 GB | ~35 GB |
| AMP (BF16) | 19.2 GB (master) + 2.4 GB (working) | ~8 GB | ~30 GB |
| AMP + FSDP | ~5 GB (sharded) + 2.4 GB | ~8 GB | ~15 GB |
The 30 GB AMP configuration fits on an A100 80 GB with room for larger batch sizes. Combined with FSDP (Section 26.7), the memory footprint drops further.
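The parameter-memory column of the table comes from a simple bytes-per-parameter accounting. The sketch below covers the FP32 and AMP rows only (it ignores activations and FSDP sharding, and assumes Adam's two moment buffers):

```python
def training_state_gb(num_params: int, amp: bool = False) -> dict:
    """Bytes-per-parameter accounting for Adam training state.

    FP32 baseline: params (4 B) + grads (4 B) + Adam m and v (8 B)
    = 16 B per parameter. AMP keeps that FP32 master state and adds
    a 2 B BF16/FP16 working copy. Activations are not included.
    """
    gb = num_params / 1e9
    state = {"fp32 master state (params + grads + m + v)": 16 * gb}
    if amp:
        state["bf16 working copy"] = 2 * gb
    state["total"] = sum(state.values())
    return state


for name, val in training_state_gb(1_200_000_000, amp=True).items():
    print(f"{name}: {val:.1f} GB")
```

For the 1.2B-parameter model this reproduces the 19.2 GB master state plus the 2.4 GB working copy.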
26.6 Gradient Checkpointing: Trading Compute for Memory
26.6.1 The Activation Memory Problem
During the backward pass, PyTorch needs the intermediate activations from the forward pass to compute gradients. By default, it stores all of them — for a 48-layer transformer, that means storing the output of every layer, every attention computation, and every feed-forward block.
Gradient checkpointing (Chen et al., 2016) offers a trade-off: discard some activations during the forward pass and recompute them during the backward pass. The result is less memory at the cost of more compute.
26.6.2 The Trade-off
For a network with $L$ layers:
| Strategy | Activation Memory | Compute (forward passes) |
|---|---|---|
| No checkpointing | $O(L)$ | 1 |
| Store input only, recompute everything | $O(1)$ | $O(L)$ |
| Checkpoint every $\sqrt{L}$ layers | $O(\sqrt{L})$ | $\leq 2$ (each segment recomputed once) |
The $\sqrt{L}$ strategy is the standard compromise: for a 48-layer network, it stores $\lceil\sqrt{48}\rceil = 7$ checkpoint activations and recomputes each ~7-layer segment once during the backward pass — at most one full extra forward pass in total. Since the forward pass accounts for roughly one third of combined forward-and-backward compute, the overhead is approximately 30-40% additional compute: significant, but usually a worthwhile price for the memory saved.
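The arithmetic for choosing the interval fits in a few lines (a sketch of the $\sqrt{L}$ heuristic; the overhead fraction is the rough estimate discussed above):

```python
import math


def checkpoint_interval(n_layers: int) -> tuple[int, int]:
    """Return (interval, stored checkpoints) for the sqrt(L) strategy."""
    k = max(1, round(math.sqrt(n_layers)))
    return k, math.ceil(n_layers / k)


interval, stored = checkpoint_interval(48)
print(interval, stored)  # 7 7
# Backward recomputes each segment once: ~one extra forward pass,
# i.e. roughly a third of the combined forward+backward compute.
```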
26.6.3 Implementation
PyTorch provides torch.utils.checkpoint.checkpoint for wrapping individual modules:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from dataclasses import dataclass
from typing import List
class TransformerBlock(nn.Module):
"""A single transformer block (simplified for illustration).
Args:
d_model: Hidden dimension.
n_heads: Number of attention heads.
d_ff: Feed-forward dimension.
dropout: Dropout rate.
"""
def __init__(
self,
d_model: int = 768,
n_heads: int = 12,
d_ff: int = 3072,
dropout: float = 0.1,
) -> None:
super().__init__()
self.attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True
)
self.norm1 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout),
)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with pre-norm residual connections.
Args:
x: Input tensor of shape (batch, seq_len, d_model).
Returns:
Output tensor of same shape.
"""
normed = self.norm1(x)
attn_out, _ = self.attn(normed, normed, normed)
x = x + attn_out
x = x + self.ff(self.norm2(x))
return x
class CheckpointedTransformer(nn.Module):
"""Transformer with gradient checkpointing.
Checkpoints every `checkpoint_every` layers to reduce activation
memory at the cost of recomputing activations during backward.
Args:
d_model: Hidden dimension.
n_heads: Number of attention heads.
d_ff: Feed-forward dimension.
n_layers: Total number of transformer layers.
dropout: Dropout rate.
checkpoint_every: Checkpoint interval (0 = no checkpointing).
"""
def __init__(
self,
d_model: int = 768,
n_heads: int = 12,
d_ff: int = 3072,
n_layers: int = 48,
dropout: float = 0.1,
checkpoint_every: int = 0,
) -> None:
super().__init__()
self.layers = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.checkpoint_every = checkpoint_every
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with optional gradient checkpointing.
Args:
x: Input tensor of shape (batch, seq_len, d_model).
Returns:
Output tensor of same shape.
"""
for i, layer in enumerate(self.layers):
if (
self.checkpoint_every > 0
and i % self.checkpoint_every == 0
and self.training
):
# Checkpoint: activations will be recomputed in backward
x = checkpoint(layer, x, use_reentrant=False)
else:
x = layer(x)
return x
# Example: Climate DL transformer with sqrt(48) ≈ 7 checkpoint interval
climate_model = CheckpointedTransformer(
d_model=1536,
n_heads=24,
d_ff=6144,
n_layers=48,
checkpoint_every=7, # Checkpoint every 7 layers (≈ sqrt(48))
)
# Memory comparison (illustrative)
# Without checkpointing: stores 48 layers of activations
# With checkpoint_every=7: stores 7 checkpoint activations,
# recomputes up to 6 layers per segment during backward
# Memory reduction: ~85% of activation memory saved
# Compute overhead: ~33% additional forward compute
26.6.4 When to Use Gradient Checkpointing
Gradient checkpointing is most valuable when:
- Activation memory is the bottleneck. If optimizer state dominates (small model, large optimizer), checkpointing provides little benefit.
- Increasing batch size would improve training efficiency. If you are running a small batch per GPU because activations consume too much memory, checkpointing lets you increase the batch size — which reduces the relative overhead of gradient synchronization in DDP.
- The model is deep. Checkpointing saves $O(L)$ activation memory. Shallow models with large per-layer activations benefit less than deep models.
For the Climate DL model, gradient checkpointing reduces activation memory from approximately 16 GB to 2.3 GB, freeing 14 GB for larger batch sizes. The 33% compute overhead is a worthwhile trade-off: the larger batch size improves GPU utilization (more compute-bound operations relative to memory-bound ones) and reduces the number of gradient synchronization steps per epoch.
26.7 FlashAttention: Algorithmic Memory Efficiency
26.7.1 The Standard Attention Memory Problem
Standard self-attention computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
where $Q, K, V \in \mathbb{R}^{B \times s \times d_k}$. The intermediate attention matrix $QK^T \in \mathbb{R}^{B \times s \times s}$ has $O(Bs^2)$ elements. For $B = 32$, $s = 2048$, and FP16:
$$\text{Attention matrix memory} = 32 \times 2048 \times 2048 \times 2 \text{ bytes} = 256 \text{ MB}$$
For a 48-layer model, the stored attention matrices alone total $48 \times 256 \text{ MB} = 12$ GB — with $h$ attention heads the scores tensor is $B \times h \times s \times s$, a factor of $h$ larger still — and that is just one of the intermediate tensors stored for the backward pass.
26.7.2 FlashAttention: Tiling on SRAM
FlashAttention (Dao et al., 2022) avoids materializing the full $s \times s$ attention matrix in HBM. Instead, it computes attention in tiles that fit in the GPU's on-chip SRAM (shared memory), streaming through the $Q$, $K$, and $V$ matrices block by block.
The algorithm:
- Divide $Q$ into blocks of $B_r$ rows and $K$, $V$ into blocks of $B_c$ rows.
- For each block of $Q$:
  a. Load the $Q$ block into SRAM.
  b. Iterate over $K$, $V$ blocks, computing the partial attention scores, applying softmax (with online softmax for numerical stability), and accumulating the output.
  c. Write the output block back to HBM.
The key insight: softmax can be computed in a single pass over the $K$ blocks using the online softmax trick (Milakov and Gimelshein, 2018), which tracks the running maximum and denominator without storing the full attention matrix.
Memory savings. FlashAttention reduces the memory overhead of attention from $O(Bs^2)$ to $O(Bs)$ — the full attention matrix is never materialized in HBM. For $s = 2048$, this is a $2048\times$ reduction in intermediate memory.
Speed improvement. Despite computing the same mathematical function, FlashAttention is faster because it reduces HBM reads and writes. Standard attention makes several HBM round trips: it writes $QK^T$, reads it back for softmax, writes the softmax result, and reads that again to multiply with $V$. FlashAttention fuses all of these into a single kernel, performing most arithmetic in SRAM.
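The online softmax trick at the heart of the kernel can be sketched in plain Python: a single pass maintains the running maximum $m$ and denominator $d$ as each block of scores arrives, rescaling past contributions whenever the maximum changes.

```python
import math


def online_softmax_stats(score_blocks):
    """Single-pass softmax statistics over a stream of score blocks.

    Maintains the running max m and denominator d without holding all
    scores at once (Milakov and Gimelshein, 2018). FlashAttention
    applies the same correction factor exp(m_old - m_new) to a running
    output accumulator instead of materializing the score matrix.
    """
    m, d = float("-inf"), 0.0
    for block in score_blocks:
        m_new = max(m, max(block))
        # Rescale the old denominator to the new max, then add this block
        d = d * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, d


# Agrees with the all-at-once softmax denominator
blocks = [[1.0, 2.0], [3.0, 0.5]]
m, d = online_softmax_stats(blocks)
flat = [x for b in blocks for x in b]
full_d = sum(math.exp(x - max(flat)) for x in flat)
print(abs(d - full_d) < 1e-12)  # True
```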
26.7.3 FlashAttention in Practice
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional
@dataclass
class AttentionBenchmark:
"""Benchmark configuration for comparing attention implementations.
Attributes:
batch_size: Batch size.
seq_len: Sequence length.
d_model: Model dimension.
n_heads: Number of attention heads.
dtype: Data type for computation.
"""
batch_size: int = 32
seq_len: int = 2048
d_model: int = 1536
n_heads: int = 24
dtype: torch.dtype = torch.bfloat16
@property
def d_head(self) -> int:
"""Dimension per attention head."""
return self.d_model // self.n_heads
def standard_attention_memory_mb(self) -> float:
"""Estimate HBM memory for standard attention (forward only).
Includes Q, K, V, attention scores, softmax output, and
attention output. Does not include backward pass activations.
Returns:
Estimated memory in MB.
"""
bytes_per_element = 2 if self.dtype in (
torch.float16, torch.bfloat16
) else 4
# QKV: 3 * B * s * d
qkv_bytes = 3 * self.batch_size * self.seq_len * self.d_model * bytes_per_element
# Attention scores: B * n_heads * s * s
attn_bytes = (
self.batch_size * self.n_heads * self.seq_len ** 2
* bytes_per_element
)
# Output: B * s * d
out_bytes = self.batch_size * self.seq_len * self.d_model * bytes_per_element
total = qkv_bytes + attn_bytes + out_bytes
return total / (1024 ** 2)
def flash_attention_memory_mb(self) -> float:
"""Estimate HBM memory for FlashAttention (forward only).
FlashAttention does not materialize the full attention matrix.
Memory is O(B * s * d), not O(B * n_heads * s^2).
Returns:
Estimated memory in MB.
"""
bytes_per_element = 2 if self.dtype in (
torch.float16, torch.bfloat16
) else 4
# QKV: 3 * B * s * d
qkv_bytes = 3 * self.batch_size * self.seq_len * self.d_model * bytes_per_element
# Output: B * s * d
out_bytes = self.batch_size * self.seq_len * self.d_model * bytes_per_element
# Logsumexp for backward: B * n_heads * s
lse_bytes = (
self.batch_size * self.n_heads * self.seq_len * 4 # FP32
)
total = qkv_bytes + out_bytes + lse_bytes
return total / (1024 ** 2)
# Compare for Climate DL configuration
bench = AttentionBenchmark(
batch_size=32, seq_len=2048, d_model=1536, n_heads=24,
dtype=torch.bfloat16,
)
print(f"Standard attention memory: {bench.standard_attention_memory_mb():.0f} MB")
print(f"FlashAttention memory: {bench.flash_attention_memory_mb():.0f} MB")
print(f"Memory reduction: {bench.standard_attention_memory_mb() / bench.flash_attention_memory_mb():.1f}x")
Standard attention memory: 6912 MB
FlashAttention memory: 774 MB
Memory reduction: 8.9x
In PyTorch 2.0+, FlashAttention is available through torch.nn.functional.scaled_dot_product_attention, which automatically selects the most efficient attention kernel based on hardware and input size:
class EfficientAttention(nn.Module):
"""Multi-head attention using PyTorch's SDPA with FlashAttention.
Args:
d_model: Model dimension.
n_heads: Number of attention heads.
dropout: Attention dropout rate (only applied during training).
"""
def __init__(
self, d_model: int = 1536, n_heads: int = 24, dropout: float = 0.0
) -> None:
super().__init__()
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = dropout
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass using scaled_dot_product_attention.
PyTorch 2.0+ automatically selects FlashAttention, memory-
efficient attention, or standard attention based on hardware,
input size, and mask type.
Args:
x: Input of shape (batch, seq_len, d_model).
attn_mask: Optional attention mask.
Returns:
Output of shape (batch, seq_len, d_model).
"""
B, S, D = x.shape
qkv = self.qkv_proj(x).reshape(B, S, 3, self.n_heads, self.d_head)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, n_heads, S, d_head)
q, k, v = qkv.unbind(0)
# PyTorch dispatches to FlashAttention when available
attn_out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)
attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
return self.out_proj(attn_out)
Fundamentals > Frontier: FlashAttention is not a new attention mechanism — it computes exactly the same mathematical function as standard attention. The innovation is algorithmic: by restructuring the computation to exploit the GPU memory hierarchy (tiling on SRAM, avoiding HBM round trips), it achieves both lower memory usage and higher throughput. This is a recurring pattern in high-performance computing: understanding the hardware — not just the mathematics — determines whether an algorithm is practical at scale.
26.8 Efficient Large-Batch Training
26.8.1 Why Large Batches?
Distributed training naturally increases the effective batch size: $N$ GPUs each processing a local batch of size $b$ yield a global batch size of $B = Nb$. For StreamRec on 4 GPUs with a local batch of 2048, $B = 8192$. For Climate DL on 64 GPUs with a local batch of 32, $B = 2048$.
Larger batches offer two benefits:
1. Fewer gradient synchronization steps per epoch. Each all-reduce is amortized over more data.
2. Higher GPU utilization. Larger matrix multiplications have higher arithmetic intensity.
But larger batches introduce a problem: the generalization gap. Models trained with very large batches often converge to sharper minima that generalize worse than those found with small batches (Keskar et al., 2017). Naively scaling the batch size from 256 to 8192 can degrade validation accuracy by 1-3 percentage points — a significant loss.
26.8.2 The Linear Scaling Rule
Goyal et al. (2017) proposed a simple recipe: when multiplying the batch size by $k$, multiply the learning rate by $k$. The intuition is that the gradient variance scales as $1/k$ with a $k\times$ larger batch, so a proportionally larger step size is needed to make the same expected progress per epoch.
$$\eta_{\text{new}} = \eta_{\text{base}} \times \frac{B_{\text{new}}}{B_{\text{base}}}$$
This rule works well up to a critical batch size, beyond which it breaks down. For ImageNet training with ResNet-50, the critical batch size is approximately 8,192; for transformer language models, it can be much larger (Kaplan et al., 2020, observed scaling up to batch sizes of 500,000 tokens).
26.8.3 Learning Rate Warmup
The linear scaling rule fails in the early training phase when the network is far from a good region of the loss landscape. Large learning rates applied to randomly initialized weights cause divergence. The fix: learning rate warmup.
During the first $T_{\text{warmup}}$ steps, the learning rate increases linearly from $\eta_0$ (a small value, or zero) to $\eta_{\text{target}}$:
$$\eta(t) = \eta_{\text{target}} \cdot \frac{t}{T_{\text{warmup}}}, \quad t \leq T_{\text{warmup}}$$
After warmup, any standard schedule (cosine annealing, step decay, linear decay) takes over.
import torch
from torch.optim.lr_scheduler import LambdaLR
from dataclasses import dataclass
@dataclass
class LargeBatchConfig:
"""Configuration for large-batch training with LR scaling.
Attributes:
base_lr: Learning rate for the base batch size.
base_batch_size: The batch size at which base_lr was tuned.
global_batch_size: The actual global batch size (local * num_gpus).
warmup_steps: Number of warmup steps.
total_steps: Total training steps.
min_lr_fraction: Minimum LR as a fraction of peak LR
(for cosine annealing).
"""
base_lr: float = 1e-3
base_batch_size: int = 256
global_batch_size: int = 8192
warmup_steps: int = 1000
total_steps: int = 50000
min_lr_fraction: float = 0.01
@property
def scaled_lr(self) -> float:
"""Learning rate after linear scaling."""
return self.base_lr * (self.global_batch_size / self.base_batch_size)
def create_warmup_cosine_scheduler(
optimizer: torch.optim.Optimizer,
config: LargeBatchConfig,
) -> LambdaLR:
"""Create a warmup + cosine annealing LR scheduler.
Linear warmup for config.warmup_steps, then cosine decay
from config.scaled_lr to config.scaled_lr * config.min_lr_fraction.
Args:
optimizer: The optimizer whose LR to schedule.
config: Large-batch configuration.
Returns:
LambdaLR scheduler.
"""
import math
def lr_lambda(step: int) -> float:
if step < config.warmup_steps:
return step / max(config.warmup_steps, 1)
progress = (step - config.warmup_steps) / max(
config.total_steps - config.warmup_steps, 1
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
return config.min_lr_fraction + (
1 - config.min_lr_fraction
) * cosine_decay
return LambdaLR(optimizer, lr_lambda)
# Example: StreamRec scaling from 256 to 8192
config = LargeBatchConfig(
base_lr=1e-3,
base_batch_size=256,
global_batch_size=8192,
warmup_steps=500,
total_steps=20000,
)
print(f"Base LR: {config.base_lr}")
print(f"Scaled LR: {config.scaled_lr}")
print(f"Scaling factor: {config.global_batch_size / config.base_batch_size}x")
Base LR: 0.001
Scaled LR: 0.032
Scaling factor: 32.0x
26.8.4 LARS and LAMB: Layer-wise Adaptive Scaling
The linear scaling rule assumes all layers need the same learning rate scaling, but in practice, different layers have different gradient-to-weight ratios. LARS (Layer-wise Adaptive Rate Scaling) (You et al., 2017) adjusts the learning rate per layer based on the ratio of the weight norm to the gradient norm:
$$\eta_l = \eta \cdot \phi \cdot \frac{\|w_l\|}{\|g_l\| + \lambda\|w_l\|}$$
where $\eta$ is the global learning rate, $\phi$ is a trust coefficient (typically 0.001-0.02), $w_l$ and $g_l$ are the weights and gradients of layer $l$, and $\lambda$ is the weight decay coefficient.
LAMB (Layer-wise Adaptive Moments optimizer for Batch training) (You et al., 2020) extends LARS to Adam by applying the layer-wise scaling to the Adam update:
$$r_l = \frac{\|w_l\|}{\left\|\frac{m_l / (1 - \beta_1^t)}{\sqrt{v_l / (1 - \beta_2^t)} + \epsilon} + \lambda w_l\right\|}$$
$$w_l \leftarrow w_l - \eta \cdot r_l \cdot \left(\frac{m_l / (1 - \beta_1^t)}{\sqrt{v_l / (1 - \beta_2^t)} + \epsilon} + \lambda w_l\right)$$
LAMB enabled training BERT in 76 minutes on 1,024 TPU v3 chips with a batch size of 65,536 — a dramatic reduction from the original 3-day training time.
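The per-layer scaling in the LARS formula is a one-liner once the norms are in hand. A pure-Python sketch (real implementations operate on parameter tensors, per parameter group):

```python
import math


def lars_layer_lr(
    global_lr: float,
    weights: list[float],
    grads: list[float],
    trust_coef: float = 0.001,
    weight_decay: float = 0.0,
) -> float:
    """Layer-wise learning rate per You et al. (2017):
    eta_l = eta * phi * ||w_l|| / (||g_l|| + lambda * ||w_l||).
    """
    w_norm = math.sqrt(sum(w * w for w in weights))
    g_norm = math.sqrt(sum(g * g for g in grads))
    return global_lr * trust_coef * w_norm / (g_norm + weight_decay * w_norm)


# A layer with large weights and small gradients takes a larger step
print(lars_layer_lr(1.0, [3.0, 4.0], [0.6, 0.8]))  # eta * 0.001 * 5/1 = 0.005
```

The effect is self-normalizing: layers whose gradients are large relative to their weights get smaller steps, which is what stabilizes very large batch sizes.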
26.8.5 Practical Recipe
The following recipe summarizes the state of the art for efficient large-batch training:
- Start with a well-tuned single-GPU configuration (batch size $B_0$, learning rate $\eta_0$, converged validation metric).
- Scale batch size gradually ($2\times$, $4\times$, $8\times$, ...) and apply linear LR scaling.
- Add warmup proportional to the scaling factor: $T_{\text{warmup}} \approx 5\text{-}10\%$ of total steps.
- Use LAMB for transformer models with batch sizes above 4,096.
- Monitor the critical batch size: if validation loss degrades despite LR scaling and warmup, you have exceeded the critical batch size.
- Gradient accumulation can simulate large batches on fewer GPUs by accumulating gradients over multiple micro-steps before the optimizer step.
def train_with_gradient_accumulation(
model: torch.nn.Module,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
accumulation_steps: int = 4,
device: str = "cuda",
) -> float:
"""Train with gradient accumulation to simulate a larger batch.
Accumulates gradients over `accumulation_steps` micro-batches
before performing an optimizer step. The effective batch size is
local_batch_size * accumulation_steps * num_gpus.
Args:
model: The model.
dataloader: Training data loader.
optimizer: Optimizer.
accumulation_steps: Number of micro-batches to accumulate.
device: Device string.
Returns:
Average loss over the epoch.
"""
model.train()
loss_fn = torch.nn.BCEWithLogitsLoss()
total_loss = 0.0
num_steps = 0
for i, batch in enumerate(dataloader):
user_ids = batch["user_id"].to(device)
item_ids = batch["item_id"].to(device)
labels = batch["label"].float().to(device)
scores = model(user_ids, item_ids)
loss = loss_fn(scores, labels) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=1.0
)
optimizer.step()
optimizer.zero_grad()
num_steps += 1
total_loss += loss.item() * accumulation_steps
# Flush remaining accumulated gradients (with the same clipping)
if (i + 1) % accumulation_steps != 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
num_steps += 1
return total_loss / max(len(dataloader), 1)
26.9 DeepSpeed and Fully Sharded Data Parallelism (FSDP)
26.9.1 The ZeRO Insight
In standard data parallelism, every GPU holds a complete copy of the model parameters, gradients, and optimizer state. For a 1.2B parameter model with Adam in FP32, that is 19.2 GB replicated on every GPU — a waste of memory that limits the maximum model size and batch size.
ZeRO (Zero Redundancy Optimizer) (Rajbhandari et al., 2020) eliminates this redundancy by sharding the model state across GPUs. It has three stages:
| Stage | Shards | Memory per GPU | Communication |
|---|---|---|---|
| ZeRO-1 | Optimizer state | $4P + 4P + \frac{8P}{N}$ | Same as DDP |
| ZeRO-2 | Optimizer state + gradients | $4P + \frac{12P}{N}$ | Same as DDP |
| ZeRO-3 | Everything (params + grads + optimizer) | $\frac{16P}{N}$ | All-gather before each forward/backward step |
For the Climate DL model (1.2B parameters, $P = 1.2 \times 10^9$) on 8 GPUs:
| Configuration | Memory per GPU |
|---|---|
| Standard DDP | 19.2 GB |
| ZeRO-1 | 10.8 GB |
| ZeRO-2 | 6.6 GB |
| ZeRO-3 | 2.4 GB |
ZeRO-3 reduces the per-GPU memory by $8\times$, enabling either much larger models or much larger batch sizes.
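The stage formulas translate directly into code, assuming (as above) FP32 parameters and gradients plus Adam's two moment buffers:

```python
def zero_memory_gb(num_params: float, num_gpus: int, stage: int = 0) -> float:
    """Per-GPU model-state memory (GB) under ZeRO sharding.

    FP32 params (4 B) + FP32 grads (4 B) + Adam m, v (8 B) per parameter.
    stage 0 = plain DDP; 1 = shard optimizer; 2 = + grads; 3 = + params.
    """
    param_b, grad_b, opt_b = 4 * num_params, 4 * num_params, 8 * num_params
    if stage >= 1:
        opt_b /= num_gpus
    if stage >= 2:
        grad_b /= num_gpus
    if stage >= 3:
        param_b /= num_gpus
    return (param_b + grad_b + opt_b) / 1e9


for s in range(4):
    print(f"ZeRO-{s}: {zero_memory_gb(1.2e9, num_gpus=8, stage=s):.1f} GB")
```

For the 1.2B-parameter model on 8 GPUs this prints 19.2, 10.8, 6.6, and 2.4 GB.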
26.9.2 DeepSpeed
DeepSpeed is Microsoft's library for large-scale distributed training. It implements ZeRO and integrates with PyTorch.
import torch
import json
from dataclasses import dataclass, asdict
from typing import Dict, Any
@dataclass
class DeepSpeedConfig:
"""DeepSpeed configuration for distributed training.
Generates the JSON config file that DeepSpeed expects.
Attributes:
zero_stage: ZeRO optimization stage (0, 1, 2, or 3).
fp16_enabled: Whether to use FP16 mixed precision.
bf16_enabled: Whether to use BF16 mixed precision.
gradient_accumulation_steps: Micro-batches before optimizer step.
train_micro_batch_size_per_gpu: Local batch size per GPU.
gradient_clipping: Maximum gradient norm.
offload_optimizer: Whether to offload optimizer state to CPU.
offload_param: Whether to offload parameters to CPU (ZeRO-3).
"""
zero_stage: int = 2
fp16_enabled: bool = False
bf16_enabled: bool = True
gradient_accumulation_steps: int = 4
train_micro_batch_size_per_gpu: int = 32
gradient_clipping: float = 1.0
offload_optimizer: bool = False
offload_param: bool = False
def to_dict(self) -> Dict[str, Any]:
"""Convert to DeepSpeed config dictionary.
Returns:
Dictionary suitable for deepspeed.initialize().
"""
config: Dict[str, Any] = {
"train_micro_batch_size_per_gpu": (
self.train_micro_batch_size_per_gpu
),
"gradient_accumulation_steps": (
self.gradient_accumulation_steps
),
"gradient_clipping": self.gradient_clipping,
"zero_optimization": {
"stage": self.zero_stage,
},
}
if self.zero_stage >= 3:
config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
if self.offload_optimizer:
config["zero_optimization"]["offload_optimizer"] = {
"device": "cpu",
"pin_memory": True,
}
if self.offload_param and self.zero_stage == 3:
config["zero_optimization"]["offload_param"] = {
"device": "cpu",
"pin_memory": True,
}
if self.bf16_enabled:
config["bf16"] = {"enabled": True}
elif self.fp16_enabled:
config["fp16"] = {
"enabled": True,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
}
return config
def to_json(self, path: str) -> None:
"""Write config to a JSON file.
Args:
path: Output file path.
"""
with open(path, "w") as f:
json.dump(self.to_dict(), f, indent=2)
# Example: Climate DL configuration
climate_ds_config = DeepSpeedConfig(
zero_stage=2,
bf16_enabled=True,
gradient_accumulation_steps=8,
train_micro_batch_size_per_gpu=4, # Small due to large activations
gradient_clipping=1.0,
)
print(json.dumps(climate_ds_config.to_dict(), indent=2))
{
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 8,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2
},
"bf16": {
"enabled": true
}
}
26.9.3 PyTorch FSDP
Fully Sharded Data Parallelism (FSDP), PyTorch's native implementation of ZeRO-3, integrates directly with PyTorch's autograd and distributed infrastructure. Unlike DeepSpeed, FSDP does not require a separate launcher or configuration file.
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial
def wrap_model_fsdp(
model: nn.Module,
sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
use_bf16: bool = True,
) -> FSDP:
"""Wrap a model with PyTorch FSDP for memory-efficient training.
Args:
model: The model to shard.
sharding_strategy: FULL_SHARD (ZeRO-3), SHARD_GRAD_OP (ZeRO-2),
or NO_SHARD (DDP equivalent).
use_bf16: Whether to use BF16 mixed precision.
Returns:
FSDP-wrapped model.
"""
mixed_precision = None
if use_bf16:
mixed_precision = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# Auto-wrap policy: shard at transformer block boundaries
# (TransformerBlock is the block class defined in Section 26.6.3)
wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerBlock},
)
return FSDP(
model,
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
auto_wrap_policy=wrap_policy,
device_id=torch.cuda.current_device(),
)
FSDP shards parameters across GPUs: before each forward pass, it performs an all-gather to reconstruct the full parameters for the current layer, computes the forward pass, and then frees the parameters. This all-gather/free pattern repeats for every transformer block, trading communication for memory.
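The memory consequence of that pattern can be captured with a toy accounting model: each GPU permanently holds only its $1/N$ shard of every block, plus the one block currently gathered in full. (This is a sketch only; it ignores activations, gradients, and prefetch overlap, and the block sizes below are illustrative.)

```python
def fsdp_peak_params(n_blocks: int, params_per_block: int, n_gpus: int) -> int:
    """Peak parameter count resident on one GPU under all-gather-then-free.

    Persistent: this GPU's 1/N shard of every block.
    Transient: the one block currently gathered in full.
    (Slight overcount: the gathered block includes the local shard.)
    """
    persistent_shard = n_blocks * params_per_block // n_gpus
    return persistent_shard + params_per_block


# 48 blocks x 25M params on 8 GPUs: ~175M resident vs 1.2B replicated
print(f"{fsdp_peak_params(48, 25_000_000, 8) / 1e6:.0f}M parameters")
```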
26.10 Cost Estimation and Management
26.10.1 The Cost of Compute
Training a large model is expensive. The total cost depends on three factors:
$$\text{Training cost} = \text{GPU-hours} \times \text{Price per GPU-hour}$$
where:
$$\text{GPU-hours} = \frac{\text{Total FLOPs}}{(\text{GPU peak FLOP/s}) \times \text{MFU} \times 3600}$$
For a rough estimate of total FLOPs for a transformer model, Kaplan et al. (2020) provide the approximation:
$$C \approx 6PD$$
where $C$ is the total compute in FLOPs, $P$ is the number of parameters, and $D$ is the number of training tokens (or, more generally, the number of training examples times the sequence length). The factor of 6 accounts for the forward pass ($\approx 2PD$) and the backward pass ($\approx 4PD$, since backprop requires approximately twice the compute of the forward pass).
26.10.2 Cost Estimation Framework
from dataclasses import dataclass
from typing import Dict
@dataclass
class TrainingCostEstimator:
"""Estimate the cost of a distributed training run.
Attributes:
model_params: Number of model parameters.
training_tokens: Total training tokens (or examples * seq_len).
gpu_peak_tflops: Peak GPU throughput in TFLOP/s (for the dtype).
mfu: Expected Model FLOP Utilization (0-1).
num_gpus: Number of GPUs to use.
price_per_gpu_hour: Cloud price per GPU-hour in USD.
spot_discount: Fraction discount for spot instances (0-1).
"""
model_params: int
training_tokens: int
gpu_peak_tflops: float = 312.0 # A100 BF16
mfu: float = 0.45
num_gpus: int = 8
price_per_gpu_hour: float = 3.50 # A100 80GB on-demand (approximate)
spot_discount: float = 0.65 # Spot is ~65% cheaper
def total_flops(self) -> float:
"""Estimate total training FLOPs using the 6PD approximation.
Returns:
Total FLOPs.
"""
return 6.0 * self.model_params * self.training_tokens
def single_gpu_hours(self) -> float:
"""Wall-clock hours if training on a single GPU.
Returns:
Hours.
"""
effective_tflops = self.gpu_peak_tflops * self.mfu
flops = self.total_flops()
seconds = flops / (effective_tflops * 1e12)
return seconds / 3600
def distributed_hours(self) -> float:
"""Wall-clock hours with distributed training.
Assumes linear scaling (ideal). In practice, communication
overhead reduces efficiency by 10-30%.
Returns:
Hours.
"""
return self.single_gpu_hours() / self.num_gpus
def on_demand_cost(self) -> float:
"""Cost using on-demand GPU instances.
Returns:
Cost in USD.
"""
return self.distributed_hours() * self.num_gpus * self.price_per_gpu_hour
def spot_cost(self) -> float:
"""Cost using spot/preemptible instances.
Returns:
Cost in USD.
"""
spot_price = self.price_per_gpu_hour * (1 - self.spot_discount)
return self.distributed_hours() * self.num_gpus * spot_price
def summary(self) -> Dict[str, str]:
"""Generate a human-readable cost summary.
Returns:
Dictionary of cost components.
"""
return {
"Total FLOPs": f"{self.total_flops():.2e}",
"Single GPU hours": f"{self.single_gpu_hours():.1f}",
f"Distributed hours ({self.num_gpus} GPUs)": (
f"{self.distributed_hours():.1f}"
),
"On-demand cost": f"${self.on_demand_cost():,.0f}",
"Spot cost (~65% discount)": f"${self.spot_cost():,.0f}",
}
# StreamRec: 45M params, 1.2B interactions, seq_len=1 (tabular)
streamrec_cost = TrainingCostEstimator(
model_params=45_000_000,
training_tokens=1_200_000_000,
gpu_peak_tflops=312.0,
mfu=0.50,
num_gpus=4,
price_per_gpu_hour=3.50,
)
print("=== StreamRec Training Cost ===")
for k, v in streamrec_cost.summary().items():
print(f" {k}: {v}")
print()
# Climate DL: 1.2B params, 350K timesteps * 2M grid points = 700B tokens
climate_cost = TrainingCostEstimator(
model_params=1_200_000_000,
training_tokens=700_000_000_000,
gpu_peak_tflops=312.0,
mfu=0.42,
num_gpus=64,
price_per_gpu_hour=3.50,
)
print("=== Climate DL Training Cost ===")
for k, v in climate_cost.summary().items():
print(f" {k}: {v}")
=== StreamRec Training Cost ===
Total FLOPs: 3.24e+17
Single GPU hours: 0.6
Distributed hours (4 GPUs): 0.1
On-demand cost: $2
Spot cost (~65% discount): $1
=== Climate DL Training Cost ===
Total FLOPs: 5.04e+21
Single GPU hours: 10687.4
Distributed hours (64 GPUs): 167.0
On-demand cost: $37,360
Spot cost (~65% discount): $13,076
StreamRec's training is cheap — about two dollars per run on demand, under a dollar on spot. This is characteristic of recommendation models: many interactions but small models. Climate DL is roughly four orders of magnitude more expensive: $37,000 per training run on demand, or $13,000 with spot instances. At a daily retraining cadence, that would be over $1.1 million per month on demand — a powerful argument for training efficiency.
26.10.3 Spot Instances and Fault Tolerance
Spot (AWS) or preemptible (GCP) instances offer 60-90% discounts compared to on-demand pricing, with the caveat that the cloud provider can reclaim the instance with minimal notice (2 minutes on AWS, 30 seconds on GCP).
For training jobs, spot instances require checkpoint-based fault tolerance:
- **Checkpoint frequently.** Save model parameters, optimizer state, learning rate scheduler state, and the current epoch/step every $k$ steps. For Climate DL at $13,000 per run, a checkpoint every 30 minutes costs ~$50 of redundant compute (to redo up to 30 minutes of work after preemption) but saves thousands compared to on-demand pricing.
- **Auto-restart.** Use a job scheduler (SLURM, a Kubernetes Job, or a cloud-native service like AWS Batch) that automatically restarts the training job from the latest checkpoint when a spot instance is reclaimed.
- **Elastic training.** Some frameworks (TorchElastic, Horovod Elastic) support training that dynamically adjusts to the number of available GPUs. If 2 of 8 GPUs are reclaimed, training continues on 6 GPUs at reduced throughput rather than failing entirely.
import os
import torch
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
@dataclass
class CheckpointManager:
"""Manages training checkpoints for fault-tolerant distributed training.
Saves checkpoints at regular intervals and supports resuming
from the latest checkpoint after spot instance preemption.
Attributes:
checkpoint_dir: Directory to store checkpoints.
save_every_steps: Save a checkpoint every N steps.
max_checkpoints: Maximum number of checkpoints to keep
(oldest are deleted).
"""
checkpoint_dir: str = "./checkpoints"
save_every_steps: int = 500
max_checkpoints: int = 3
def __post_init__(self) -> None:
Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
def save(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[object],
epoch: int,
step: int,
loss: float,
rank: int = 0,
) -> Optional[str]:
"""Save a checkpoint (only on rank 0 for DDP).
Args:
model: The model (may be DDP-wrapped).
optimizer: The optimizer.
scheduler: LR scheduler (optional).
epoch: Current epoch.
step: Current global step.
loss: Current loss.
rank: Process rank (checkpoint saved only on rank 0).
Returns:
Checkpoint path if saved, None otherwise.
"""
if rank != 0:
return None
state = {
"epoch": epoch,
"step": step,
"loss": loss,
"model_state_dict": (
model.module.state_dict()
if hasattr(model, "module") else model.state_dict()
),
"optimizer_state_dict": optimizer.state_dict(),
}
if scheduler is not None:
state["scheduler_state_dict"] = scheduler.state_dict()
path = os.path.join(
self.checkpoint_dir, f"checkpoint_step_{step:08d}.pt"
)
torch.save(state, path)
self._cleanup_old_checkpoints()
return path
def load_latest(self) -> Optional[dict]:
"""Load the most recent checkpoint.
Returns:
Checkpoint state dict, or None if no checkpoint exists.
"""
checkpoints = sorted(
Path(self.checkpoint_dir).glob("checkpoint_step_*.pt")
)
if not checkpoints:
return None
return torch.load(checkpoints[-1], map_location="cpu")
def _cleanup_old_checkpoints(self) -> None:
"""Remove old checkpoints, keeping only max_checkpoints."""
checkpoints = sorted(
Path(self.checkpoint_dir).glob("checkpoint_step_*.pt")
)
while len(checkpoints) > self.max_checkpoints:
checkpoints[0].unlink()
checkpoints.pop(0)
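The intended lifecycle — save during the run, resume in a fresh process after preemption — looks like the following standalone sketch. It mirrors the manager's glob-the-latest pattern but uses `pickle` and a toy state dict so the mechanics are visible; a real run would store `torch` state dicts exactly as `CheckpointManager.save` does.

```python
# Standalone sketch of the save/preempt/resume cycle that CheckpointManager
# supports (pickle and a toy state dict stand in for torch.save and real
# model/optimizer state).
import pickle
import tempfile
from pathlib import Path
from typing import Optional

ckpt_dir = Path(tempfile.mkdtemp())

def save_checkpoint(step: int, state: dict) -> None:
    # Zero-padded step so lexicographic filename order equals step order.
    with open(ckpt_dir / f"checkpoint_step_{step:08d}.pkl", "wb") as f:
        pickle.dump({"step": step, **state}, f)

def load_latest() -> Optional[dict]:
    checkpoints = sorted(ckpt_dir.glob("checkpoint_step_*.pkl"))
    if not checkpoints:
        return None
    with open(checkpoints[-1], "rb") as f:
        return pickle.load(f)

# "First run": checkpoint at steps 500 and 1000, then get preempted.
save_checkpoint(500, {"loss": 0.83})
save_checkpoint(1000, {"loss": 0.61})

# "Second run" (fresh process after preemption): resume past the last save.
state = load_latest()
start_step = state["step"] + 1 if state else 0
print(f"Resuming from step {start_step}")  # Resuming from step 1001
```

The zero-padded step number in the filename is what makes `sorted(...)[-1]` pick the latest checkpoint correctly; without padding, `step_1000` would sort before `step_500`.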
26.10.4 Cost Optimization Strategies
| Strategy | Savings | Risk / Trade-off |
|---|---|---|
| Spot instances | 60-90% | Preemption; requires checkpointing |
| Mixed precision (BF16) | 30-50% fewer GPU-hours (faster compute) | Numerical issues (rare with BF16) |
| Gradient checkpointing | Larger batch sizes → fewer steps | 30-40% compute overhead |
| Progressive resizing | Train on lower resolution first | May not transfer to full resolution |
| Curriculum learning | Faster convergence on easy examples first | Requires data difficulty ordering |
| Reduced-precision optimizer | 8-bit Adam (bitsandbytes) saves memory | Slight convergence difference |
| Model distillation | Train small model from large model | Two-stage process |
For the Climate DL project, the team used spot instances (60% savings), BF16 training (2x throughput), and gradient checkpointing (4x larger batch, 25% fewer gradient syncs). Total cost per training run dropped from $37,000 (on-demand, FP32) to approximately $5,400 — a 7x reduction.
26.11 Putting It All Together: The Training Pipeline
A production training pipeline combines all the techniques from this chapter into a coherent system. Here is a complete configuration schema, instantiated below for both the StreamRec two-tower model and the Climate DL model:
from dataclasses import dataclass
from typing import Optional, List
@dataclass
class ProductionTrainingConfig:
"""Complete configuration for a production-grade training pipeline.
Combines distributed training, mixed precision, gradient checkpointing,
large-batch optimization, cost management, and fault tolerance.
Attributes:
model_name: Model identifier for logging and registry.
num_gpus: Total number of GPUs.
parallelism: Parallelism strategy ('ddp', 'fsdp', 'deepspeed').
local_batch_size: Batch size per GPU.
gradient_accumulation_steps: Micro-batches per optimizer step.
amp_dtype: Mixed-precision dtype (None for FP32).
gradient_checkpointing: Whether to enable gradient checkpointing.
base_lr: Base learning rate (before scaling).
base_batch_size: Batch size at which base_lr was tuned.
warmup_fraction: Fraction of total steps for LR warmup.
optimizer: Optimizer name ('adamw', 'lamb').
max_epochs: Maximum training epochs.
checkpoint_every_steps: Checkpoint interval.
use_spot: Whether to use spot instances.
max_training_hours: Maximum wall-clock hours (cost guard).
"""
model_name: str = "streamrec-two-tower-v2"
num_gpus: int = 4
parallelism: str = "ddp"
local_batch_size: int = 2048
gradient_accumulation_steps: int = 1
amp_dtype: Optional[str] = "bfloat16"
gradient_checkpointing: bool = False
base_lr: float = 1e-3
base_batch_size: int = 256
warmup_fraction: float = 0.05
optimizer: str = "adamw"
max_epochs: int = 3
checkpoint_every_steps: int = 1000
use_spot: bool = True
max_training_hours: float = 2.0
@property
def global_batch_size(self) -> int:
"""Effective global batch size."""
return (
self.local_batch_size
* self.gradient_accumulation_steps
* self.num_gpus
)
@property
def scaled_lr(self) -> float:
"""Learning rate after linear scaling."""
return self.base_lr * (self.global_batch_size / self.base_batch_size)
def validate(self) -> List[str]:
"""Check for common configuration errors.
Returns:
List of warning messages (empty if valid).
"""
warnings = []
if self.global_batch_size > 65536:
warnings.append(
f"Global batch size {self.global_batch_size} exceeds 65536. "
"Consider LAMB optimizer and monitor convergence carefully."
)
if self.scaled_lr > 0.1:
warnings.append(
f"Scaled LR {self.scaled_lr:.4f} is very large. "
"Increase warmup or use LAMB."
)
if self.use_spot and self.checkpoint_every_steps > 2000:
warnings.append(
"Using spot instances with infrequent checkpointing. "
"Risk of losing >30 minutes of work on preemption."
)
if self.gradient_checkpointing and self.parallelism == "fsdp":
warnings.append(
"Gradient checkpointing + FSDP: ensure activation "
"checkpointing wrapper is applied to FSDP units."
)
return warnings
def summary(self) -> str:
"""Human-readable summary of the training configuration.
Returns:
Formatted summary string.
"""
lines = [
f"=== Training Configuration: {self.model_name} ===",
f" Parallelism: {self.parallelism.upper()}",
f" GPUs: {self.num_gpus}",
f" Local batch: {self.local_batch_size}",
f" Gradient accumulation: {self.gradient_accumulation_steps}x",
f" Global batch: {self.global_batch_size}",
f" AMP: {self.amp_dtype or 'disabled'}",
f" Gradient checkpointing: {self.gradient_checkpointing}",
f" Optimizer: {self.optimizer}",
f" Base LR: {self.base_lr}",
f" Scaled LR: {self.scaled_lr:.6f}",
f" Warmup: {self.warmup_fraction:.0%} of steps",
f" Max epochs: {self.max_epochs}",
f" Spot instances: {self.use_spot}",
f" Checkpoint every: {self.checkpoint_every_steps} steps",
]
warnings = self.validate()
if warnings:
lines.append(" WARNINGS:")
for w in warnings:
lines.append(f" - {w}")
return "\n".join(lines)
# StreamRec production configuration
streamrec_config = ProductionTrainingConfig(
model_name="streamrec-two-tower-v2",
num_gpus=4,
parallelism="ddp",
local_batch_size=2048,
gradient_accumulation_steps=1,
amp_dtype="bfloat16",
gradient_checkpointing=False, # Model is small enough
base_lr=1e-3,
base_batch_size=256,
warmup_fraction=0.05,
optimizer="adamw",
max_epochs=3,
use_spot=True,
)
print(streamrec_config.summary())
print()
# Climate DL production configuration
climate_config = ProductionTrainingConfig(
model_name="climate-vit-1.2b",
num_gpus=64,
parallelism="fsdp",
local_batch_size=4,
gradient_accumulation_steps=8,
amp_dtype="bfloat16",
gradient_checkpointing=True,
base_lr=3e-4,
base_batch_size=256,
warmup_fraction=0.10,
optimizer="lamb",
max_epochs=50,
checkpoint_every_steps=500,
use_spot=True,
max_training_hours=200.0,
)
print(climate_config.summary())
=== Training Configuration: streamrec-two-tower-v2 ===
Parallelism: DDP
GPUs: 4
Local batch: 2048
Gradient accumulation: 1x
Global batch: 8192
AMP: bfloat16
Gradient checkpointing: False
Optimizer: adamw
Base LR: 0.001
Scaled LR: 0.032000
Warmup: 5% of steps
Max epochs: 3
Spot instances: True
Checkpoint every: 1000 steps
=== Training Configuration: climate-vit-1.2b ===
Parallelism: FSDP
GPUs: 64
Local batch: 4
Gradient accumulation: 8x
Global batch: 2048
AMP: bfloat16
Gradient checkpointing: True
Optimizer: lamb
Base LR: 0.0003
Scaled LR: 0.002400
Warmup: 10% of steps
Max epochs: 50
Spot instances: True
Checkpoint every: 500 steps
The two configurations illustrate different regimes. StreamRec is a small model on a large dataset: data parallelism with DDP, large local batch sizes (the model is small, so activations fit easily), and minimal optimization beyond AMP. Climate DL is a large model on a massive dataset: FSDP for memory efficiency, small local batch sizes compensated by gradient accumulation, gradient checkpointing to fit activations, LAMB for stable large-batch convergence, and frequent checkpointing for spot-instance fault tolerance.
26.12 Measuring Training Efficiency
The final skill is knowing whether your training pipeline is efficient and where the bottlenecks are.
26.12.1 Key Metrics
| Metric | Formula | Target |
|---|---|---|
| MFU | Model FLOPs / (Peak FLOPs $\times$ wall time) | > 40% |
| GPU utilization | Time GPU is computing / total time | > 80% |
| Communication fraction | Time in all-reduce / total step time | < 20% |
| Throughput | Samples/second or tokens/second | Maximize |
| Time to accuracy | Wall-clock hours to reach target metric | Minimize |
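The MFU entry can be computed directly from observed throughput and the 6PD approximation of Section 26.10. A sketch with assumed numbers (a 1.2B-parameter model sustaining ~1.15M tokens/s across 64 A100s; both figures are illustrative, not measurements):

```python
# MFU from observed throughput: achieved model FLOP/s over aggregate peak.
def mfu(model_params: int, tokens_per_second: float,
        num_gpus: int, peak_tflops_per_gpu: float) -> float:
    """Model FLOP Utilization in [0, 1]."""
    achieved_flops = 6.0 * model_params * tokens_per_second  # 6PD per token
    peak_flops = num_gpus * peak_tflops_per_gpu * 1e12
    return achieved_flops / peak_flops

# Assumed: 1.2B params, 1.15M tokens/s, 64 A100s at 312 TFLOP/s (BF16 peak).
print(f"MFU: {mfu(1_200_000_000, 1_150_000, 64, 312.0):.1%}")  # MFU: 41.5%
```

An MFU above 40% is a healthy result for a large distributed run; values far below that point to communication, data loading, or kernel-efficiency bottlenecks.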
26.12.2 Profiling Tools
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class TrainingProfiler:
    """Lightweight profiler for distributed training steps.

    Tracks computation, communication, and memory metrics
    per training step to identify bottlenecks.

    Attributes:
        step_times_ms: List of total step times in milliseconds.
        compute_times_ms: List of forward + backward compute times.
        comm_times_ms: List of communication (all-reduce) times.
        peak_memory_gb: Peak GPU memory usage in GB per step.
    """

    step_times_ms: List[float] = field(default_factory=list)
    compute_times_ms: List[float] = field(default_factory=list)
    comm_times_ms: List[float] = field(default_factory=list)
    peak_memory_gb: List[float] = field(default_factory=list)
def record_step(
self,
step_ms: float,
compute_ms: float,
comm_ms: float,
peak_mem_gb: float,
) -> None:
"""Record metrics for a single training step.
Args:
step_ms: Total step time.
compute_ms: Compute-only time (forward + backward).
comm_ms: Communication time (all-reduce).
peak_mem_gb: Peak GPU memory in GB.
"""
self.step_times_ms.append(step_ms)
self.compute_times_ms.append(compute_ms)
self.comm_times_ms.append(comm_ms)
self.peak_memory_gb.append(peak_mem_gb)
def summary(self) -> Dict[str, float]:
"""Compute summary statistics over recorded steps.
Returns:
Dictionary of mean metrics and derived ratios.
"""
n = len(self.step_times_ms)
if n == 0:
return {}
mean = lambda xs: sum(xs) / len(xs)
avg_step = mean(self.step_times_ms)
avg_compute = mean(self.compute_times_ms)
avg_comm = mean(self.comm_times_ms)
return {
"avg_step_ms": round(avg_step, 1),
"avg_compute_ms": round(avg_compute, 1),
"avg_comm_ms": round(avg_comm, 1),
"comm_fraction": round(avg_comm / avg_step, 3),
"compute_fraction": round(avg_compute / avg_step, 3),
"overhead_fraction": round(
1 - (avg_compute + avg_comm) / avg_step, 3
),
"peak_memory_gb": round(max(self.peak_memory_gb), 2),
}
def diagnose(self) -> List[str]:
"""Identify training bottlenecks from profiled metrics.
Returns:
List of diagnostic messages.
"""
s = self.summary()
if not s:
return ["No steps recorded."]
issues = []
if s["comm_fraction"] > 0.3:
issues.append(
f"Communication overhead is {s['comm_fraction']:.0%}. "
"Consider: (1) increasing local batch size to amortize "
"all-reduce, (2) gradient accumulation, or (3) overlapping "
"communication with computation."
)
if s["overhead_fraction"] > 0.15:
issues.append(
f"Non-compute/non-comm overhead is {s['overhead_fraction']:.0%}. "
"Check data loading (use num_workers > 0, pin_memory=True), "
"or CPU-GPU transfer bottlenecks."
)
if s["peak_memory_gb"] > 75:
issues.append(
f"Peak memory {s['peak_memory_gb']:.1f} GB is near A100 "
"limit. Risk of OOM. Enable gradient checkpointing or "
"reduce batch size."
)
if not issues:
issues.append("Training pipeline is well-optimized.")
return issues
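When the profiled `comm_fraction` looks suspicious, it helps to know roughly what the all-reduce *should* cost. A rough ring all-reduce model, under an assumed interconnect bandwidth (real NCCL throughput varies with message size and topology):

```python
# Ring all-reduce lower bound: each GPU sends and receives roughly
# 2(N-1)/N times the gradient buffer size.
def allreduce_ms(grad_bytes: float, num_gpus: int, bandwidth_gb_s: float) -> float:
    """Approximate ring all-reduce time in milliseconds."""
    traffic = 2.0 * (num_gpus - 1) / num_gpus * grad_bytes
    return traffic / (bandwidth_gb_s * 1e9) * 1000.0

# Assumed: 45M params with BF16 gradients (2 bytes each) on 4 GPUs
# over a 300 GB/s NVLink-class interconnect.
print(f"{allreduce_ms(45_000_000 * 2, 4, 300.0):.2f} ms")  # 0.45 ms
```

If the profiler reports all-reduce times far above this kind of estimate, suspect the network path (e.g. traffic crossing PCIe or Ethernet instead of NVLink) rather than the model.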
Production ML = Software Engineering: Profiling a training pipeline is no different from profiling any distributed system. You measure, you identify the bottleneck, you optimize the bottleneck, and then you measure again. The metrics above — MFU, communication fraction, memory utilization — are the training-pipeline equivalents of latency percentiles and throughput metrics for a serving system. The discipline is identical; only the units differ.
26.13 Progressive Project: Scaling StreamRec Training
M10a: Multi-GPU Training with DDP
Scale the StreamRec two-tower model training from the single-GPU setup in Chapter 13 to multi-GPU using DDP.
Deliverables:
- Implement the DDP training loop using the `setup_ddp`, `create_ddp_model`, and `train_one_epoch_ddp` functions from Section 26.2. Verify that training produces the same loss trajectory (within floating-point tolerance) as single-GPU training with the equivalent global batch size.
- Add mixed-precision training (BF16) using the AMP recipe from Section 26.5. Measure:
  - Training throughput (samples/second) with and without AMP.
  - Peak GPU memory with and without AMP.
  - Final Recall@20 and NDCG@20 with and without AMP (should be within 0.5% of each other).
- Add gradient checkpointing to the item tower (the larger tower). Measure the memory savings and compute overhead. Determine whether the freed memory allows a larger local batch size, and whether the larger batch improves throughput enough to offset the recomputation cost.
- Profile the training loop using the `TrainingProfiler` class. Report:
  - Average step-time breakdown (compute, communication, other).
  - Communication fraction as a function of local batch size.
  - MFU estimate.
- Compute the training cost using the `TrainingCostEstimator`. Compare on-demand vs. spot pricing. Implement checkpoint-based fault tolerance using `CheckpointManager`.

Evaluation criteria:
- DDP training reproduces single-GPU quality (Recall@20 within 1%).
- AMP provides measurable speedup (>1.5x expected).
- Profiling identifies the primary bottleneck, and the chosen optimization addresses it.
- Cost estimate is realistic and documented.
26.14 Chapter Summary
Training at scale is a distributed systems problem. The model is a mathematical object; the training pipeline is software that must handle communication, synchronization, memory management, fault tolerance, and cost optimization.
Data parallelism (DDP) is the right starting point for most production training. It replicates the model, partitions the data, and synchronizes gradients via ring all-reduce — a communication primitive whose per-GPU cost is nearly independent of the number of GPUs. For models too large to fit on a single device, model parallelism (layer partitioning), pipeline parallelism (micro-batch interleaving to reduce bubbles), and tensor parallelism (splitting individual layers) provide the tools, with frameworks like DeepSpeed and FSDP implementing the ZeRO optimization that shards optimizer state across GPUs.
GPU memory optimization — mixed precision (AMP), gradient checkpointing, and FlashAttention — is not optional for large-scale training. AMP provides 2-8x throughput improvement by leveraging Tensor Cores. Gradient checkpointing trades 30-40% recomputation for 85% activation memory savings. FlashAttention eliminates the $O(s^2)$ attention matrix from HBM, enabling longer sequences and larger batches.
Efficient large-batch training — linear LR scaling, warmup, LARS, and LAMB — maintains model quality as the effective batch size grows with the number of GPUs. And cost management — spot instances with checkpoint-based fault tolerance, MFU monitoring, and systematic profiling — ensures that compute budgets are spent efficiently.
The result is a training pipeline that is fast, memory-efficient, fault-tolerant, and cost-effective. That pipeline is a prerequisite for everything that follows in Part V: you cannot orchestrate (Chapter 27), test (Chapter 28), deploy (Chapter 29), or monitor (Chapter 30) a model that takes too long to train, costs too much to iterate on, or fails unpredictably on distributed hardware.