Chapter 26: Further Reading
Essential Sources
1. Shen Li, Yanli Zhao, Rohan Varma, Omkar Salpekar, Pieter Noordhuis, Teng Li, Adam Paszke, Jeff Smith, Brian Vaughan, Pritam Damania, and Soumith Chintala, "PyTorch Distributed: Experiences on Accelerating Data Parallel Training" (VLDB, 2020)
The authoritative description of PyTorch's distributed training infrastructure, written by the team that built it. The paper covers the design of DistributedDataParallel (DDP), including the gradient bucketing strategy that overlaps all-reduce with backward-pass computation, the communication backend abstraction (NCCL, Gloo, MPI), and the DistributedSampler design. Section 3.2 on gradient bucketing is essential for understanding why DDP achieves near-linear scaling: by starting the all-reduce for early layers while later layers are still computing gradients, DDP hides most of the communication behind computation.
Reading guidance: Section 2 provides the clearest available explanation of the ring all-reduce algorithm and its performance characteristics, including the bandwidth-optimal property that makes the per-GPU communication cost nearly independent of the number of GPUs. Section 3 describes the DDP implementation, with performance benchmarks on up to 256 GPUs. Section 4 discusses common pitfalls — the ones that cause DDP to hang or produce incorrect results — and is invaluable for debugging production training pipelines. For the FSDP extension, see Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel" (VLDB, 2023), which covers the ZeRO-3 implementation in PyTorch, including the auto-wrap policy and mixed-precision integration used in Section 26.9.3 of this chapter.
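The bandwidth-optimal property can be made concrete with the standard ring all-reduce cost model: each GPU sends $2(N-1)/N$ times the buffer size, regardless of $N$. A minimal sketch (the function name is ours, and the ResNet-50 gradient size is approximate):

```python
def ring_allreduce_mb_per_gpu(grad_mb: float, n_gpus: int) -> float:
    """Per-GPU traffic for ring all-reduce: a reduce-scatter phase plus an
    all-gather phase, each sending (N-1)/N of the gradient buffer."""
    return 2 * (n_gpus - 1) / n_gpus * grad_mb

# ResNet-50 has roughly 25.6M parameters; fp32 gradients are 4 bytes each.
grad_mb = 25.6e6 * 4 / 1e6  # ~102 MB of gradients

for n in (2, 8, 64, 256):
    print(f"{n:>3} GPUs: {ring_allreduce_mb_per_gpu(grad_mb, n):.1f} MB sent per GPU")
```

As $N$ grows the per-GPU cost saturates at twice the buffer size, which is why the communication cost is nearly independent of the number of GPUs.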
2. Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré, "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS, 2022)
The paper that transformed attention computation by making the GPU memory hierarchy — not just the FLOP count — the optimization target. Dao et al. show that standard attention is memory-bound because it materializes the $O(s^2)$ attention matrix in HBM, requiring multiple slow read/write round trips. FlashAttention avoids this by tiling the computation onto the GPU's on-chip SRAM, using the online softmax trick to compute numerically correct results in a single pass. The result is both lower memory ($O(s)$ instead of $O(s^2)$) and higher throughput (2-4x faster than standard attention on A100s).
Reading guidance: Section 3 presents the algorithm with the tiling and online softmax derivation — this is the core technical contribution and is worth studying carefully. Theorem 1 proves the IO complexity bound, showing that FlashAttention is optimal in the number of HBM accesses. Section 5 provides benchmarks on GPT-2, BERT, and long-document models, demonstrating that FlashAttention enables training with sequence lengths that were previously infeasible. For the extension to the backward pass and FlashAttention-2, see Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR, 2024), which further optimizes the kernel by reducing non-matmul FLOPs and improving work distribution across GPU warps and thread blocks.
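The online softmax trick can be demonstrated in a few lines of NumPy: process the scores tile by tile, carrying a running max, a running normalizer, and a rescaled accumulator, and the result matches the full softmax exactly. A sketch of the idea (a one-dimensional simplification of the per-row computation FlashAttention performs; names are ours):

```python
import numpy as np

def online_softmax_weighted_sum(scores, values, tile=4):
    """Streaming, numerically stable softmax(scores) @ values, processing
    `tile` scores at a time -- the rescaling trick FlashAttention applies
    per tile to avoid materializing the full attention row."""
    m = -np.inf                                   # running max of scores seen
    l = 0.0                                       # running sum of exp(score - m)
    acc = np.zeros_like(values[0], dtype=float)   # running weighted sum
    for i in range(0, len(scores), tile):
        s, v = scores[i:i + tile], values[i:i + tile]
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)    # rescale old accumulator to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
scores = rng.normal(size=16)
values = rng.normal(size=(16, 8))
w = np.exp(scores - scores.max())
ref = (w / w.sum()) @ values                 # full-softmax reference
out = online_softmax_weighted_sum(scores, values)
assert np.allclose(out, ref)
```

The single pass never holds more than one tile of exponentiated scores, which is the source of the $O(s)$ memory bound.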
3. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, and Yuxiong He, "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" (SC, 2020)
The paper that introduced the Zero Redundancy Optimizer, the foundational idea behind both DeepSpeed and PyTorch FSDP. Rajbhandari et al. observe that standard data parallelism wastes memory by replicating optimizer state, gradients, and parameters on every GPU, and propose three stages of progressively aggressive sharding (optimizer state in ZeRO-1, plus gradients in ZeRO-2, plus parameters in ZeRO-3) that reduce per-GPU model-state memory from $16P$ bytes, for mixed-precision Adam training of a $P$-parameter model, to $16P/N$ bytes on $N$ GPUs. CPU offloading for cases where even the sharded state does not fit in GPU memory came later, in the follow-up ZeRO-Offload paper (Ren et al., USENIX ATC, 2021).
Reading guidance: Section 3 provides the memory analysis that motivates ZeRO — Table 1 is the definitive reference for understanding where GPU memory goes during training and how much each ZeRO stage saves. Section 4 describes the communication analysis, showing that ZeRO-1 and ZeRO-2 have the same communication volume as standard DDP (the sharding is "free" in terms of communication), while ZeRO-3 adds an all-gather per layer. Section 5 demonstrates training a 100B parameter model on 400 GPUs — a configuration that was impossible with standard data parallelism. For the practical implementation, the DeepSpeed documentation (deepspeed.ai) provides tutorials and configuration guides. For a comparison with FSDP, see the PyTorch FSDP paper cited above.
4. Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He, "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" (arXiv, 2017)
The paper that established the linear scaling rule and warmup recipe for large-batch training. Goyal et al. train ResNet-50 on ImageNet with batch sizes up to 8,192 across 256 GPUs, matching the accuracy of small-batch training while completing the full schedule in one hour. The two key contributions, the linear learning rate scaling rule ($\eta_{\text{new}} = \eta_{\text{base}} \times k$ when the batch size is multiplied by $k$) and the gradual warmup strategy, remain the standard recipe for scaling batch size in data-parallel training. The paper's experimental methodology is a model of rigor: every claim is backed by controlled experiments with clearly stated baselines.
Reading guidance: Section 2 derives the linear scaling rule from first principles, showing that it preserves the expected SGD update magnitude across batch sizes. Section 3 describes the warmup strategy and provides practical guidance on the warmup duration. Section 5 provides the communication analysis for 256-GPU training, including the observation that gradient all-reduce can be overlapped with backward computation — the same insight implemented in DDP's gradient bucketing. For the extension to Adam-based optimizers, see You et al., "Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes" (ICLR, 2020), which introduces LAMB and demonstrates batch sizes up to 65,536.
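The two ingredients combine into a simple schedule: ramp linearly from the base learning rate to $k$ times the base rate over the warmup period, then hold (before any later decay). A sketch under stated assumptions (the function name and step counts are ours; the paper warms up over 5 epochs, roughly 780 steps at batch 8,192 on ImageNet):

```python
def lr_at_step(step: int, base_lr: float, k: float, warmup_steps: int) -> float:
    """Linear scaling with gradual warmup: the target LR is base_lr * k for
    a batch k times larger, reached by ramping linearly from base_lr."""
    target = base_lr * k
    if step >= warmup_steps:
        return target  # a decay schedule (step/cosine) would apply after this
    return base_lr + (target - base_lr) * step / warmup_steps

# Illustrative values: base LR 0.1 at batch 256, scaled 32x for batch 8,192.
for step in (0, 390, 780, 2000):
    print(step, round(lr_at_step(step, base_lr=0.1, k=32, warmup_steps=780), 3))
```

Starting the ramp at the small-batch rate rather than zero is the paper's "gradual warmup"; it avoids the early-training instability that an immediate $32\times$ learning rate would cause.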
5. Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, Tom Hennigan, Eric Noland, Katie Millican, George van den Driessche, Bogdan Damoc, Aurelia Guy, Simon Osindero, Karen Simonyan, Erich Elsen, Jack W. Rae, Oriol Vinyals, and Laurent Sifre, "Training Compute-Optimal Large Language Models" (NeurIPS, 2022)
Known as the "Chinchilla paper," this work from DeepMind challenges the prevailing practice of training very large models on relatively small datasets. Hoffmann et al. demonstrate that for a given compute budget, the optimal allocation is to scale model size and dataset size equally — specifically, each doubling of model parameters should be matched by a doubling of training tokens. The Chinchilla scaling laws ($L(P, D) = A/P^\alpha + B/D^\beta + L_\infty$) provide a principled framework for compute budgeting that directly informs the cost estimation analysis in Section 26.10.
Reading guidance: Section 3 presents the scaling law derivation and the three estimation approaches. Table 3 provides the optimal parameter-count-to-token-count ratios for different compute budgets — use this as a planning tool before committing to a training configuration. Section 4 validates the scaling law by training the 70B parameter Chinchilla model with the compute-optimal token count, outperforming the 280B parameter Gopher model trained on fewer tokens. For practitioners, the key takeaway is that undertrained large models are a waste of compute: a smaller model trained on more data is both cheaper and better. For extensions beyond language modeling, see the scaling laws discussion in Kaplan et al., "Scaling Laws for Neural Language Models" (arXiv, 2020), and for vision models, Zhai et al., "Scaling Vision Transformers" (CVPR, 2022).
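The scaling law is easy to evaluate numerically. A sketch using the fitted constants reported in the paper (treat them as illustrative): at a fixed compute budget $C \approx 6PD$ FLOPs, a Chinchilla-sized model achieves lower predicted loss than a Gopher-sized one trained on a quarter of the tokens.

```python
def chinchilla_loss(P: float, D: float) -> float:
    """Parametric loss L(P, D) = L_inf + A/P^alpha + B/D^beta, with the
    fitted constants from the Chinchilla paper (illustrative only)."""
    A, alpha, B, beta, L_inf = 406.4, 0.34, 410.7, 0.28, 1.69
    return L_inf + A / P**alpha + B / D**beta

# Fix the compute budget at roughly Chinchilla's: 70B params x 1.4T tokens.
C = 6 * 70e9 * 1.4e12
for P in (70e9, 280e9):
    D = C / (6 * P)  # tokens affordable at this parameter count
    print(f"P={P / 1e9:.0f}B  D={D / 1e12:.2f}T  L={chinchilla_loss(P, D):.3f}")
```

The comparison makes the chapter's takeaway quantitative: at equal compute, shifting FLOPs from parameters to tokens lowers the predicted loss until the two terms are balanced.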