Chapter 27: Case Study 2 - Automated Play Classification Using Deep Learning

Introduction

Understanding offensive and defensive schemes at scale requires automated classification of basketball plays. This case study details building a deep learning system that classifies offensive plays from tracking data, enabling coaches and analysts to study tendencies without manual video tagging.

Part 1: Problem Definition

Objective

Build a system that:

  • Classifies offensive plays into defined categories
  • Operates on tracking data (player positions over time)
  • Achieves high accuracy for strategic analysis
  • Provides interpretable confidence scores

Play Categories

We define 12 offensive play types:

| Category | Description | Example |
|---|---|---|
| Pick and Roll | Ball handler uses screen; screener rolls to rim | Classic two-man game |
| Pick and Pop | Ball handler uses screen; screener pops out | Stretch-5 action |
| Isolation | One-on-one with a cleared-out floor | Star player in clutch |
| Post Up | Back-to-basket play | Traditional center play |
| Spot Up | Catch-and-shoot opportunity | Off-ball movement |
| Cut | Player cuts to basket | Backdoor cut |
| Transition | Fast-break opportunity | Numbers advantage |
| Handoff | Dribble handoff action | Guard-to-guard |
| Off Screen | Player uses off-ball screen | Shooter coming off curl |
| Drive and Kick | Penetrate and pass out | Inside-out offense |
| Motion | Continuous player movement | System offense |
| Other | Unclassified actions | Broken plays |
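For reference throughout the chapter, these categories can be encoded as integer class labels. A minimal sketch (the list name `PLAY_TYPES` matches its later use in the real-time classifier; the ordering itself is an arbitrary assumption and just needs to stay fixed between training and inference):

```python
# Hypothetical label mapping; the order is arbitrary but must be
# identical at training and inference time.
PLAY_TYPES = [
    "Pick and Roll", "Pick and Pop", "Isolation", "Post Up",
    "Spot Up", "Cut", "Transition", "Handoff",
    "Off Screen", "Drive and Kick", "Motion", "Other",
]

LABEL_TO_INDEX = {name: i for i, name in enumerate(PLAY_TYPES)}
```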

Data Sources

  • NBA tracking data (2018-2023 seasons)
  • Synergy play-type labels (ground truth)
  • ~500,000 labeled possessions

Part 2: Data Representation

Input Format

Each possession is represented as a tensor:

import numpy as np

def create_possession_tensor(tracking_df, possession_id):
    """
    Convert tracking data to model input format.

    Output shape: (T, 23) where T is sequence length
    - 10 players × 2 coordinates = 20
    - 1 ball × 3 coordinates (x, y, z) = 3
    """
    possession = tracking_df[tracking_df['possession_id'] == possession_id]

    # Normalize to offensive half-court
    possession = normalize_to_halfcourt(possession)

    # Resample to fixed length (5 seconds at 10 fps = 50 frames)
    possession = resample_sequence(possession, target_frames=50)

    # Extract position matrix
    positions = []
    for frame in range(len(possession)):
        frame_data = []

        # 5 offensive players (sorted by jersey number for consistency)
        for player in get_offensive_players(possession, frame):
            frame_data.extend([player['x'], player['y']])

        # 5 defensive players
        for player in get_defensive_players(possession, frame):
            frame_data.extend([player['x'], player['y']])

        # Ball position
        ball = get_ball_position(possession, frame)
        frame_data.extend([ball['x'], ball['y'], ball['z']])

        positions.append(frame_data)

    return np.array(positions)
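`resample_sequence` is assumed above but not shown. A minimal sketch using linear interpolation, operating on an already-extracted `(T, D)` position array rather than the raw tracking DataFrame (an assumption about its interface):

```python
import numpy as np

def resample_sequence(positions, target_frames=50):
    """Linearly interpolate a (T, D) position array to (target_frames, D)."""
    positions = np.asarray(positions, dtype=float)
    T, D = positions.shape
    src_t = np.linspace(0.0, 1.0, T)
    dst_t = np.linspace(0.0, 1.0, target_frames)
    out = np.empty((target_frames, D))
    for d in range(D):
        # Interpolate each coordinate channel independently
        out[:, d] = np.interp(dst_t, src_t, positions[:, d])
    return out
```

Linear interpolation is adequate at 10 fps; higher-order schemes matter more for derived quantities like velocity.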

Data Augmentation

def augment_possession(positions):
    """
    Create augmented versions of a (T, 23) possession array.
    Column layout: 20 player x/y values, then ball x, y, z.
    """
    augmented = []

    # Original
    augmented.append(positions)

    # Mirror across the court's long axis (left/right symmetry).
    # Flip every y coordinate (odd columns 1-21, court width 50 ft);
    # leave x and the ball's z untouched.
    mirrored = positions.copy()
    mirrored[:, 1:22:2] = 50 - mirrored[:, 1:22:2]
    augmented.append(mirrored)

    # Add small positional noise (~0.5 ft standard deviation)
    noisy = positions + np.random.normal(0, 0.5, positions.shape)
    augmented.append(noisy)

    # Time stretch (speed variation)
    stretched = time_stretch(positions, factor=np.random.uniform(0.9, 1.1))
    augmented.append(stretched)

    return augmented
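`time_stretch` can likewise be sketched with linear interpolation. This version (an assumption; the chapter does not define it) keeps the output the same length as the input, so downstream tensor shapes are unaffected:

```python
import numpy as np

def time_stretch(positions, factor):
    """Speed up (factor > 1) or slow down (factor < 1) a (T, D) sequence,
    keeping the output length equal to the input length."""
    positions = np.asarray(positions, dtype=float)
    T, D = positions.shape
    src_t = np.linspace(0.0, 1.0, T)
    # Sample the original timeline at a compressed/expanded rate,
    # clipping so we never read past the end of the possession.
    dst_t = np.clip(np.linspace(0.0, 1.0, T) * factor, 0.0, 1.0)
    out = np.empty_like(positions)
    for d in range(D):
        out[:, d] = np.interp(dst_t, src_t, positions[:, d])
    return out
```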

Part 3: Model Architecture

Approach Comparison

We evaluate three architectures:

| Model | Description | Strengths |
|---|---|---|
| LSTM | Recurrent network | Good for sequences |
| Transformer | Attention-based | Captures long-range dependencies |
| Graph Neural Network | Models relationships | Natural for player interactions |

Final Architecture: Temporal Graph Network

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class PlayClassifier(nn.Module):
    def __init__(self, input_dim=23, hidden_dim=128, num_classes=12):
        super().__init__()

        # Spatial encoder (GNN for each timestep)
        self.spatial_conv1 = GCNConv(2, 32)
        self.spatial_conv2 = GCNConv(32, 64)

        # Temporal encoder (LSTM over time)
        self.temporal_lstm = nn.LSTM(
            input_size=64 * 11,  # 10 players + 1 ball
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )

        # Attention layer
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=8,
            batch_first=True
        )

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, edge_index):
        """
        x: (batch, time, 11, 2) - positions
        edge_index: connectivity for the batched graph - batch_size copies
            of the 11-node court graph, with node indices offset by 11 per
            sample (e.g. built via torch_geometric.data.Batch)
        """
        batch_size, seq_len, num_nodes, features = x.shape

        # Process each timestep with GNN
        spatial_features = []
        for t in range(seq_len):
            h = x[:, t, :, :]  # (batch, 11, 2)
            h = h.view(-1, features)  # (batch*11, 2)

            h = F.relu(self.spatial_conv1(h, edge_index))
            h = F.relu(self.spatial_conv2(h, edge_index))

            h = h.view(batch_size, num_nodes, -1)  # (batch, 11, 64)
            h = h.view(batch_size, -1)  # (batch, 11*64)
            spatial_features.append(h)

        # Stack temporal features
        spatial_features = torch.stack(spatial_features, dim=1)  # (batch, time, 704)

        # Temporal modeling
        temporal_features, _ = self.temporal_lstm(spatial_features)

        # Self-attention
        attended, _ = self.attention(
            temporal_features, temporal_features, temporal_features
        )

        # Global pooling and classification
        pooled = attended.mean(dim=1)  # (batch, hidden*2)
        logits = self.classifier(pooled)

        return logits

Graph Structure

def create_court_graph():
    """
    Define edges between players based on court relationships.
    """
    edges = []

    # Offensive players connected to each other (complete subgraph)
    for i in range(5):
        for j in range(i+1, 5):
            edges.append([i, j])
            edges.append([j, i])

    # Defensive players connected to each other
    for i in range(5, 10):
        for j in range(i+1, 10):
            edges.append([i, j])
            edges.append([j, i])

    # Each offensive player connected to nearest defenders
    # (Dynamic edges based on proximity)

    # Ball connected to all players
    for i in range(10):
        edges.append([10, i])
        edges.append([i, 10])

    return torch.tensor(edges).t().contiguous()
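The dynamic offense-to-defense edges left as a stub above could be computed per frame from player proximity. A NumPy sketch (the helper name `proximity_edges` and the k-nearest-defender rule are assumptions, not part of the trained system):

```python
import numpy as np

def proximity_edges(frame_positions, k=1):
    """For one frame, connect each offensive player (nodes 0-4) to his
    k nearest defenders (nodes 5-9), bidirectionally. frame_positions
    is an (11, 2)-style xy array; the ball row is ignored here."""
    xy = np.asarray(frame_positions, dtype=float)[:, :2]
    offense, defense = xy[0:5], xy[5:10]
    edges = []
    for i in range(5):
        dists = np.linalg.norm(defense - offense[i], axis=1)
        for j in np.argsort(dists)[:k]:
            d = 5 + int(j)
            edges.append([i, d])
            edges.append([d, i])
    return edges
```

These edges would be recomputed each frame and concatenated with the static edges before the `GCNConv` layers.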

Part 4: Training

Dataset Split

  • Training: 2018-2021 seasons (400,000 possessions)
  • Validation: 2021-22 season (50,000 possessions)
  • Test: 2022-23 season (50,000 possessions)

Training Configuration

config = {
    'batch_size': 64,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'epochs': 100,
    'early_stopping_patience': 10,
    'label_smoothing': 0.1
}

# Class weights for imbalanced data
class_weights = compute_class_weights(train_labels)

criterion = nn.CrossEntropyLoss(
    weight=class_weights,
    label_smoothing=0.1
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)
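`compute_class_weights` is referenced but not defined above. One common choice is inverse-frequency weighting, sketched here with NumPy (normalized so the weights average to 1; convert with `torch.tensor(...)` before passing to `CrossEntropyLoss`):

```python
import numpy as np

def compute_class_weights(labels, num_classes=12):
    """Inverse-frequency class weights, rescaled to mean 1."""
    counts = np.bincount(np.asarray(labels), minlength=num_classes).astype(float)
    counts = np.maximum(counts, 1.0)              # guard against empty classes
    weights = counts.sum() / (num_classes * counts)
    return weights * num_classes / weights.sum()  # mean-1 normalization
```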

Training Loop

def train_epoch(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in dataloader:
        positions = batch['positions'].to(device)
        labels = batch['labels'].to(device)
        edge_index = batch['edge_index'].to(device)

        optimizer.zero_grad()
        logits = model(positions, edge_index)
        loss = criterion(logits, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        predictions = logits.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(dataloader), correct / total

Part 5: Results

Overall Performance

| Metric | Value |
|---|---|
| Overall Accuracy | 78.3% |
| Macro F1 Score | 0.72 |
| Top-2 Accuracy | 91.2% |
| AUC-ROC (avg) | 0.94 |
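Top-2 accuracy, as reported above, credits a possession whenever the true play type is among the model's two most probable classes. A NumPy sketch:

```python
import numpy as np

def top_k_accuracy(probs, labels, k=2):
    """Fraction of possessions whose true label is among the k
    highest-probability classes. probs: (N, C), labels: (N,)."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    topk = np.argsort(probs, axis=1)[:, -k:]       # indices of top-k classes
    hits = (topk == labels[:, None]).any(axis=1)   # true label in top-k?
    return hits.mean()
```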

Per-Class Performance

| Play Type | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Pick and Roll | 0.85 | 0.88 | 0.86 | 12,450 |
| Pick and Pop | 0.72 | 0.68 | 0.70 | 3,210 |
| Isolation | 0.81 | 0.79 | 0.80 | 5,890 |
| Post Up | 0.88 | 0.85 | 0.86 | 2,340 |
| Spot Up | 0.76 | 0.82 | 0.79 | 8,920 |
| Cut | 0.69 | 0.61 | 0.65 | 2,100 |
| Transition | 0.92 | 0.94 | 0.93 | 6,780 |
| Handoff | 0.67 | 0.63 | 0.65 | 1,890 |
| Off Screen | 0.71 | 0.67 | 0.69 | 2,560 |
| Drive and Kick | 0.62 | 0.58 | 0.60 | 1,450 |
| Motion | 0.55 | 0.48 | 0.51 | 980 |
| Other | 0.42 | 0.38 | 0.40 | 1,430 |

Confusion Analysis

Most common confusions:

  1. Pick and Roll vs Pick and Pop (14%): Same initial action
  2. Handoff vs Off Screen (18%): Similar movement patterns
  3. Motion vs Other (22%): Continuous motion hard to classify
  4. Cut vs Spot Up (11%): Quick transitions between states
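These confusion pairs fall out of a standard confusion matrix. A small sketch that builds the matrix and ranks off-diagonal cells by per-class error rate (the helper names are illustrative, not from the project code):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] counts possessions with true class i predicted as j."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

def most_confused_pairs(cm, n=3):
    """Off-diagonal cells with the highest error rate relative to the
    true class's support."""
    rates = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)
    np.fill_diagonal(rates, 0.0)                 # ignore correct predictions
    flat = np.argsort(rates, axis=None)[::-1][:n]
    return [tuple(np.unravel_index(i, cm.shape)) for i in flat]
```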

Part 6: Applications

Team Tendency Analysis

from collections import defaultdict

def analyze_team_tendencies(model, team_id, season):
    """
    Analyze play type distribution for a team.
    """
    possessions = get_team_possessions(team_id, season)

    play_counts = defaultdict(int)
    play_efficiency = defaultdict(list)

    model.eval()
    edge_index = create_court_graph()

    for poss in possessions:
        # Classify play (no gradients needed at inference time)
        input_tensor = create_possession_tensor(poss)
        with torch.no_grad():
            logits = model(input_tensor, edge_index)
        play_type = logits.argmax().item()

        play_counts[play_type] += 1
        play_efficiency[play_type].append(poss['points_scored'])

    return {
        'distribution': play_counts,
        'efficiency': {k: np.mean(v) for k, v in play_efficiency.items()}
    }

Example: 2022-23 Team Profiles

| Team | Top Play Type | Usage | Efficiency |
|---|---|---|---|
| Warriors | Motion | 28% | 1.12 PPP |
| Celtics | Pick and Roll | 34% | 1.08 PPP |
| Nuggets | Post Up | 22% | 1.15 PPP |
| Heat | Spot Up | 31% | 1.06 PPP |
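The PPP figures above are simply average points scored per classified possession. A minimal stdlib sketch of that aggregation:

```python
from collections import defaultdict

def points_per_possession(possessions):
    """possessions: iterable of (play_type, points_scored) pairs.
    Returns {play_type: mean points per possession}."""
    totals = defaultdict(lambda: [0, 0])  # play_type -> [points, count]
    for play_type, points in possessions:
        totals[play_type][0] += points
        totals[play_type][1] += 1
    return {pt: pts / n for pt, (pts, n) in totals.items()}
```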

Opponent Scouting

def generate_scouting_report(model, opponent_id):
    """
    Generate defensive scouting report based on opponent tendencies.
    """
    tendencies = analyze_team_tendencies(model, opponent_id)

    report = {
        'primary_actions': get_top_plays(tendencies, n=3),
        'situational_tendencies': {
            'clutch': analyze_situation(opponent_id, 'clutch'),
            'after_timeout': analyze_situation(opponent_id, 'ato'),
            'vs_zone': analyze_situation(opponent_id, 'vs_zone')
        },
        'personnel_combos': analyze_lineup_plays(opponent_id)
    }

    return report

Real-Time Classification

from collections import deque

class RealTimeClassifier:
    def __init__(self, model, buffer_size=50):
        self.model = model
        self.model.eval()
        self.edge_index = create_court_graph()
        self.buffer = deque(maxlen=buffer_size)
        self.current_prediction = None

    def update(self, frame_data):
        """
        Update with new tracking frame.
        """
        self.buffer.append(frame_data)

        if len(self.buffer) >= 25:  # Minimum frames (2.5 s) before predicting
            input_tensor = self.prepare_input()
            with torch.no_grad():
                logits = self.model(input_tensor, self.edge_index)
                probs = F.softmax(logits, dim=-1)

            self.current_prediction = {
                'play_type': PLAY_TYPES[probs.argmax().item()],
                'confidence': probs.max().item(),
                'all_probs': probs.tolist()
            }

        return self.current_prediction
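Frame-by-frame predictions can flicker between adjacent play types as a possession develops. One simple stabilizer (an addition not part of the system above) is a majority vote over recent predictions:

```python
from collections import Counter, deque

class PredictionSmoother:
    """Majority vote over the last `window` per-frame predictions,
    reducing flicker between adjacent play types."""
    def __init__(self, window=10):
        self.history = deque(maxlen=window)

    def update(self, play_type):
        self.history.append(play_type)
        # most_common(1) returns [(label, count)] for the modal label
        return Counter(self.history).most_common(1)[0][0]
```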

Part 7: Limitations and Future Work

Current Limitations

  1. Boundary cases: Plays that transition between types are misclassified
  2. Rare plays: Low accuracy for infrequent play types
  3. Context blindness: Model doesn't consider score, time, or personnel
  4. Defensive schemes: Not yet classifying defensive play types

Future Improvements

  1. Multi-task learning: Jointly predict play type and outcome
  2. Hierarchical classification: Coarse-to-fine play categorization
  3. Temporal attention: Learn to focus on key moments
  4. Explainability: Visualize which player movements drive predictions

Research Directions

  • Self-supervised pretraining: Learn representations without labels
  • Few-shot learning: Quickly adapt to new play types
  • Cross-league transfer: Apply to college/international basketball

Conclusion

Deep learning enables automated classification of basketball plays with sufficient accuracy for practical use. The combination of graph neural networks for spatial relationships and LSTMs for temporal modeling captures the essence of coordinated team movement. While challenges remain for ambiguous plays and rare categories, the system provides valuable automation for video analysis workflows.

Technical Appendix

Hardware Requirements

  • Training: 4× NVIDIA A100 GPUs
  • Inference: Single GPU or CPU (100ms per possession)
  • Real-time: NVIDIA RTX 3080 or better

Reproducibility

  • Random seeds: 42 for all experiments
  • Framework versions: PyTorch 2.0, PyG 2.3
  • Code available: github.com/example/play-classification
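A seeding helper consistent with the settings above might look like the following sketch (the `torch` import is guarded so the function also runs in torch-free environments; full GPU determinism additionally requires PyTorch's deterministic-algorithms flags):

```python
import os
import random

import numpy as np

def seed_everything(seed=42):
    """Seed Python, NumPy, and (if available) PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # torch not installed; stdlib/NumPy seeding still applies
```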

Exercises

Exercise 1

Modify the graph structure to use dynamic edges based on player proximity. How does this affect accuracy?

Exercise 2

Implement a visualization that shows which timesteps the attention mechanism focuses on for different play types.

Exercise 3

Design a data augmentation strategy specific to basketball tracking data. What transformations preserve play semantics?