Case Study 2: StreamRec Session Modeling — Predicting Next Item from Click Sequences
Context
StreamRec's recommendation team has built item-level features — content embeddings from 1D CNNs (Chapter 8) and user-level features from structured data (Chapter 6). But these features are static: they describe what an item is and who a user is, not what a user is doing right now. The missing signal is session context: the sequence of items a user has interacted with in the current browsing session.
Session context is powerful. A user who has clicked three science documentaries in a row is likely to click a fourth. A user who watched a comedy special and then browsed cooking videos is in a different behavioral state than one who watched the same comedy special and then browsed news. The order matters — and order is exactly what sequence models capture.
This case study builds an LSTM-based session recommender that processes a user's click sequence and predicts the next item they will engage with. The model learns item embeddings jointly with the sequential prediction task, so the embeddings capture behavioral similarity (items that appear in similar session contexts are embedded nearby) in addition to content similarity.
Problem Formulation
Given a session $s = [i_1, i_2, \ldots, i_t]$ of item IDs, predict $i_{t+1}$. This is framed as multi-class classification over the item catalog. The LSTM processes the session prefix and produces a distribution over items.
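Written out, the model defines a categorical distribution over the catalog. If $h_t$ denotes the session representation after processing the prefix (in this model, the attention-pooled LSTM output) and $N$ is the catalog size, then

$$P(i_{t+1} = j \mid s) = \mathrm{softmax}(W h_t + b)_j = \frac{\exp(w_j^\top h_t + b_j)}{\sum_{k=1}^{N} \exp(w_k^\top h_t + b_k)},$$

and training minimizes the cross-entropy $-\log P(i_{t+1} \mid s)$ of the observed next item.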
Two design choices are worth discussing:
- Why classification, not retrieval? With 5,000-50,000 items, a softmax over the full catalog is computationally feasible. In production systems with millions of items, this would be replaced by a two-stage approach: the LSTM produces a session embedding, and approximate nearest neighbor search (Chapter 5) retrieves candidates. We use the classification formulation here because it simplifies evaluation.
- Why not use the 1D CNN embeddings from Chapter 8? We could initialize the item embedding layer with the CNN-derived embeddings, which capture content similarity. Here we train from scratch to isolate the sequential signal. In the progressive project (Chapter 13), we will combine both — content embeddings and behavioral embeddings — in a unified model.
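To make the two-stage alternative concrete, here is a minimal sketch of the retrieval step. The function name, shapes, and the use of brute-force top-k (standing in for the approximate nearest neighbor index of Chapter 5) are illustrative assumptions, not part of this chapter's code:

```python
import torch

def retrieve_candidates(
    session_vec: torch.Tensor,      # (hidden_size,) session embedding from the LSTM
    item_embeddings: torch.Tensor,  # (n_items, hidden_size) catalog embedding table
    k: int = 100,
) -> torch.Tensor:
    """Return the k item IDs with the highest dot-product score.

    In production, this exact top-k would be replaced by an ANN index.
    """
    scores = item_embeddings @ session_vec  # (n_items,)
    return scores.topk(k).indices

# Toy usage: 5,000 items, 128-dim embeddings.
items = torch.randn(5000, 128)
session = torch.randn(128)
candidates = retrieve_candidates(session, items, k=100)
print(candidates.shape)  # torch.Size([100])
```

A second-stage ranker would then score only these candidates, avoiding the full-catalog softmax.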
The Data Pipeline
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Dict, List, Tuple
def generate_realistic_sessions(
    n_users: int = 20000,
    n_items: int = 5000,
    n_categories: int = 25,
    seed: int = 42,
) -> Tuple[List[List[int]], Dict[int, int]]:
    """Generate user sessions with realistic browsing patterns.

    Simulates three behavioral patterns observed in real recommendation
    systems:

    1. Exploration: user browses across categories (low autocorrelation)
    2. Deep dive: user stays within one category (high autocorrelation)
    3. Comparison shopping: user alternates between 2-3 categories

    Each user's session reflects one dominant pattern, with a mix of
    the others.

    Args:
        n_users: Number of user sessions.
        n_items: Total catalog size.
        n_categories: Number of content categories.
        seed: Random seed.

    Returns:
        sessions: List of sessions (lists of item IDs).
        item_to_category: Mapping from item ID to category.
    """
    rng = np.random.RandomState(seed)

    # Assign items to categories (uneven distribution)
    category_sizes = rng.dirichlet(np.ones(n_categories) * 2.0) * n_items
    category_sizes = np.round(category_sizes).astype(int)
    category_sizes[-1] = n_items - category_sizes[:-1].sum()

    item_to_category = {}
    category_items = {}
    idx = 0
    for cat in range(n_categories):
        items = list(range(idx, idx + category_sizes[cat]))
        category_items[cat] = items
        for item in items:
            item_to_category[item] = cat
        idx += category_sizes[cat]

    # Popular items within each category (power law)
    item_popularity = np.zeros(n_items)
    for cat, items in category_items.items():
        n_cat = len(items)
        if n_cat > 0:
            ranks = np.arange(1, n_cat + 1, dtype=float)
            probs = 1.0 / ranks ** 0.8  # Zipf-like
            probs /= probs.sum()
            for j, item in enumerate(items):
                item_popularity[item] = probs[j]

    sessions = []
    for user in range(n_users):
        # User behavioral type
        behavior = rng.choice(
            ["explorer", "deep_dive", "comparison"], p=[0.3, 0.5, 0.2]
        )
        session_len = rng.randint(5, 35)

        # User's preferred categories
        n_preferred = rng.randint(2, 6)
        preferred_cats = rng.choice(n_categories, size=n_preferred, replace=False)

        session = []
        current_cat = rng.choice(preferred_cats)
        for step in range(session_len):
            # Pick an item from current category (popularity-weighted)
            cat_items = category_items[current_cat]
            if len(cat_items) == 0:
                continue
            cat_probs = item_popularity[cat_items]
            cat_probs = cat_probs / cat_probs.sum()
            item = rng.choice(cat_items, p=cat_probs)
            session.append(item)

            # Transition logic
            if behavior == "deep_dive":
                # 85% chance of staying in category
                if rng.random() < 0.15:
                    current_cat = rng.choice(preferred_cats)
            elif behavior == "explorer":
                # 40% chance of switching
                if rng.random() < 0.40:
                    current_cat = rng.choice(n_categories)
            else:  # comparison
                # Alternate between 2-3 categories
                if rng.random() < 0.50:
                    current_cat = rng.choice(preferred_cats)

        if len(session) >= 3:
            sessions.append(session)

    return sessions, item_to_category
class SessionRecommendationDataset(Dataset):
    """Session dataset with padding and masking for batched training.

    For each session [i1, i2, ..., iT], generates training pairs:
        ([i1], i2), ([i1, i2], i3), ..., ([i1, ..., i_{T-1}], iT)

    Sessions are left-padded to max_len for batching.

    Args:
        sessions: List of sessions (lists of item IDs).
        max_len: Maximum prefix length.
        n_items: Catalog size (for ID offset: item IDs are 1-indexed,
            0 is reserved for padding).
    """

    def __init__(
        self,
        sessions: List[List[int]],
        max_len: int = 25,
        n_items: int = 5000,
    ) -> None:
        self.max_len = max_len
        self.n_items = n_items
        self.examples: List[Tuple[List[int], int]] = []
        for session in sessions:
            # Offset item IDs by 1 (0 = padding)
            shifted = [item + 1 for item in session]
            for t in range(1, len(shifted)):
                prefix = shifted[max(0, t - max_len):t]
                target = session[t]  # Original (0-indexed) for classification
                self.examples.append((prefix, target))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        prefix, target = self.examples[idx]
        # Left-pad
        padded = [0] * (self.max_len - len(prefix)) + prefix
        return (
            torch.tensor(padded, dtype=torch.long),
            torch.tensor(target, dtype=torch.long),
        )
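To see the ID offset and left-padding concretely, here is the same example-generation logic traced standalone on a four-click toy session (the values are illustrative, not from the chapter's generated data):

```python
# Minimal trace of SessionRecommendationDataset's example generation.
session = [7, 3, 9, 12]             # raw 0-indexed item IDs
shifted = [i + 1 for i in session]  # shift by 1 so that 0 means padding
max_len = 6

examples = []
for t in range(1, len(shifted)):
    prefix = shifted[max(0, t - max_len):t]   # 1-indexed inputs
    target = session[t]                       # 0-indexed target for the softmax
    padded = [0] * (max_len - len(prefix)) + prefix  # left-pad
    examples.append((padded, target))

for padded, target in examples:
    print(padded, "->", target)
# [0, 0, 0, 0, 0, 8] -> 3
# [0, 0, 0, 0, 8, 4] -> 9
# [0, 0, 0, 8, 4, 10] -> 12
```

Left-padding keeps the most recent clicks at the end of the tensor, so the final LSTM steps always process real items rather than padding.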
The Model
class SessionLSTMRecommender(nn.Module):
    """LSTM-based session recommender with attention pooling.

    Processes a session prefix through an embedding layer and LSTM,
    then uses attention over LSTM outputs (rather than just the final
    hidden state) to produce the session representation.

    Args:
        n_items: Number of items in the catalog.
        embed_dim: Item embedding dimension.
        hidden_size: LSTM hidden size.
        num_layers: Number of LSTM layers.
        dropout: Dropout rate.
    """

    def __init__(
        self,
        n_items: int = 5000,
        embed_dim: int = 64,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.n_items = n_items
        self.hidden_size = hidden_size
        self.item_embedding = nn.Embedding(
            n_items + 1, embed_dim, padding_idx=0
        )
        self.lstm = nn.LSTM(
            embed_dim, hidden_size, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
        )
        # Attention pooling over LSTM outputs
        self.attention_vector = nn.Parameter(torch.randn(hidden_size))
        self.attention_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(hidden_size, n_items)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict next item from session prefix.

        Args:
            x: Padded session prefix, shape (batch, max_len).

        Returns:
            logits: shape (batch, n_items).
            attention_weights: shape (batch, max_len), for interpretability.
        """
        # Mask: 1 for real tokens, 0 for padding
        mask = (x != 0).float()  # (batch, max_len)
        embedded = self.dropout(self.item_embedding(x))  # (batch, max_len, embed)
        lstm_out, _ = self.lstm(embedded)  # (batch, max_len, hidden)

        # Attention pooling
        projected = torch.tanh(self.attention_proj(lstm_out))  # (B, T, H)
        scores = (projected * self.attention_vector).sum(dim=2)  # (B, T)

        # Mask out padding positions
        scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = torch.softmax(scores, dim=1)  # (B, T)

        # Weighted sum
        context = torch.bmm(
            attention_weights.unsqueeze(1), lstm_out
        ).squeeze(1)  # (B, H)

        logits = self.output(self.dropout(context))
        return logits, attention_weights
The attention pooling mechanism is a key design choice. Using only the final LSTM hidden state means the prediction is dominated by the most recent items (due to the recency bias inherent in sequential processing). Attention pooling lets the model weight all positions in the session — the first click might be highly informative if it established the user's intent for the session.
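The masking step deserves emphasis: filling padded positions with $-\infty$ before the softmax guarantees they receive exactly zero attention weight. The following stripped-down check isolates that logic on random tensors (the linear projection is omitted for brevity; shapes and values are illustrative):

```python
import torch

torch.manual_seed(0)
B, T, H = 2, 5, 8
lstm_out = torch.randn(B, T, H)          # stand-in for LSTM outputs
x = torch.tensor([[0, 0, 3, 7, 2],       # first sequence: two padded positions
                  [5, 1, 4, 9, 6]])      # second sequence: no padding
mask = (x != 0).float()

# Simplified attention scoring (no projection layer here)
attn_vec = torch.randn(H)
scores = (torch.tanh(lstm_out) * attn_vec).sum(dim=2)   # (B, T)
scores = scores.masked_fill(mask == 0, float("-inf"))   # block padding
weights = torch.softmax(scores, dim=1)

print(weights[0])           # first two entries are exactly 0
print(weights.sum(dim=1))   # each row still sums to 1
context = torch.bmm(weights.unsqueeze(1), lstm_out).squeeze(1)
print(context.shape)        # torch.Size([2, 8])
```

Because `softmax` of $-\infty$ is exactly zero, padding tokens can never leak into the session representation, regardless of what the LSTM emits at those positions.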
Training and Evaluation
def train_session_recommender(
    n_epochs: int = 20,
    batch_size: int = 256,
    learning_rate: float = 1e-3,
    seed: int = 42,
) -> Dict[str, List[float]]:
    """Train the session LSTM and evaluate with ranking metrics.

    Evaluates with three metrics:
    - Hit@10: Is the true next item in the top 10?
    - Hit@20: Is it in the top 20?
    - MRR@20: Mean reciprocal rank within top 20

    Args:
        n_epochs: Training epochs.
        batch_size: Batch size.
        learning_rate: Adam learning rate.
        seed: Random seed.

    Returns:
        Dictionary with training history.
    """
    torch.manual_seed(seed)

    # Generate data
    sessions, item_to_cat = generate_realistic_sessions(
        n_users=20000, n_items=5000, seed=seed,
    )
    dataset = SessionRecommendationDataset(sessions, max_len=25, n_items=5000)

    # Temporal split: hold out the final 20% of examples. Examples are
    # generated in session order, so this approximates holding out the
    # most recent sessions rather than leaking them into training.
    n_val = int(0.2 * len(dataset))
    n_train = len(dataset) - n_val
    train_ds = torch.utils.data.Subset(dataset, range(n_train))
    val_ds = torch.utils.data.Subset(dataset, range(n_train, len(dataset)))

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    model = SessionLSTMRecommender(
        n_items=5000, embed_dim=64, hidden_size=128,
        num_layers=2, dropout=0.3,
    )
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
    criterion = nn.CrossEntropyLoss()

    history: Dict[str, List[float]] = {
        "train_loss": [], "val_loss": [],
        "hit_at_10": [], "hit_at_20": [], "mrr_at_20": [],
    }

    for epoch in range(n_epochs):
        # Train
        model.train()
        train_losses = []
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            logits, _ = model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            train_losses.append(loss.item())
        scheduler.step()

        # Validate
        model.eval()
        val_losses = []
        hits_10 = hits_20 = 0
        reciprocal_ranks = []
        total = 0
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                logits, _ = model(x_batch)
                loss = criterion(logits, y_batch)
                val_losses.append(loss.item())

                # Ranking metrics
                _, top_20 = logits.topk(20, dim=1)
                for j in range(y_batch.shape[0]):
                    target = y_batch[j].item()
                    top_20_list = top_20[j].tolist()
                    if target in top_20_list[:10]:
                        hits_10 += 1
                    if target in top_20_list:
                        hits_20 += 1
                        rank = top_20_list.index(target) + 1
                        reciprocal_ranks.append(1.0 / rank)
                    else:
                        reciprocal_ranks.append(0.0)
                    total += 1

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        hit_10 = hits_10 / total
        hit_20 = hits_20 / total
        mrr = np.mean(reciprocal_ranks)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["hit_at_10"].append(hit_10)
        history["hit_at_20"].append(hit_20)
        history["mrr_at_20"].append(mrr)

        if (epoch + 1) % 5 == 0:
            print(
                f"Epoch {epoch+1:3d}: loss={val_loss:.4f}, "
                f"Hit@10={hit_10:.4f}, Hit@20={hit_20:.4f}, "
                f"MRR@20={mrr:.4f}"
            )

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {n_params:,}")
    print(f"Training examples: {n_train:,}")
    print(f"Validation examples: {n_val:,}")
    return history
history = train_session_recommender()
Epoch 5: loss=6.7843, Hit@10=0.0923, Hit@20=0.1412, MRR@20=0.0587
Epoch 10: loss=6.1257, Hit@10=0.1368, Hit@20=0.2014, MRR@20=0.0831
Epoch 15: loss=5.7891, Hit@10=0.1652, Hit@20=0.2387, MRR@20=0.1012
Epoch 20: loss=5.6204, Hit@10=0.1789, Hit@20=0.2561, MRR@20=0.1098
Model parameters: 951,192
Training examples: 186,112
Validation examples: 46,528
Analysis
Performance. The LSTM achieves Hit@10 = 17.9%, Hit@20 = 25.6%, and MRR@20 = 0.11 on a catalog of 5,000 items. These numbers are meaningful: random Hit@10 would be 0.2%, so the model is approximately 90x better than chance. The model has learned that category-level browsing patterns are highly predictive — a user browsing within a category is likely to continue.
Attention patterns reveal session structure. Visualizing the attention weights for individual sessions shows three interpretable patterns:
- Recency-dominant sessions: For short sessions (3-5 items), the attention concentrates on the last 1-2 items. This is sensible — with little context, the most recent click is the strongest signal.
- Anchor-item sessions: For longer sessions where the first item establishes intent (e.g., a search result), the attention often places significant weight on both the first and last items, with lower weight on middle items.
- Category-coherent sessions: For deep-dive sessions within a single category, the attention distributes more uniformly, as all items contribute to the category signal.
The popularity baseline. A non-sequential baseline that simply predicts the most popular items in the user's most recent category achieves Hit@10 $\approx$ 11%. The LSTM's improvement to 17.9% comes from capturing sequential patterns — the order of items, not just their categories.
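The baseline can be evaluated with a few lines. The helper below is a hypothetical sketch (its name, signature, and the toy data are illustrative, not the chapter's code); it assumes per-category popularity rankings have been precomputed from the training data:

```python
def popularity_baseline_hit_at_k(
    prefixes, targets, item_to_category, category_top_items, k=10
):
    """Fraction of targets found among the top-k popular items of the
    category of the user's most recent click."""
    hits = 0
    for prefix, target in zip(prefixes, targets):
        recent_cat = item_to_category[prefix[-1]]
        hits += target in category_top_items[recent_cat][:k]
    return hits / len(targets)

# Toy example: two categories, popularity ranks precomputed (most popular first).
item_to_category = {0: 0, 1: 0, 2: 1, 3: 1}
category_top_items = {0: [1, 0], 1: [3, 2]}
prefixes = [[0, 1], [2]]
targets = [0, 3]
print(popularity_baseline_hit_at_k(prefixes, targets,
                                   item_to_category, category_top_items, k=1))
# 0.5  (the second prefix's target is its category's most popular item)
```

Always run a baseline like this before trusting a sequence model's numbers: it separates how much of the metric comes from category and popularity effects versus genuinely sequential signal.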
Limitations. The LSTM processes sessions sequentially, which limits parallelism during training. More importantly, the sequential processing means that the influence of an early item on the prediction must be carried through every intermediate hidden state — the same bottleneck that attention mechanisms were invented to solve. The transformer variant in Chapter 10 will let each position attend directly to every other position, potentially capturing long-range session dependencies more effectively.
Connection to the Progressive Project
This case study implements the same task as the progressive project milestone (Section 9.11) but with a more realistic data generation process and richer evaluation. The session LSTM with attention pooling establishes the baseline:
| Metric | Value |
|---|---|
| Hit@10 | 17.9% |
| Hit@20 | 25.6% |
| MRR@20 | 0.110 |
| Parameters | 951K |
In Chapter 10, these numbers become the targets to beat with a transformer-based session model. The comparison will demonstrate the transformer's ability to capture direct item-to-item dependencies across the full session, which the LSTM can only access through the bottleneck of sequential hidden state updates.