Case Study 1: Climate DL — Distributed Training for a Global Weather Prediction Model

Context

The Climate DL team is building a 1.2-billion-parameter Vision Transformer (ViT) for global weather prediction. The model takes atmospheric fields — temperature, humidity, wind speed, geopotential height, and 13 other variables, 17 in total — at $1024 \times 2048$ resolution (approximately 0.25-degree latitude-longitude grid) and predicts the same fields 6 hours into the future. By autoregressively applying the model, the team produces 10-day forecasts that compete with the European Centre for Medium-Range Weather Forecasts (ECMWF) operational model.

The training data is ERA5 reanalysis: 40 years of hourly atmospheric data (1979-2019), approximately 350,000 time steps. Each time step is a $1024 \times 2048 \times 17$ tensor (17 atmospheric variables). Uncompressed in FP32, the dataset is approximately 50 TB; stored as float16 with lossless compression, it is under 25 TB on a distributed filesystem (Lustre) accessible from all compute nodes.

The model architecture:

| Component | Configuration |
|---|---|
| Encoder | Patch embedding ($16 \times 32$ patches), 4,096 patches per field |
| Backbone | 48 transformer layers, $d = 1536$, 24 attention heads, $d_{\text{ff}} = 6144$ |
| Parameters | 1.2 billion |
| Sequence length | 4,096 tokens (one per patch) |
| Input | 17 atmospheric variables at 6-hour intervals |
| Output | Same 17 variables, 6 hours ahead |
| Loss | Latitude-weighted MSE (accounting for convergence of meridians at poles) |

The team has access to a cluster of 8 DGX A100 nodes, each with 8 A100 80GB GPUs connected via NVLink (600 GB/s intra-node), with 4x HDR InfiniBand (200 Gb/s inter-node). Total: 64 GPUs.

The challenge: A full training run on a single A100 takes months (the Phase 1 baseline below measures roughly 238 days, even after single-GPU optimizations). The team needs to train the model in under 7 days to iterate on architecture changes weekly.

The Scaling Journey

Phase 1: Single-GPU Baseline

The team begins with a single-GPU training run to establish quality baselines and identify bottlenecks.

Memory analysis. With FP32 precision:

| Component | Memory |
|---|---|
| Parameters | 4.8 GB |
| Gradients | 4.8 GB |
| Adam optimizer state | 9.6 GB |
| Activations (batch size 1, 48 layers, seq 4,096) | ~52 GB |
| Total | ~71 GB |

The model does not fit on a single A100 at batch size 1 in FP32. Activations alone exceed the memory available after loading the model.
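The FP32 budget above follows the standard bytes-per-parameter rule of thumb; a small helper makes the arithmetic explicit. The function name and the treatment of activations as a measured input (they depend on architecture and batch size, not just parameter count) are illustrative choices, not part of the case study:

```python
# Rough per-GPU training-memory estimate for a dense model trained with Adam
# in FP32, mirroring the table above. 1 GB = 1e9 bytes.

def fp32_training_memory_gb(n_params: float, activations_gb: float) -> dict:
    """Return a component-wise memory breakdown in GB."""
    params_gb = n_params * 4 / 1e9   # FP32 weights: 4 bytes/param
    grads_gb = n_params * 4 / 1e9    # FP32 gradients: 4 bytes/param
    adam_gb = n_params * 8 / 1e9     # Adam m and v: 2 x 4 bytes/param
    return {
        "parameters": params_gb,
        "gradients": grads_gb,
        "adam_state": adam_gb,
        "activations": activations_gb,
        "total": params_gb + grads_gb + adam_gb + activations_gb,
    }

breakdown = fp32_training_memory_gb(n_params=1.2e9, activations_gb=52.0)
print(breakdown["total"])  # ~71.2 GB, matching the table
```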

First optimization: AMP (BF16). Switching to BF16 halves activation memory:

| Component | Memory |
|---|---|
| Master parameters (FP32) | 4.8 GB |
| Working parameters (BF16) | 2.4 GB |
| Gradients (BF16) | 2.4 GB |
| Adam state (FP32) | 9.6 GB |
| Activations (BF16, batch 1) | ~26 GB |
| Total | ~45 GB |

The model now fits on a single A100 80GB at batch size 1 with 35 GB headroom.
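In PyTorch, this BF16 setup is a `torch.autocast` region: parameters stay in FP32 ("master" weights) while matmuls and activations run in BF16, and no loss scaler is needed (unlike FP16). The sketch below uses a toy model on CPU purely for illustration; the actual run would use `device_type="cuda"` and the real ViT:

```python
import torch
import torch.nn as nn

# Minimal BF16 mixed-precision sketch. The tiny model and shapes are
# placeholders, not the 1.2B-parameter ViT from the case study.
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

x = torch.randn(8, 64)
target = torch.randn(8, 64)

# Parameters remain FP32; ops inside the region compute in BF16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)
    loss = nn.functional.mse_loss(out.float(), target)

loss.backward()  # gradients flow back into the FP32 parameters
opt.step()
```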

Second optimization: Gradient checkpointing. To increase the batch size beyond 1, the team enables gradient checkpointing every 7 layers ($\approx \sqrt{48}$). Activation memory drops from 26 GB to approximately 3.8 GB — enough for batch size 4:

Component Memory (batch 4)
Master + working parameters 7.2 GB
Gradients 2.4 GB
Adam state 9.6 GB
Checkpointed activations ~15 GB
Total ~34 GB

Third optimization: FlashAttention. Standard attention for sequence length 4,096 with 24 heads at batch size 4 would consume approximately $4 \times 24 \times 4096^2 \times 2 = 3.2$ GB per layer for the attention matrix alone (in BF16). With FlashAttention, this is eliminated entirely — the attention matrix is never materialized in HBM. This frees additional memory for a batch size of 8 per GPU.

Baseline result: Single GPU, batch size 8, BF16, gradient checkpointing, FlashAttention: approximately 0.85 samples/second. Estimated time for 50 epochs over 350,000 time steps: $50 \times 350{,}000 / 0.85 / 3600 \approx 5{,}720$ GPU-hours (238 days).

Phase 2: Multi-GPU Data Parallelism (Intra-Node)

The team scales to 8 GPUs on a single DGX node using DDP.

Configuration:

- Parallelism: DDP (NCCL backend, NVLink)
- Local batch size per GPU: 8
- Global batch size: 64
- Learning rate: $3 \times 10^{-4} \times (64/8) = 2.4 \times 10^{-3}$ (linear scaling, 8x base)
- Warmup: 2,000 steps
- Optimizer: LAMB (required once the global batch size later scales beyond 256)
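The linear scaling rule with warmup can be written as a small schedule function. The base batch of 8 and base LR of $3 \times 10^{-4}$ come from the configuration above; the function names are illustrative:

```python
# Linear LR scaling with linear warmup, as a plain-Python sketch.
BASE_LR = 3e-4
BASE_BATCH = 8

def scaled_lr(global_batch: int) -> float:
    """Linear scaling rule: LR grows proportionally with global batch size."""
    return BASE_LR * (global_batch / BASE_BATCH)

def lr_at_step(step: int, global_batch: int, warmup_steps: int = 2000) -> float:
    """Ramp linearly from ~0 to the scaled peak LR over warmup_steps."""
    peak = scaled_lr(global_batch)
    if step < warmup_steps:
        return peak * (step + 1) / warmup_steps
    return peak

print(scaled_lr(64))  # 0.0024, the 8x-scaled LR from the configuration
```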

Scaling efficiency: Intra-node NVLink provides 600 GB/s bidirectional bandwidth. The gradient tensor is approximately 2.4 GB (BF16). Ring all-reduce time:

$$T_{\text{ring}} = 2 \times 7 \times 10\,\mu\text{s} + 2 \times \frac{7}{8} \times \frac{2.4 \times 10^9}{600 \times 10^9} = 140\,\mu\text{s} + 7.0\,\text{ms} \approx 7.0\,\text{ms}$$

A single training step (forward + backward) takes approximately 1.18 seconds. Communication is 7ms / 1180ms = 0.6% of step time. Scaling is near-ideal.
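The ring all-reduce estimate is the same formula both intra-node and inter-node, so it is worth capturing as a helper: $2(N-1)$ latency hops plus $2\frac{N-1}{N}$ of the tensor moved at link bandwidth. Parameter names are illustrative:

```python
# Ring all-reduce time model: latency term + bandwidth term.

def ring_allreduce_seconds(n_gpus: int, tensor_bytes: float,
                           bandwidth_bytes_per_s: float,
                           latency_s: float) -> float:
    hops = 2 * (n_gpus - 1)                       # reduce-scatter + all-gather hops
    data = 2 * (n_gpus - 1) / n_gpus * tensor_bytes
    return hops * latency_s + data / bandwidth_bytes_per_s

# Intra-node (NVLink): 8 GPUs, 2.4 GB gradients, 600 GB/s, 10 us latency
print(ring_allreduce_seconds(8, 2.4e9, 600e9, 10e-6))   # ~0.0071 s
# Inter-node (InfiniBand): 64 GPUs, 25 GB/s effective, 50 us latency
print(ring_allreduce_seconds(64, 2.4e9, 25e9, 50e-6))   # ~0.195 s
```

The second call reproduces the 195 ms figure used in the multi-node analysis below.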

Result: 8 GPUs, throughput 6.4 samples/second (7.5x single GPU, 94% scaling efficiency). Training time: approximately 760 hours wall clock (about 30 days on 8 GPUs), or roughly 6,100 GPU-hours.

Still too slow for weekly iteration.

Phase 3: Multi-Node Data Parallelism (64 GPUs)

The team scales to 8 nodes (64 GPUs) using DDP across InfiniBand.

The communication challenge. Inter-node bandwidth is 25 GB/s effective (200 Gb/s HDR InfiniBand with encoding overhead) — 24x lower than intra-node NVLink. The all-reduce now crosses the node boundary:

$$T_{\text{ring}} = 2 \times 63 \times 50\,\mu\text{s} + 2 \times \frac{63}{64} \times \frac{2.4 \times 10^9}{25 \times 10^9} = 6.3\,\text{ms} + 189\,\text{ms} \approx 195\,\text{ms}$$

At 195 ms per all-reduce and a 1.18-second training step, communication adds 195 ms / 1180 ms = 16.5% to step time. Scaling efficiency drops to roughly 86% (1180 / 1375 ms per step).

Optimization 1: Gradient accumulation. The team accumulates gradients over 4 micro-batches before each all-reduce, reducing communication frequency by 4x. Effective communication fraction: 195ms / (4 × 1180ms) = 4.1%.
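The accumulation loop is standard PyTorch: scale each micro-batch loss by $1/\text{accum\_steps}$ so gradients average over the effective batch, and step the optimizer once per cycle. The toy model and data below are placeholders; the `no_sync()` comment describes how this interacts with DDP's all-reduce:

```python
import torch
import torch.nn as nn

# Gradient accumulation sketch: 4 backward passes, one optimizer step.
ACCUM_STEPS = 4
model = nn.Linear(16, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

data = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(ACCUM_STEPS)]

opt.zero_grad()
for x, y in data:
    # Under DDP, the first ACCUM_STEPS - 1 passes would run inside
    # model.no_sync() so only the final backward triggers the all-reduce.
    loss = nn.functional.mse_loss(model(x), y) / ACCUM_STEPS
    loss.backward()  # gradients accumulate in .grad across micro-batches
opt.step()
```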

Optimization 2: Overlap communication with backward pass. DDP's built-in gradient bucketing overlaps all-reduce with gradient computation. Since the backward pass takes ~900ms and the all-reduce takes ~195ms, approximately 80% of the communication is hidden behind computation. Effective communication overhead: ~40ms (the tail that is not overlapped).

Optimization 3: Switch to FSDP (ZeRO-2). With 64 GPUs, sharding optimizer state and gradients cuts per-GPU training state by roughly 12 GB (about 9.5 GB from sharding the 9.6 GB of Adam state, plus about 2.4 GB from sharding gradients), allowing the team to increase the local batch size from 8 to 12. The larger batch size further amortizes communication and increases GPU utilization.
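The sharding arithmetic can be checked with a back-of-envelope sketch. It assumes the BF16 setup from Phase 1 (FP32 master weights, BF16 working weights and gradients, FP32 Adam state) and counts training state only, excluding activations:

```python
# Per-GPU training state under ZeRO-2-style sharding: parameters stay
# replicated; gradients and Adam state are split across the group.

def zero2_per_gpu_gb(n_params: float, n_gpus: int) -> dict:
    master = n_params * 4 / 1e9           # FP32 master weights, replicated
    working = n_params * 2 / 1e9          # BF16 working weights, replicated
    grads = n_params * 2 / 1e9 / n_gpus   # BF16 gradients, sharded
    adam = n_params * 8 / 1e9 / n_gpus    # FP32 Adam m and v, sharded
    return {"master": master, "working": working,
            "grads": grads, "adam": adam,
            "total": master + working + grads + adam}

unsharded = zero2_per_gpu_gb(1.2e9, 1)["total"]   # ~19.2 GB of state
sharded = zero2_per_gpu_gb(1.2e9, 64)["total"]    # ~7.4 GB of state
print(round(unsharded - sharded, 1))              # ~11.8 GB freed for activations
```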

Configuration after optimization:

- Parallelism: FSDP (ZeRO-2 equivalent), NCCL backend
- Local batch size per GPU: 12
- Gradient accumulation: 4
- Effective global batch size: $12 \times 4 \times 64 = 3{,}072$
- Optimizer: LAMB with warmup cosine schedule
- Learning rate: $3 \times 10^{-4} \times (3072/8) = 0.1152$ (capped at 0.01 with trust-ratio adjustment)

Result: 64 GPUs, throughput 44 samples/second (about 52x the single-GPU baseline, approximately 80% scaling efficiency). Wall-clock training time: $50 \times 350{,}000 / 44 / 3600 \approx 110$ hours (4.6 days).

Phase 4: Cost Optimization

The 64-GPU training run costs:

$$\text{On-demand:} \quad 110\,\text{hours} \times 64\,\text{GPUs} \times \$3.50/\text{GPU-hour} = \$24{,}640$$

Spot instance strategy. The team uses a mix of on-demand (1 node, hosting the rank-0 process and checkpoints) and spot instances (7 nodes). Spot pricing at 65% discount: $1.23/GPU-hour.

Expected cost with spot:

- 1 on-demand node: $110 \times 8 \times \$3.50 = \$3{,}080$
- 7 spot nodes: $110 \times 56 \times \$1.23 = \$7{,}577$
- Expected preemption cost (2 preemptions, 45 min lost each): $2 \times 0.75 \times 64 \times \$1.23 = \$118$
- Total: \$10,775 (56% savings vs. on-demand)
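The cost comparison above, spelled out as arithmetic. Prices and the preemption assumptions (2 events, 45 minutes of lost work each, billed across all 64 GPUs at the spot rate) are taken from the case study; variable names are illustrative:

```python
# Spot-vs-on-demand cost model for the 64-GPU run.
HOURS = 110
ON_DEMAND = 3.50        # $/GPU-hour
SPOT = 1.23             # $/GPU-hour after ~65% discount
GPUS_PER_NODE = 8

on_demand_total = HOURS * 64 * ON_DEMAND             # all 8 nodes on-demand
anchor_node = HOURS * GPUS_PER_NODE * ON_DEMAND      # rank-0 node stays on-demand
spot_nodes = HOURS * 7 * GPUS_PER_NODE * SPOT        # 7 preemptible nodes
preemptions = 2 * 0.75 * 64 * SPOT                   # 2 x 45 min of lost work

mixed_total = anchor_node + spot_nodes + preemptions
savings = 1 - mixed_total / on_demand_total          # ~0.56
print(round(on_demand_total), round(mixed_total))    # 24640 10775
```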

Checkpoint strategy: Save every 500 steps (approximately every 35 minutes). Each checkpoint is 12 GB (model + optimizer state). Checkpoints are saved to a shared Lustre filesystem visible to all nodes, with asynchronous I/O to avoid blocking training.
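A minimal sketch of the non-blocking save pattern: snapshot the state dict on the training thread (cheap, in-memory), then hand the actual write to a background thread so training is not blocked on filesystem I/O. The tiny model and temp-file path are placeholders for the real model and Lustre path:

```python
import os
import tempfile
import threading
import torch

# Asynchronous checkpoint sketch.
model = torch.nn.Linear(8, 8)
ckpt_path = os.path.join(tempfile.mkdtemp(), "step_000500.pt")

# 1. Snapshot on the training thread so later parameter updates
#    cannot race with the write.
snapshot = {k: v.detach().clone() for k, v in model.state_dict().items()}

# 2. Write in the background; the training loop continues meanwhile.
writer = threading.Thread(target=torch.save, args=(snapshot, ckpt_path))
writer.start()
# ... training continues here ...
writer.join()  # in practice, join (or poll) before starting the next save
```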

Results and Lessons

The final model achieves a latitude-weighted root-mean-square error (RMSE) of 3.21 dam (decameters) on 500 hPa geopotential height at 5-day lead time, compared to 3.45 dam for the team's previous CNN-based model (Chapter 8) and 3.12 dam for the ECMWF operational model.

| Configuration | Throughput | Wall-clock | Scaling Eff. | Cost |
|---|---|---|---|---|
| 1 GPU baseline | 0.85 samples/s | 238 days | 100% | — |
| 8 GPU (DDP, NVLink) | 6.4 samples/s | ~30 days | 94% | — |
| 64 GPU (FSDP, IB) | 44 samples/s | 4.6 days | 80% | $10,775 |

Lesson 1: Memory optimization unlocks data parallelism. Without AMP, gradient checkpointing, and FlashAttention, the model would not fit on a single GPU at a useful batch size. These optimizations are prerequisites for DDP, not enhancements.

Lesson 2: Communication overhead is the primary scaling bottleneck across nodes. Intra-node (NVLink) scaling is near-ideal; inter-node (InfiniBand) introduces meaningful overhead. Gradient accumulation and DDP's communication-computation overlap are essential for multi-node efficiency.

Lesson 3: Cost optimization is engineering, not compromise. Spot instances with checkpointing saved $14,000 per training run — enabling the team to run 2.3x more experiments within their compute budget, which contributed more to final model quality than any single architectural change.

Lesson 4: Profiling before optimizing. The team's first instinct was to use tensor parallelism (splitting the transformer layers across GPUs). Profiling revealed that the model fit comfortably on a single GPU with memory optimizations, making data parallelism (with FSDP for optimizer state) the simpler and more efficient choice. The profiling took 2 hours; the avoided complexity of tensor parallelism implementation saved weeks.