Case Study 1: StreamRec Session Transformer — Attention Visualization for Recommendation Explainability

Context

StreamRec's recommendation team has had the LSTM-based session model from Chapter 9 in production for six months. The model performs well: it improved next-item click-through rate by 12% over the collaborative filtering baseline. But the product team has a new requirement: explainability. Regulators and users both want to understand why a particular item was recommended. The LSTM's hidden state is a dense 256-dimensional vector with no interpretable structure. When asked "why did you recommend this documentary?", the system can only say "based on your viewing history", which is a non-answer.

The team hypothesizes that replacing the LSTM with a transformer will solve two problems simultaneously: (1) improve prediction quality by enabling direct attention to any item in the session history, and (2) provide interpretable attention weights that explain each recommendation.

This case study implements the transformer session model, trains it alongside the LSTM baseline, and analyzes the attention patterns to build a recommendation explanation system.

The Data

StreamRec sessions contain 5 to 50 items, each characterized by an item ID, a content category (one of 20 categories), and an engagement type (click, partial view, full view, save, share). The simulation below is simplified: sessions have 8 to 30 items, and engagement type is not modeled. It reproduces the key behavioral patterns: users cluster around a few preferred categories, occasionally explore other genres, and show recency bias (the most recent item is the strongest predictor of the next item).

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, List, Dict
from torch.utils.data import Dataset, DataLoader, random_split


class StreamRecSessionDataset(Dataset):
    """Simulated StreamRec session data with category-based patterns.

    Sessions have structure:
    - Category clusters: each item comes from one of the user's top-3
      categories with probability 0.7
    - Exploration: otherwise (0.3) from a uniformly random category
    - Recency: the target shares the most recent item's category with
      probability 0.5
    - Long-range: the target shares the first item's category with
      probability 0.15 (reflecting session intent)
    - Otherwise (0.35) the target comes from a preferred category
    """

    def __init__(
        self,
        n_sessions: int = 20000,
        max_session_len: int = 30,
        num_items: int = 5000,
        num_categories: int = 20,
        seed: int = 42,
    ) -> None:
        super().__init__()
        rng = np.random.RandomState(seed)

        # Assign items to categories
        self.item_categories = rng.randint(0, num_categories, size=num_items)

        self.sessions = []
        self.targets = []
        self.session_lengths = []

        for _ in range(n_sessions):
            session_len = rng.randint(8, max_session_len + 1)
            # Pick user's preferred categories (top 3)
            preferred = rng.choice(num_categories, size=3, replace=False)

            # Generate session items
            session = []
            for t in range(session_len):
                if rng.random() < 0.7:
                    # Pick from preferred categories
                    cat = preferred[rng.randint(0, 3)]
                else:
                    cat = rng.randint(0, num_categories)
                candidates = np.where(self.item_categories == cat)[0]
                if len(candidates) > 0:
                    session.append(int(rng.choice(candidates)) + 1)  # +1 for padding
                else:
                    session.append(rng.randint(1, num_items + 1))

            # Target category: 0.5 most recent item's category, 0.15 first
            # item's category (the elif makes a second draw: 0.5 * 0.3),
            # 0.35 one of the user's preferred categories
            if rng.random() < 0.5:
                # Recency: same category as most recent item
                target_cat = self.item_categories[session[-1] - 1]
            elif rng.random() < 0.3:
                # Long-range: same category as first item
                target_cat = self.item_categories[session[0] - 1]
            else:
                target_cat = preferred[rng.randint(0, 3)]

            candidates = np.where(self.item_categories == target_cat)[0]
            if len(candidates) > 0:
                target = int(rng.choice(candidates)) + 1
            else:
                target = rng.randint(1, num_items + 1)

            # Pad session to max length
            padded = session + [0] * (max_session_len - len(session))
            self.sessions.append(padded)
            self.targets.append(target)
            self.session_lengths.append(session_len)

        self.sessions = torch.tensor(self.sessions, dtype=torch.long)
        self.targets = torch.tensor(self.targets, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.sessions)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.sessions[idx], self.targets[idx]


# Build dataset
dataset = StreamRecSessionDataset(n_sessions=20000, max_session_len=30, num_items=5000)
train_set, val_set = random_split(dataset, [16000, 4000])
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128)

print(f"Training sessions: {len(train_set)}")
print(f"Validation sessions: {len(val_set)}")
print(f"Sample session: {dataset.sessions[0][:10].tolist()}")
print(f"Sample target: {dataset.targets[0].item()}")
Training sessions: 16000
Validation sessions: 4000
Sample session: [2341, 487, 2338, 1102, 487, 3921, 2344, 1105, 487, 3918]
Sample target: 2339
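The target-generation branch in StreamRecSessionDataset makes a fresh rng.random() draw in its elif. The raw thresholds 0.5 and 0.3 are therefore not the branch probabilities. The actual probabilities are 0.5 for the most recent item's category, 0.5 × 0.3 = 0.15 for the first item's category, and 0.5 × 0.7 = 0.35 for a preferred category. The sketch below checks this by simulating the branch structure in isolation. The helper name and draw count are illustrative, and it uses vectorized draws rather than the dataset's per-session loop.

```python
import numpy as np

def target_source_frequencies(n_draws: int = 200_000, seed: int = 0) -> dict:
    """Empirical frequency of each target-category branch in the simulator.

    Mirrors the if/elif/else in StreamRecSessionDataset: the elif makes a
    second, independent uniform draw rather than reusing the first.
    """
    rng = np.random.RandomState(seed)
    first_draw = rng.random_sample(n_draws)
    second_draw = rng.random_sample(n_draws)
    recent = first_draw < 0.5                   # if-branch
    first_item = ~recent & (second_draw < 0.3)  # elif-branch
    preferred = ~recent & ~first_item           # else-branch
    return {
        "recent": float(recent.mean()),
        "first": float(first_item.mean()),
        "preferred": float(preferred.mean()),
    }

freqs = target_source_frequencies()
expected = {"recent": 0.5, "first": 0.5 * 0.3, "preferred": 0.5 * 0.7}
for source in expected:
    print(f"{source:9s} simulated={freqs[source]:.3f} analytic={expected[source]:.3f}")
```

Only 15% of targets carry the long-range first-item signal. Any head that learns to attend to position 0 is therefore exploiting a comparatively weak pattern.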

The Transformer Session Model

We implement the session transformer with causal masking and a deliberately small architecture for interpretability: 2 layers with 4 heads per layer. That gives 8 attention patterns, few enough to inspect by hand.
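The masking convention is easy to get backwards. In PyTorch's boolean attention masks, True means "this key may not be attended to." The small sketch below uses a hypothetical right-padded session of three items. It builds both masks the same way SessionTransformer.forward does and counts how many keys each query position can still see.

```python
import torch

# Hypothetical right-padded session: three real items, two padding slots
session = torch.tensor([[5, 9, 3, 0, 0]])
seq_len = session.size(1)

# True = blocked: future positions (strict upper triangle) ...
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# ... and padding keys (item ID 0)
pad_mask = session == 0

# Query i cannot see key j if j is in the future OR j is padding.
# Broadcast to shape (batch, seq_len, seq_len).
blocked = causal_mask.unsqueeze(0) | pad_mask.unsqueeze(1)
visible_keys = (~blocked).sum(dim=-1)
print(visible_keys.tolist())  # [[1, 2, 3, 3, 3]]

# Index of the last real item, where the prediction is read out
last_idx = (session != 0).sum(dim=1) - 1
print(last_idx.tolist())  # [2]
```

Every row keeps at least one visible key because position 0 always holds a real item. A row with every key masked can turn into NaN after the softmax. The padding rows still produce outputs, but those outputs are never read, since the prediction comes from the last real position.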

class SessionTransformer(nn.Module):
    """Transformer for session-based next-item prediction.

    Designed for attention interpretability: relatively few heads,
    allowing manual inspection of what each head learns.
    """

    def __init__(
        self,
        num_items: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        d_ff: int = 512,
        max_len: int = 50,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.item_emb = nn.Embedding(num_items + 1, d_model, padding_idx=0)

        # Sinusoidal positional encoding
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)

        # Transformer encoder layers (used with causal mask)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # Pre-LN
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.final_norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, num_items + 1)

    def forward(
        self,
        session: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass: predict next item from last position.

        Args:
            session: Item IDs, shape (batch, seq_len). 0 = padding.

        Returns:
            Logits of shape (batch, num_items + 1).
        """
        seq_len = session.size(1)

        # Masks
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=session.device), diagonal=1
        ).bool()
        pad_mask = (session == 0)

        # Embed and encode
        x = self.item_emb(session) * self.scale
        x = x + self.pe[:, :seq_len, :]
        x = self.dropout(x)

        x = self.transformer(x, mask=causal_mask, src_key_padding_mask=pad_mask)
        x = self.final_norm(x)

        # Index of the last non-padding position (assumes right padding)
        last_idx = (session != 0).sum(dim=1) - 1  # (batch,)
        batch_idx = torch.arange(session.size(0), device=session.device)
        last_hidden = x[batch_idx, last_idx]

        logits = self.output_proj(last_hidden)
        return logits


# Train the model
model = SessionTransformer(num_items=5000, d_model=128, num_heads=4, num_layers=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

for epoch in range(10):
    model.train()
    total_loss, n_batches = 0.0, 0
    for sessions, targets in train_loader:
        optimizer.zero_grad()
        logits = model(sessions)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1

    # Validation
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for sessions, targets in val_loader:
            logits = model(sessions)
            top10 = logits.topk(10, dim=-1).indices
            val_correct += (top10 == targets.unsqueeze(1)).any(dim=1).sum().item()
            val_total += targets.size(0)

    print(f"Epoch {epoch+1:2d} | Loss: {total_loss/n_batches:.4f} | "
          f"Val Hit@10: {val_correct/val_total:.4f}")
Model parameters: 1,682,057
Epoch  1 | Loss: 8.5174 | Val Hit@10: 0.0328
Epoch  2 | Loss: 8.2015 | Val Hit@10: 0.0475
Epoch  3 | Loss: 7.8463 | Val Hit@10: 0.0698
Epoch  4 | Loss: 7.4321 | Val Hit@10: 0.0912
Epoch  5 | Loss: 7.0587 | Val Hit@10: 0.1145
Epoch  6 | Loss: 6.7432 | Val Hit@10: 0.1328
Epoch  7 | Loss: 6.4890 | Val Hit@10: 0.1487
Epoch  8 | Loss: 6.2811 | Val Hit@10: 0.1602
Epoch  9 | Loss: 6.1103 | Val Hit@10: 0.1694
Epoch 10 | Loss: 5.9642 | Val Hit@10: 0.1761
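Most of the model's parameters sit in the two vocabulary-sized matrices. The sketch below tallies the count in closed form from the constructor arguments. It then cross-checks the tally against the same PyTorch building blocks that SessionTransformer instantiates. The positional encoding is a registered buffer, not a parameter, so it does not count. The function names are illustrative.

```python
import torch.nn as nn

def closed_form_param_count(num_items: int, d_model: int, d_ff: int, num_layers: int) -> int:
    vocab = num_items + 1                              # item IDs plus padding index 0
    embedding = vocab * d_model                        # item_emb
    attention = 4 * d_model * d_model + 4 * d_model    # packed Q/K/V + out_proj, with biases
    feed_forward = 2 * d_model * d_ff + d_ff + d_model # linear1 + linear2, with biases
    layer_norms = 2 * (2 * d_model)                    # norm1 + norm2 (weight and bias)
    per_layer = attention + feed_forward + layer_norms
    final_norm = 2 * d_model
    output_proj = d_model * vocab + vocab
    return embedding + num_layers * per_layer + final_norm + output_proj

def module_param_count(num_items: int, d_model: int, d_ff: int,
                       num_layers: int, num_heads: int) -> int:
    # Same modules as SessionTransformer: one encoder layer is counted and
    # multiplied, since nn.TransformerEncoder holds independent copies.
    layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads,
                                       dim_feedforward=d_ff, batch_first=True,
                                       norm_first=True)
    others = [nn.Embedding(num_items + 1, d_model, padding_idx=0),
              nn.LayerNorm(d_model),
              nn.Linear(d_model, num_items + 1)]
    shared = sum(p.numel() for m in others for p in m.parameters())
    return shared + num_layers * sum(p.numel() for p in layer.parameters())

print(closed_form_param_count(5000, 128, 512, 2))   # 1682057
print(module_param_count(5000, 128, 512, 2, 4))     # 1682057
```

With these settings, item_emb and output_proj together hold about 76% of the parameters. The two transformer layers are a small fraction of the model, so growing the item catalog grows the model far faster than adding layers would.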

Extracting and Visualizing Attention

To visualize what the model attends to, we register forward hooks on the attention layers to capture the attention weights during inference.
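Capturing attention from nn.TransformerEncoderLayer has one catch. The layer calls its self_attn module with need_weights=False, so a forward hook on its own sees None where the weights would be. The standalone sketch below shows this behavior on a bare nn.MultiheadAttention. It also shows a workaround: a forward pre-hook registered with with_kwargs=True (PyTorch 2.0 or later) that overrides need_weights and requests per-head weights via average_attn_weights=False. The embedding size, head count, and random inputs are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Arbitrary sizes; dropout defaults to 0.0, so outputs are deterministic
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(1, 5, 16)  # (batch, seq_len, embed_dim)

# 1. Called the way nn.TransformerEncoderLayer calls it: no weights returned
_, weights_off = mha(x, x, x, need_weights=False)
print(weights_off)  # None

captured = []

def force_weights(module, args, kwargs):
    # Override the caller's keyword arguments before forward runs
    kwargs["need_weights"] = True
    kwargs["average_attn_weights"] = False  # keep one map per head
    return args, kwargs

def capture(module, inputs, output):
    # output is (attn_output, attn_weights)
    captured.append(output[1].detach())

# 2. The same call with both hooks attached
pre = mha.register_forward_pre_hook(force_weights, with_kwargs=True)
post = mha.register_forward_hook(capture)
mha(x, x, x, need_weights=False)
pre.remove()
post.remove()

print(captured[0].shape)  # torch.Size([1, 4, 5, 5])
```

Each captured map has shape (batch, num_heads, query_len, key_len), and every row sums to 1.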

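def extract_attention_weights(
    model: SessionTransformer,
    session: torch.Tensor,
) -> List[torch.Tensor]:
    """Extract per-head attention weights from all layers via hooks.

    nn.TransformerEncoderLayer calls its self-attention module with
    need_weights=False, so a plain forward hook would receive None in
    place of the weights. A forward pre-hook (PyTorch >= 2.0) rewrites
    that keyword argument first; a forward hook then captures the result.

    Args:
        model: Trained SessionTransformer.
        session: Single session, shape (1, seq_len).

    Returns:
        List of attention weight tensors, one per layer.
        Each has shape (1, num_heads, seq_len, seq_len).
    """
    attention_maps = []

    def force_weights(module, args, kwargs):
        # Ask nn.MultiheadAttention for per-head (unaveraged) weights
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False
        return args, kwargs

    def capture_weights(module, input, output):
        # nn.MultiheadAttention returns (attn_output, attn_weights)
        attention_maps.append(output[1].detach())

    hooks = []
    for layer in model.transformer.layers:
        hooks.append(layer.self_attn.register_forward_pre_hook(
            force_weights, with_kwargs=True))
        hooks.append(layer.self_attn.register_forward_hook(capture_weights))

    model.eval()
    with torch.no_grad():
        _ = model(session)

    for hook in hooks:
        hook.remove()

    return attention_maps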


def analyze_session_attention(
    model: SessionTransformer,
    session: torch.Tensor,
    item_categories: np.ndarray,
    session_length: int,
) -> Dict:
    """Analyze attention patterns for a single session.

    Args:
        model: Trained model.
        session: Session tensor, shape (1, max_len).
        item_categories: Array mapping item_id-1 -> category.
        session_length: Actual (non-padded) length.

    Returns:
        Analysis dictionary with per-head characterizations.
    """
    attention_maps = extract_attention_weights(model, session)

    # Focus on the last position (the prediction position)
    items = session[0, :session_length].tolist()
    # Cast to int so categories print as plain integers under NumPy 2
    categories = [int(item_categories[item_id - 1]) for item_id in items]

    analysis = {"items": items, "categories": categories, "heads": []}

    for layer_idx, attn in enumerate(attention_maps):
        for head_idx in range(attn.size(1)):
            # Attention from the last position to all previous positions
            weights = attn[0, head_idx, session_length - 1, :session_length]
            weights = weights.numpy()

            # Characterize the attention pattern
            entropy = -np.sum(weights * np.log(weights + 1e-10))
            max_entropy = np.log(session_length)
            normalized_entropy = entropy / max_entropy

            # Recency: correlation between attention weight and position
            positions = np.arange(session_length)
            recency_corr = np.corrcoef(weights, positions)[0, 1]

            # Category concentration: does attention focus on same-category items?
            last_cat = categories[-1]
            same_cat_mask = np.array([c == last_cat for c in categories])
            same_cat_weight = weights[same_cat_mask].sum() if same_cat_mask.any() else 0

            head_analysis = {
                "layer": layer_idx,
                "head": head_idx,
                "entropy": normalized_entropy,
                "recency_correlation": recency_corr,
                "same_category_weight": same_cat_weight,
                "top_3_positions": np.argsort(weights)[-3:][::-1].tolist(),
                "top_3_weights": np.sort(weights)[-3:][::-1].tolist(),
            }

            # Classify head type
            if normalized_entropy < 0.5:
                head_analysis["type"] = "focused"
            elif recency_corr > 0.6:
                head_analysis["type"] = "recency"
            elif same_cat_weight > 0.5:
                head_analysis["type"] = "category"
            else:
                head_analysis["type"] = "distributed"

            analysis["heads"].append(head_analysis)

    return analysis


# Analyze a sample session
sample_session = dataset.sessions[0:1]
sample_length = dataset.session_lengths[0]
analysis = analyze_session_attention(
    model, sample_session, dataset.item_categories, sample_length
)

print(f"Session length: {sample_length}")
print(f"Categories in session: {analysis['categories']}")
print()
for head in analysis["heads"]:
    print(f"Layer {head['layer']}, Head {head['head']}: "
          f"type={head['type']}, entropy={head['entropy']:.3f}, "
          f"recency_corr={head['recency_correlation']:.3f}, "
          f"same_cat_weight={head['same_category_weight']:.3f}")
Session length: 15
Categories in session: [12, 3, 12, 7, 3, 18, 12, 7, 3, 18, 12, 7, 12, 3, 12]

Layer 0, Head 0: type=recency, entropy=0.712, recency_corr=0.734, same_cat_weight=0.423
Layer 0, Head 1: type=category, entropy=0.648, recency_corr=0.215, same_cat_weight=0.612
Layer 0, Head 2: type=focused, entropy=0.389, recency_corr=0.456, same_cat_weight=0.287
Layer 0, Head 3: type=distributed, entropy=0.891, recency_corr=0.102, same_cat_weight=0.341
Layer 1, Head 0: type=focused, entropy=0.412, recency_corr=0.523, same_cat_weight=0.534
Layer 1, Head 1: type=category, entropy=0.567, recency_corr=0.189, same_cat_weight=0.678
Layer 1, Head 2: type=recency, entropy=0.623, recency_corr=0.812, same_cat_weight=0.398
Layer 1, Head 3: type=distributed, entropy=0.834, recency_corr=0.067, same_cat_weight=0.312
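The head-type thresholds are easier to reason about on hand-made weight vectors. The sketch below reimplements the two metrics and the classification order from analyze_session_attention. It applies them to synthetic distributions over ten positions, which are illustrative and not taken from the trained model. The two metrics measure different things: a linearly increasing ramp is almost as spread out as uniform attention, yet it correlates perfectly with position.

```python
import numpy as np

def normalized_entropy(weights: np.ndarray) -> float:
    # Same formula as analyze_session_attention: H(w) / log(n)
    entropy = -np.sum(weights * np.log(weights + 1e-10))
    return float(entropy / np.log(len(weights)))

def recency_correlation(weights: np.ndarray) -> float:
    # Pearson correlation between attention weight and position index
    return float(np.corrcoef(weights, np.arange(len(weights)))[0, 1])

def classify(weights: np.ndarray, same_cat_weight: float = 0.0) -> str:
    # Same thresholds, checked in the same order as analyze_session_attention
    if normalized_entropy(weights) < 0.5:
        return "focused"
    if recency_correlation(weights) > 0.6:
        return "recency"
    if same_cat_weight > 0.5:
        return "category"
    return "distributed"

n = 10
uniform = np.full(n, 1.0 / n)
ramp = np.arange(1, n + 1, dtype=float) / np.arange(1, n + 1).sum()  # grows with position
peaked = np.full(n, 0.01)
peaked[3] = 1.0 - 0.01 * (n - 1)      # 0.91 on one early item
last_heavy = np.full(n, 0.01)
last_heavy[-1] = 1.0 - 0.01 * (n - 1)  # 0.91 on the most recent item

print(f"uniform    entropy={normalized_entropy(uniform):.3f}")  # 1.000
# ramp: entropy ≈ 0.934, recency = 1.000, type=recency
# peaked and last_heavy: entropy ≈ 0.217, type=focused
for name, w in [("ramp", ramp), ("peaked", peaked), ("last_heavy", last_heavy)]:
    print(f"{name:10s} entropy={normalized_entropy(w):.3f} "
          f"recency={recency_correlation(w):.3f} type={classify(w)}")
```

Entropy is checked first. As a result, a head that puts nearly all of its weight on the most recent item is labeled focused, not recency. The recency label is reserved for heads whose weight rises gradually toward the end of the session.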

Building the Explanation System

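The attention analysis produces four head types. Three of them map onto interpretable recommendation explanations. Distributed heads spread their weight too evenly to single out any item, so they contribute no explanation: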

Head Type   Explanation Template
Recency     "Because you recently watched [item]"
Category    "Because you enjoy [category] content, like [items]"
Focused     "Because of your interest in [specific item]"

def generate_explanation(
    analysis: Dict,
    category_names: List[str],
    item_names: Dict[int, str],
) -> str:
    """Generate a human-readable recommendation explanation from attention.

    Args:
        analysis: Output of analyze_session_attention.
        category_names: List mapping category_id -> name.
        item_names: Dict mapping item_id -> display name.

    Returns:
        Human-readable explanation string.
    """
    explanations = []

    for head in analysis["heads"]:
        if head["type"] == "recency" and head["recency_correlation"] > 0.6:
            recent_pos = head["top_3_positions"][0]
            recent_item = analysis["items"][recent_pos]
            explanations.append(
                f"your recent viewing of '{item_names.get(recent_item, f'item {recent_item}')}'"
            )
        elif head["type"] == "category" and head["same_category_weight"] > 0.5:
            cat = analysis["categories"][-1]
            explanations.append(
                f"your interest in {category_names[cat]} content"
            )
        elif head["type"] == "focused":
            focus_pos = head["top_3_positions"][0]
            focus_item = analysis["items"][focus_pos]
            explanations.append(
                f"your engagement with '{item_names.get(focus_item, f'item {focus_item}')}'"
            )

    # Deduplicate and format
    unique_explanations = list(dict.fromkeys(explanations))[:3]
    if not unique_explanations:
        return "Based on your viewing history."

    return "Recommended because of " + ", ".join(unique_explanations[:-1]) + \
           (" and " if len(unique_explanations) > 1 else "") + unique_explanations[-1] + "."

Results and Lessons

The attention-based explanation system provides three concrete improvements over the LSTM baseline:

  1. Grounded explanations. Each explanation points to specific items or patterns in the user's session. Product testing showed that users rated attention-grounded explanations as "more trustworthy" than generic explanations (4.2 vs. 2.8 on a 5-point scale).

  2. Debugging tool. When the model makes a bad recommendation, the attention weights show which items in the history the model weighted most heavily. The recommendation team used this to identify a bug in which the model attended excessively to a common "homepage" item that appeared in most sessions but carried no predictive signal.

  3. Category-aware recommendations. The category-focused attention heads naturally emerged without explicit supervision. The model discovered that grouping items by category is useful for prediction, validating the team's earlier manual feature engineering.
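The case study does not say how the homepage item was found. One plausible diagnostic is to aggregate attention from the prediction position across many sessions, then compare each item's share with what uniform attention would give it. The sketch below shows this under assumptions. The item ID 999, the three sessions, and the attention weights are synthetic, and the lift threshold of 2.0 and frequency threshold of 0.5 are arbitrary choices.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

def attention_lift_by_item(
    sessions: List[Sequence[int]],
    weights: List[Sequence[float]],
) -> Dict[int, float]:
    """Mean attention each item receives, relative to uniform attention.

    For an item in a session of length n, lift is its attention weight
    times n, so uniform attention gives a lift of 1.0. Lifts are averaged
    over every occurrence of the item.
    """
    total: Dict[int, float] = defaultdict(float)
    count: Dict[int, int] = defaultdict(int)
    for items, w in zip(sessions, weights):
        n = len(items)
        for item, a in zip(items, w):
            total[item] += a * n
            count[item] += 1
    return {item: total[item] / count[item] for item in total}

HOMEPAGE = 999  # hypothetical ID of a ubiquitous "homepage" item
sessions = [
    [HOMEPAGE, 1, 2, 3, 4],
    [5, HOMEPAGE, 6, 7, 8],
    [9, 10, HOMEPAGE, 11, 12],
]
# Synthetic last-position attention: half the mass on the homepage item
weights = [[0.5 if item == HOMEPAGE else 0.125 for item in s] for s in sessions]

lift = attention_lift_by_item(sessions, weights)
# Fraction of sessions that contain each item
session_freq = {item: sum(item in s for s in sessions) / len(sessions) for item in lift}
flagged = [item for item in lift if lift[item] > 2.0 and session_freq[item] >= 0.5]
print(flagged, lift[HOMEPAGE])  # [999] 2.5
```

Combining high lift with high session frequency separates an attention sink from an item that is merely important in a few sessions.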

Production Reality: Attention weights are a useful interpretability tool, but they are not a complete explanation. Research (Jain and Wallace, 2019; Wiegreffe and Pinter, 2019) has shown that attention weights do not always faithfully reflect the model's decision process — alternative attention distributions can produce the same output. For production explainability, attention weights should be combined with other techniques (input perturbation, integrated gradients) and validated against human judgment. The StreamRec team treats attention as a "first-pass explanation" that is refined by gradient-based attribution before being shown to users.
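A simple form of input perturbation is leave-one-out occlusion. Delete item i from the session, re-pad it, and measure how much the target item's logit drops. The sketch below works with any model that maps a (1, max_len) tensor of item IDs to logits, such as SessionTransformer. To stay self-contained, it uses a hypothetical mean-pooling ToySessionModel as a stand-in, and the session, target, and sizes are arbitrary.

```python
from typing import List

import torch
import torch.nn as nn

def occlusion_scores(model: nn.Module, session: torch.Tensor,
                     length: int, target: int) -> List[float]:
    """Leave-one-out attribution for a right-padded session of shape (1, max_len).

    Score i = target logit for the full session minus the target logit with
    item i deleted and the session re-padded on the right. Positive scores
    mean the item supported the recommendation.
    """
    model.eval()
    max_len = session.size(1)
    items = session[0, :length].tolist()
    with torch.no_grad():
        base = model(session)[0, target].item()
        scores = []
        for i in range(length):
            kept = items[:i] + items[i + 1:]
            padded = torch.tensor([kept + [0] * (max_len - len(kept))], dtype=torch.long)
            scores.append(base - model(padded)[0, target].item())
    return scores

class ToySessionModel(nn.Module):
    """Hypothetical stand-in: mean of non-padding item embeddings -> logits."""

    def __init__(self, num_items: int, d_model: int = 16) -> None:
        super().__init__()
        self.item_emb = nn.Embedding(num_items + 1, d_model, padding_idx=0)
        self.output_proj = nn.Linear(d_model, num_items + 1)

    def forward(self, session: torch.Tensor) -> torch.Tensor:
        mask = (session != 0).unsqueeze(-1).float()
        pooled = (self.item_emb(session) * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.output_proj(pooled)

torch.manual_seed(0)
toy = ToySessionModel(num_items=50)
session = torch.tensor([[3, 7, 7, 12, 0, 0]])
scores = occlusion_scores(toy, session, length=4, target=12)
print([round(s, 4) for s in scores])
```

The sketch deletes and re-pads instead of zeroing an item in place. SessionTransformer locates its prediction position by counting non-zero IDs, so a zero in the middle of a session would make it read from the wrong position. Deletion also shifts later items by one position, which changes their positional encodings. This is one reason occlusion scores should be cross-checked against gradient-based attributions.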