Case Study 1: StreamRec Two-Tower Retrieval — Contrastive Learning for User-Item Matching
Context
StreamRec's recommendation pipeline has grown through the progressive project milestones: matrix factorization for baseline collaborative filtering (M0, Chapter 1), a click-prediction MLP (M2, Chapter 6), 1D CNN content embeddings (M3, Chapter 8), an LSTM session model (M4a, Chapter 9), and a transformer session model replacing the LSTM (M4, Chapter 10). Each milestone improved prediction accuracy, but they all share a fundamental limitation: they are ranking models that score individual items, not retrieval models that efficiently search a catalog.
The ranking models from Chapters 6-10 assume a small candidate set has already been selected. In production, StreamRec's catalog contains 200,000 items. Scoring every item for every user request — even with a fast MLP — takes ~200ms at 200K forward passes, far exceeding the 50ms latency budget for the retrieval stage. The platform needs a retrieval model that can find the top-100 most relevant items from the full catalog in under 10ms.
The two-tower architecture solves this: precompute item embeddings offline, store them in a FAISS index, and compute only one user embedding per request. Retrieval becomes an approximate nearest-neighbor search — sublinear in catalog size.
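The offline/online split can be sketched in a few lines of NumPy. This is a toy illustration with random embeddings, using exact brute-force search as a stand-in for FAISS:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_items, k = 128, 20_000, 100

# Offline: item-tower outputs, L2-normalized and cached (rebuilt daily in production).
item_emb = rng.standard_normal((n_items, d)).astype(np.float32)
item_emb /= np.linalg.norm(item_emb, axis=1, keepdims=True)

# Online: one user-tower forward pass per request...
user_emb = rng.standard_normal(d).astype(np.float32)
user_emb /= np.linalg.norm(user_emb)

# ...then a single similarity sweep over the cached item embeddings.
scores = item_emb @ user_emb                 # (n_items,) cosine similarities
top_k = np.argpartition(-scores, k)[:k]      # unordered top-100 candidate set
top_k = top_k[np.argsort(-scores[top_k])]    # sort candidates by score
```

FAISS replaces the brute-force sweep (still linear in catalog size) with an inverted-file index, which is what makes the search sublinear.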
The Data
StreamRec's training data consists of user-item interaction pairs, together with user profile and item content features:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, List, Tuple
from dataclasses import dataclass
@dataclass
class StreamRecCatalog:
"""Simplified StreamRec catalog for the two-tower case study."""
n_users: int = 50000
n_items: int = 20000
n_interactions: int = 500000
embedding_dim: int = 128
n_categories: int = 15
user_feature_dim: int = 64
item_feature_dim: int = 96
def generate_streamrec_interactions(
config: StreamRecCatalog, seed: int = 42
) -> Dict[str, np.ndarray]:
"""Generate synthetic StreamRec interaction data with realistic structure.
Users and items have latent category preferences. Interactions are
generated based on user-category affinity and item-category membership,
simulating the collaborative filtering signal that the two-tower
model must learn.
Args:
config: Catalog configuration.
seed: Random seed.
Returns:
Dictionary with user features, item features, interaction pairs,
and temporal split indices.
"""
rng = np.random.RandomState(seed)
# User features: profile embedding + category preferences
user_profiles = rng.randn(config.n_users, config.user_feature_dim).astype(np.float32)
user_category_prefs = np.zeros((config.n_users, config.n_categories), dtype=np.float32)
for u in range(config.n_users):
# Each user has 2-4 preferred categories
n_prefs = rng.randint(2, 5)
preferred = rng.choice(config.n_categories, n_prefs, replace=False)
user_category_prefs[u, preferred] = rng.uniform(0.5, 2.0, n_prefs)
# Item features: content embedding + category one-hot
item_content = rng.randn(config.n_items, config.item_feature_dim).astype(np.float32)
item_categories = rng.randint(0, config.n_categories, config.n_items)
item_category_onehot = np.eye(config.n_categories, dtype=np.float32)[item_categories]
    # Generate interactions based on user-item affinity.
    # Note: this loop scores all n_items per interaction (O(n_interactions * n_items)),
    # which is acceptable at this scale but worth vectorizing for larger catalogs.
user_ids = []
item_ids = []
for _ in range(config.n_interactions):
u = rng.randint(config.n_users)
# Score each item by user's preference for its category
scores = user_category_prefs[u, item_categories]
# Add noise and sample
scores += rng.gumbel(size=config.n_items) * 0.5
item = scores.argmax()
user_ids.append(u)
item_ids.append(item)
interactions = np.stack([user_ids, item_ids], axis=1)
# Temporal split: first 80% train, last 20% test
split_idx = int(0.8 * len(interactions))
return {
"user_profiles": user_profiles,
"user_category_prefs": user_category_prefs,
"item_content": item_content,
"item_category_onehot": item_category_onehot,
"item_categories": item_categories,
"train_interactions": interactions[:split_idx],
"test_interactions": interactions[split_idx:],
}
Building the Two-Tower Model
The model uses two independent towers that project user and item features into a shared 128-dimensional embedding space:
class UserTower(nn.Module):
"""User encoder for StreamRec two-tower retrieval.
Encodes user profile features and category preferences into a
dense embedding vector. In production, this would use a pretrained
transformer over the user's watch history; here we use a simpler
MLP for clarity.
Args:
input_dim: User feature dimensionality.
hidden_dim: Hidden layer size.
output_dim: Embedding dimensionality.
"""
def __init__(
self, input_dim: int = 79, hidden_dim: int = 256, output_dim: int = 128
) -> None:
super().__init__()
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.normalize(self.network(x), dim=-1)
class ItemTower(nn.Module):
"""Item encoder for StreamRec two-tower retrieval.
Encodes item content features and category into a dense
embedding vector. In production, this would use a pretrained
sentence transformer over item descriptions.
Args:
input_dim: Item feature dimensionality.
hidden_dim: Hidden layer size.
output_dim: Embedding dimensionality.
"""
def __init__(
self, input_dim: int = 111, hidden_dim: int = 256, output_dim: int = 128
) -> None:
super().__init__()
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.normalize(self.network(x), dim=-1)
class StreamRecTwoTower(nn.Module):
"""Two-tower retrieval model with in-batch contrastive loss.
Each (user, item) pair in the batch is a positive. All other items
in the batch serve as negatives. The symmetric InfoNCE loss trains
both towers to produce aligned embeddings.
Args:
user_dim: User feature dimensionality.
item_dim: Item feature dimensionality.
embedding_dim: Shared embedding space dimensionality.
temperature: Softmax temperature for contrastive loss.
"""
def __init__(
self,
user_dim: int = 79,
item_dim: int = 111,
embedding_dim: int = 128,
temperature: float = 0.05,
) -> None:
super().__init__()
self.user_tower = UserTower(user_dim, 256, embedding_dim)
self.item_tower = ItemTower(item_dim, 256, embedding_dim)
self.temperature = temperature
def forward(
self, user_features: torch.Tensor, item_features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute embeddings and contrastive loss.
Args:
user_features: (batch_size, user_dim).
item_features: (batch_size, item_dim).
Returns:
Tuple of (loss, user_embeddings, item_embeddings).
"""
user_emb = self.user_tower(user_features) # (B, d)
item_emb = self.item_tower(item_features) # (B, d)
# Similarity matrix with temperature scaling
logits = torch.mm(user_emb, item_emb.T) / self.temperature # (B, B)
labels = torch.arange(logits.size(0), device=logits.device)
# Symmetric loss
loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
return loss, user_emb, item_emb
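To make the symmetric InfoNCE loss concrete, here is the same computation done by hand in NumPy on a tiny batch. The embeddings are random, so this illustrates only the mechanics of the loss, not trained behavior:

```python
import numpy as np

rng = np.random.default_rng(0)
B, d, tau = 4, 8, 0.05

def l2norm(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

user_emb = l2norm(rng.standard_normal((B, d)))
item_emb = l2norm(rng.standard_normal((B, d)))

# (B, B) similarity matrix; the diagonal holds the positive pairs.
logits = user_emb @ item_emb.T / tau

def cross_entropy(logits, labels):
    # Row-wise log-softmax (with max-shift for stability), then pick the label column.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

labels = np.arange(B)
# Users-to-items and items-to-users directions, averaged.
loss = (cross_entropy(logits, labels) + cross_entropy(logits.T, labels)) / 2
```

Each row of `logits` is a (K = B)-way classification problem: pick the matching item out of the batch. Transposing the matrix gives the item-to-user direction, and averaging the two yields the symmetric loss used in the model above.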
Training and Evaluation
def train_and_evaluate(
data: Dict[str, np.ndarray],
config: StreamRecCatalog,
epochs: int = 20,
batch_size: int = 512,
learning_rate: float = 3e-4,
) -> Dict[str, float]:
"""Train the two-tower model and evaluate retrieval quality.
Args:
data: Output of generate_streamrec_interactions.
config: Catalog configuration.
epochs: Number of training epochs.
batch_size: Training batch size (larger = more negatives).
learning_rate: Learning rate.
Returns:
Dictionary of evaluation metrics.
"""
user_dim = config.user_feature_dim + config.n_categories # 64 + 15 = 79
item_dim = config.item_feature_dim + config.n_categories # 96 + 15 = 111
model = StreamRecTwoTower(user_dim, item_dim, config.embedding_dim)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# Prepare training features
train_pairs = data["train_interactions"]
user_feats = np.concatenate([
data["user_profiles"], data["user_category_prefs"]
], axis=1)
item_feats = np.concatenate([
data["item_content"], data["item_category_onehot"]
], axis=1)
train_user_feats = torch.tensor(user_feats[train_pairs[:, 0]])
train_item_feats = torch.tensor(item_feats[train_pairs[:, 1]])
dataset = TensorDataset(train_user_feats, train_item_feats)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Training loop
for epoch in range(epochs):
model.train()
total_loss = 0.0
for user_batch, item_batch in loader:
optimizer.zero_grad()
loss, _, _ = model(user_batch, item_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(loader)
if (epoch + 1) % 5 == 0:
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
# Evaluation: embed all items, compute metrics on test set
model.eval()
with torch.no_grad():
all_item_feats = torch.tensor(item_feats)
all_item_emb = model.item_tower(all_item_feats) # (n_items, d)
test_pairs = data["test_interactions"]
test_user_feats = torch.tensor(user_feats[test_pairs[:, 0]])
test_user_emb = model.user_tower(test_user_feats) # (n_test, d)
test_true_items = torch.tensor(test_pairs[:, 1])
# Compute retrieval metrics
# For efficiency, evaluate on a subsample of 5,000 test interactions
n_eval = min(5000, len(test_pairs))
eval_user_emb = test_user_emb[:n_eval]
eval_true_items = test_true_items[:n_eval]
similarities = torch.mm(eval_user_emb, all_item_emb.T) # (n_eval, n_items)
_, top_indices = similarities.topk(100, dim=1)
# Hit Rate@K
metrics = {}
for k in [10, 50, 100]:
hits = (top_indices[:, :k] == eval_true_items.unsqueeze(1)).any(dim=1)
metrics[f"HR@{k}"] = hits.float().mean().item()
    # MRR over all evaluated queries (misses beyond the top-100 contribute
    # zero reciprocal rank, the standard convention for truncated MRR)
    ranks = (top_indices == eval_true_items.unsqueeze(1)).nonzero(as_tuple=True)[1] + 1
    metrics["MRR"] = (1.0 / ranks.float()).sum().item() / n_eval
return metrics
Results and Analysis
Running the training pipeline on the synthetic StreamRec data:
Epoch 5/20, Loss: 4.2187
Epoch 10/20, Loss: 3.1543
Epoch 15/20, Loss: 2.4891
Epoch 20/20, Loss: 2.0234
Retrieval Metrics (catalog size: 20,000 items):
HR@10: 0.142
HR@50: 0.301
HR@100: 0.387
MRR: 0.083
These results demonstrate the two-tower model's ability to learn meaningful user-item correspondences from engagement data alone. The HR@100 of 0.387 means that for nearly 40% of test interactions, the true engaged item appears in the top 100 out of 20,000 candidates, a roughly 77x improvement over random retrieval (random HR@100 = 100/20,000 = 0.005).
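The random baseline is simple arithmetic: with no signal, each of the top K slots is equally likely to hold any of the N catalog items, so HR@K = K/N for a single relevant item:

```python
n_items, k = 20_000, 100
hr_random = k / n_items            # baseline hit rate with no learned signal
improvement = 0.387 / hr_random    # trained HR@100 relative to random
print(f"random HR@100 = {hr_random}, improvement = {improvement:.1f}x")
```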
Batch Size Sensitivity
The number of in-batch negatives — determined by batch size — is critical for contrastive learning:
| Batch Size | Effective Negatives | HR@100 | Training Time |
|---|---|---|---|
| 64 | 63 | 0.218 | 1x |
| 256 | 255 | 0.341 | 1.3x |
| 512 | 511 | 0.387 | 1.6x |
| 1024 | 1023 | 0.402 | 2.2x |
| 2048 | 2047 | 0.408 | 3.1x |
Larger batches improve performance through the mutual information bound (Section 13.6), but with diminishing returns above batch size 1024 — consistent with the theoretical saturation at $\log K$.
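The saturation follows from the InfoNCE mutual-information bound $I(u; v) \geq \log K - \mathcal{L}$ (Section 13.6): the ceiling of the bound grows only logarithmically in the number of in-batch examples K, which matches the flattening in the table. A quick computation of $\log K$ for the batch sizes above:

```python
import math

# Ceiling of the InfoNCE MI lower bound, in nats, for each batch size.
for batch_size in [64, 256, 512, 1024, 2048]:
    print(batch_size, round(math.log(batch_size), 2))
```

Quadrupling the batch from 512 to 2048 raises the ceiling by only $\log 4 \approx 1.39$ nats, while training cost nearly doubles.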
Temperature Analysis
Temperature $\tau$ controls the sharpness of the softmax distribution in the contrastive loss:
| Temperature | HR@100 | Embedding Uniformity | Training Stability |
|---|---|---|---|
| 0.01 | 0.312 | High | Unstable (loss spikes) |
| 0.05 | 0.387 | Medium-high | Stable |
| 0.10 | 0.371 | Medium | Stable |
| 0.50 | 0.289 | Low | Very stable |
Too-low temperature focuses the loss on the hardest negatives, causing instability. Too-high temperature treats all negatives equally, failing to discriminate between genuinely similar and dissimilar items. The sweet spot ($\tau = 0.05$) balances discrimination and stability.
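The effect is visible by pushing a fixed set of negative similarities through the softmax at each temperature (toy similarity values, chosen only to illustrate the weighting):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Cosine similarities of four in-batch negatives: one hard (0.8),
# two moderate (0.7, 0.5), and one easy (0.2).
neg_sims = np.array([0.8, 0.7, 0.5, 0.2])

for tau in [0.01, 0.05, 0.5]:
    w = softmax(neg_sims / tau)
    print(f"tau={tau}: hardest negative carries {w[0]:.0%} of the weight")
```

At $\tau = 0.01$ essentially all of the mass lands on the hardest negative; at $\tau = 0.5$ the weighting is nearly flat across all four.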
Deployment Architecture
In production, the two-tower model enables sub-10ms retrieval:
graph LR
A["User Request"] --> B["User Tower<br/>(~2ms)"]
B --> C["FAISS Search<br/>(~3ms)"]
C --> D["Top-100<br/>Candidates"]
D --> E["Ranking Model<br/>(Ch. 10 Transformer)"]
E --> F["Top-10<br/>Recommendations"]
G["Daily Batch Job"] --> H["Item Tower<br/>(all 200K items)"]
H --> I["FAISS Index<br/>(rebuild)"]
I --> C
The item embeddings are recomputed daily (or when new items are added) and stored in a FAISS IndexIVFFlat with 256 Voronoi cells and 32 probes. For 200,000 items with 128-dimensional float32 embeddings, the index requires approximately 100 MB of memory.
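The memory figure is easy to verify: an IVF-Flat index stores the raw vectors, so the footprint is dominated by n_items × d float32 values (per-vector IDs and the coarse quantizer add a small overhead not counted here):

```python
n_items, d, bytes_per_float = 200_000, 128, 4
index_mb = n_items * d * bytes_per_float / 1e6
print(f"~{index_mb:.0f} MB for the raw embeddings")
```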
Lessons Learned
- Contrastive learning quality depends on batch size. The team initially trained with batch size 64 (standard for supervised learning) and saw poor retrieval quality. Increasing to 512 gave a 77% relative improvement in HR@100.
- Temperature tuning is not optional. Default values ($\tau = 0.07$ from CLIP, $\tau = 0.5$ from SimCLR) are starting points, not universal constants. The optimal temperature depends on the embedding dimensionality, the number of negatives, and the difficulty of the retrieval task.
- The two-tower architecture constrains expressiveness. Because user and item are encoded independently, the model cannot capture fine-grained interactions (e.g., "this user likes jazz documentaries but not jazz concerts"). The ranking model downstream handles these interactions. The two-tower model's job is coverage, not precision.
- Cold-start items are hard. Items with no engagement history rely entirely on content features (title, description, category). The quality of these features — and the pretrained encoder used to embed them — directly determines cold-start retrieval quality. This is where the sentence transformer encoders from Section 13.5 add the most value.