Chapter 27: Case Study 2 - Automated Play Classification Using Deep Learning
Introduction
Understanding offensive and defensive schemes at scale requires automated classification of basketball plays. This case study details building a deep learning system that classifies offensive plays from tracking data, enabling coaches and analysts to study tendencies without manual video tagging.
Part 1: Problem Definition
Objective
Build a system that:
- Classifies offensive plays into defined categories
- Operates on tracking data (player positions over time)
- Achieves accuracy high enough for strategic analysis
- Provides interpretable confidence scores
Play Categories
We define 12 offensive play types:
| Category | Description | Example |
|---|---|---|
| Pick and Roll | Ball handler uses screen, rolls to rim | Classic two-man game |
| Pick and Pop | Ball handler uses screen, screener pops | Stretch-5 action |
| Isolation | One-on-one with cleared out floor | Star player in clutch |
| Post Up | Back-to-basket play | Traditional center play |
| Spot Up | Catch and shoot opportunity | Off-ball movement |
| Cut | Player cuts to basket | Backdoor cut |
| Transition | Fast break opportunity | Numbers advantage |
| Handoff | Dribble handoff action | Guard-to-guard |
| Off Screen | Player uses off-ball screen | Shooter coming off curl |
| Drive and Kick | Penetrate and pass out | Inside-out offense |
| Motion | Continuous player movement | System offense |
| Other | Unclassified actions | Broken plays |
Data Sources
- NBA tracking data (2018-2023 seasons)
- Synergy play-type labels (ground truth)
- ~500,000 labeled possessions
Part 2: Data Representation
Input Format
Each possession is represented as a tensor:
```python
import numpy as np

def create_possession_tensor(tracking_df, possession_id):
    """
    Convert tracking data to model input format.
    Output shape: (T, 23) where T is sequence length
    - 10 players × 2 coordinates = 20
    - 1 ball × 3 coordinates (x, y, z) = 3
    """
    possession = tracking_df[tracking_df['possession_id'] == possession_id]
    # Normalize to offensive half-court
    possession = normalize_to_halfcourt(possession)
    # Resample to fixed length (5 seconds at 10 fps = 50 frames)
    possession = resample_sequence(possession, target_frames=50)
    # Extract position matrix
    positions = []
    for frame in range(len(possession)):
        frame_data = []
        # 5 offensive players (sorted by jersey number for consistency)
        for player in get_offensive_players(possession, frame):
            frame_data.extend([player['x'], player['y']])
        # 5 defensive players
        for player in get_defensive_players(possession, frame):
            frame_data.extend([player['x'], player['y']])
        # Ball position
        ball = get_ball_position(possession, frame)
        frame_data.extend([ball['x'], ball['y'], ball['z']])
        positions.append(frame_data)
    return np.array(positions)
```
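The helpers `normalize_to_halfcourt` and `resample_sequence` are assumed above. As one plausible sketch (not the chapter's exact implementation), `resample_sequence` can be written with per-column linear interpolation; it is shown here operating on a plain `(T, D)` array, whereas the pipeline above applies it to a tracking DataFrame:

```python
import numpy as np

def resample_sequence(positions, target_frames=50):
    """Resample a (T, D) position array to a fixed number of frames.

    Each of the D coordinate columns is interpolated independently with
    np.interp, so variable-length possessions all map to `target_frames`
    rows while preserving start and end positions.
    """
    positions = np.asarray(positions, dtype=float)
    T = positions.shape[0]
    if T == target_frames:
        return positions
    old_t = np.linspace(0.0, 1.0, T)
    new_t = np.linspace(0.0, 1.0, target_frames)
    return np.stack(
        [np.interp(new_t, old_t, positions[:, d]) for d in range(positions.shape[1])],
        axis=1,
    )
```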
Data Augmentation
```python
import numpy as np

def augment_possession(positions):
    """
    Create augmented versions of possession data.
    positions: (T, 23) array from create_possession_tensor.
    """
    augmented = []
    # Original
    augmented.append(positions)
    # Mirror laterally across the court's long axis: flip every y coordinate
    # (columns 1, 3, ..., 21) across the 50 ft court width. This preserves
    # play semantics; flipping x would reverse the direction of attack, and
    # a naive [::2] slice would also corrupt the ball's z coordinate.
    mirrored = positions.copy()
    mirrored[:, 1::2] = 50 - mirrored[:, 1::2]
    augmented.append(mirrored)
    # Add small positional noise (~0.5 ft) to simulate tracking jitter
    noisy = positions + np.random.normal(0, 0.5, positions.shape)
    augmented.append(noisy)
    # Time stretch (speed variation)
    stretched = time_stretch(positions, factor=np.random.uniform(0.9, 1.1))
    augmented.append(stretched)
    return augmented
```
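`time_stretch` is assumed above; a minimal sketch (hypothetical implementation) rescales the time axis by the stretch factor and interpolates back to the original length, so downstream tensor shapes are unchanged:

```python
import numpy as np

def time_stretch(positions, factor=1.1):
    """Speed up (factor > 1) or slow down (factor < 1) a (T, D) sequence.

    The sampling times are scaled by `factor` and clipped to the valid
    range, then each coordinate column is linearly interpolated back onto
    T frames; with factor > 1 the final frame is held at the end.
    """
    positions = np.asarray(positions, dtype=float)
    T = positions.shape[0]
    base_t = np.linspace(0.0, 1.0, T)
    src_t = np.clip(base_t * factor, 0.0, 1.0)
    return np.stack(
        [np.interp(src_t, base_t, positions[:, d]) for d in range(positions.shape[1])],
        axis=1,
    )
```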
Part 3: Model Architecture
Approach Comparison
We evaluate three architectures:
| Model | Description | Strengths |
|---|---|---|
| LSTM | Recurrent network | Good for sequences |
| Transformer | Attention-based | Captures long-range dependencies |
| Graph Neural Network | Models relationships | Natural for player interactions |
Final Architecture: Temporal Graph Network
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class PlayClassifier(nn.Module):
    def __init__(self, input_dim=23, hidden_dim=128, num_classes=12):
        super().__init__()
        # Spatial encoder (GNN over the 11-node court graph at each timestep)
        self.spatial_conv1 = GCNConv(2, 32)
        self.spatial_conv2 = GCNConv(32, 64)
        # Temporal encoder (LSTM over time)
        self.temporal_lstm = nn.LSTM(
            input_size=64 * 11,  # 10 players + 1 ball
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        # Attention layer
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=8,
            batch_first=True
        )
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, edge_index):
        """
        x: (batch, time, 11, 2) - positions
        edge_index: (2, E) connectivity of a single 11-node court graph
        """
        batch_size, seq_len, num_nodes, features = x.shape
        # GCNConv treats the flattened batch as one large disconnected graph,
        # so the single-graph edge_index must be replicated with node offsets
        offsets = torch.arange(batch_size, device=x.device) * num_nodes
        batch_edge_index = torch.cat(
            [edge_index + off for off in offsets], dim=1
        )
        # Process each timestep with the GNN
        spatial_features = []
        for t in range(seq_len):
            h = x[:, t, :, :].reshape(-1, features)   # (batch*11, 2)
            h = F.relu(self.spatial_conv1(h, batch_edge_index))
            h = F.relu(self.spatial_conv2(h, batch_edge_index))
            h = h.reshape(batch_size, -1)             # (batch, 11*64)
            spatial_features.append(h)
        # Stack temporal features
        spatial_features = torch.stack(spatial_features, dim=1)  # (batch, time, 704)
        # Temporal modeling
        temporal_features, _ = self.temporal_lstm(spatial_features)
        # Self-attention
        attended, _ = self.attention(
            temporal_features, temporal_features, temporal_features
        )
        # Global pooling and classification
        pooled = attended.mean(dim=1)  # (batch, hidden*2)
        logits = self.classifier(pooled)
        return logits
```
Graph Structure
```python
import torch

def create_court_graph():
    """
    Define edges between players based on court relationships.
    Node ids: 0-4 offense, 5-9 defense, 10 ball.
    """
    edges = []
    # Offensive players connected to each other (complete subgraph)
    for i in range(5):
        for j in range(i + 1, 5):
            edges.append([i, j])
            edges.append([j, i])
    # Defensive players connected to each other
    for i in range(5, 10):
        for j in range(i + 1, 10):
            edges.append([i, j])
            edges.append([j, i])
    # Each offensive player could also be connected to nearest defenders
    # (dynamic edges based on proximity)
    # Ball connected to all players
    for i in range(10):
        edges.append([10, i])
        edges.append([i, 10])
    return torch.tensor(edges).t().contiguous()
```
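As a sanity check on `create_court_graph`: the two complete five-node subgraphs contribute 20 directed edges each, and the ball-to-player links contribute another 20, so the returned `edge_index` tensor should have shape `(2, 60)`. The edge list can be rebuilt compactly to verify this:

```python
import itertools

# Rebuild the same static edge set: ordered pairs within the offensive
# nodes (0-4), within the defensive nodes (5-9), and ball (10) <-> players.
offense = list(itertools.permutations(range(5), 2))
defense = list(itertools.permutations(range(5, 10), 2))
ball = [(10, i) for i in range(10)] + [(i, 10) for i in range(10)]
edges = offense + defense + ball
```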
Part 4: Training
Dataset Split
- Training: 2018-2021 seasons (400,000 possessions)
- Validation: 2021-22 season (50,000 possessions)
- Test: 2022-23 season (50,000 possessions)
Training Configuration
```python
config = {
    'batch_size': 64,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'epochs': 100,
    'early_stopping_patience': 10,
    'label_smoothing': 0.1
}

# Class weights for imbalanced data
class_weights = compute_class_weights(train_labels)
criterion = nn.CrossEntropyLoss(
    weight=class_weights,
    label_smoothing=config['label_smoothing']
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)
```
Training Loop
```python
def train_epoch(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch in dataloader:
        # `device` is defined globally (e.g. 'cuda' or 'cpu')
        positions = batch['positions'].to(device)
        labels = batch['labels'].to(device)
        edge_index = batch['edge_index'].to(device)
        optimizer.zero_grad()
        logits = model(positions, edge_index)
        loss = criterion(logits, labels)
        loss.backward()
        # Clip gradients to stabilize LSTM training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
        predictions = logits.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
    return total_loss / len(dataloader), correct / total
```
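`train_epoch` plugs into an outer loop that evaluates on the validation set each epoch, steps the `ReduceLROnPlateau` scheduler on validation loss, and applies the patience-10 early stopping from the config. A framework-agnostic sketch, where `train_fn` and `val_fn` are assumed callables wrapping `train_epoch` and a matching evaluation pass:

```python
def fit(train_fn, val_fn, scheduler_step, epochs=100, patience=10):
    """Outer training loop: early stopping on validation loss.

    train_fn/val_fn each return (loss, accuracy) for one pass;
    scheduler_step receives the validation loss, as ReduceLROnPlateau expects.
    """
    best_val, bad_epochs = float("inf"), 0
    for epoch in range(epochs):
        train_loss, _ = train_fn()
        val_loss, _ = val_fn()
        scheduler_step(val_loss)
        if val_loss < best_val:
            best_val, bad_epochs = val_loss, 0  # checkpoint best model here
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break  # early stopping
    return epoch + 1, best_val
```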
Part 5: Results
Overall Performance
| Metric | Value |
|---|---|
| Overall Accuracy | 78.3% |
| Macro F1 Score | 0.72 |
| Top-2 Accuracy | 91.2% |
| AUC-ROC (avg) | 0.94 |
Per-Class Performance
| Play Type | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Pick and Roll | 0.85 | 0.88 | 0.86 | 12,450 |
| Pick and Pop | 0.72 | 0.68 | 0.70 | 3,210 |
| Isolation | 0.81 | 0.79 | 0.80 | 5,890 |
| Post Up | 0.88 | 0.85 | 0.86 | 2,340 |
| Spot Up | 0.76 | 0.82 | 0.79 | 8,920 |
| Cut | 0.69 | 0.61 | 0.65 | 2,100 |
| Transition | 0.92 | 0.94 | 0.93 | 6,780 |
| Handoff | 0.67 | 0.63 | 0.65 | 1,890 |
| Off Screen | 0.71 | 0.67 | 0.69 | 2,560 |
| Drive and Kick | 0.62 | 0.58 | 0.60 | 1,450 |
| Motion | 0.55 | 0.48 | 0.51 | 980 |
| Other | 0.42 | 0.38 | 0.40 | 1,430 |
Confusion Analysis
Most common confusions:
1. Pick and Roll vs Pick and Pop (14%): identical initial screen action
2. Handoff vs Off Screen (18%): similar movement patterns
3. Motion vs Other (22%): continuous motion is hard to segment
4. Cut vs Spot Up (11%): quick transitions between the two states
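Confusion pairs like these fall out of a row-normalized confusion matrix. A minimal sketch (shown on synthetic labels, not the chapter's results):

```python
import numpy as np

def top_confusions(y_true, y_pred, num_classes):
    """Return (true, pred, rate) triples sorted by off-diagonal confusion.

    rate = fraction of class `true` examples that were predicted as `pred`.
    """
    cm = np.zeros((num_classes, num_classes), dtype=float)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    # Row-normalize so each row sums to 1 (guarding against empty classes)
    cm /= np.maximum(cm.sum(axis=1, keepdims=True), 1)
    pairs = [(i, j, cm[i, j])
             for i in range(num_classes) for j in range(num_classes) if i != j]
    return sorted(pairs, key=lambda x: -x[2])
```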
Part 6: Applications
Team Tendency Analysis
```python
from collections import defaultdict

import numpy as np
import torch

def analyze_team_tendencies(model, team_id, season):
    """
    Analyze play type distribution and per-play efficiency for a team.
    """
    possessions = get_team_possessions(team_id, season)
    edge_index = create_court_graph()
    play_counts = defaultdict(int)
    play_efficiency = defaultdict(list)
    model.eval()
    for poss in possessions:
        # Classify play (create_possession_tensor here takes a single
        # possession record and returns its (50, 23) position matrix)
        flat = torch.as_tensor(create_possession_tensor(poss), dtype=torch.float32)
        # Drop ball z and reshape to the model's (batch, time, 11, 2) input
        x = flat[:, :22].view(1, -1, 11, 2)
        with torch.no_grad():
            logits = model(x, edge_index)
        play_type = logits.argmax(dim=1).item()
        play_counts[play_type] += 1
        play_efficiency[play_type].append(poss['points_scored'])
    return {
        'distribution': play_counts,
        'efficiency': {k: np.mean(v) for k, v in play_efficiency.items()}
    }
```
Example: 2022-23 Team Profiles
| Team | Top Play Type | Usage | Efficiency |
|---|---|---|---|
| Warriors | Motion | 28% | 1.12 PPP |
| Celtics | Pick and Roll | 34% | 1.08 PPP |
| Nuggets | Post Up | 22% | 1.15 PPP |
| Heat | Spot Up | 31% | 1.06 PPP |
Opponent Scouting
```python
def generate_scouting_report(model, opponent_id, season):
    """
    Generate a defensive scouting report from opponent tendencies.
    """
    tendencies = analyze_team_tendencies(model, opponent_id, season)
    report = {
        'primary_actions': get_top_plays(tendencies, n=3),
        'situational_tendencies': {
            'clutch': analyze_situation(opponent_id, 'clutch'),
            'after_timeout': analyze_situation(opponent_id, 'ato'),
            'vs_zone': analyze_situation(opponent_id, 'vs_zone')
        },
        'personnel_combos': analyze_lineup_plays(opponent_id)
    }
    return report
```
Real-Time Classification
```python
from collections import deque

import torch
import torch.nn.functional as F

class RealTimeClassifier:
    def __init__(self, model, edge_index, buffer_size=50):
        self.model = model
        self.edge_index = edge_index  # static court graph from create_court_graph()
        self.buffer = deque(maxlen=buffer_size)
        self.current_prediction = None

    def update(self, frame_data):
        """
        Update with a new tracking frame; returns the latest prediction once
        at least 2.5 seconds of data (25 frames at 10 fps) are buffered.
        """
        self.buffer.append(frame_data)
        if len(self.buffer) >= 25:  # Minimum for prediction
            input_tensor = self.prepare_input()
            with torch.no_grad():
                logits = self.model(input_tensor, self.edge_index)
            probs = F.softmax(logits, dim=-1).squeeze(0)
            # PLAY_TYPES: list of the 12 category names, defined elsewhere
            self.current_prediction = {
                'play_type': PLAY_TYPES[probs.argmax().item()],
                'confidence': probs.max().item(),
                'all_probs': probs.tolist()
            }
        return self.current_prediction
```
Part 7: Limitations and Future Work
Current Limitations
- Boundary cases: Plays that transition between types are misclassified
- Rare plays: Low accuracy for infrequent play types
- Context blindness: Model doesn't consider score, time, or personnel
- Defensive schemes: Not yet classifying defensive play types
Future Improvements
- Multi-task learning: Jointly predict play type and outcome
- Hierarchical classification: Coarse-to-fine play categorization
- Temporal attention: Learn to focus on key moments
- Explainability: Visualize which player movements drive predictions
Research Directions
- Self-supervised pretraining: Learn representations without labels
- Few-shot learning: Quickly adapt to new play types
- Cross-league transfer: Apply to college/international basketball
Conclusion
Deep learning enables automated classification of basketball plays with sufficient accuracy for practical use. The combination of graph neural networks for spatial relationships and LSTMs for temporal modeling captures the essence of coordinated team movement. While challenges remain for ambiguous plays and rare categories, the system provides valuable automation for video analysis workflows.
Technical Appendix
Hardware Requirements
- Training: 4× NVIDIA A100 GPUs
- Inference: Single GPU or CPU (100ms per possession)
- Real-time: NVIDIA RTX 3080 or better
Reproducibility
- Random seeds: 42 for all experiments
- Framework versions: PyTorch 2.0, PyG 2.3
- Code available: github.com/example/play-classification
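The fixed-seed setup above is typically applied to every RNG the pipeline touches. A sketch of such a helper:

```python
import random

import numpy as np
import torch

def set_seed(seed=42):
    """Seed all RNGs used in training, per the appendix settings."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade kernel speed for determinism in cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```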
Exercises
Exercise 1
Modify the graph structure to use dynamic edges based on player proximity. How does this affect accuracy?
Exercise 2
Implement a visualization that shows which timesteps the attention mechanism focuses on for different play types.
Exercise 3
Design a data augmentation strategy specific to basketball tracking data. What transformations preserve play semantics?