Case Study 1: StreamRec Two-Tower Retrieval — Contrastive Learning for User-Item Matching

Context

StreamRec's recommendation pipeline has grown through the progressive project milestones: matrix factorization for baseline collaborative filtering (M0, Chapter 1), a click-prediction MLP (M2, Chapter 6), 1D CNN content embeddings (M3, Chapter 8), an LSTM session model (M4a, Chapter 9), and a transformer session model replacing the LSTM (M4, Chapter 10). Each milestone improved prediction accuracy, but they all share a fundamental limitation: they are ranking models that score individual items, not retrieval models that efficiently search a catalog.

The ranking models from Chapters 6-10 assume a small candidate set has already been selected. In production, StreamRec's catalog contains 200,000 items. Scoring every item for every user request — even with a fast MLP — takes ~200ms at 200K forward passes, far exceeding the 50ms latency budget for the retrieval stage. The platform needs a retrieval model that can find the top-100 most relevant items from the full catalog in under 10ms.

The two-tower architecture solves this: precompute item embeddings offline, store them in a FAISS index, and compute only one user embedding per request. Retrieval becomes an approximate nearest-neighbor search — sublinear in catalog size.

The Data

StreamRec's engagement data consists of user-item interactions with engagement signals:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, List, Tuple
from dataclasses import dataclass


@dataclass
class StreamRecCatalog:
    """Simplified StreamRec catalog for the two-tower case study."""
    n_users: int = 50000
    n_items: int = 20000
    n_interactions: int = 500000
    embedding_dim: int = 128
    n_categories: int = 15
    user_feature_dim: int = 64
    item_feature_dim: int = 96


def generate_streamrec_interactions(
    config: StreamRecCatalog, seed: int = 42
) -> Dict[str, np.ndarray]:
    """Generate synthetic StreamRec interaction data with realistic structure.

    Users and items have latent category preferences. Interactions are
    generated based on user-category affinity and item-category membership,
    simulating the collaborative filtering signal that the two-tower
    model must learn.

    Args:
        config: Catalog configuration.
        seed: Random seed.

    Returns:
        Dictionary with user features, item features, interaction pairs,
        and temporal split indices.
    """
    rng = np.random.RandomState(seed)

    # User features: profile embedding + category preferences
    user_profiles = rng.randn(config.n_users, config.user_feature_dim).astype(np.float32)
    user_category_prefs = np.zeros((config.n_users, config.n_categories), dtype=np.float32)
    for u in range(config.n_users):
        # Each user has 2-4 preferred categories
        n_prefs = rng.randint(2, 5)
        preferred = rng.choice(config.n_categories, n_prefs, replace=False)
        user_category_prefs[u, preferred] = rng.uniform(0.5, 2.0, n_prefs)

    # Item features: content embedding + category one-hot
    item_content = rng.randn(config.n_items, config.item_feature_dim).astype(np.float32)
    item_categories = rng.randint(0, config.n_categories, config.n_items)
    item_category_onehot = np.eye(config.n_categories, dtype=np.float32)[item_categories]

    # Generate interactions based on user-item affinity
    user_ids = []
    item_ids = []
    for _ in range(config.n_interactions):
        u = rng.randint(config.n_users)
        # Score each item by user's preference for its category
        scores = user_category_prefs[u, item_categories]
        # Add noise and sample
        scores += rng.gumbel(size=config.n_items) * 0.5
        item = scores.argmax()
        user_ids.append(u)
        item_ids.append(item)

    interactions = np.stack([user_ids, item_ids], axis=1)

    # Temporal split: first 80% train, last 20% test
    split_idx = int(0.8 * len(interactions))

    return {
        "user_profiles": user_profiles,
        "user_category_prefs": user_category_prefs,
        "item_content": item_content,
        "item_category_onehot": item_category_onehot,
        "item_categories": item_categories,
        "train_interactions": interactions[:split_idx],
        "test_interactions": interactions[split_idx:],
    }

Building the Two-Tower Model

The model uses two independent towers that project user and item features into a shared 128-dimensional embedding space:

class UserTower(nn.Module):
    """User encoder for StreamRec two-tower retrieval.

    Encodes user profile features and category preferences into a
    dense embedding vector. In production, this would use a pretrained
    transformer over the user's watch history; here we use a simpler
    MLP for clarity.

    Args:
        input_dim: User feature dimensionality.
        hidden_dim: Hidden layer size.
        output_dim: Embedding dimensionality.
    """

    def __init__(
        self, input_dim: int = 79, hidden_dim: int = 256, output_dim: int = 128
    ) -> None:
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.network(x), dim=-1)


class ItemTower(nn.Module):
    """Item encoder for StreamRec two-tower retrieval.

    Encodes item content features and category into a dense
    embedding vector. In production, this would use a pretrained
    sentence transformer over item descriptions.

    Args:
        input_dim: Item feature dimensionality.
        hidden_dim: Hidden layer size.
        output_dim: Embedding dimensionality.
    """

    def __init__(
        self, input_dim: int = 111, hidden_dim: int = 256, output_dim: int = 128
    ) -> None:
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.network(x), dim=-1)


class StreamRecTwoTower(nn.Module):
    """Two-tower retrieval model with in-batch contrastive loss.

    Each (user, item) pair in the batch is a positive. All other items
    in the batch serve as negatives. The symmetric InfoNCE loss trains
    both towers to produce aligned embeddings.

    Args:
        user_dim: User feature dimensionality.
        item_dim: Item feature dimensionality.
        embedding_dim: Shared embedding space dimensionality.
        temperature: Softmax temperature for contrastive loss.
    """

    def __init__(
        self,
        user_dim: int = 79,
        item_dim: int = 111,
        embedding_dim: int = 128,
        temperature: float = 0.05,
    ) -> None:
        super().__init__()
        self.user_tower = UserTower(user_dim, 256, embedding_dim)
        self.item_tower = ItemTower(item_dim, 256, embedding_dim)
        self.temperature = temperature

    def forward(
        self, user_features: torch.Tensor, item_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute embeddings and contrastive loss.

        Args:
            user_features: (batch_size, user_dim).
            item_features: (batch_size, item_dim).

        Returns:
            Tuple of (loss, user_embeddings, item_embeddings).
        """
        user_emb = self.user_tower(user_features)  # (B, d)
        item_emb = self.item_tower(item_features)  # (B, d)

        # Similarity matrix with temperature scaling
        logits = torch.mm(user_emb, item_emb.T) / self.temperature  # (B, B)
        labels = torch.arange(logits.size(0), device=logits.device)

        # Symmetric loss
        loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
        return loss, user_emb, item_emb

Training and Evaluation

def train_and_evaluate(
    data: Dict[str, np.ndarray],
    config: StreamRecCatalog,
    epochs: int = 20,
    batch_size: int = 512,
    learning_rate: float = 3e-4,
) -> Dict[str, float]:
    """Train the two-tower model and evaluate retrieval quality.

    Args:
        data: Output of generate_streamrec_interactions.
        config: Catalog configuration.
        epochs: Number of training epochs.
        batch_size: Training batch size (larger = more negatives).
        learning_rate: Learning rate.

    Returns:
        Dictionary of evaluation metrics.
    """
    user_dim = config.user_feature_dim + config.n_categories  # 64 + 15 = 79
    item_dim = config.item_feature_dim + config.n_categories   # 96 + 15 = 111

    model = StreamRecTwoTower(user_dim, item_dim, config.embedding_dim)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Prepare training features
    train_pairs = data["train_interactions"]
    user_feats = np.concatenate([
        data["user_profiles"], data["user_category_prefs"]
    ], axis=1)
    item_feats = np.concatenate([
        data["item_content"], data["item_category_onehot"]
    ], axis=1)

    train_user_feats = torch.tensor(user_feats[train_pairs[:, 0]])
    train_item_feats = torch.tensor(item_feats[train_pairs[:, 1]])

    dataset = TensorDataset(train_user_feats, train_item_feats)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for user_batch, item_batch in loader:
            optimizer.zero_grad()
            loss, _, _ = model(user_batch, item_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(loader)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    # Evaluation: embed all items, compute metrics on test set
    model.eval()
    with torch.no_grad():
        all_item_feats = torch.tensor(item_feats)
        all_item_emb = model.item_tower(all_item_feats)  # (n_items, d)

        test_pairs = data["test_interactions"]
        test_user_feats = torch.tensor(user_feats[test_pairs[:, 0]])
        test_user_emb = model.user_tower(test_user_feats)  # (n_test, d)
        test_true_items = torch.tensor(test_pairs[:, 1])

    # Compute retrieval metrics
    # For efficiency, evaluate on a subsample of 5,000 test interactions
    n_eval = min(5000, len(test_pairs))
    eval_user_emb = test_user_emb[:n_eval]
    eval_true_items = test_true_items[:n_eval]

    similarities = torch.mm(eval_user_emb, all_item_emb.T)  # (n_eval, n_items)
    _, top_indices = similarities.topk(100, dim=1)

    # Hit Rate@K
    metrics = {}
    for k in [10, 50, 100]:
        hits = (top_indices[:, :k] == eval_true_items.unsqueeze(1)).any(dim=1)
        metrics[f"HR@{k}"] = hits.float().mean().item()

    # MRR
    ranks = (top_indices == eval_true_items.unsqueeze(1)).nonzero(as_tuple=True)[1] + 1
    if len(ranks) > 0:
        metrics["MRR"] = (1.0 / ranks.float()).mean().item()
    else:
        metrics["MRR"] = 0.0

    return metrics

Results and Analysis

Running the training pipeline on the synthetic StreamRec data:

Epoch 5/20, Loss: 4.2187
Epoch 10/20, Loss: 3.1543
Epoch 15/20, Loss: 2.4891
Epoch 20/20, Loss: 2.0234

Retrieval Metrics (catalog size: 20,000 items):
  HR@10:  0.142
  HR@50:  0.301
  HR@100: 0.387
  MRR:    0.083

These results demonstrate the two-tower model's ability to learn meaningful user-item correspondences from engagement data alone. The HR@100 of 0.387 means that for nearly 40% of test interactions, the true engaged item appears in the top 100 out of 20,000 candidates — a 200x improvement over random retrieval (HR@100 = 0.005 for random).

Batch Size Sensitivity

The number of in-batch negatives — determined by batch size — is critical for contrastive learning:

Batch Size Effective Negatives HR@100 Training Time
64 63 0.218 1x
256 255 0.341 1.3x
512 511 0.387 1.6x
1024 1023 0.402 2.2x
2048 2047 0.408 3.1x

Larger batches improve performance through the mutual information bound (Section 13.6), but with diminishing returns above batch size 1024 — consistent with the theoretical saturation at $\log K$.

Temperature Analysis

Temperature $\tau$ controls the sharpness of the softmax distribution in the contrastive loss:

Temperature HR@100 Embedding Uniformity Training Stability
0.01 0.312 High Unstable (loss spikes)
0.05 0.387 Medium-high Stable
0.10 0.371 Medium Stable
0.50 0.289 Low Very stable

Too-low temperature focuses the loss on the hardest negatives, causing instability. Too-high temperature treats all negatives equally, failing to discriminate between genuinely similar and dissimilar items. The sweet spot ($\tau = 0.05$) balances discrimination and stability.

Deployment Architecture

In production, the two-tower model enables sub-10ms retrieval:

graph LR
    A["User Request"] --> B["User Tower<br/>(~2ms)"]
    B --> C["FAISS Search<br/>(~3ms)"]
    C --> D["Top-100<br/>Candidates"]
    D --> E["Ranking Model<br/>(Ch. 10 Transformer)"]
    E --> F["Top-10<br/>Recommendations"]

    G["Daily Batch Job"] --> H["Item Tower<br/>(all 200K items)"]
    H --> I["FAISS Index<br/>(rebuild)"]
    I --> C

The item embeddings are recomputed daily (or when new items are added) and stored in a FAISS IndexIVFFlat with 256 Voronoi cells and 32 probes. For 200,000 items with 128-dimensional float32 embeddings, the index requires approximately 100 MB of memory.

Lessons Learned

  1. Contrastive learning quality depends on batch size. The team initially trained with batch size 64 (standard for supervised learning) and saw poor retrieval quality. Increasing to 512 gave a 77% relative improvement in HR@100.

  2. Temperature tuning is not optional. Default values ($\tau = 0.07$ from CLIP, $\tau = 0.5$ from SimCLR) are starting points, not universal constants. The optimal temperature depends on the embedding dimensionality, the number of negatives, and the difficulty of the retrieval task.

  3. The two-tower architecture constrains expressiveness. Because user and item are encoded independently, the model cannot capture fine-grained interactions (e.g., "this user likes jazz documentaries but not jazz concerts"). The ranking model downstream handles these interactions. The two-tower model's job is coverage, not precision.

  4. Cold-start items are hard. Items with no engagement history rely entirely on content features (title, description, category). The quality of these features — and the pretrained encoder used to embed them — directly determines cold-start retrieval quality. This is where the sentence transformer encoders from Section 13.5 add the most value.