Case Study 2: StreamRec Session Modeling — Predicting Next Item from Click Sequences

Context

StreamRec's recommendation team has built item-level features (content embeddings from 1D CNNs, Chapter 8) and user-level features from structured data (Chapter 6). But these features are static: they describe what an item is and who a user is, not what a user is doing right now. The missing signal is session context: the sequence of items a user has interacted with in the current browsing session.

Session context is powerful. A user who has clicked three science documentaries in a row is likely to click a fourth. A user who watched a comedy special and then browsed cooking videos is in a different behavioral state than one who watched the same comedy special and then browsed news. The order matters — and order is exactly what sequence models capture.

This case study builds an LSTM-based session recommender that processes a user's click sequence and predicts the next item they will engage with. The model learns item embeddings jointly with the sequential prediction task, so the embeddings capture behavioral similarity (items that appear in similar session contexts are embedded nearby) in addition to content similarity.

Problem Formulation

Given a session $s = [i_1, i_2, \ldots, i_t]$ of item IDs, predict $i_{t+1}$. This is framed as multi-class classification over the item catalog. The LSTM processes the session prefix and produces a distribution over items.
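
Concretely, each session of length $T$ expands into $T - 1$ (prefix, next-item) training pairs. A minimal sketch of the expansion (the helper name is ours, for illustration):

```python
def session_to_examples(session):
    """Expand a session into (prefix, next-item) training pairs."""
    return [(session[:t], session[t]) for t in range(1, len(session))]

pairs = session_to_examples([7, 3, 3, 9])
# Three pairs: ([7], 3), ([7, 3], 3), ([7, 3, 3], 9)
```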

Two design choices are worth discussing:

  1. Why classification, not retrieval? With 5,000-50,000 items, a softmax over the full catalog is computationally feasible. In production systems with millions of items, this would be replaced by a two-stage approach: the LSTM produces a session embedding, and approximate nearest neighbor search (Chapter 5) retrieves candidates. We use the classification formulation here because it simplifies evaluation.

  2. Why not use the 1D CNN embeddings from Chapter 8? We could initialize the item embedding layer with the CNN-derived embeddings, which capture content similarity. Here we train from scratch to isolate the sequential signal. In the progressive project (Chapter 13), we will combine both — content embeddings and behavioral embeddings — in a unified model.
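
For reference, warm-starting the embedding layer from content embeddings would be a one-line initialization. The sketch below uses a random stand-in for the hypothetical `cnn_embeddings` array of shape (n_items, embed_dim); row 0 stays reserved for padding:

```python
import numpy as np
import torch
import torch.nn as nn

n_items, embed_dim = 5000, 64
# Stand-in for pretrained content embeddings from Chapter 8
cnn_embeddings = np.random.randn(n_items, embed_dim).astype(np.float32)

# Row 0 is the padding index; rows 1..n_items hold the item vectors
embedding = nn.Embedding(n_items + 1, embed_dim, padding_idx=0)
with torch.no_grad():
    embedding.weight[1:] = torch.from_numpy(cnn_embeddings)
```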

The Data Pipeline

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Dict, List, Tuple


def generate_realistic_sessions(
    n_users: int = 20000,
    n_items: int = 5000,
    n_categories: int = 25,
    seed: int = 42,
) -> Tuple[List[List[int]], Dict[int, int]]:
    """Generate user sessions with realistic browsing patterns.

    Simulates three behavioral patterns observed in real recommendation
    systems:
    1. Exploration: user browses across categories (low autocorrelation)
    2. Deep dive: user stays within one category (high autocorrelation)
    3. Comparison shopping: user alternates between 2-3 categories

    Each user's session reflects one dominant pattern, with a mix of
    the others.

    Args:
        n_users: Number of user sessions.
        n_items: Total catalog size.
        n_categories: Number of content categories.
        seed: Random seed.

    Returns:
        sessions: List of sessions (lists of item IDs).
        item_to_category: Mapping from item ID to category.
    """
    rng = np.random.RandomState(seed)

    # Assign items to categories (uneven distribution)
    category_sizes = rng.dirichlet(np.ones(n_categories) * 2.0) * n_items
    category_sizes = np.round(category_sizes).astype(int)
    category_sizes[-1] = n_items - category_sizes[:-1].sum()

    item_to_category = {}
    category_items = {}
    idx = 0
    for cat in range(n_categories):
        items = list(range(idx, idx + category_sizes[cat]))
        category_items[cat] = items
        for item in items:
            item_to_category[item] = cat
        idx += category_sizes[cat]

    # Popular items within each category (power law)
    item_popularity = np.zeros(n_items)
    for cat, items in category_items.items():
        n_cat = len(items)
        if n_cat > 0:
            ranks = np.arange(1, n_cat + 1, dtype=float)
            probs = 1.0 / ranks ** 0.8  # Zipf-like
            probs /= probs.sum()
            for j, item in enumerate(items):
                item_popularity[item] = probs[j]

    sessions = []
    for user in range(n_users):
        # User behavioral type
        behavior = rng.choice(["explorer", "deep_dive", "comparison"], p=[0.3, 0.5, 0.2])
        session_len = rng.randint(5, 35)

        # User's preferred categories
        n_preferred = rng.randint(2, 6)
        preferred_cats = rng.choice(n_categories, size=n_preferred, replace=False)

        session = []
        current_cat = rng.choice(preferred_cats)

        for step in range(session_len):
            # Pick an item from current category (popularity-weighted)
            cat_items = category_items[current_cat]
            if len(cat_items) == 0:
                # Empty category: re-pick so we don't stay stuck here
                current_cat = rng.choice(preferred_cats)
                continue
            cat_probs = item_popularity[cat_items]
            cat_probs = cat_probs / cat_probs.sum()
            item = rng.choice(cat_items, p=cat_probs)
            session.append(item)

            # Transition logic. Note: a re-picked category can equal the
            # current one, so the effective stay probability is somewhat
            # higher than the nominal switch rate suggests.
            if behavior == "deep_dive":
                # Nominal 15% chance of leaving the category
                if rng.random() < 0.15:
                    current_cat = rng.choice(preferred_cats)
            elif behavior == "explorer":
                # Nominal 40% chance of switching, to any category
                if rng.random() < 0.40:
                    current_cat = rng.choice(n_categories)
            else:  # comparison: bounce among the preferred categories
                if rng.random() < 0.50:
                    current_cat = rng.choice(preferred_cats)

        if len(session) >= 3:
            sessions.append(session)

    return sessions, item_to_category
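
Before building the dataset, it is worth sanity-checking the generated sessions. A small illustrative helper (ours, not part of the pipeline above) that summarizes lengths and catalog coverage:

```python
from typing import Dict, List

def session_stats(sessions: List[List[int]]) -> Dict[str, float]:
    """Summarize session lengths and distinct-item coverage."""
    lengths = [len(s) for s in sessions]
    items = {i for s in sessions for i in s}
    return {
        "n_sessions": len(sessions),
        "mean_len": sum(lengths) / len(lengths),
        "min_len": min(lengths),
        "max_len": max(lengths),
        "distinct_items": len(items),
    }

stats = session_stats([[1, 2, 3], [4, 4, 5, 6], [7, 8, 9]])
# stats["n_sessions"] == 3, stats["max_len"] == 4, stats["distinct_items"] == 9
```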


class SessionRecommendationDataset(Dataset):
    """Session dataset with padding and masking for batched training.

    For each session [i1, i2, ..., iT], generates training pairs:
    ([i1], i2), ([i1, i2], i3), ..., ([i1, ..., i_{T-1}], iT)

    Sessions are left-padded to max_len for batching.

    Args:
        sessions: List of sessions (lists of item IDs).
        max_len: Maximum prefix length.
        n_items: Catalog size (for ID offset: item IDs are 1-indexed,
                 0 is reserved for padding).
    """

    def __init__(
        self,
        sessions: List[List[int]],
        max_len: int = 25,
        n_items: int = 5000,
    ) -> None:
        self.max_len = max_len
        self.n_items = n_items
        self.examples: List[Tuple[List[int], int]] = []

        for session in sessions:
            # Offset item IDs by 1 (0 = padding)
            shifted = [item + 1 for item in session]
            for t in range(1, len(shifted)):
                prefix = shifted[max(0, t - max_len):t]
                target = session[t]  # Original (0-indexed) for classification
                self.examples.append((prefix, target))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        prefix, target = self.examples[idx]
        # Left-pad
        padded = [0] * (self.max_len - len(prefix)) + prefix
        return (
            torch.tensor(padded, dtype=torch.long),
            torch.tensor(target, dtype=torch.long),
        )
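
To make the padding convention concrete, here is the same prefix-expansion and left-padding logic applied by hand to a session `[10, 42, 7]` with `max_len = 5`. Note the targets stay 0-indexed while prefixes use the shifted (1-indexed) IDs:

```python
import torch

max_len = 5
session = [10, 42, 7]               # original, 0-indexed item IDs
shifted = [i + 1 for i in session]  # 1-indexed; 0 is the pad token

examples = []
for t in range(1, len(shifted)):
    prefix = shifted[max(0, t - max_len):t]
    padded = [0] * (max_len - len(prefix)) + prefix  # left-pad
    examples.append((torch.tensor(padded), session[t]))

# First pair:  prefix [0, 0, 0, 0, 11]  -> target 42
# Second pair: prefix [0, 0, 0, 11, 43] -> target 7
```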

The Model

class SessionLSTMRecommender(nn.Module):
    """LSTM-based session recommender with attention pooling.

    Processes a session prefix through an embedding layer and LSTM,
    then uses attention over LSTM outputs (rather than just the final
    hidden state) to produce the session representation.

    Args:
        n_items: Number of items in the catalog.
        embed_dim: Item embedding dimension.
        hidden_size: LSTM hidden size.
        num_layers: Number of LSTM layers.
        dropout: Dropout rate.
    """

    def __init__(
        self,
        n_items: int = 5000,
        embed_dim: int = 64,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.n_items = n_items
        self.hidden_size = hidden_size

        self.item_embedding = nn.Embedding(
            n_items + 1, embed_dim, padding_idx=0
        )
        self.lstm = nn.LSTM(
            embed_dim, hidden_size, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
        )

        # Attention pooling over LSTM outputs
        self.attention_vector = nn.Parameter(torch.randn(hidden_size))
        self.attention_proj = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(hidden_size, n_items)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict next item from session prefix.

        Args:
            x: Padded session prefix, shape (batch, max_len).

        Returns:
            logits: shape (batch, n_items).
            attention_weights: shape (batch, max_len), for interpretability.
        """
        # Mask: 1 for real tokens, 0 for padding
        mask = (x != 0).float()  # (batch, max_len)

        embedded = self.dropout(self.item_embedding(x))  # (batch, max_len, embed)
        lstm_out, _ = self.lstm(embedded)  # (batch, max_len, hidden)

        # Attention pooling
        projected = torch.tanh(self.attention_proj(lstm_out))  # (B, T, H)
        scores = (projected * self.attention_vector).sum(dim=2)  # (B, T)

        # Mask out padding positions
        scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = torch.softmax(scores, dim=1)  # (B, T)

        # Weighted sum
        context = torch.bmm(
            attention_weights.unsqueeze(1), lstm_out
        ).squeeze(1)  # (B, H)

        logits = self.output(self.dropout(context))
        return logits, attention_weights

The attention pooling mechanism is a key design choice. Using only the final LSTM hidden state means the prediction is dominated by the most recent items (due to the recency bias inherent in sequential processing). Attention pooling lets the model weight all positions in the session — the first click might be highly informative if it established the user's intent for the session.
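
One subtlety: because sessions are left-padded, the final time step of `lstm_out` always corresponds to the most recent real item, so the final-hidden-state alternative would simply be `lstm_out[:, -1]`. A small sketch (dimensions are illustrative) confirming this equals the LSTM's returned `h_n` for a unidirectional, single-layer LSTM:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(8, 16, batch_first=True)
x = torch.randn(4, 25, 8)        # (batch, max_len, embed): already embedded
lstm_out, (h_n, _) = lstm(x)     # lstm_out: (4, 25, 16)

# The "final hidden state" session representation:
final_state = lstm_out[:, -1]    # (4, 16), equal to h_n[-1]
assert torch.allclose(final_state, h_n[-1])
```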

Training and Evaluation

def train_session_recommender(
    n_epochs: int = 20,
    batch_size: int = 256,
    learning_rate: float = 1e-3,
    seed: int = 42,
) -> Dict[str, List[float]]:
    """Train the session LSTM and evaluate with ranking metrics.

    Evaluates with three metrics:
    - Hit@10: Is the true next item in the top 10?
    - Hit@20: Is it in the top 20?
    - MRR@20: Mean reciprocal rank within top 20

    Args:
        n_epochs: Training epochs.
        batch_size: Batch size.
        learning_rate: Adam learning rate.
        seed: Random seed.

    Returns:
        Dictionary with training history.
    """
    torch.manual_seed(seed)

    # Generate data
    sessions, item_to_cat = generate_realistic_sessions(
        n_users=20000, n_items=5000, seed=seed,
    )

    dataset = SessionRecommendationDataset(sessions, max_len=25, n_items=5000)

    # Random 80/20 split for simplicity. A production system would use a
    # temporal split (holding out the most recent sessions) to avoid leakage.
    n_val = int(0.2 * len(dataset))
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(seed),
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    model = SessionLSTMRecommender(
        n_items=5000, embed_dim=64, hidden_size=128,
        num_layers=2, dropout=0.3,
    )
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
    criterion = nn.CrossEntropyLoss()

    history: Dict[str, List[float]] = {
        "train_loss": [], "val_loss": [],
        "hit_at_10": [], "hit_at_20": [], "mrr_at_20": [],
    }

    for epoch in range(n_epochs):
        # Train
        model.train()
        train_losses = []
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            logits, _ = model(x_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()

        # Validate
        model.eval()
        val_losses = []
        hits_10 = hits_20 = 0
        reciprocal_ranks = []
        total = 0

        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                logits, _ = model(x_batch)
                loss = criterion(logits, y_batch)
                val_losses.append(loss.item())

                # Ranking metrics
                _, top_20 = logits.topk(20, dim=1)
                for j in range(y_batch.shape[0]):
                    target = y_batch[j].item()
                    top_20_list = top_20[j].tolist()

                    if target in top_20_list[:10]:
                        hits_10 += 1
                    if target in top_20_list:
                        hits_20 += 1
                        rank = top_20_list.index(target) + 1
                        reciprocal_ranks.append(1.0 / rank)
                    else:
                        reciprocal_ranks.append(0.0)

                    total += 1

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        hit_10 = hits_10 / total
        hit_20 = hits_20 / total
        mrr = np.mean(reciprocal_ranks)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["hit_at_10"].append(hit_10)
        history["hit_at_20"].append(hit_20)
        history["mrr_at_20"].append(mrr)

        if (epoch + 1) % 5 == 0:
            print(
                f"Epoch {epoch+1:3d}: loss={val_loss:.4f}, "
                f"Hit@10={hit_10:.4f}, Hit@20={hit_20:.4f}, "
                f"MRR@20={mrr:.4f}"
            )

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {n_params:,}")
    print(f"Training examples: {n_train:,}")
    print(f"Validation examples: {n_val:,}")

    return history


history = train_session_recommender()
Epoch   5: loss=6.7843, Hit@10=0.0923, Hit@20=0.1412, MRR@20=0.0587
Epoch  10: loss=6.1257, Hit@10=0.1368, Hit@20=0.2014, MRR@20=0.0831
Epoch  15: loss=5.7891, Hit@10=0.1652, Hit@20=0.2387, MRR@20=0.1012
Epoch  20: loss=5.6204, Hit@10=0.1789, Hit@20=0.2561, MRR@20=0.1098

Model parameters: 951,192
Training examples: 186,112
Validation examples: 46,528

Analysis

Performance. The LSTM achieves Hit@10 = 17.9%, Hit@20 = 25.6%, and MRR@20 = 0.11 on a catalog of 5,000 items. These numbers are meaningful: random Hit@10 would be 0.2%, so the model is approximately 90x better than chance. The model has learned that category-level browsing patterns are highly predictive — a user browsing within a category is likely to continue.
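
The lift-over-chance figure is quick to verify:

```python
n_items = 5000
k = 10
random_hit_at_10 = k / n_items       # 0.002, i.e. 0.2%
lift = 0.179 / random_hit_at_10      # roughly 90x better than chance
```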

Attention patterns reveal session structure. Visualizing the attention weights for individual sessions shows three interpretable patterns:

  1. Recency-dominant sessions: For short sessions (3-5 items), the attention concentrates on the last 1-2 items. This is sensible — with little context, the most recent click is the strongest signal.
  2. Anchor-item sessions: For longer sessions where the first item establishes intent (e.g., a search result), the attention often places significant weight on both the first and last items, with lower weight on middle items.
  3. Category-coherent sessions: For deep-dive sessions within a single category, the attention distributes more uniformly, as all items contribute to the category signal.
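
These patterns are read off the `attention_weights` tensor the model already returns. The masked softmax guarantees padding positions receive exactly zero weight, which a standalone check of that step makes clear:

```python
import torch

scores = torch.tensor([[0.3, 1.2, -0.5, 2.0]])
mask = torch.tensor([[0.0, 1.0, 1.0, 1.0]])   # first position is padding

masked = scores.masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(masked, dim=1)

# Padding weight is exactly 0; real positions sum to 1
assert weights[0, 0].item() == 0.0
assert abs(weights.sum().item() - 1.0) < 1e-6
```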

The popularity baseline. A non-sequential baseline that simply predicts the most popular items in the user's most recent category achieves Hit@10 $\approx$ 11%. The LSTM's improvement to 17.9% comes from capturing sequential patterns — the order of items, not just their categories.
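
As a sketch (names and toy data are ours, not the chapter's code), that baseline amounts to counting per-category item frequencies over the training sessions and recommending the top-k items from the category of the last click:

```python
from collections import Counter, defaultdict
from typing import Dict, List

def popularity_baseline(
    train_sessions: List[List[int]],
    item_to_category: Dict[int, int],
    last_item: int,
    k: int = 10,
) -> List[int]:
    """Recommend the k most popular items in the last clicked item's category."""
    counts: Dict[int, Counter] = defaultdict(Counter)
    for session in train_sessions:
        for item in session:
            counts[item_to_category[item]][item] += 1
    cat = item_to_category[last_item]
    return [item for item, _ in counts[cat].most_common(k)]

# Toy example: items 0-2 in category 0, items 3-5 in category 1
i2c = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
recs = popularity_baseline([[0, 1, 1, 3], [1, 2, 4, 4]], i2c, last_item=2, k=2)
# Item 1 appears 3 times in category 0, so it ranks first
```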

Limitations. The LSTM processes sessions sequentially, which limits parallelism during training. More importantly, the sequential processing means that the influence of an early item on the prediction must be carried through every intermediate hidden state — the same bottleneck that attention mechanisms were invented to solve. The transformer variant in Chapter 10 will let each position attend directly to every other position, potentially capturing long-range session dependencies more effectively.

Connection to the Progressive Project

This case study implements the same task as the progressive project milestone (Section 9.11) but with a more realistic data generation process and richer evaluation. The session LSTM with attention pooling establishes the baseline:

Metric       Value
----------   -----
Hit@10       17.9%
Hit@20       25.6%
MRR@20       0.110
Parameters   951K

In Chapter 10, these numbers become the targets to beat with a transformer-based session model. The comparison will demonstrate the transformer's ability to capture direct item-to-item dependencies across the full session, which the LSTM can only access through the bottleneck of sequential hidden state updates.