Case Study 2: Training with FSDP and DeepSpeed

Overview

A research team needs to fine-tune a 13B parameter language model on domain-specific data. The model requires approximately 52 GB in float32 for parameters alone, far exceeding single-GPU memory. The team has access to 4 A100-80GB GPUs on a single node and must choose between FSDP and DeepSpeed for memory-efficient training.

Problem Statement

Training a 13B model with AdamW requires:

- Parameters: 13B * 4 bytes = 52 GB
- Gradients: 52 GB
- Adam states: 104 GB (two parameter-sized moment buffers in FP32)
- Total (excluding activations): 208 GB

Even the combined 320 GB across the 4 A100-80GB GPUs is insufficient once activation memory and framework overhead are added, and standard data parallelism would replicate the full 208 GB of training state on every 80 GB device. The training state must therefore be sharded across GPUs.
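
As a quick sanity check, the arithmetic above can be reproduced in a few lines of Python. The 13B parameter count and the FP32/AdamW assumptions are the ones stated above; the snippet itself is illustrative and not part of the case study code:

```python
# Rough training-state estimate for a 13B-parameter model trained with AdamW in FP32.
NUM_PARAMS = 13e9
BYTES_FP32 = 4

param_mem = NUM_PARAMS * BYTES_FP32      # 52 GB of parameters
grad_mem = NUM_PARAMS * BYTES_FP32       # 52 GB of gradients
adam_mem = 2 * NUM_PARAMS * BYTES_FP32   # 104 GB for the two Adam moment buffers

total_gb = (param_mem + grad_mem + adam_mem) / 1e9
print(f"Training state, excluding activations: {total_gb:.0f} GB")  # ~208 GB
```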

Approach

Configuration 1: FSDP

PyTorch FSDP with full sharding (equivalent to ZeRO Stage 3):

- Shard parameters, gradients, and optimizer states across 4 GPUs
- Mixed precision: compute in BF16, reduce in FP32, store in BF16
- Activation checkpointing on every transformer block
- Per-block FSDP wrapping
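
In PyTorch, this configuration can be sketched roughly as follows. This is a minimal sketch rather than the case study's actual script: `build_model` stands in for however the 13B checkpoint is loaded, and `TransformerBlock` stands in for the model's decoder-block class (both imported from a hypothetical `mymodel` module).

```python
# Minimal FSDP setup sketch; `build_model` and `TransformerBlock` are
# placeholders for the team's model loader and decoder-block class.
import functools

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

from mymodel import TransformerBlock, build_model  # hypothetical module

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = build_model()  # the 13B model

# BF16 compute with FP32 gradient reduction, as described above.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

# Per-block wrapping: each transformer block becomes its own FSDP unit.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3-equivalent sharding
    mixed_precision=mp_policy,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)

# Activation checkpointing on every transformer block.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, TransformerBlock),
)
```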

Memory per GPU (FSDP):

- Sharded parameters: 52 GB / 4 = 13 GB (stored in BF16: 6.5 GB)
- Sharded gradients: 13 GB (in BF16: 6.5 GB)
- Sharded optimizer states: 26 GB (in FP32)
- Activations (with checkpointing): ~8 GB
- Total: ~47 GB per GPU (fits in 80 GB)

Configuration 2: DeepSpeed ZeRO Stage 3

DeepSpeed with ZeRO-3 and CPU offloading:

- ZeRO Stage 3 parameter sharding
- Optimizer state offloading to CPU
- Mixed precision with FP16
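
A DeepSpeed configuration along these lines can be sketched as a Python dict passed to `deepspeed.initialize`. This is an illustrative sketch, not the team's actual config: the batch settings mirror the training table below, and `model` is assumed to be the 13B model built elsewhere.

```python
# Sketch of a ZeRO-3 + CPU-offload DeepSpeed setup (illustrative, not the
# case study's exact configuration).
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 2,
    "fp16": {"enabled": True, "loss_scale": 0},  # 0 = dynamic loss scaling
    "zero_optimization": {
        "stage": 3,  # shard parameters, gradients, and optimizer states
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "optimizer": {"type": "AdamW", "params": {"lr": 2e-5}},
}

# Activation checkpointing is enabled on the model itself (e.g. the model's
# gradient-checkpointing switch), not through this config dict.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,                         # the 13B model, assumed built elsewhere
    model_parameters=model.parameters(),
    config=ds_config,
)
```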

Memory per GPU (DeepSpeed + offload):

- Sharded parameters: 6.5 GB (FP16)
- Sharded gradients: 6.5 GB (FP16)
- Optimizer states: offloaded to CPU RAM
- Activations: ~8 GB
- Total: ~21 GB per GPU

Training Configuration

| Setting                  | FSDP                      | DeepSpeed              |
|--------------------------|---------------------------|------------------------|
| Precision                | BF16 compute, FP32 reduce | FP16 with loss scaling |
| Batch size (per GPU)     | 2                         | 4                      |
| Gradient accumulation    | 4 steps                   | 2 steps                |
| Effective batch size     | 32                        | 32                     |
| Learning rate            | 2e-5                      | 2e-5                   |
| Warmup steps             | 100                       | 100                    |
| Activation checkpointing | Yes                       | Yes                    |
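
Both setups target the same effective batch size; the relationship is simply per-GPU batch size × gradient-accumulation steps × number of GPUs, as the quick check below illustrates:

```python
def effective_batch_size(per_gpu: int, accum_steps: int, num_gpus: int) -> int:
    """Number of sequences contributing to each optimizer step."""
    return per_gpu * accum_steps * num_gpus

print(effective_batch_size(2, 4, 4))  # FSDP:      2 * 4 * 4 = 32
print(effective_batch_size(4, 2, 4))  # DeepSpeed: 4 * 2 * 4 = 32
```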

Results

| Metric                     | FSDP         | DeepSpeed ZeRO-3 | DeepSpeed + Offload |
|----------------------------|--------------|------------------|---------------------|
| GPU memory per device      | 47 GB        | 44 GB            | 21 GB               |
| Training throughput        | 820 tokens/s | 790 tokens/s     | 420 tokens/s        |
| Time per epoch (10K steps) | 3.4 hours    | 3.5 hours        | 6.6 hours           |
| Validation loss            | 1.82         | 1.83             | 1.83                |
| Setup complexity           | Medium       | Medium-High      | High                |

Throughput Breakdown

| Component              | FSDP | DeepSpeed ZeRO-3 |
|------------------------|------|------------------|
| Forward pass           | 35%  | 34%              |
| Backward pass          | 40%  | 39%              |
| All-gather (params)    | 12%  | 13%              |
| Reduce-scatter (grads) | 8%   | 9%               |
| Optimizer step         | 5%   | 5%               |

Key Lessons

  1. FSDP and DeepSpeed ZeRO-3 achieve comparable performance. Without CPU offloading, both solutions provide similar throughput and memory efficiency. The choice depends on ecosystem preferences (PyTorch-native vs DeepSpeed library).

  2. CPU offloading trades throughput for memory. DeepSpeed with offloading reduced GPU memory by 55% but halved training throughput. This is worthwhile when GPU memory is the binding constraint (e.g., larger models or fewer GPUs).

  3. Activation checkpointing is essential for large models. Without it, activation memory exceeded GPU capacity even with parameter sharding. Checkpointing added ~30% compute overhead but made training feasible.

  4. BF16 is preferred over FP16 for large models. BF16's larger dynamic range eliminated the need for loss scaling and gradient overflow handling, simplifying the training loop (see the sketch after this list).

  5. Per-block FSDP wrapping optimizes the memory-communication trade-off. Wrapping at the transformer block level (rather than per-layer or whole-model) provided the best balance: sufficient sharding for memory savings without excessive all-gather communication.

  6. Effective batch size should be kept constant across configurations. When comparing FSDP and DeepSpeed, keeping the effective batch size at 32 (via different per-GPU batch sizes and gradient accumulation steps) ensured comparable training dynamics and final loss.
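
To illustrate lesson 4, the sketch below contrasts a generic FP16 training step (with loss scaling) against a BF16 step. It is schematic PyTorch rather than the case study's training loop: `model`, `batch`, and `optimizer` are assumed to exist, and the model is assumed to return a HuggingFace-style output with a `.loss` field.

```python
import torch

def fp16_step(model, batch, optimizer, scaler):
    # FP16 needs dynamic loss scaling so small gradients do not underflow.
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        loss = model(**batch).loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # skipped internally if gradients overflowed
    scaler.update()

def bf16_step(model, batch, optimizer):
    # BF16 has FP32's exponent range, so no scaler or overflow handling is needed.
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()
    optimizer.step()

# For the FP16 path, create the scaler once: scaler = torch.amp.GradScaler("cuda")
```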

Code Reference

The complete implementation is available in code/case-study-code.py.