> "The only way to train really big models is to spread the work across many devices. The art lies in doing it efficiently."
In This Chapter
- 35.1 Introduction: Why Distributed Training
- 35.2 Foundations: Communication Primitives
- 35.3 Data Parallelism
- 35.4 Model Parallelism
- 35.5 Fully Sharded Data Parallel (FSDP)
- 35.6 DeepSpeed ZeRO
- 35.7 Gradient Accumulation
- 35.8 Mixed Precision Training
- 35.9 HuggingFace Accelerate
- 35.10 Multi-GPU and Multi-Node Training
- 35.11 Cost-Effective Training Strategies
- 35.12 Combining Parallelism Strategies: 3D Parallelism
- 35.13 Summary
Chapter 35: Distributed Training and Scaling
Training Models Beyond the Limits of a Single GPU
"The only way to train really big models is to spread the work across many devices. The art lies in doing it efficiently." --- A distributed systems engineer's axiom
35.1 Introduction: Why Distributed Training
Modern deep learning models have grown at an astonishing pace. GPT-3 has 175 billion parameters. LLaMA 2 70B has 70 billion. Even "small" models like BERT-large have 340 million parameters. Training these models on a single GPU is either impossibly slow or physically impossible---the model simply does not fit in memory.
Consider the arithmetic. A model with 7 billion parameters in float32 requires:
$$\text{Model memory} = 7 \times 10^9 \times 4\;\text{bytes} = 28\;\text{GB}$$
But training requires far more than just storing the model. The optimizer state for Adam requires two additional copies (first and second moments), and we need memory for gradients and activations:
$$\text{Training memory} \approx \underbrace{4B}_{\text{params}} + \underbrace{4B}_{\text{gradients}} + \underbrace{8B}_{\text{Adam states}} + \underbrace{\text{variable}}_{\text{activations}} = 16B + \text{activations}$$
For a 7B parameter model, that is at least 112 GB---far beyond the 80 GB of an A100 GPU, even before accounting for activations.
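This arithmetic is easy to script. The helper below is a minimal sketch that reproduces the estimate above; it deliberately ignores activations, which depend on batch size, sequence length, and architecture.
def estimate_training_memory_gb(num_params: float, bytes_per_param: int = 4) -> float:
    """Lower bound on training memory in GB: params + grads + two Adam moments."""
    return num_params * bytes_per_param * 4 / 1e9
print(f"7B model (fp32): ~{estimate_training_memory_gb(7e9):.0f} GB")    # ~112 GB
print(f"70B model (fp32): ~{estimate_training_memory_gb(70e9):.0f} GB")  # ~1120 GB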
Distributed training solves this by splitting the work across multiple GPUs, potentially across multiple machines. This chapter covers the full spectrum of distributed training techniques, from the conceptually simple (data parallelism) to the sophisticated (tensor parallelism, pipeline parallelism, and ZeRO), with hands-on PyTorch code throughout.
What You Will Learn
- How data parallelism (DDP) scales training across GPUs
- How model parallelism (tensor and pipeline) handles models too large for one GPU
- How FSDP and DeepSpeed ZeRO optimize memory efficiency
- How gradient accumulation simulates larger batch sizes
- How communication primitives (all-reduce, all-gather) underpin distributed training
- How mixed precision training accelerates computation
- How HuggingFace Accelerate simplifies distributed training
- Practical multi-GPU and multi-node training workflows
- Cost-effective training strategies
35.2 Foundations: Communication Primitives
Before diving into parallelism strategies, we need to understand the communication operations that underpin distributed training. These primitives coordinate data exchange between GPUs.
35.2.1 Key Communication Operations
┌─────────────────────────────────────────────────────────────┐
│ Communication Primitives │
│ │
│ BROADCAST: One-to-all │
│ GPU 0: [A] ──> GPU 0: [A] │
│ GPU 1: [A] │
│ GPU 2: [A] │
│ │
│ REDUCE: All-to-one (with operation, e.g., sum) │
│ GPU 0: [A] ──┐ │
│ GPU 1: [B] ──┼──> GPU 0: [A+B+C] │
│ GPU 2: [C] ──┘ │
│ │
│ ALL-REDUCE: All-to-all (with operation) │
│ GPU 0: [A] ──┐ GPU 0: [A+B+C] │
│ GPU 1: [B] ──┼──> GPU 1: [A+B+C] │
│ GPU 2: [C] ──┘ GPU 2: [A+B+C] │
│ │
│ ALL-GATHER: Gather all, distribute to all │
│ GPU 0: [A] ──┐ GPU 0: [A,B,C] │
│ GPU 1: [B] ──┼──> GPU 1: [A,B,C] │
│ GPU 2: [C] ──┘ GPU 2: [A,B,C] │
│ │
│ REDUCE-SCATTER: Reduce + Scatter │
│ GPU 0: [A0,A1,A2]──┐ GPU 0: [A0+B0+C0] │
│ GPU 1: [B0,B1,B2]──┼─>GPU 1: [A1+B1+C1] │
│ GPU 2: [C0,C1,C2]──┘ GPU 2: [A2+B2+C2] │
└─────────────────────────────────────────────────────────────┘
35.2.2 Communication Backends
PyTorch supports multiple communication backends:
| Backend | Hardware | Features |
|---|---|---|
| NCCL | NVIDIA GPUs | Fastest for GPU-to-GPU communication |
| Gloo | CPU + GPU | General-purpose, works on CPU |
| MPI | CPU + GPU | Standards-based, broad compatibility |
For GPU training, NCCL (NVIDIA Collective Communications Library) is almost always the right choice.
35.2.3 PyTorch Distributed Basics
"""Fundamentals of PyTorch distributed communication.
Demonstrates process group initialization and basic
collective operations.
"""
import os
import torch
import torch.distributed as dist
torch.manual_seed(42)
def setup_distributed(rank: int, world_size: int) -> None:
"""Initialize the distributed process group.
Args:
rank: Global rank of this process (0 to world_size-1).
world_size: Total number of processes.
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# Initialize process group
dist.init_process_group(
backend="nccl", # Use "gloo" for CPU
rank=rank,
world_size=world_size,
)
# Set the device for this rank
torch.cuda.set_device(rank)
def cleanup_distributed() -> None:
"""Clean up the distributed process group."""
dist.destroy_process_group()
def demonstrate_collectives(rank: int, world_size: int) -> None:
"""Demonstrate basic collective operations.
Args:
rank: Process rank.
world_size: Total processes.
"""
setup_distributed(rank, world_size)
device = torch.device(f"cuda:{rank}")
# ALL-REDUCE: Sum tensors across all ranks
tensor = torch.tensor([rank + 1.0], device=device)
print(f"[Rank {rank}] Before all-reduce: {tensor.item()}")
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"[Rank {rank}] After all-reduce (SUM): {tensor.item()}")
# All ranks now have the sum: 1 + 2 + ... + world_size
# BROADCAST: Send from rank 0 to all
if rank == 0:
data = torch.tensor([42.0], device=device)
else:
data = torch.zeros(1, device=device)
dist.broadcast(data, src=0)
print(f"[Rank {rank}] After broadcast: {data.item()}")
# All ranks now have 42.0
# ALL-GATHER: Collect tensors from all ranks
local_tensor = torch.tensor([rank * 10.0], device=device)
gathered = [torch.zeros(1, device=device) for _ in range(world_size)]
dist.all_gather(gathered, local_tensor)
print(f"[Rank {rank}] After all-gather: {[t.item() for t in gathered]}")
cleanup_distributed()
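To actually run the demonstration, each rank must live in its own process. One possible launcher is sketched below; it assumes at least two CUDA devices (switch the backend to "gloo" in setup_distributed to experiment on CPU).
if __name__ == "__main__":
    # Spawn one process per GPU; spawn passes the rank as the first argument.
    world_size = torch.cuda.device_count()
    if world_size >= 2:
        torch.multiprocessing.spawn(
            demonstrate_collectives,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )
    else:
        print(f"Need at least 2 GPUs for this demo, found {world_size}.")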
35.2.4 Ring All-Reduce
Ring all-reduce is the standard bandwidth-optimal algorithm for performing all-reduce across $N$ GPUs. It works in two phases:
Phase 1: Reduce-Scatter. Each GPU sends and receives $N-1$ times. After this phase, each GPU has a unique portion of the fully reduced result.
Phase 2: All-Gather. Each GPU sends and receives $N-1$ times again, distributing the fully reduced portions to all GPUs.
The total communication volume is $2(N-1)/N$ times the data size, which approaches $2\times$ the data size as $N$ grows. This is bandwidth-optimal.
$$\text{Communication time} = 2 \cdot \frac{N-1}{N} \cdot \frac{D}{B}$$
where $D$ is the data size and $B$ is the bandwidth per link.
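The two phases can be simulated in a single process to make the data movement concrete. The sketch below is purely illustrative (real implementations such as NCCL exchange point-to-point messages between devices in parallel); it verifies that after both phases every simulated rank holds the full sum.
import torch
torch.manual_seed(42)
def simulate_ring_all_reduce(data_per_rank: list[torch.Tensor]) -> list[torch.Tensor]:
    """Simulate ring all-reduce: each rank's tensor is split into N chunks."""
    n = len(data_per_rank)
    chunks = [list(t.chunk(n)) for t in data_per_rank]
    # Phase 1: reduce-scatter. After N-1 steps, rank i holds the full sum
    # of chunk (i + 1) % n.
    for step in range(n - 1):
        for i in range(n):
            src = (i - 1) % n                  # neighbour sending to rank i
            c = (i - step - 1) % n             # chunk being accumulated this step
            chunks[i][c] = chunks[i][c] + chunks[src][c]
    # Phase 2: all-gather. The fully reduced chunks circulate around the ring.
    for step in range(n - 1):
        for i in range(n):
            src = (i - 1) % n
            c = (i - step) % n                 # chunk being copied this step
            chunks[i][c] = chunks[src][c]
    return [torch.cat(c) for c in chunks]
ranks = [torch.randn(12) for _ in range(4)]
expected = torch.stack(ranks).sum(dim=0)
for result in simulate_ring_all_reduce(ranks):
    assert torch.allclose(result, expected)
print("All simulated ranks hold the full sum.")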
35.3 Data Parallelism
Data parallelism is the simplest and most widely used distributed training strategy. The idea is straightforward: replicate the model on every GPU and split the data across GPUs.
35.3.1 How Data Parallelism Works
┌──────────────────────────────────────────────────────────────┐
│ Data Parallelism (DDP) │
│ │
│ Mini-batch: [B0, B1, B2, B3] │
│ │ │
│ ┌─────────┼─────────┐ │
│ v v v │
│ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ GPU 0 │ │ GPU 1 │ │ GPU 2 │ Each GPU has full model │
│  │Model(B0)│  │Model(B1)│  │Model(B2,│                      │
│  │         │  │         │  │   B3)   │                      │
│ │ copy │ │ copy │ │ copy │ │
│ └────┬───┘ └────┬───┘ └────┬───┘ │
│ │ │ │ │
│ └──── ALL-REDUCE ─────┘ Average gradients │
│ │ │
│ ┌──────────┼──────────┐ │
│ v v v │
│ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │Update │ │Update │ │Update │ Identical updates │
│ │weights │ │weights │ │weights │ │
│ └────────┘ └────────┘ └────────┘ │
└──────────────────────────────────────────────────────────────┘
Steps:
1. Each GPU receives a full copy of the model.
2. The training batch is split evenly across GPUs.
3. Each GPU computes the forward and backward pass on its portion.
4. Gradients are averaged across all GPUs via all-reduce.
5. Each GPU updates its model with the averaged gradients.
6. All GPUs remain synchronized.
35.3.2 PyTorch DistributedDataParallel (DDP)
DDP is PyTorch's recommended approach for data parallelism. It is more efficient than the older DataParallel because it uses multi-process parallelism (one process per GPU) and overlaps gradient communication with backward computation.
"""Distributed Data Parallel (DDP) training with PyTorch.
Complete example showing multi-GPU training with DDP,
including proper data loading, gradient synchronization,
and checkpointing.
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
torch.manual_seed(42)
class TransformerClassifier(nn.Module):
"""A small transformer-based classifier for demonstration.
Args:
vocab_size: Size of the vocabulary.
d_model: Model dimension.
nhead: Number of attention heads.
num_layers: Number of transformer layers.
num_classes: Number of output classes.
max_seq_len: Maximum sequence length.
"""
def __init__(
self,
vocab_size: int = 10000,
d_model: int = 256,
nhead: int = 8,
num_layers: int = 4,
num_classes: int = 5,
max_seq_len: int = 128,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=d_model * 4,
batch_first=True,
)
self.transformer = nn.TransformerEncoder(
encoder_layer, num_layers=num_layers,
)
self.classifier = nn.Linear(d_model, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x: Input token IDs, shape (batch_size, seq_len).
Returns:
Logits, shape (batch_size, num_classes).
"""
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
x = self.embedding(x) + self.pos_embedding(positions)
x = self.transformer(x)
x = x.mean(dim=1) # Global average pooling
return self.classifier(x)
def create_synthetic_dataset(
num_samples: int = 10000,
seq_len: int = 64,
vocab_size: int = 10000,
num_classes: int = 5,
) -> TensorDataset:
"""Create a synthetic text classification dataset.
Args:
num_samples: Number of samples.
seq_len: Sequence length.
vocab_size: Vocabulary size.
num_classes: Number of classes.
Returns:
TensorDataset with (input_ids, labels).
"""
input_ids = torch.randint(0, vocab_size, (num_samples, seq_len))
labels = torch.randint(0, num_classes, (num_samples,))
return TensorDataset(input_ids, labels)
def setup(rank: int, world_size: int) -> None:
"""Initialize distributed training.
Args:
rank: Process rank.
world_size: Number of processes.
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup() -> None:
"""Destroy the process group."""
dist.destroy_process_group()
def train_ddp(rank: int, world_size: int, epochs: int = 10) -> None:
"""Train a model using DistributedDataParallel.
Args:
rank: Process rank (GPU index for single-node).
world_size: Total number of GPUs.
epochs: Number of training epochs.
"""
setup(rank, world_size)
# Create model and move to GPU
model = TransformerClassifier().to(rank)
# Wrap with DDP
ddp_model = DDP(model, device_ids=[rank])
# Create dataset and distributed sampler
dataset = create_synthetic_dataset()
sampler = DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
shuffle=True,
)
dataloader = DataLoader(
dataset,
batch_size=32,
sampler=sampler,
num_workers=2,
pin_memory=True,
)
# Optimizer and loss
optimizer = optim.AdamW(ddp_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(epochs):
# IMPORTANT: Set epoch for sampler to ensure proper shuffling
sampler.set_epoch(epoch)
ddp_model.train()
total_loss = 0.0
num_batches = 0
for input_ids, labels in dataloader:
input_ids = input_ids.to(rank)
labels = labels.to(rank)
optimizer.zero_grad()
logits = ddp_model(input_ids)
loss = criterion(logits, labels)
loss.backward() # DDP handles gradient sync automatically
optimizer.step()
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
# Only print from rank 0
if rank == 0:
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
# Save checkpoint from rank 0 only
if (epoch + 1) % 5 == 0:
checkpoint = {
"epoch": epoch,
"model_state_dict": ddp_model.module.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": avg_loss,
}
torch.save(checkpoint, f"checkpoint_epoch_{epoch+1}.pt")
cleanup()
def main() -> None:
"""Launch DDP training."""
world_size = torch.cuda.device_count()
if world_size < 2:
print(f"DDP requires at least 2 GPUs. Found {world_size}.")
print("Running single-GPU training for demonstration.")
model = TransformerClassifier()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
return
# Use torch.multiprocessing to spawn processes
torch.multiprocessing.spawn(
train_ddp,
args=(world_size,),
nprocs=world_size,
join=True,
)
if __name__ == "__main__":
main()
35.3.3 DDP Key Concepts
Gradient Bucketing: DDP does not wait until all gradients are computed before communicating. Instead, it groups gradients into buckets and overlaps all-reduce communication with backward computation. This significantly reduces wall-clock time.
Broadcast at Initialization: DDP broadcasts the model state from rank 0 to all other ranks at initialization, ensuring all replicas start with identical parameters.
No Redundant Computation: Unlike DataParallel, DDP does not replicate the model on a single process. Each process has its own model replica, eliminating the GIL bottleneck.
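These behaviors can be tuned when constructing the DDP wrapper. The fragment below (reusing model and rank from the training script above; the values shown are illustrative) is a sketch of the most relevant knobs:
ddp_model = DDP(
    model,
    device_ids=[rank],
    bucket_cap_mb=25,               # size of each gradient bucket in MB
    gradient_as_bucket_view=True,   # let gradients alias bucket memory to avoid a copy
    broadcast_buffers=True,         # re-sync buffers (e.g. BatchNorm stats) each forward
)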
35.3.4 Effective Batch Size
With DDP, the effective batch size scales with the number of GPUs:
$$B_{\text{effective}} = B_{\text{per\_gpu}} \times N_{\text{gpus}}$$
This is important because larger batch sizes affect convergence. If you scale from 1 to 8 GPUs, you should consider:
- Linear scaling rule: Scale the learning rate proportionally: $\text{lr}_{\text{new}} = \text{lr}_{\text{base}} \times N_{\text{gpus}}$
- Learning rate warmup: Gradually increase the learning rate over the first few hundred steps.
- LARS/LAMB optimizers: Specialized optimizers for large-batch training.
$$\text{lr}(t) = \text{lr}_{\text{target}} \times \min\left(1, \frac{t}{T_{\text{warmup}}}\right)$$
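The scaling rule and warmup schedule above can be combined in a few lines. The sketch below uses torch.optim.lr_scheduler.LambdaLR; base_lr and warmup_steps are illustrative values, and the linear layer is a stand-in for the real model.
import torch
import torch.nn as nn
model = nn.Linear(512, 512)          # stand-in for the real model
num_gpus = 8
base_lr = 1e-4
warmup_steps = 500
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr * num_gpus)  # linear scaling rule
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # Ramp linearly from 0 to the target lr over warmup_steps, then hold.
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
)
# Call scheduler.step() once per optimizer step inside the training loop.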
35.4 Model Parallelism
When a model is too large to fit on a single GPU, we must split it across GPUs. There are two main approaches: tensor parallelism and pipeline parallelism.
35.4.1 Tensor Parallelism
Tensor parallelism splits individual operations (typically matrix multiplications) across GPUs. This is the approach used by Megatron-LM for training massive transformer models.
Consider a linear layer: $Y = XW + b$, where $W \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}$.
Column parallelism splits $W$ along columns:
$$W = [W_1 | W_2], \quad Y = X[W_1 | W_2] = [XW_1 | XW_2]$$
GPU 0 computes $XW_1$, GPU 1 computes $XW_2$, and the results are concatenated.
Row parallelism splits $W$ along rows:
$$W = \begin{bmatrix} W_1 \\ W_2 \end{bmatrix}, \quad X = [X_1 | X_2], \quad Y = X_1 W_1 + X_2 W_2$$
GPU 0 computes $X_1 W_1$, GPU 1 computes $X_2 W_2$, and the results are summed via all-reduce.
"""Tensor parallelism demonstration.
Shows how to split a linear layer across multiple GPUs
using column and row parallelism.
"""
import torch
import torch.nn as nn
import torch.distributed as dist
torch.manual_seed(42)
class ColumnParallelLinear(nn.Module):
"""Linear layer split along the output dimension.
Each GPU computes a portion of the output features.
Args:
in_features: Input feature dimension.
out_features: Total output features (before splitting).
world_size: Number of GPUs.
rank: This GPU's rank.
bias: Whether to include bias.
"""
def __init__(
self,
in_features: int,
out_features: int,
world_size: int,
rank: int,
bias: bool = True,
) -> None:
super().__init__()
assert out_features % world_size == 0, (
f"out_features ({out_features}) must be divisible by "
f"world_size ({world_size})"
)
self.out_features_per_gpu = out_features // world_size
self.rank = rank
self.world_size = world_size
self.linear = nn.Linear(
in_features, self.out_features_per_gpu, bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass computing this GPU's output portion.
Args:
x: Input tensor, shape (batch, in_features).
Returns:
Partial output, shape (batch, out_features_per_gpu).
"""
return self.linear(x)
class RowParallelLinear(nn.Module):
"""Linear layer split along the input dimension.
Each GPU has a portion of the input features and computes
a partial result. Results are summed via all-reduce.
Args:
in_features: Total input features (before splitting).
out_features: Output feature dimension.
world_size: Number of GPUs.
rank: This GPU's rank.
bias: Whether to include bias.
"""
def __init__(
self,
in_features: int,
out_features: int,
world_size: int,
rank: int,
bias: bool = True,
) -> None:
super().__init__()
assert in_features % world_size == 0, (
f"in_features ({in_features}) must be divisible by "
f"world_size ({world_size})"
)
self.in_features_per_gpu = in_features // world_size
self.rank = rank
self.world_size = world_size
# Only rank 0 has bias to avoid double-counting
self.linear = nn.Linear(
self.in_features_per_gpu, out_features,
bias=(bias and rank == 0),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass computing partial result.
Args:
x: Input shard, shape (batch, in_features_per_gpu).
Returns:
Output after all-reduce, shape (batch, out_features).
"""
output = self.linear(x)
# In actual distributed setting, perform all-reduce here:
# dist.all_reduce(output, op=dist.ReduceOp.SUM)
return output
class TensorParallelTransformerLayer(nn.Module):
"""A transformer layer with tensor parallelism.
The MLP is split using column-then-row parallelism.
Attention is split by partitioning heads across GPUs.
Args:
d_model: Model dimension.
nhead: Total number of attention heads.
world_size: Number of GPUs.
rank: This GPU's rank.
"""
def __init__(
self,
d_model: int = 1024,
nhead: int = 16,
world_size: int = 2,
rank: int = 0,
) -> None:
super().__init__()
self.d_model = d_model
self.world_size = world_size
self.rank = rank
# Attention: split heads across GPUs
assert nhead % world_size == 0
self.heads_per_gpu = nhead // world_size
self.attention = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=self.heads_per_gpu,
batch_first=True,
)
# MLP: Column parallel -> GeLU -> Row parallel
self.mlp_col = ColumnParallelLinear(
d_model, d_model * 4, world_size, rank,
)
self.activation = nn.GELU()
self.mlp_row = RowParallelLinear(
d_model * 4, d_model, world_size, rank,
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with tensor parallelism.
Args:
x: Input tensor, shape (batch, seq_len, d_model).
Returns:
Output tensor, shape (batch, seq_len, d_model).
"""
# Self-attention (heads are split across GPUs)
normed = self.norm1(x)
attn_out, _ = self.attention(normed, normed, normed)
x = x + attn_out
# MLP (column -> row parallel)
normed = self.norm2(x)
mlp_out = self.mlp_col(normed)
mlp_out = self.activation(mlp_out)
mlp_out = self.mlp_row(mlp_out)
x = x + mlp_out
return x
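A single-process sanity check of the layer above, simulating rank 0 of a world of size 2. Because the all-reduce in RowParallelLinear is commented out, the MLP output is only this rank's partial sum, so only the shapes are meaningful here.
layer = TensorParallelTransformerLayer(d_model=1024, nhead=16, world_size=2, rank=0)
x = torch.randn(2, 16, 1024)   # (batch, seq_len, d_model)
out = layer(x)
print(f"Input shape: {tuple(x.shape)}, output shape: {tuple(out.shape)}")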
35.4.2 Pipeline Parallelism
Pipeline parallelism splits the model vertically---different layers go on different GPUs. This avoids the communication overhead of tensor parallelism but introduces pipeline bubbles where some GPUs are idle.
┌──────────────────────────────────────────────────────────────┐
│ Pipeline Parallelism │
│ │
│ GPU 0: [Layer 0, Layer 1] │
│ GPU 1: [Layer 2, Layer 3] │
│ GPU 2: [Layer 4, Layer 5] │
│ GPU 3: [Layer 6, Layer 7] │
│ │
│ Naive pipeline (lots of bubbles): │
│ Time ──────────────────────────────────────> │
│ GPU 0: [F0] [F1] [F2] [F3] [B3] [B2] [B1] [B0] │
│ GPU 1: [F0] [F1] [F2] [F3] [B3] [B2] [B1] [B0] │
│ GPU 2: [F0] [F1] [F2] [F3][B3] [B2] [B1] [B0] │
│ GPU 3: [F0] [F1] [F2][F3] [B3] [B2] [B1][B0]│
│ │
│ GPipe: micro-batches reduce bubble overhead │
│ 1F1B: interleave forward and backward for better overlap │
└──────────────────────────────────────────────────────────────┘
GPipe reduces bubble overhead by splitting the batch into micro-batches. 1F1B (one forward, one backward) scheduling interleaves forward and backward passes for even better GPU utilization.
The pipeline bubble ratio is:
$$\text{Bubble fraction} = \frac{P - 1}{M + P - 1}$$
where $P$ is the number of pipeline stages and $M$ is the number of micro-batches. As $M \to \infty$, the bubble fraction approaches zero.
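Plugging in numbers makes the effect of micro-batching concrete. For $P = 4$ pipeline stages:
P = 4  # pipeline stages
for M in (1, 4, 16, 64):
    bubble = (P - 1) / (M + P - 1)
    print(f"micro-batches M={M:3d}  bubble fraction = {bubble:.1%}")
# M=1 -> 75.0%, M=4 -> 42.9%, M=16 -> 15.8%, M=64 -> 4.5%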
"""Pipeline parallelism demonstration.
Shows how to partition a model across GPUs in a pipeline
and process micro-batches.
"""
import torch
import torch.nn as nn
from typing import Any
torch.manual_seed(42)
class PipelineStage(nn.Module):
"""A single stage in a pipeline-parallel model.
Args:
layers: Sequential layers for this stage.
stage_id: Index of this stage.
device: Device to place this stage on.
"""
def __init__(
self,
layers: nn.Sequential,
stage_id: int,
device: torch.device,
) -> None:
super().__init__()
self.layers = layers.to(device)
self.stage_id = stage_id
self.device = device
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input through this stage's layers.
Args:
x: Input tensor (moved to this stage's device).
Returns:
Output tensor on this stage's device.
"""
x = x.to(self.device)
return self.layers(x)
class SimplePipeline:
"""A simple pipeline parallelism implementation.
Splits a model into stages across available devices
and processes micro-batches in a pipeline fashion.
Args:
model_layers: List of nn.Module layers.
num_stages: Number of pipeline stages.
num_microbatches: Number of micro-batches.
"""
def __init__(
self,
model_layers: list[nn.Module],
num_stages: int = 2,
num_microbatches: int = 4,
) -> None:
self.num_stages = num_stages
self.num_microbatches = num_microbatches
# Partition layers across stages
layers_per_stage = len(model_layers) // num_stages
self.stages: list[PipelineStage] = []
for i in range(num_stages):
start = i * layers_per_stage
end = start + layers_per_stage if i < num_stages - 1 else len(model_layers)
stage_layers = nn.Sequential(*model_layers[start:end])
# Use different GPUs if available, else CPU
if torch.cuda.is_available() and i < torch.cuda.device_count():
device = torch.device(f"cuda:{i}")
else:
device = torch.device("cpu")
self.stages.append(PipelineStage(stage_layers, i, device))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input through the pipeline.
Splits input into micro-batches and pipelines them
through stages.
Args:
x: Input tensor, shape (batch_size, ...).
Returns:
Output tensor from the last stage.
"""
batch_size = x.size(0)
microbatch_size = batch_size // self.num_microbatches
assert batch_size % self.num_microbatches == 0, (
f"Batch size ({batch_size}) must be divisible by "
f"num_microbatches ({self.num_microbatches})"
)
# Split into micro-batches
microbatches = x.split(microbatch_size, dim=0)
# Pipeline execution (simplified GPipe-style)
outputs = []
for mb in microbatches:
current = mb
for stage in self.stages:
current = stage(current)
outputs.append(current)
# Concatenate results
return torch.cat(outputs, dim=0)
def parameters(self):
"""Yield all parameters across all stages."""
for stage in self.stages:
yield from stage.parameters()
# Demonstration
def demo_pipeline() -> None:
"""Demonstrate pipeline parallelism with a simple model."""
# Create a deep model
layers = []
for i in range(8):
layers.extend([
nn.Linear(256, 256),
nn.ReLU(),
nn.LayerNorm(256),
])
pipeline = SimplePipeline(
model_layers=layers,
num_stages=2,
num_microbatches=4,
)
# Forward pass
x = torch.randn(32, 256) # batch_size=32
output = pipeline.forward(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of stages: {pipeline.num_stages}")
print(f"Number of micro-batches: {pipeline.num_microbatches}")
if __name__ == "__main__":
demo_pipeline()
35.5 Fully Sharded Data Parallel (FSDP)
FSDP is PyTorch's native implementation of sharded data parallelism, inspired by DeepSpeed ZeRO. It combines the simplicity of data parallelism with the memory efficiency of model parallelism.
35.5.1 How FSDP Works
In standard DDP, every GPU holds a full copy of the model, optimizer states, and gradients. This is wasteful---if you have 8 GPUs, you store 8 identical copies of everything.
FSDP shards the model parameters, gradients, and optimizer states across GPUs. During the forward pass, parameters are gathered just-in-time. During the backward pass, gradients are computed and then reduced and scattered.
┌──────────────────────────────────────────────────────────────┐
│ FSDP Memory Savings │
│ │
│ Standard DDP (4 GPUs, 4B param model): │
│ Each GPU: 4B params + 4B grads + 8B optimizer = 16B │
│ Total: 16B × 4 = 64B │
│ │
│ FSDP (4 GPUs, 4B param model): │
│ Each GPU: 1B params + 1B grads + 2B optimizer = 4B │
│ Total: 4B × 4 = 16B │
│ │
│ Memory per GPU: reduced from 16B to 4B (4× savings!) │
└──────────────────────────────────────────────────────────────┘
35.5.2 FSDP in PyTorch
"""Fully Sharded Data Parallel (FSDP) training.
Demonstrates memory-efficient distributed training using
PyTorch FSDP with mixed precision.
"""
import os
import functools
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
torch.manual_seed(42)
class LargeTransformerBlock(nn.Module):
"""A transformer block designed for FSDP wrapping.
Args:
d_model: Model dimension.
nhead: Number of attention heads.
dim_feedforward: Feedforward dimension.
dropout: Dropout rate.
"""
def __init__(
self,
d_model: int = 1024,
nhead: int = 16,
dim_feedforward: int = 4096,
dropout: float = 0.1,
) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=True,
)
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
nn.Dropout(dropout),
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with pre-norm residual connections.
Args:
x: Input tensor, shape (batch, seq_len, d_model).
Returns:
Output tensor, same shape as input.
"""
normed = self.norm1(x)
attn_out, _ = self.self_attn(normed, normed, normed)
x = x + attn_out
normed = self.norm2(x)
x = x + self.ffn(normed)
return x
class LargeTransformerModel(nn.Module):
"""A large transformer model suitable for FSDP training.
Args:
vocab_size: Vocabulary size.
d_model: Model dimension.
nhead: Number of attention heads.
num_layers: Number of transformer blocks.
num_classes: Number of output classes.
max_seq_len: Maximum sequence length.
"""
def __init__(
self,
vocab_size: int = 32000,
d_model: int = 1024,
nhead: int = 16,
num_layers: int = 24,
num_classes: int = 10,
max_seq_len: int = 512,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
self.layers = nn.ModuleList([
LargeTransformerBlock(d_model, nhead)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
self.classifier = nn.Linear(d_model, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x: Input token IDs, shape (batch, seq_len).
Returns:
Logits, shape (batch, num_classes).
"""
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
x = self.embedding(x) + self.pos_embedding(positions)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
x = x.mean(dim=1)
return self.classifier(x)
def setup_fsdp(rank: int, world_size: int) -> None:
"""Initialize FSDP distributed training.
Args:
rank: Process rank.
world_size: Total processes.
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def train_with_fsdp(rank: int, world_size: int) -> None:
"""Train a large model using FSDP.
Args:
rank: Process rank.
world_size: Total processes.
"""
setup_fsdp(rank, world_size)
# Create model
model = LargeTransformerModel()
if rank == 0:
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")
print(
f"Estimated model size (fp32): "
f"{total_params * 4 / 1e9:.2f} GB"
)
# Define FSDP mixed precision policy
mixed_precision_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# Define auto-wrapping policy
# Wrap each LargeTransformerBlock as a separate FSDP unit
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LargeTransformerBlock},
)
# Wrap model with FSDP
fsdp_model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=rank,
)
# Create synthetic data
dataset = TensorDataset(
torch.randint(0, 32000, (1000, 128)),
torch.randint(0, 10, (1000,)),
)
sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size)
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
# Optimizer
optimizer = optim.AdamW(fsdp_model.parameters(), lr=1e-4)
# Training loop
for epoch in range(3):
sampler.set_epoch(epoch)
fsdp_model.train()
total_loss = 0.0
num_batches = 0
for input_ids, labels in dataloader:
input_ids = input_ids.to(rank)
labels = labels.to(rank)
optimizer.zero_grad()
logits = fsdp_model(input_ids)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
num_batches += 1
if rank == 0:
avg_loss = total_loss / num_batches
print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
dist.destroy_process_group()
35.5.3 FSDP Sharding Strategies
FSDP offers different sharding strategies that trade off memory savings against communication overhead:
| Strategy | Shards | Memory Savings | Communication |
|---|---|---|---|
| FULL_SHARD (ZeRO-3) | Params + Grads + Optimizer | Maximum | Highest |
| SHARD_GRAD_OP (ZeRO-2) | Grads + Optimizer | Moderate | Moderate |
| NO_SHARD (DDP) | Nothing | None | Lowest |
| HYBRID_SHARD | Full shard within node, DDP between nodes | Balanced | Balanced |
When to use which:
- FULL_SHARD: When GPU memory is the bottleneck. Best for very large models.
- SHARD_GRAD_OP: When you need more memory than DDP provides but want less communication overhead than FULL_SHARD.
- HYBRID_SHARD: Multi-node training where intra-node bandwidth is high but inter-node bandwidth is limited (see the sketch below).
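Switching strategies is a one-argument change when wrapping the model. The fragment below is a sketch for multi-node training, reusing the model and policies from the FSDP script above:
fsdp_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mixed_precision_policy,
    # Shard fully within each node, replicate (DDP-style) across nodes.
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_id=rank,
)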
35.6 DeepSpeed ZeRO
DeepSpeed is Microsoft's library for efficient distributed training. Its signature feature is ZeRO (Zero Redundancy Optimizer), which progressively eliminates memory redundancy.
35.6.1 ZeRO Stages
┌──────────────────────────────────────────────────────────────┐
│ DeepSpeed ZeRO Stages │
│ │
│ Memory per GPU (7B parameter model, 4 GPUs): │
│ │
│ No ZeRO (DDP): │
│ ┌──────────────────────────────────────────┐ │
│ │ Params (28GB) + Grads (28GB) + Opt (56GB)│ = 112 GB │
│ └──────────────────────────────────────────┘ │
│ │
│ ZeRO Stage 1 (Partition optimizer states): │
│ ┌──────────────────────────────────┐ │
│ │ Params (28GB) + Grads (28GB) │ │
│ │ + Opt (56/4 = 14GB) │ = 70 GB │
│ └──────────────────────────────────┘ │
│ │
│ ZeRO Stage 2 (+ Partition gradients): │
│ ┌─────────────────────────────┐ │
│ │ Params (28GB) + Grads (7GB) │ │
│ │ + Opt (14GB) │ = 49 GB │
│ └─────────────────────────────┘ │
│ │
│ ZeRO Stage 3 (+ Partition parameters): │
│ ┌──────────────────────┐ │
│ │ Params (7GB) │ │
│ │ + Grads (7GB) │ │
│ │ + Opt (14GB) │ = 28 GB │
│ └──────────────────────┘ │
└──────────────────────────────────────────────────────────────┘
35.6.2 DeepSpeed Configuration
{
"train_batch_size": 64,
"gradient_accumulation_steps": 4,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": true,
"contiguous_gradients": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-4,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 1e-4,
"warmup_num_steps": 1000,
"total_num_steps": 50000
}
},
"gradient_clipping": 1.0,
"wall_clock_breakdown": false
}
35.6.3 DeepSpeed Training Script
"""Training with DeepSpeed ZeRO optimization.
Demonstrates DeepSpeed integration for memory-efficient
distributed training with ZeRO stages.
"""
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(42)
# DeepSpeed import (install: pip install deepspeed)
try:
import deepspeed
HAS_DEEPSPEED = True
except ImportError:
HAS_DEEPSPEED = False
class GPTBlock(nn.Module):
"""A GPT-style transformer block.
Args:
d_model: Model dimension.
nhead: Number of attention heads.
dim_feedforward: FFN inner dimension.
dropout: Dropout rate.
"""
def __init__(
self,
d_model: int = 768,
nhead: int = 12,
dim_feedforward: int = 3072,
dropout: float = 0.1,
) -> None:
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=True,
)
self.ln2 = nn.LayerNorm(d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with pre-norm architecture.
Args:
x: Input, shape (batch, seq_len, d_model).
Returns:
Output, same shape as input.
"""
normed = self.ln1(x)
attn_out, _ = self.attn(normed, normed, normed)
x = x + attn_out
normed = self.ln2(x)
x = x + self.mlp(normed)
return x
class GPTModel(nn.Module):
"""A GPT-style language model for DeepSpeed demonstration.
Args:
vocab_size: Vocabulary size.
d_model: Model dimension.
nhead: Number of attention heads.
num_layers: Number of transformer blocks.
max_seq_len: Maximum sequence length.
"""
def __init__(
self,
vocab_size: int = 32000,
d_model: int = 768,
nhead: int = 12,
num_layers: int = 12,
max_seq_len: int = 512,
) -> None:
super().__init__()
self.tok_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_seq_len, d_model)
self.blocks = nn.ModuleList([
GPTBlock(d_model, nhead) for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size, bias=False)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Forward pass for language modeling.
Args:
input_ids: Token IDs, shape (batch, seq_len).
Returns:
Logits, shape (batch, seq_len, vocab_size).
"""
seq_len = input_ids.size(1)
positions = torch.arange(seq_len, device=input_ids.device)
x = self.tok_emb(input_ids) + self.pos_emb(positions)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
return self.head(x)
def train_with_deepspeed() -> None:
"""Train a model using DeepSpeed."""
if not HAS_DEEPSPEED:
print("DeepSpeed not installed. Install with: pip install deepspeed")
print("Showing model structure instead.")
model = GPTModel()
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")
print(f"Estimated memory (fp32): {total_params * 4 / 1e9:.2f} GB")
print(f"Estimated memory (fp16): {total_params * 2 / 1e9:.2f} GB")
return
# Parse DeepSpeed arguments
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--epochs", type=int, default=3)
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
# Create model
model = GPTModel()
# Create dataset
dataset = TensorDataset(
torch.randint(0, 32000, (5000, 256)), # input_ids
)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
# Initialize DeepSpeed
model_engine, optimizer, _, _ = deepspeed.initialize(
args=args,
model=model,
model_parameters=model.parameters(),
training_data=dataset,
)
# Training loop
for epoch in range(args.epochs):
total_loss = 0.0
num_batches = 0
for batch in dataloader:
input_ids = batch[0].to(model_engine.local_rank)
# Shift for language modeling
targets = input_ids[:, 1:].contiguous()
inputs = input_ids[:, :-1].contiguous()
logits = model_engine(inputs)
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
)
model_engine.backward(loss)
model_engine.step()
total_loss += loss.item()
num_batches += 1
if model_engine.local_rank == 0:
avg_loss = total_loss / num_batches
print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
if __name__ == "__main__":
train_with_deepspeed()
35.6.4 ZeRO-Offload and ZeRO-Infinity
DeepSpeed extends ZeRO with CPU and NVMe offloading:
- ZeRO-Offload: Offloads optimizer states and/or gradients to CPU memory, enabling training of larger models on fewer GPUs.
- ZeRO-Infinity: Extends offloading to NVMe SSDs, enabling training of models with trillions of parameters.
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
}
}
}
The trade-off is clear: offloading reduces GPU memory usage but increases training time due to data movement between GPU and CPU/NVMe.
35.7 Gradient Accumulation
Gradient accumulation is a simple but powerful technique that simulates larger batch sizes without increasing memory usage. Instead of updating weights after every batch, you accumulate gradients over multiple mini-batches and then update.
35.7.1 How It Works
$$\text{Effective batch size} = \text{micro\_batch\_size} \times \text{accumulation\_steps} \times \text{num\_gpus}$$
"""Gradient accumulation for effective large-batch training.
Demonstrates how to accumulate gradients over multiple
micro-batches to simulate larger effective batch sizes.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(42)
def train_with_gradient_accumulation(
model: nn.Module,
dataloader: DataLoader,
optimizer: optim.Optimizer,
criterion: nn.Module,
accumulation_steps: int = 4,
max_grad_norm: float = 1.0,
epochs: int = 5,
device: torch.device = torch.device("cpu"),
) -> list[float]:
"""Train with gradient accumulation.
Args:
model: The model to train.
dataloader: Training data loader.
optimizer: Optimizer instance.
criterion: Loss function.
accumulation_steps: Number of micro-batches to accumulate.
max_grad_norm: Maximum gradient norm for clipping.
epochs: Number of training epochs.
device: Device to train on.
Returns:
List of average losses per epoch.
"""
model.to(device)
epoch_losses = []
for epoch in range(epochs):
model.train()
total_loss = 0.0
num_updates = 0
optimizer.zero_grad() # Zero gradients at the start
for step, (inputs, targets) in enumerate(dataloader):
inputs = inputs.to(device)
targets = targets.to(device)
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)
# Scale loss by accumulation steps to get correct average
loss = loss / accumulation_steps
loss.backward()
total_loss += loss.item() * accumulation_steps
# Update weights every accumulation_steps
if (step + 1) % accumulation_steps == 0:
# Clip gradients
torch.nn.utils.clip_grad_norm_(
model.parameters(), max_grad_norm,
)
optimizer.step()
optimizer.zero_grad()
num_updates += 1
# Handle remaining gradients
if (step + 1) % accumulation_steps != 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), max_grad_norm,
)
optimizer.step()
optimizer.zero_grad()
num_updates += 1
avg_loss = total_loss / (step + 1)
epoch_losses.append(avg_loss)
print(
f"Epoch {epoch+1}/{epochs}, "
f"Loss: {avg_loss:.4f}, "
f"Updates: {num_updates}"
)
return epoch_losses
def demo_gradient_accumulation() -> None:
"""Demonstrate gradient accumulation."""
# Create a simple model
model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
# Create synthetic data
X = torch.randn(1000, 100)
y = torch.randint(0, 10, (1000,))
dataset = TensorDataset(X, y)
# With micro-batch size 8 and accumulation steps 8,
# effective batch size is 64
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
print("Micro-batch size: 8")
print("Accumulation steps: 8")
print("Effective batch size: 64")
print()
losses = train_with_gradient_accumulation(
model=model,
dataloader=dataloader,
optimizer=optimizer,
criterion=criterion,
accumulation_steps=8,
epochs=5,
)
if __name__ == "__main__":
demo_gradient_accumulation()
35.7.2 Gradient Accumulation with DDP
When using gradient accumulation with DDP, you want to avoid unnecessary gradient synchronization during accumulation steps:
"""Gradient accumulation with DDP.
Shows how to efficiently combine gradient accumulation
with distributed data parallelism by disabling gradient
synchronization during accumulation steps.
"""
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from contextlib import nullcontext
torch.manual_seed(42)
def train_step_with_accumulation(
model: DDP,
batch_iterator,
optimizer: torch.optim.Optimizer,
criterion: nn.Module,
accumulation_steps: int,
device: torch.device,
) -> float:
"""Perform one optimizer step with gradient accumulation.
Disables gradient synchronization during accumulation steps
for efficiency. Only synchronizes on the final accumulation step.
Args:
model: DDP-wrapped model.
batch_iterator: Iterator over batches.
optimizer: Optimizer.
criterion: Loss function.
accumulation_steps: Number of steps to accumulate.
device: Training device.
Returns:
Average loss over the accumulation steps.
"""
optimizer.zero_grad()
total_loss = 0.0
for micro_step in range(accumulation_steps):
try:
inputs, targets = next(batch_iterator)
except StopIteration:
break
inputs = inputs.to(device)
targets = targets.to(device)
# Only sync gradients on the last accumulation step
if micro_step < accumulation_steps - 1:
# no_sync() prevents DDP from doing all-reduce
context = model.no_sync()
else:
context = nullcontext()
with context:
outputs = model(inputs)
loss = criterion(outputs, targets) / accumulation_steps
loss.backward()
total_loss += loss.item() * accumulation_steps
# Clip gradients and update
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return total_loss / accumulation_steps
35.8 Mixed Precision Training
Mixed precision training uses lower-precision floating-point formats (float16 or bfloat16) to speed up computation and reduce memory usage, while maintaining model accuracy.
35.8.1 Why Mixed Precision
Benefits:
- Memory reduction: fp16 uses half the memory of fp32, enabling larger batches or models.
- Speed improvement: Modern GPUs (V100, A100, H100) have tensor cores that operate 2-8x faster on fp16/bf16 than fp32.
- Communication savings: Gradient synchronization is faster with smaller tensors.
Challenges:
- Underflow: Small gradient values become zero in fp16 (minimum positive value: ~6e-8).
- Overflow: Large values exceed the fp16 range (max: 65504).
- Loss of precision: Accumulated rounding errors can affect training.
35.8.2 float16 vs bfloat16
| Property | float32 | float16 | bfloat16 |
|---|---|---|---|
| Total bits | 32 | 16 | 16 |
| Exponent bits | 8 | 5 | 8 |
| Mantissa bits | 23 | 10 | 7 |
| Max value | ~3.4e38 | 65504 | ~3.4e38 |
| Min positive | ~1.2e-38 | ~6e-8 | ~1.2e-38 |
| Precision | High | Moderate | Low |
bfloat16 has the same dynamic range as float32 (same exponent bits) but lower precision. This makes it more robust for training---it rarely overflows or underflows.
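The ranges in the table can be checked directly with torch.finfo:
import torch
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # info.tiny is the smallest *normal* value; fp16 subnormals reach down to ~6e-8.
    print(f"{str(dtype):16s} max={info.max:.3e}  smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")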
35.8.3 PyTorch AMP (Automatic Mixed Precision)
"""Mixed precision training with PyTorch AMP.
Demonstrates automatic mixed precision using GradScaler
and autocast for efficient training.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(42)
class VisionTransformer(nn.Module):
"""Simple Vision Transformer for mixed precision demonstration.
Args:
image_size: Input image size (square).
patch_size: Patch size.
num_classes: Number of output classes.
d_model: Model dimension.
nhead: Number of attention heads.
num_layers: Number of transformer layers.
"""
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
num_classes: int = 1000,
d_model: int = 768,
nhead: int = 12,
num_layers: int = 12,
in_channels: int = 3,
) -> None:
super().__init__()
num_patches = (image_size // patch_size) ** 2
self.patch_embed = nn.Conv2d(
in_channels, d_model,
kernel_size=patch_size, stride=patch_size,
)
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
self.pos_embed = nn.Parameter(
torch.randn(1, num_patches + 1, d_model)
)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=d_model * 4,
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer, num_layers=num_layers,
)
self.norm = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x: Input images, shape (batch, channels, height, width).
Returns:
Logits, shape (batch, num_classes).
"""
# Patch embedding
x = self.patch_embed(x) # (B, d_model, H/P, W/P)
x = x.flatten(2).transpose(1, 2) # (B, num_patches, d_model)
# Prepend CLS token
cls = self.cls_token.expand(x.size(0), -1, -1)
x = torch.cat([cls, x], dim=1)
x = x + self.pos_embed
# Transformer encoder
x = self.encoder(x)
x = self.norm(x)
# Classification from CLS token
return self.head(x[:, 0])
def train_mixed_precision(
use_amp: bool = True,
epochs: int = 5,
) -> list[float]:
"""Train with optional mixed precision.
Args:
use_amp: Whether to use automatic mixed precision.
epochs: Number of training epochs.
Returns:
List of average losses per epoch.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype_str = "mixed (fp16)" if use_amp else "fp32"
print(f"Training with {dtype_str} precision on {device}")
# Create model (smaller for demo)
model = VisionTransformer(
image_size=32,
patch_size=4,
num_classes=10,
d_model=256,
nhead=8,
num_layers=6,
in_channels=3,
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")
# Synthetic data
images = torch.randn(500, 3, 32, 32)
labels = torch.randint(0, 10, (500,))
dataset = TensorDataset(images, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# GradScaler for fp16 (not needed for bf16)
scaler = GradScaler("cuda", enabled=use_amp)
epoch_losses = []
for epoch in range(epochs):
model.train()
total_loss = 0.0
num_batches = 0
for images_batch, labels_batch in dataloader:
images_batch = images_batch.to(device)
labels_batch = labels_batch.to(device)
optimizer.zero_grad()
# Autocast for mixed precision
with autocast("cuda", enabled=use_amp):
logits = model(images_batch)
loss = criterion(logits, labels_batch)
# Scale loss and backward
scaler.scale(loss).backward()
# Unscale before clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Step with scaler
scaler.step(optimizer)
scaler.update()
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
epoch_losses.append(avg_loss)
print(f" Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
return epoch_losses
if __name__ == "__main__":
print("=" * 50)
print("Mixed Precision Training Demo")
print("=" * 50)
if torch.cuda.is_available():
train_mixed_precision(use_amp=True)
else:
print("CUDA not available. Showing model structure only.")
model = VisionTransformer(
image_size=32, patch_size=4, num_classes=10,
d_model=256, nhead=8, num_layers=6,
)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
35.8.4 Loss Scaling
Loss scaling is essential for fp16 training. Small gradients can underflow to zero in fp16. By multiplying the loss by a large scale factor before the backward pass, we shift gradients into the representable range.
$$\tilde{L} = S \cdot L \quad \Rightarrow \quad \frac{\partial \tilde{L}}{\partial \theta} = S \cdot \frac{\partial L}{\partial \theta}$$
After the backward pass, gradients are divided by $S$ before the optimizer step. PyTorch's GradScaler handles this automatically with dynamic scaling---it starts with a large scale and reduces it if inf/nan gradients are detected.
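For intuition, the fragment below is a minimal sketch of *static* loss scaling inside a single training step; model, optimizer, and loss refer to the usual training-loop variables, and GradScaler replaces this pattern with dynamic scaling plus inf/nan checks.
scale = 2.0 ** 16                    # static scale factor S
scaled_loss = loss * scale           # shift small gradients into the fp16 range
scaled_loss.backward()
with torch.no_grad():
    for p in model.parameters():     # unscale gradients before the update
        if p.grad is not None:
            p.grad.div_(scale)
optimizer.step()
optimizer.zero_grad()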
35.9 HuggingFace Accelerate
HuggingFace Accelerate provides a high-level abstraction that makes it easy to run the same training code on a single GPU, multiple GPUs, or multiple nodes without code changes.
35.9.1 Minimal Changes for Distributed Training
"""Training with HuggingFace Accelerate.
Demonstrates how Accelerate simplifies distributed training
by abstracting away device management and gradient synchronization.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(42)
try:
from accelerate import Accelerator
HAS_ACCELERATE = True
except ImportError:
HAS_ACCELERATE = False
class TextClassifier(nn.Module):
"""A transformer-based text classifier.
Args:
vocab_size: Vocabulary size.
d_model: Model dimension.
nhead: Number of attention heads.
num_layers: Number of layers.
num_classes: Output classes.
max_seq_len: Maximum sequence length.
"""
def __init__(
self,
vocab_size: int = 30000,
d_model: int = 512,
nhead: int = 8,
num_layers: int = 6,
num_classes: int = 5,
max_seq_len: int = 256,
) -> None:
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_seq_len, d_model)
encoder_layer = nn.TransformerEncoderLayer(
d_model, nhead, d_model * 4, batch_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
self.classifier = nn.Linear(d_model, num_classes)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
input_ids: Token IDs, shape (batch, seq_len).
Returns:
Logits, shape (batch, num_classes).
"""
seq_len = input_ids.size(1)
positions = torch.arange(seq_len, device=input_ids.device)
x = self.embedding(input_ids) + self.pos_emb(positions)
x = self.encoder(x)
return self.classifier(x.mean(dim=1))
def train_with_accelerate(epochs: int = 5) -> None:
"""Train a model using HuggingFace Accelerate.
The same code works for:
- Single GPU
- Multi-GPU (DDP)
- Multi-node
- TPU
- Mixed precision
Args:
epochs: Number of training epochs.
"""
if not HAS_ACCELERATE:
print("HuggingFace Accelerate not installed.")
print("Install with: pip install accelerate")
return
# Initialize Accelerator
accelerator = Accelerator(
mixed_precision="fp16", # or "bf16", "no"
gradient_accumulation_steps=4,
)
# Create model, optimizer, and data (on CPU - Accelerate handles device)
model = TextClassifier()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# Create dataset
dataset = TensorDataset(
torch.randint(0, 30000, (2000, 128)),
torch.randint(0, 5, (2000,)),
)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
# Prepare everything with Accelerator
# This handles: device placement, DDP wrapping, sampler, etc.
model, optimizer, dataloader = accelerator.prepare(
model, optimizer, dataloader,
)
# Training loop - looks exactly like single-GPU code!
for epoch in range(epochs):
model.train()
total_loss = 0.0
num_batches = 0
for input_ids, labels in dataloader:
# Gradient accumulation context
with accelerator.accumulate(model):
logits = model(input_ids)
loss = criterion(logits, labels)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
# Only print on main process
if accelerator.is_main_process:
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
# Save model
if accelerator.is_main_process:
accelerator.save_model(model, "saved_model")
print("Model saved.")
if __name__ == "__main__":
train_with_accelerate()
35.9.2 Accelerate Configuration
Accelerate uses a configuration file to specify the distributed setup:
# accelerate_config.yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 1
num_processes: 4
mixed_precision: bf16
gpu_ids: 0,1,2,3
main_training_function: main
use_cpu: false
Launch with:
accelerate launch --config_file accelerate_config.yaml train.py
35.9.3 Multi-Node Training with Accelerate
For multi-node training, the configuration changes slightly:
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 2
num_processes: 8 # total across all machines
machine_rank: 0 # different for each machine
main_process_ip: 10.0.0.1
main_process_port: 29500
mixed_precision: bf16
On each node:
# Node 0
accelerate launch --config_file config_node0.yaml train.py
# Node 1
accelerate launch --config_file config_node1.yaml train.py
35.10 Multi-GPU and Multi-Node Training
35.10.1 Launching Multi-GPU Training
PyTorch provides multiple ways to launch distributed training:
torchrun (recommended):
# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py
# Multi-node (2 nodes, 4 GPUs each)
# On node 0:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr=10.0.0.1 --master_port=29500 train.py
# On node 1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
--master_addr=10.0.0.1 --master_port=29500 train.py
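Inside train.py, torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for each process, so initialization reduces to reading the environment. A minimal sketch:
import os
import torch
import torch.distributed as dist
def init_from_torchrun() -> int:
    """Initialize the process group from torchrun's environment variables."""
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")   # rank and world size are read from the env
    torch.cuda.set_device(local_rank)
    return local_rank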
35.10.2 Multi-Node Considerations
Network topology matters. Within a single node, GPUs communicate via NVLink (up to 900 GB/s on H100). Between nodes, communication goes over InfiniBand (up to 400 Gb/s) or Ethernet. This asymmetry affects parallelism strategy choice.
┌─────────────────────────────────────────────────────────┐
│ Multi-Node Communication │
│ │
│ Node 0 Node 1 │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ GPU0 ←NVLink→ GPU1 │←─IB/ETH→│ GPU4 ←NVLink→ GPU5 │ │
│ │ ↕ ↕ │ │ ↕ ↕ │ │
│ │ GPU2 ←NVLink→ GPU3 │←─IB/ETH→│ GPU6 ←NVLink→ GPU7 │ │
│ └────────────────────┘ └────────────────────┘ │
│ │
│ NVLink bandwidth: 600-900 GB/s (intra-node) │
│ InfiniBand bandwidth: ~50 GB/s (inter-node) │
│ Ethernet bandwidth: ~3-12 GB/s (inter-node) │
└─────────────────────────────────────────────────────────┘
Best practices for multi-node training:
- Use FSDP HYBRID_SHARD or DeepSpeed with hierarchical communication to minimize inter-node traffic.
- Ensure all nodes have identical hardware and software environments.
- Use a high-speed network (InfiniBand preferred) between nodes.
- Implement robust checkpointing---nodes can fail independently.
- Monitor network utilization alongside GPU utilization.
35.10.3 Efficient Checkpointing
"""Efficient distributed checkpointing.
Demonstrates saving and loading checkpoints in distributed
training with support for resuming after failures.
"""
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.distributed as dist
torch.manual_seed(42)
class DistributedCheckpointer:
"""Manages checkpoints for distributed training.
Handles saving from rank 0 and loading across all ranks.
Args:
checkpoint_dir: Directory for checkpoint storage.
max_checkpoints: Maximum checkpoints to keep.
"""
def __init__(
self,
checkpoint_dir: str = "./checkpoints",
max_checkpoints: int = 3,
) -> None:
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.max_checkpoints = max_checkpoints
def save(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
epoch: int,
step: int,
loss: float,
rank: int = 0,
is_fsdp: bool = False,
) -> str | None:
"""Save a checkpoint.
Only rank 0 saves in DDP mode. For FSDP, each rank
saves its shard.
Args:
model: The model (or DDP/FSDP-wrapped model).
optimizer: The optimizer.
epoch: Current epoch.
step: Current global step.
loss: Current loss value.
rank: Process rank.
is_fsdp: Whether using FSDP (each rank saves its shard).
Returns:
Path to saved checkpoint, or None for non-saving ranks.
"""
if not is_fsdp and rank != 0:
# In DDP, only rank 0 saves
dist.barrier() # Wait for rank 0 to finish saving
return None
checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch{epoch}_step{step}"
if is_fsdp:
# FSDP: each rank saves its own shard
shard_path = checkpoint_path / f"rank_{rank}"
shard_path.mkdir(parents=True, exist_ok=True)
            # Save this rank's shard (assumes the FSDP model was configured
            # for sharded state dicts, e.g., via FSDP.set_state_dict_type
            # with StateDictType.SHARDED_STATE_DICT)
checkpoint = {
"epoch": epoch,
"step": step,
"loss": loss,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, shard_path / "checkpoint.pt")
else:
# DDP: rank 0 saves the full model
checkpoint_path.mkdir(parents=True, exist_ok=True)
# Access underlying model (unwrap DDP)
model_to_save = (
model.module if hasattr(model, "module") else model
)
checkpoint = {
"epoch": epoch,
"step": step,
"loss": loss,
"model_state_dict": model_to_save.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, checkpoint_path / "checkpoint.pt")
        dist.barrier()  # Signal other ranks that the save is complete
        # Remove stale checkpoints from a single rank to avoid concurrent deletes
        if rank == 0:
            self._cleanup_old_checkpoints()
return str(checkpoint_path)
def load(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer | None = None,
checkpoint_path: str | None = None,
rank: int = 0,
device: torch.device = torch.device("cpu"),
) -> dict:
"""Load a checkpoint.
Args:
model: The model to load into.
optimizer: Optional optimizer to restore.
checkpoint_path: Specific checkpoint to load. If None,
loads the latest.
rank: Process rank.
device: Device to load onto.
Returns:
Checkpoint metadata (epoch, step, loss).
"""
if checkpoint_path is None:
checkpoint_path = self._get_latest_checkpoint()
if checkpoint_path is None:
return {"epoch": 0, "step": 0, "loss": float("inf")}
path = Path(checkpoint_path)
# Check for FSDP sharded checkpoint
shard_path = path / f"rank_{rank}" / "checkpoint.pt"
if shard_path.exists():
checkpoint = torch.load(shard_path, map_location=device)
else:
checkpoint = torch.load(
path / "checkpoint.pt", map_location=device,
)
model_to_load = model.module if hasattr(model, "module") else model
model_to_load.load_state_dict(checkpoint["model_state_dict"])
if optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
return {
"epoch": checkpoint["epoch"],
"step": checkpoint["step"],
"loss": checkpoint["loss"],
}
def _get_latest_checkpoint(self) -> str | None:
"""Find the most recent checkpoint.
Returns:
Path to the latest checkpoint, or None.
"""
checkpoints = sorted(
self.checkpoint_dir.glob("checkpoint_*"),
key=lambda p: p.stat().st_mtime,
)
return str(checkpoints[-1]) if checkpoints else None
def _cleanup_old_checkpoints(self) -> None:
"""Remove old checkpoints, keeping only the most recent."""
import shutil
checkpoints = sorted(
self.checkpoint_dir.glob("checkpoint_*"),
key=lambda p: p.stat().st_mtime,
)
while len(checkpoints) > self.max_checkpoints:
oldest = checkpoints.pop(0)
shutil.rmtree(oldest)
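A hypothetical usage of this class inside a DDP training loop might look like the sketch below. Here `rank`, `device`, `model` (already DDP-wrapped), `optimizer`, and `dataloader` are assumed to come from the surrounding script, and the process group is assumed to be initialized.
"""Hypothetical DDP usage of DistributedCheckpointer (names assumed above)."""
checkpointer = DistributedCheckpointer("./checkpoints", max_checkpoints=3)

# Resume from the latest checkpoint if one exists; otherwise start fresh.
state = checkpointer.load(model, optimizer, rank=rank, device=device)

for epoch in range(state["epoch"], 10):
    last_loss = state["loss"]
    for step, (inputs, targets) in enumerate(dataloader):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs.to(device)), targets.to(device))
        loss.backward()
        optimizer.step()
        last_loss = loss.item()

    # Save once per epoch; in DDP mode only rank 0 writes while the rest wait.
    checkpointer.save(
        model, optimizer, epoch=epoch + 1, step=step,
        loss=last_loss, rank=rank,
    )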
35.11 Cost-Effective Training Strategies
Training large models can cost millions of dollars. Here are strategies to reduce costs without sacrificing quality.
35.11.1 Spot/Preemptible Instances
Cloud spot instances cost 60--90% less than on-demand instances but can be interrupted. Robust checkpointing is essential.
"""Cost-effective training utilities.
Provides spot instance handling, learning rate scheduling,
and training efficiency monitoring.
"""
import signal
import time
from collections.abc import Callable
from dataclasses import dataclass
import torch
import torch.nn as nn
torch.manual_seed(42)
@dataclass
class TrainingEfficiency:
"""Metrics for training efficiency analysis.
Attributes:
gpu_utilization: Average GPU utilization (0-1).
memory_utilization: Peak GPU memory usage (0-1).
samples_per_second: Training throughput.
cost_per_sample: Estimated cost per training sample.
time_in_data_loading: Fraction of time in data loading.
time_in_forward: Fraction of time in forward pass.
time_in_backward: Fraction of time in backward pass.
time_in_communication: Fraction of time in gradient sync.
"""
gpu_utilization: float
memory_utilization: float
samples_per_second: float
cost_per_sample: float
time_in_data_loading: float
time_in_forward: float
time_in_backward: float
time_in_communication: float
class SpotInstanceHandler:
"""Handles spot instance preemption gracefully.
Listens for preemption signals and saves a checkpoint
before the instance is terminated.
Args:
checkpoint_fn: Callable to save a checkpoint.
"""
    def __init__(self, checkpoint_fn: Callable[[], None]) -> None:
self.checkpoint_fn = checkpoint_fn
self.preempted = False
# Register signal handlers
signal.signal(signal.SIGTERM, self._handle_preemption)
signal.signal(signal.SIGINT, self._handle_preemption)
def _handle_preemption(self, signum: int, frame) -> None:
"""Handle preemption signal.
Args:
signum: Signal number.
frame: Current stack frame.
"""
print(f"\nPreemption signal received (signal {signum})!")
print("Saving emergency checkpoint...")
self.preempted = True
self.checkpoint_fn()
print("Checkpoint saved. Exiting gracefully.")
@property
def should_continue(self) -> bool:
"""Check if training should continue."""
return not self.preempted
class TrainingProfiler:
"""Profiles training loop performance.
Measures time spent in each phase of the training loop
to identify bottlenecks.
"""
def __init__(self) -> None:
self.timings: dict[str, list[float]] = {
"data_loading": [],
"forward": [],
"backward": [],
"optimizer_step": [],
"other": [],
}
self._current_phase: str | None = None
self._phase_start: float = 0.0
def start_phase(self, phase: str) -> None:
"""Start timing a phase.
Args:
phase: Name of the phase.
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
self._current_phase = phase
self._phase_start = time.perf_counter()
def end_phase(self) -> None:
"""End timing the current phase."""
if self._current_phase is None:
return
if torch.cuda.is_available():
torch.cuda.synchronize()
elapsed = time.perf_counter() - self._phase_start
self.timings[self._current_phase].append(elapsed)
self._current_phase = None
def get_summary(self) -> dict[str, float]:
"""Get timing summary as fractions of total time.
Returns:
Dictionary mapping phase names to fraction of total time.
"""
total = sum(
sum(times) for times in self.timings.values()
)
if total == 0:
return {}
return {
phase: sum(times) / total
for phase, times in self.timings.items()
if times
}
def print_report(self) -> None:
"""Print a formatted profiling report."""
summary = self.get_summary()
if not summary:
print("No profiling data collected.")
return
print("\n" + "=" * 50)
print("Training Profiling Report")
print("=" * 50)
for phase, fraction in sorted(
summary.items(), key=lambda x: -x[1]
):
bar = "#" * int(fraction * 40)
print(f" {phase:20s}: {fraction*100:5.1f}% {bar}")
# Identify bottleneck
bottleneck = max(summary, key=summary.get)
print(f"\nBottleneck: {bottleneck} ({summary[bottleneck]*100:.1f}%)")
if bottleneck == "data_loading":
print("Recommendation: Increase num_workers, use pin_memory, "
"or pre-process data.")
elif bottleneck == "backward":
print("Recommendation: Consider gradient checkpointing or "
"mixed precision.")
elif bottleneck == "optimizer_step":
print("Recommendation: Consider fused optimizer "
"(e.g., apex.optimizers.FusedAdam).")
def compute_training_flops(
model_params: int,
batch_size: int,
seq_length: int,
num_steps: int,
) -> dict[str, float]:
"""Estimate training FLOPs for a transformer model.
Uses the approximation from the Chinchilla paper:
FLOPs per token ~ 6 * model_params
Args:
model_params: Number of model parameters.
batch_size: Training batch size.
seq_length: Sequence length.
num_steps: Total training steps.
Returns:
Dictionary with FLOPs estimates and efficiency metrics.
"""
tokens_per_step = batch_size * seq_length
total_tokens = tokens_per_step * num_steps
flops_per_token = 6 * model_params # Forward + backward
total_flops = flops_per_token * total_tokens
return {
"total_tokens": total_tokens,
"total_flops": total_flops,
"pflops": total_flops / 1e15, # PetaFLOPs
"tokens_per_step": tokens_per_step,
}
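As a hypothetical usage example, the profiler can wrap each phase of a toy single-GPU loop; the model, batch size, and step count below are placeholders chosen only to produce a readable report.
"""Hypothetical single-process usage of TrainingProfiler on a toy model."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(512, 512).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
profiler = TrainingProfiler()

for _ in range(20):
    profiler.start_phase("data_loading")
    x = torch.randn(64, 512, device=device)  # stand-in for a real DataLoader
    y = torch.randn(64, 512, device=device)
    profiler.end_phase()

    profiler.start_phase("forward")
    loss = nn.functional.mse_loss(model(x), y)
    profiler.end_phase()

    profiler.start_phase("backward")
    loss.backward()
    profiler.end_phase()

    profiler.start_phase("optimizer_step")
    optimizer.step()
    optimizer.zero_grad()
    profiler.end_phase()

profiler.print_report()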
35.11.2 Chinchilla Optimal Training
The Chinchilla scaling laws showed that many large models are undertrained: they have more parameters than their training data can support. The compute-optimal ratio is approximately:
$$D_{\text{optimal}} \approx 20 \times N$$
where $D$ is the number of training tokens and $N$ is the number of model parameters. Training a 7B parameter model optimally requires about 140B tokens.
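As a quick worked example, the compute_training_flops helper from the previous listing can translate this rule into a concrete budget; the batch size and sequence length below are illustrative choices, not recommendations.
"""Worked example: Chinchilla-optimal budget for a 7B-parameter model."""
n_params = 7e9
optimal_tokens = 20 * n_params                  # ~140B tokens

batch_size, seq_length = 512, 2048              # illustrative values
tokens_per_step = batch_size * seq_length       # ~1.05M tokens per step
num_steps = int(optimal_tokens / tokens_per_step)

estimate = compute_training_flops(
    model_params=int(n_params),
    batch_size=batch_size,
    seq_length=seq_length,
    num_steps=num_steps,
)
print(f"Steps needed:  {num_steps:,}")                    # ~133,000
print(f"Total tokens:  {estimate['total_tokens']:.2e}")   # ~1.4e11
print(f"Total compute: {estimate['pflops']:.2e} PFLOPs")  # ~5.9e6 PFLOPs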
35.11.3 Cost Optimization Checklist
- Use mixed precision (bf16/fp16). Nearly free 2x speedup.
- Right-size your model. Smaller models trained on more data often outperform larger undertrained models.
- Use spot instances with robust checkpointing.
- Profile before optimizing. Identify whether you are compute-bound, memory-bound, or communication-bound.
- Use gradient checkpointing to trade compute for memory.
- Optimize data loading. Ensure GPUs are never waiting for data.
- Choose the right parallelism strategy. DDP for models that fit on one GPU; FSDP/ZeRO for larger models.
- Use compiled models. torch.compile() can provide a 10--30% speedup (a sketch follows this list).
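A minimal sketch of combining compilation with DDP is shown below, assuming PyTorch 2.0 or later and a script launched with torchrun; the model here is a toy stand-in.
"""Sketch: combining torch.compile with DDP (assumes PyTorch 2.0+ and torchrun)."""
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

model = nn.TransformerEncoderLayer(d_model=512, nhead=8).cuda(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])

# Compile the wrapped model: the first few steps pay a one-time compilation
# cost, after which iterations typically run 10-30% faster.
compiled_model = torch.compile(ddp_model)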
35.12 Combining Parallelism Strategies: 3D Parallelism
The largest models use a combination of all three parallelism strategies, often called 3D parallelism:
┌──────────────────────────────────────────────────────────────┐
│ 3D Parallelism │
│ │
│ ┌─────────────────────────────────────────┐ │
│ │ Data Parallelism (across replicas) │ │
│ │ │ │
│ │ Replica 0 Replica 1 │ │
│ │ ┌──────────────┐ ┌──────────────┐ │ │
│ │ │Pipeline Stage│ │Pipeline Stage│ │ │
│ │ │ 0 │ │ 0 │ │ │
│ │ │ ┌──┐ ┌──┐ │ │ ┌──┐ ┌──┐ │ │ │
│ │ │ │G0│ │G1│ │ │ │G4│ │G5│ │ │ Tensor │
│ │ │ └──┘ └──┘ │ │ └──┘ └──┘ │ │ Parallel │
│ │ ├──────────────┤ ├──────────────┤ │ │
│ │ │Pipeline Stage│ │Pipeline Stage│ │ Pipeline │
│ │ │ 1 │ │ 1 │ │ Parallel │
│ │ │ ┌──┐ ┌──┐ │ │ ┌──┐ ┌──┐ │ │ │
│ │ │ │G2│ │G3│ │ │ │G6│ │G7│ │ │ Tensor │
│ │ │ └──┘ └──┘ │ │ └──┘ └──┘ │ │ Parallel │
│ │ └──────────────┘ └──────────────┘ │ │
│ └─────────────────────────────────────────┘ │
│ │
│ Total: 8 GPUs │
│ - 2 data parallel replicas │
│ - 2 pipeline stages per replica │
│ - 2 tensor parallel devices per stage │
└──────────────────────────────────────────────────────────────┘
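The 8-GPU layout above can also be described programmatically. The sketch below uses PyTorch's device-mesh API (available in PyTorch 2.2+) and assumes torchrun has launched 8 processes; frameworks such as Megatron-LM and DeepSpeed construct their tensor-, pipeline-, and data-parallel process groups in a similar way.
"""Sketch: expressing the 2 x 2 x 2 layout as a 3-D device mesh."""
from torch.distributed.device_mesh import init_device_mesh

# 2 data-parallel replicas x 2 pipeline stages x 2 tensor-parallel ranks = 8 GPUs
mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp", "pp", "tp"),
)

# Each named dimension yields a process group for that flavor of collective.
dp_group = mesh["dp"].get_group()  # gradient all-reduce across replicas
tp_group = mesh["tp"].get_group()  # tensor-parallel collectives within a stage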
35.12.1 Choosing the Right Combination
| Model Size | Recommended Strategy |
|---|---|
| < 1B params | DDP (single GPU per replica) |
| 1B -- 10B params | FSDP or DDP + gradient checkpointing |
| 10B -- 100B params | FSDP or ZeRO Stage 3 |
| 100B+ params | 3D parallelism (TP + PP + DP) |
35.13 Summary
Distributed training is essential for modern deep learning, where models are often too large for a single GPU and datasets are too massive for reasonable training times. In this chapter, we covered:
- Communication primitives (all-reduce, all-gather, reduce-scatter) are the building blocks of distributed training. Understanding them helps you reason about performance bottlenecks.
- Data parallelism (DDP) is the simplest approach: replicate the model, split the data, synchronize gradients. It works when the model fits on one GPU.
- Model parallelism splits the model across GPUs using tensor parallelism (splitting operations) or pipeline parallelism (splitting layers).
- FSDP provides memory-efficient data parallelism by sharding parameters, gradients, and optimizer states across GPUs.
- DeepSpeed ZeRO offers three stages of memory optimization, from partitioning optimizer states (Stage 1) to partitioning everything (Stage 3), with optional CPU/NVMe offloading.
- Gradient accumulation simulates larger batch sizes without increasing memory, but requires care around DDP's gradient synchronization.
- Mixed precision training uses fp16 or bf16 to reduce memory usage and increase throughput, with loss scaling to prevent gradient underflow.
- HuggingFace Accelerate abstracts away distributed training boilerplate, letting you write code that works on any hardware configuration.
- Multi-node training introduces network bandwidth as a new bottleneck, requiring careful architecture choices and robust checkpointing.
- Cost-effective strategies like spot instances, proper profiling, and Chinchilla-optimal training can reduce costs by orders of magnitude.
The right distributed training strategy depends on your model size, hardware, and budget. Start with the simplest approach that works (usually DDP), and scale up to more complex strategies only as needed. In the next chapter, we will explore how to take these trained models and deploy them efficiently for inference at scale.