Case Study 2: StreamRec — Scaling Two-Tower Training to 1.2 Billion Interactions

Context

Since Chapter 13, the StreamRec two-tower retrieval model has been trained on a 10% sample of user-item interactions — approximately 120 million events. This sample was sufficient for architecture iteration and offline evaluation, but the production system (Chapter 24) serves 50 million users, and the team has strong evidence that training on the full dataset improves recommendation quality.

An offline experiment on historical data shows the relationship between training data size and Recall@20:

| Training data | Recall@20 | NDCG@20 |
| --- | --- | --- |
| 12M (1% sample) | 0.182 | 0.091 |
| 120M (10% sample) | 0.214 | 0.112 |
| 600M (50% sample) | 0.231 | 0.124 |
| 1.2B (full dataset) | 0.243 | 0.133 |

The full-dataset model shows a meaningful improvement: +13.6% relative Recall@20 and +18.8% relative NDCG@20 compared to the 10% sample. A/B tests on a 5% traffic holdout confirmed that this translates to a 4.2% increase in user engagement (content completion rate), which the product team estimates is worth $2.1M in annual revenue.
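The relative-improvement figures follow directly from the table; a quick arithmetic check:

```python
recall = {"10% sample": 0.214, "full": 0.243}
ndcg = {"10% sample": 0.112, "full": 0.133}

# Relative gains of the full-dataset model over the 10% baseline
rel_recall = recall["full"] / recall["10% sample"] - 1  # ~ +13.6%
rel_ndcg = ndcg["full"] / ndcg["10% sample"] - 1        # ~ +18.8%
```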

The constraint: StreamRec retrains daily (Chapter 24). The training pipeline must complete within 6 hours to fit the daily cycle. Single-GPU training on the full dataset takes approximately 14 hours.

The Model

The two-tower model from Chapter 13:

| Component | Details |
| --- | --- |
| User tower | User ID embedding (50M users, dim 256) → MLP (256→256→128) |
| Item tower | Item ID embedding (200K items, dim 256) + content features → MLP (512→256→128) |
| Training objective | InfoNCE contrastive loss with in-batch negatives |
| Parameters | 12.8B in the user embedding table (50M × 256; × 4 bytes = 51.2 GB in FP32), plus ~45M in the towers and item embedding |
| Training data | 1.2 billion (user_id, item_id, label) triples |

The memory challenge. Excluding the user embedding table, the model totals only about 45M parameters — but the table itself holds 50M × 256 = 12.8B parameters, or 51 GB in FP32, and does not fit in GPU memory. In the 10% sample training, the team used a vocabulary-subsetting trick: only the embeddings for users appearing in the current batch are loaded to the GPU. With DDP, each GPU needs only the embeddings for its local batch.
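The sizes quoted throughout this case study are plain arithmetic over the table dimensions; a small helper makes them easy to verify:

```python
def table_gb(rows: int, dim: int, bytes_per_elem: int) -> float:
    """Size of a dense embedding table in decimal gigabytes."""
    return rows * dim * bytes_per_elem / 1e9

user_fp32 = table_gb(50_000_000, 256, 4)  # user table, FP32 -> 51.2 GB
user_bf16 = table_gb(50_000_000, 256, 2)  # user table, BF16 -> 25.6 GB
item_fp32 = table_gb(200_000, 256, 4)     # item table, FP32 -> ~0.2 GB
```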

The Scaling Solution

Step 1: Parallelism Selection

The ParallelismDecision framework from Section 26.3.5 recommends data parallelism: the model's active parameters per batch (excluding the full embedding table) are approximately 0.5 GB, which fits easily on a single GPU. The embedding table is handled via PyTorch's nn.Embedding with sparse gradients — only the embeddings for users in the current batch are updated.
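The sparse-gradient mechanism can be sketched on a toy vocabulary (the 1,000-row table and the IDs below are illustrative, not StreamRec's):

```python
import torch
import torch.nn as nn

# Toy vocabulary for illustration; the real user table has 50M rows.
emb = nn.Embedding(num_embeddings=1_000, embedding_dim=16, sparse=True)
opt = torch.optim.SparseAdam(emb.parameters(), lr=1e-3)

ids = torch.tensor([3, 7, 7, 42])  # users in the current batch
loss = emb(ids).sum()              # stand-in for the real loss
loss.backward()                    # gradient is a sparse tensor
rows_touched = emb.weight.grad.coalesce().indices().flatten().tolist()
opt.step()                         # updates only the touched rows
```

Only rows 3, 7, and 42 receive gradient (and optimizer) updates; the other 997 rows are never materialized in the gradient.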

The team uses 4 A100 80GB GPUs on a single DGX node. With NVLink interconnect, communication overhead will be minimal.

Step 2: DDP Configuration

```python
from dataclasses import dataclass


@dataclass
class StreamRecTrainingConfig:
    """Training configuration for StreamRec two-tower model at scale.

    Attributes:
        num_gpus: Number of GPUs for DDP.
        local_batch_size: Batch size per GPU.
        learning_rate_base: LR tuned at base_batch_size.
        base_batch_size: Batch size at which LR was tuned.
        num_epochs: Training epochs.
        amp_dtype: Mixed precision dtype.
        warmup_fraction: Fraction of steps for LR warmup.
        infonce_temperature: Temperature for InfoNCE loss.
    """
    num_gpus: int = 4
    local_batch_size: int = 4096
    learning_rate_base: float = 1e-3
    base_batch_size: int = 256
    num_epochs: int = 3
    amp_dtype: str = "bfloat16"
    warmup_fraction: float = 0.05
    infonce_temperature: float = 0.07

    @property
    def global_batch_size(self) -> int:
        return self.local_batch_size * self.num_gpus

    @property
    def scaled_lr(self) -> float:
        return self.learning_rate_base * (
            self.global_batch_size / self.base_batch_size
        )

    @property
    def steps_per_epoch(self) -> int:
        # 1.2B interactions in the full dataset
        return 1_200_000_000 // self.global_batch_size

    @property
    def total_steps(self) -> int:
        return self.steps_per_epoch * self.num_epochs

    @property
    def warmup_steps(self) -> int:
        return int(self.total_steps * self.warmup_fraction)
```
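The derived quantities are plain arithmetic over the defaults; spelled out as a standalone check:

```python
local_batch_size, num_gpus = 4096, 4
global_batch_size = local_batch_size * num_gpus        # 16,384

# Linear LR scaling from the batch size at which the LR was tuned
scaled_lr = 1e-3 * global_batch_size / 256             # 0.064

steps_per_epoch = 1_200_000_000 // global_batch_size   # 73,242
total_steps = steps_per_epoch * 3                      # 219,726
warmup_steps = int(total_steps * 0.05)                 # 10,986
```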

The key design decisions:

Local batch size of 4,096. For contrastive learning with in-batch negatives, larger batches provide more negative examples, which improves the quality of the learned representations. Each micro-batch of 4,096 provides 4,095 negative pairs per positive — sufficient for high-quality contrastive learning.

Global batch size of 16,384. With 4 GPUs, the global batch is 16,384. This is well above the StreamRec model's critical batch size (estimated at ~8,192 from preliminary experiments), so the team monitors convergence carefully.

BF16 mixed precision. The embedding lookups and MLP operations benefit from BF16 Tensor Core acceleration. The InfoNCE loss computation (involving a softmax over similarity scores) runs in FP32 for numerical stability.
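The FP32 loss computation can be sketched as follows. This is a minimal illustration, not the team's exact implementation: the helper name `infonce_loss` and the L2 normalization of tower outputs are our assumptions.

```python
import torch
import torch.nn.functional as F


def infonce_loss(
    user_emb: torch.Tensor,
    item_emb: torch.Tensor,
    temperature: float = 0.07,
) -> torch.Tensor:
    """InfoNCE with in-batch negatives: row i's positive is item i."""
    user_emb = F.normalize(user_emb, dim=-1)  # assumed L2-normalized towers
    item_emb = F.normalize(item_emb, dim=-1)
    # Cast similarity logits to FP32 before the softmax, even when
    # the towers themselves run in BF16.
    logits = (user_emb @ item_emb.T).float() / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    return F.cross_entropy(logits, labels)


torch.manual_seed(0)
users = torch.randn(8, 128)
items = users + 0.01 * torch.randn(8, 128)  # near-positive pairs
loss = infonce_loss(users, items)
```

Each row of the logits matrix treats its diagonal entry as the positive and the other batch entries as negatives, which is why larger batches supply more negatives per positive.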

Step 3: Addressing the Large Embedding Table

The user embedding table (50M users × 256 dimensions) is too large to replicate on every GPU in FP32. Three strategies were considered:

Strategy A: Sparse embeddings with FP32. Each GPU stores the full embedding table but uses sparse gradients — only embeddings for users in the local batch receive gradient updates. Memory cost: 51 GB per GPU for the weights alone, and an Adam-style optimizer keeps two additional FP32 moment buffers of the same shape, pushing the total far past the A100's 80 GB. This does not work.

Strategy B: BF16 embeddings + CPU offloading. Store the embedding table in BF16 (25.6 GB) so it fits on a single GPU, with optimizer state offloaded to CPU. However, this leaves limited headroom for activations and the rest of the model.

Strategy C (chosen): Embedding table sharding. Split the embedding table across 4 GPUs, with each GPU storing 12.5M user embeddings. When a GPU needs an embedding that lives on another shard, it fetches it through collective communication (production systems typically use an all-to-all; the simpler sketch below combines zero-filled partial lookups with an all-reduce). This is a form of model parallelism applied specifically to the embedding layer.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed.nn.functional as dist_fn


class ShardedEmbedding(nn.Module):
    """Embedding table sharded across GPUs in a DDP group.

    Each GPU stores a contiguous shard of the full embedding table.
    Forward pass: all-gather the batch IDs so every shard sees the
    global batch, look up the locally owned rows, all-reduce the
    zero-filled partial results, and return this rank's slice.

    Args:
        num_embeddings: Total vocabulary size.
        embedding_dim: Embedding dimension.
        world_size: Number of GPUs (shards).
        rank: This GPU's rank.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        world_size: int,
        rank: int,
    ) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.world_size = world_size
        self.rank = rank

        # Each GPU stores a contiguous shard
        shard_size = (num_embeddings + world_size - 1) // world_size
        self.shard_start = rank * shard_size
        self.shard_end = min((rank + 1) * shard_size, num_embeddings)
        self.local_embedding = nn.Embedding(
            self.shard_end - self.shard_start, embedding_dim
        )

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings with cross-GPU communication.

        Args:
            ids: Tensor of global IDs, shape (batch_size,). Every rank
                must call with the same batch size.

        Returns:
            Embeddings of shape (batch_size, embedding_dim).
        """
        # Share all ranks' IDs so each shard can serve remote lookups
        gathered = [torch.empty_like(ids) for _ in range(self.world_size)]
        dist.all_gather(gathered, ids)
        all_ids = torch.cat(gathered)  # shape (world_size * batch_size,)

        # Look up the rows this shard owns; leave the rest zero
        local_mask = (all_ids >= self.shard_start) & (all_ids < self.shard_end)
        local_ids = all_ids[local_mask] - self.shard_start
        result = torch.zeros(
            all_ids.shape[0], self.embedding_dim,
            device=ids.device, dtype=self.local_embedding.weight.dtype,
        )
        if local_ids.numel() > 0:
            result[local_mask] = self.local_embedding(local_ids)

        # Sum partial results: each row is non-zero on exactly one shard.
        # The autograd-aware collective routes gradients back to the
        # shard that owns each row during the backward pass.
        result = dist_fn.all_reduce(result, op=dist.ReduceOp.SUM)

        # Return only the rows for this rank's local batch
        batch_size = ids.shape[0]
        return result[self.rank * batch_size : (self.rank + 1) * batch_size]
```

With sharding, each GPU stores 12.8 GB of embeddings in FP32 (one quarter of the 51 GB table) — comfortably fitting alongside the MLP parameters, gradients, optimizer state, and activations.
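The shard boundaries and per-GPU memory follow from the same contiguous-shard arithmetic used in `ShardedEmbedding.__init__`:

```python
num_embeddings, world_size, dim = 50_000_000, 4, 256

# Contiguous shard per rank, matching ShardedEmbedding.__init__
shard_size = (num_embeddings + world_size - 1) // world_size  # 12.5M rows
bounds = [
    (r * shard_size, min((r + 1) * shard_size, num_embeddings))
    for r in range(world_size)
]

per_gpu_gb = shard_size * dim * 4 / 1e9  # FP32 shard -> 12.8 GB per GPU
```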

Step 4: Training Loop and Performance

The team implements the training loop with DDP for the MLP towers and sharded embeddings for the user table:

| Metric | 1 GPU (10% data) | 1 GPU (full data) | 4 GPU DDP (full data) |
| --- | --- | --- | --- |
| Throughput (samples/s) | 18,500 | 18,200 | 68,400 |
| Steps per epoch | 29,296 | 292,968 | 73,242 |
| Time per epoch | ~4.2 min | ~1.1 hours | ~18 min |
| 3-epoch training time | ~13 min | ~3.3 hours | ~54 min |
| Recall@20 | 0.214 | 0.243 | 0.241 |
| NDCG@20 | 0.112 | 0.133 | 0.131 |

The 4-GPU DDP result (Recall@20 = 0.241) is within 1% of the single-GPU baseline on the full dataset (0.243), confirming that the distributed training maintains model quality. The small gap is attributable to the larger effective batch size (16,384 vs. 4,096), which slightly changes the contrastive learning dynamics.

Training time: 54 minutes for 3 epochs on the full 1.2B-event dataset. This fits easily within the 6-hour daily retraining window, leaving 5 hours for data pipeline processing, model evaluation, shadow mode comparison, and deployment (Chapters 27-30).

Step 5: Cost and GPU Utilization

Profiling results (4 GPU DDP):

| Component | Time per step | Fraction |
| --- | --- | --- |
| Forward pass (BF16) | 42 ms | 53% |
| Backward pass (BF16) | 28 ms | 35% |
| All-reduce (gradients) | 3.2 ms | 4% |
| Embedding all-reduce | 2.8 ms | 3.5% |
| Data loading | 2.1 ms | 2.6% |
| Other (optimizer, LR) | 1.5 ms | 1.9% |
| Total step | 79.6 ms | 100% |

Communication (gradient + embedding all-reduce) is 7.5% of step time — well below the 20% threshold that would indicate a bottleneck. The pipeline is compute-bound, which is the desired regime.
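The communication fraction is derived from the profiling table above:

```python
step_ms = {
    "forward": 42.0,
    "backward": 28.0,
    "grad_all_reduce": 3.2,
    "embedding_all_reduce": 2.8,
    "data_loading": 2.1,
    "other": 1.5,
}
total_ms = sum(step_ms.values())  # 79.6 ms per step

# Fraction of the step spent in collectives
comm_frac = (
    step_ms["grad_all_reduce"] + step_ms["embedding_all_reduce"]
) / total_ms                      # ~7.5%, below the 20% threshold
```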

GPU utilization: 87% (measured via nvidia-smi). The remaining 13% is split between data loading (2.6%), communication (7.5%), and Python/CUDA overhead (2.9%).

Cost per training run:

- 4 A100 GPUs × 0.9 hours × $3.50/GPU-hour = **$12.60** on-demand
- With spot instances: $12.60 × 0.35 = **$4.41**

At daily retraining, the annual training compute cost is approximately $1,600 (spot) — negligible compared to the $2.1M annual revenue impact of the full-dataset model.
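The annualized figure follows from the per-run cost (the 0.35 spot discount factor is the one assumed above):

```python
gpu_hours_per_run = 4 * 0.9            # 4 GPUs for 54 minutes
on_demand = gpu_hours_per_run * 3.50   # $12.60 per run
spot = on_demand * 0.35                # ~$4.41 per run

annual_spot = spot * 365               # ~$1,610/year at daily retraining
```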

Lessons Learned

Lesson 1: The embedding table is the bottleneck, not the model. StreamRec's MLP is tiny (45M parameters excluding embeddings), but the 50M-user embedding table dominates memory. Sharding the embedding table is a form of model parallelism that is invisible to the rest of the training pipeline — the MLP towers still use standard DDP.

Lesson 2: Contrastive learning benefits from large batches, but there is a limit. Increasing the global batch size from 256 to 16,384 improved Recall@20 by 3.2% (from the larger pool of in-batch negatives). Beyond 16,384, returns diminished — the critical batch size for contrastive loss is related to the effective number of hard negatives, which saturates.

Lesson 3: Daily retraining at scale is cheap. The team initially budgeted $50,000/year for compute, expecting distributed training to be expensive. The actual cost ($1,600/year with spot instances) was 97% below budget. The recommendation model's computational efficiency — small model, large data — makes daily retraining economically trivial. The larger costs are in the data pipeline (Chapter 25) and serving infrastructure (Chapter 24), not in training.

Lesson 4: Profile before parallelizing. The team's initial plan was to use 8 GPUs with pipeline parallelism across the user and item towers. Profiling on a single GPU revealed that the model is so small that pipeline parallelism's bubble overhead would reduce throughput compared to simple DDP. The correct answer was 4 GPUs with DDP — less hardware, lower cost, simpler code, and faster training than the 8-GPU pipeline design.