Case Study 2: Building a Simple Attention-Based Text Classifier

Overview

In this case study, we build a text classifier that uses self-attention as its primary mechanism for understanding text --- without any recurrent layers. By constructing a lightweight attention-based classifier from scratch in PyTorch, we will see how self-attention can capture the relevant parts of a sentence for classification. We train the model on a sentiment analysis task and visualize which words receive the highest attention for positive versus negative predictions.

This case study directly applies the self-attention, scaled dot-product attention, and multi-head attention concepts from Sections 18.4--18.6 of the chapter.


Motivation

Traditional text classifiers based on RNNs process sentences sequentially and compress them into a fixed-length vector before classification. This creates the same bottleneck we discussed in Section 18.1: for long sentences, the final hidden state may lose important information from the beginning of the input.

An attention-based classifier sidesteps this limitation. Instead of compressing the sequence into a single hidden state, it uses attention to dynamically weight every position in the sequence, producing a context-aware summary that the classification head can use. The attention weights also provide a form of interpretability --- we can inspect which words the model considers most relevant for its prediction.


Architecture

Our classifier has four components:

  1. Token embedding layer: Maps each input token to a dense vector of dimension $d_{\text{model}}$.
  2. Positional encoding: Adds learned positional embeddings so the model can distinguish word order.
  3. Multi-head self-attention blocks: Two stacked Transformer encoder blocks, each combining multi-head self-attention with a position-wise feed-forward sublayer, residual connections, and layer normalization. Each position attends to all others, producing a context-enriched representation.
  4. Attention pooling + classification head: Rather than using the representation at a single position, we learn a query vector that attends over all positions. The resulting weighted average is passed through a linear classifier.

The overall data flow is:

Input tokens
    |
[Embedding + Positional Encoding]
    |
[Multi-Head Self-Attention Block 1]
    |
[Multi-Head Self-Attention Block 2]
    |
[Learned Attention Pooling]
    |
[Linear Classifier]
    |
Sentiment Prediction

Implementation

The full implementation is available in code/case-study-code.py. Here we walk through the key components.

Data Preparation

We use a synthetic sentiment dataset for simplicity. Each example is a short sentence with a binary sentiment label. In practice, you would replace this with a real dataset such as SST-2 or IMDB.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

# Synthetic vocabulary and data
# In practice, use a real tokenizer and dataset
VOCAB_SIZE = 200
PAD_IDX = 0
MAX_LEN = 32
NUM_CLASSES = 2


def generate_synthetic_data(
    n_samples: int = 2000,
    max_len: int = 32,
    vocab_size: int = 200,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate synthetic classification data.

    Args:
        n_samples: Number of samples to generate.
        max_len: Maximum sequence length.
        vocab_size: Size of the token vocabulary.

    Returns:
        Tuple of (token_ids, lengths, labels).
    """
    token_ids = torch.zeros(n_samples, max_len, dtype=torch.long)
    lengths = torch.randint(5, max_len + 1, (n_samples,))
    labels = torch.randint(0, 2, (n_samples,))

    for i in range(n_samples):
        seq_len = lengths[i].item()
        # Positive sentences use tokens from higher range
        # Negative sentences use tokens from lower range
        if labels[i] == 1:
            token_ids[i, :seq_len] = torch.randint(
                vocab_size // 2, vocab_size, (seq_len,)
            )
        else:
            token_ids[i, :seq_len] = torch.randint(
                1, vocab_size // 2, (seq_len,)
            )
    return token_ids, lengths, labels
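
If you prefer a real dataset, the sketch below shows one way to pull a small slice of SST-2 with the Hugging Face datasets library. The library, the naive whitespace tokenizer, and the helper name load_sst2_subset are illustrative assumptions, not part of the case-study code; you would still need to map tokens to integer ids and pad to MAX_LEN before feeding the model.

# A minimal sketch of loading real data, assuming the Hugging Face
# `datasets` package is installed. Tokenization here is a naive
# whitespace split purely for illustration.
from datasets import load_dataset


def load_sst2_subset(n_samples: int = 2000) -> tuple[list[list[str]], torch.Tensor]:
    """Load a slice of SST-2 and return whitespace-tokenized sentences and labels."""
    ds = load_dataset("glue", "sst2", split=f"train[:{n_samples}]")
    tokens = [sentence.lower().split() for sentence in ds["sentence"]]
    labels = torch.tensor(ds["label"])
    return tokens, labels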

The Attention Pooling Layer

Instead of mean-pooling or using a fixed position (like [CLS]), we learn an attention-based pooling mechanism. A learnable query vector attends over all positions, producing a single summary vector:

class AttentionPooling(nn.Module):
    """Learned attention pooling over sequence positions.

    Uses a learnable query vector that attends to all positions
    in the sequence to produce a single summary vector.

    Args:
        d_model: Dimension of the model.
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, d_model))
        self.scale = d_model ** 0.5

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pool sequence into a single vector using attention.

        Args:
            x: Input of shape (batch_size, seq_len, d_model).
            mask: Padding mask of shape (batch_size, seq_len).
                True indicates positions to mask.

        Returns:
            pooled: Summary vector of shape (batch_size, d_model).
            attention_weights: Weights of shape (batch_size, seq_len).
        """
        batch_size = x.size(0)
        query = self.query.expand(batch_size, -1, -1)  # (B, 1, d)

        # Attention scores: (B, 1, seq_len)
        scores = torch.bmm(query, x.transpose(1, 2)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(
                mask.unsqueeze(1), float("-inf")
            )

        attention_weights = F.softmax(scores, dim=-1)  # (B, 1, seq_len)

        # Weighted sum: (B, 1, d) -> (B, d)
        pooled = torch.bmm(attention_weights, x).squeeze(1)

        return pooled, attention_weights.squeeze(1)
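
A quick shape check on random inputs confirms the contract of this layer before wiring it into the full model; nothing here depends on trained weights.

# Sanity check: pooled output and attention weights have the expected
# shapes, the weights sum to 1, and masked positions receive zero weight.
pooling = AttentionPooling(d_model=128)
x = torch.randn(4, 10, 128)                  # (batch, seq_len, d_model)
mask = torch.zeros(4, 10, dtype=torch.bool)
mask[:, 7:] = True                           # treat the last 3 positions as padding

pooled, weights = pooling(x, mask=mask)
print(pooled.shape)                          # torch.Size([4, 128])
print(weights.shape)                         # torch.Size([4, 10])
print(weights.sum(dim=-1))                   # each row sums to 1
print(weights[:, 7:].max())                  # 0 for masked positions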

The Self-Attention Classifier

class SelfAttentionClassifier(nn.Module):
    """Text classifier using stacked self-attention layers.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Embedding and model dimension.
        num_heads: Number of attention heads.
        num_layers: Number of self-attention layers.
        num_classes: Number of output classes.
        max_len: Maximum sequence length.
        dropout: Dropout rate.
        pad_idx: Padding token index.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        num_classes: int = 2,
        max_len: int = 512,
        dropout: float = 0.1,
        pad_idx: int = 0,
    ) -> None:
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=pad_idx
        )
        self.pos_embedding = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)

        # Stack of self-attention layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        # Attention pooling and classifier
        self.pool = AttentionPooling(d_model)
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Classify input sequences.

        Args:
            input_ids: Token indices of shape (batch_size, seq_len).

        Returns:
            logits: Classification logits (batch_size, num_classes).
            pool_weights: Attention pooling weights (batch_size, seq_len).
        """
        # Padding mask
        padding_mask = input_ids == self.pad_idx

        # Embeddings
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.embedding(input_ids) + self.pos_embedding(positions)
        x = self.dropout(x)

        # Self-attention layers
        x = self.encoder(x, src_key_padding_mask=padding_mask)

        # Attention pooling
        pooled, pool_weights = self.pool(x, mask=padding_mask)

        # Classification
        logits = self.classifier(pooled)

        return logits, pool_weights
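
Before training, a quick forward pass on random token ids confirms the output shapes. This uses a throwaway model instance, separate from the one trained below.

# Shape check with an untrained model instance (illustrative only).
demo_model = SelfAttentionClassifier(vocab_size=VOCAB_SIZE, max_len=MAX_LEN)
demo_ids = torch.randint(1, VOCAB_SIZE, (4, 10))   # 4 sequences, no padding tokens
demo_logits, demo_weights = demo_model(demo_ids)
print(demo_logits.shape)    # torch.Size([4, 2])
print(demo_weights.shape)   # torch.Size([4, 10])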

Training

from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(42)

# Generate data
token_ids, lengths, labels = generate_synthetic_data(n_samples=4000)

# Train/test split
train_ids = token_ids[:3200]
train_labels = labels[:3200]
test_ids = token_ids[3200:]
test_labels = labels[3200:]

train_dataset = TensorDataset(train_ids, train_labels)
test_dataset = TensorDataset(test_ids, test_labels)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

# Create model
model = SelfAttentionClassifier(
    vocab_size=VOCAB_SIZE,
    d_model=128,
    num_heads=4,
    num_layers=2,
    num_classes=NUM_CLASSES,
    max_len=MAX_LEN,
    dropout=0.1,
    pad_idx=PAD_IDX,
)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(20):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_ids, batch_labels in train_loader:
        logits, _ = model(batch_ids)
        loss = criterion(logits, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == batch_labels).sum().item()
        total += batch_labels.size(0)

    if (epoch + 1) % 5 == 0:
        train_acc = 100.0 * correct / total
        print(f"Epoch {epoch + 1}/20, Loss: {total_loss / len(train_loader):.4f}, "
              f"Train Acc: {train_acc:.1f}%")

Evaluation and Attention Visualization

Test Set Accuracy

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch_ids, batch_labels in test_loader:
        logits, _ = model(batch_ids)
        correct += (logits.argmax(dim=-1) == batch_labels).sum().item()
        total += batch_labels.size(0)

test_acc = 100.0 * correct / total
print(f"Test Accuracy: {test_acc:.1f}%")

Visualizing Attention Pooling Weights

The attention pooling weights reveal which positions the model considers most important for its classification decision. We can extract these weights and display them over the input tokens, first as a simple text-based bar chart and then with matplotlib:

# Get attention weights for a single example
model.eval()
sample_ids = test_ids[0:1]  # Shape: (1, seq_len)

with torch.no_grad():
    logits, pool_weights = model(sample_ids)

prediction = logits.argmax(dim=-1).item()
weights = pool_weights[0].numpy()

# Find the actual (non-padding) length
actual_len = (sample_ids[0] != PAD_IDX).sum().item()
weights = weights[:actual_len]

print(f"Prediction: {'Positive' if prediction == 1 else 'Negative'}")
print(f"Attention weights (first {actual_len} positions):")
for i, w in enumerate(weights):
    bar = "#" * int(w * 100)
    print(f"  Position {i:2d} (token {sample_ids[0, i].item():3d}): "
          f"{w:.4f} {bar}")

Attention Pattern Analysis

For the synthetic dataset, we expect the model to learn clear patterns:

  • Positive examples contain tokens from the upper half of the vocabulary. The pooling attention should spread across these tokens.
  • Negative examples contain tokens from the lower half. The pooling attention should similarly highlight the informative tokens.
  • Padding positions should receive zero attention weight, since the mask sets their scores to negative infinity before the softmax (see the quick check after this list).
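
The padding claim is easy to verify directly. The sketch below reuses the test_loader defined earlier and measures the total pooling-attention mass assigned to padding positions, which should be exactly zero.

# Measure the pooling-attention mass assigned to padding positions.
model.eval()
pad_mass = 0.0
n_examples = 0
with torch.no_grad():
    for batch_ids, _ in test_loader:
        _, pool_weights = model(batch_ids)
        pad_mask = (batch_ids == PAD_IDX).float()
        pad_mass += (pool_weights * pad_mask).sum().item()
        n_examples += batch_ids.size(0)
print(f"Average attention mass on padding: {pad_mass / n_examples:.6f}")  # ~0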

In a real sentiment analysis task (e.g., SST-2 or IMDB), we would see more nuanced patterns: strong attention on sentiment-bearing words ("excellent," "terrible," "boring," "masterpiece") and weaker attention on function words ("the," "a," "is").


Comparison: Attention Classifier vs. Simple Baselines

To appreciate what the attention mechanism provides, we compare against two baselines:

Model                               Test Accuracy    Parameters
Bag-of-embeddings (mean pool)           ~75%            ~26K
Single-layer LSTM + mean pool           ~88%           ~130K
Self-attention classifier (ours)        ~95%           ~290K
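
For reference, the bag-of-embeddings baseline can be written in a few lines. The sketch below (with an illustrative d_model of 128) shows the general form; details may differ from the variant used to produce the numbers above.

class BagOfEmbeddingsClassifier(nn.Module):
    """Masked mean-pooled embedding baseline (sketch; hyperparameters illustrative)."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        num_classes: int = 2,
        pad_idx: int = 0,
    ) -> None:
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        mask = (input_ids != self.pad_idx).unsqueeze(-1).float()  # (B, L, 1)
        emb = self.embedding(input_ids) * mask                    # zero out padding
        pooled = emb.sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)  # masked mean
        return self.classifier(pooled)                            # (B, num_classes)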

The attention-based model outperforms the baselines because:

  1. Global context: Each token's representation incorporates information from every other token in the sequence after self-attention, unlike the bag-of-embeddings approach.
  2. No sequential bottleneck: Unlike the LSTM, which processes tokens one at a time and may lose information from early positions, self-attention gives every position direct access to every other position and can be computed in parallel.
  3. Interpretable pooling: The learned attention pooling provides a principled way to summarize variable-length sequences.

Analysis and Discussion

What the Attention Heads Learn

By inspecting the self-attention weights in the two stacked layers (extracted as sketched after this list), we observe:

  • Layer 1 heads tend to learn local patterns: attending to neighboring tokens, forming bigram-like features.
  • Layer 2 heads attend more broadly, capturing longer-range relationships between distant tokens that share semantic content.
  • Attention pooling focuses on the tokens that are most discriminative for the classification task.
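
One way to obtain these per-head weights is to re-run a layer's attention sub-block by hand, which works because the encoder layers use the pre-norm (norm_first=True) layout. The sketch below assumes a PyTorch version recent enough to support average_attn_weights (1.11 or later).

# Extract per-head self-attention weights from the first encoder layer.
model.eval()
sample = test_ids[0:1]
sample_padding_mask = sample == PAD_IDX

with torch.no_grad():
    positions = torch.arange(sample.size(1))
    x = model.embedding(sample) + model.pos_embedding(positions)
    x = model.dropout(x)                  # no-op in eval mode
    layer0 = model.encoder.layers[0]
    x_norm = layer0.norm1(x)              # pre-norm input to the attention sub-block
    _, attn_weights = layer0.self_attn(
        x_norm, x_norm, x_norm,
        key_padding_mask=sample_padding_mask,
        need_weights=True,
        average_attn_weights=False,
    )

print(attn_weights.shape)  # (1, num_heads, seq_len, seq_len)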

Failure Modes

  1. Short sequences: For very short sequences (fewer than 5 tokens), the attention mechanism has limited context and may produce uncertain predictions.
  2. Out-of-vocabulary tokens: In a real setting, unknown tokens would degrade attention quality. Subword tokenization (Chapter 20) addresses this.
  3. Sensitivity to word order: The positional embeddings are the model's only source of word-order information; without them, self-attention is permutation-invariant. Shuffling the words therefore changes the representations, and the model can be sensitive to syntactic rearrangements it did not see during training.

Connection to Transformers

This classifier is a simplified encoder-only Transformer (similar to BERT, which we will study in Chapter 20). The key components are identical:

  • Multi-head self-attention for contextual representation
  • Position-wise feed-forward networks for nonlinear transformation
  • Layer normalization and residual connections for training stability
  • A pooling strategy to obtain a fixed-size representation for classification

The main difference from BERT is that our model is trained from scratch on the classification objective, rather than pre-trained on a large corpus and then fine-tuned. In Chapter 20, we will see how pre-training dramatically improves classification performance.


Key Takeaways

  1. Self-attention alone is sufficient for text classification. No recurrent layers are needed --- self-attention captures both local and global dependencies.
  2. Attention pooling provides interpretable summaries. The learned query vector discovers which tokens matter most for the classification task.
  3. Multi-head attention captures diverse patterns. Different heads specialize in different types of relationships (positional, semantic, etc.).
  4. Masking is essential. Padding masks prevent attention from attending to meaningless padding tokens, which would dilute the representation.
  5. This architecture is a stepping stone to BERT. The encoder-only Transformer classifier is precisely what BERT adds pre-training to, as we will explore in Chapter 20.

The full runnable code for this case study is available in code/case-study-code.py.