Introduction to ML in Baseball
Advanced
Nov 26, 2025
# Introduction to Machine Learning for Baseball Analytics
Machine learning has revolutionized baseball analytics, transforming how teams evaluate players, make strategic decisions, and gain competitive advantages. This comprehensive guide introduces machine learning concepts specifically for baseball analytics applications.
## Why Machine Learning for Baseball Analytics?
Baseball generates massive amounts of data—from traditional statistics to modern tracking data like Statcast. Machine learning excels at finding patterns in this complex data that traditional statistical methods might miss.
**Key advantages of ML in baseball:**
- **Pattern Recognition**: ML algorithms can identify subtle patterns across thousands of at-bats, pitches, or defensive plays that human analysts might overlook
- **Predictive Power**: Models can forecast future performance, injury risk, or optimal strategies based on historical patterns
- **High-Dimensional Analysis**: Modern baseball data has hundreds of features (spin rate, launch angle, exit velocity, etc.). ML handles this complexity naturally
- **Nonlinear Relationships**: Baseball outcomes often depend on complex interactions between variables. ML captures these nonlinear relationships effectively
- **Automation**: Once trained, ML models can process new data instantly, providing real-time insights during games
- **Objective Decision-Making**: Models apply the same criteria consistently and can reduce some human biases, though they inherit any biases present in their training data
Traditional statistics like batting average or ERA provide useful summaries, but they can't capture the full complexity of baseball. For example, two pitchers with identical ERAs might have vastly different pitch movement profiles, release points, and sequencing strategies. ML can analyze all these factors simultaneously to provide deeper insights.
## Types of Machine Learning Problems in Baseball
Machine learning problems fall into several categories, each suited for different baseball analytics questions.
### Classification Problems
Classification involves predicting categorical outcomes—answering "which category?" questions.
**Baseball applications:**
- **Pitch Type Classification**: Identifying whether the next pitch will be a fastball, curveball, or slider based on game situation
- **Ball/Strike Prediction**: Predicting whether a pitch will be called a ball or strike based on location and movement
- **Hit Type Classification**: Categorizing batted balls as ground balls, line drives, fly balls, or pop-ups
- **Player Archetype Classification**: Assigning players to predefined, labeled categories (power hitters, contact hitters, etc.)
- **Defensive Positioning**: Classifying where batters are most likely to hit based on pitcher and situation
- **Injury Risk Assessment**: Identifying high-risk vs. low-risk players for specific injuries
### Regression Problems
Regression predicts continuous numerical values—answering "how much?" questions.
**Baseball applications:**
- **Exit Velocity Prediction**: Estimating how hard a ball will be hit based on swing mechanics and pitch characteristics
- **Launch Angle Optimization**: Predicting optimal swing paths to maximize expected runs
- **WAR Projection**: Forecasting a player's Wins Above Replacement for upcoming seasons
- **Contract Value Estimation**: Predicting fair market value for free agents based on projected performance
- **Run Scoring Prediction**: Estimating expected runs from specific game situations
- **Pitch Velocity Forecasting**: Predicting a pitcher's velocity trajectory throughout a game or season
### Clustering Problems
Clustering groups similar data points together without predefined categories—discovering natural patterns.
**Baseball applications:**
- **Player Similarity Analysis**: Finding comparable players for trade evaluations or prospect comparisons
- **Pitch Arsenal Grouping**: Identifying distinct pitch movement profiles
- **Defensive Zone Clustering**: Discovering natural defensive zones based on batted ball data
- **Game Situation Clustering**: Grouping similar strategic scenarios
- **Swing Pattern Analysis**: Identifying distinct swing types among hitters
### Time Series Forecasting
Time series models predict future values based on sequential data over time.
**Baseball applications:**
- **Performance Trend Analysis**: Forecasting whether a player is improving or declining
- **Season Projections**: Predicting full-season statistics from partial data
- **Fatigue Monitoring**: Detecting performance degradation due to workload
- **Aging Curves**: Modeling typical performance trajectories across career arcs
## Common Machine Learning Algorithms Used in Baseball
Different algorithms have different strengths. Here are the most popular choices for baseball analytics.
### Random Forests
Random forests create many decision trees and combine their predictions. They're robust, handle mixed data types well, and provide feature importance rankings.
**Strengths for baseball:**
- Excellent for tabular data (perfect for baseball statistics)
- Handles both categorical (position, pitch type) and numerical (velocity, spin rate) features
- Less prone to overfitting than a single decision tree
- Provides interpretable feature importance
- No need to normalize/scale features
**Common applications:**
- Player performance prediction
- Pitch outcome classification
- Injury risk modeling
### Gradient Boosting (XGBoost, LightGBM, CatBoost)
Gradient boosting builds trees sequentially, with each tree correcting the errors of previous ones. XGBoost is particularly popular in baseball analytics.
**Strengths for baseball:**
- Often achieves best predictive accuracy on structured data
- Handles missing data well (common in baseball datasets)
- Fast training and prediction
- Excellent feature importance analysis
**Common applications:**
- WAR projections
- Win probability models
- Player valuation systems
### Logistic Regression
Despite being a classical method, logistic regression remains valuable for binary classification with interpretable results.
**Strengths for baseball:**
- Simple and interpretable
- Provides probability estimates
- Works well with limited data
- Computationally efficient
**Common applications:**
- Ball/strike classification
- Hit/out prediction
- Base stealing success probability
### Neural Networks (Deep Learning)
Neural networks excel at learning complex patterns from large datasets, especially with spatial or sequential data.
**Strengths for baseball:**
- Exceptional for computer vision (video analysis, pitch tracking)
- Handles sequential data well (pitch sequences, at-bat progression)
- Can learn extremely complex patterns
- Powerful for representation learning
**Common applications:**
- Video-based swing analysis
- Pitch sequence optimization
- Advanced Statcast data modeling
- Automated scouting reports from video
### K-Nearest Neighbors (KNN)
KNN predicts a data point's class (or value) from the K most similar points in the feature space.
**Strengths for baseball:**
- Simple and intuitive
- Excellent for player similarity analysis
- No training phase required
**Common applications:**
- Finding comparable players
- Historical situation matching
- Defensive shift positioning
### Support Vector Machines (SVM)
SVMs find optimal boundaries between classes in high-dimensional space.
**Strengths for baseball:**
- Effective with high-dimensional data
- Works well with smaller datasets
- Robust to outliers
**Common applications:**
- Pitch classification
- Quality start prediction
- Hall of Fame likelihood
## Feature Engineering for Baseball Data
Feature engineering—creating meaningful input variables—is crucial for ML success in baseball. Raw data often needs transformation to maximize predictive power.
### Domain-Specific Features
Create features that capture baseball knowledge:
**Pitch sequencing features:**
- Previous pitch type, velocity, location
- Pitch count and sequence patterns
- Time since last pitch of same type
- Tunnel differential between pitches
**Situational features:**
- Count leverage (how much count favors pitcher/batter)
- Base-out states encoded numerically
- Score differential and inning
- Platoon advantage (handedness matchups)
**Statcast-derived features:**
- Spin efficiency (share of total spin that is transverse, i.e. movement-producing)
- Vertical approach angle (the pitch's downward angle as it crosses the plate)
- Attack angle differential (swing vs. pitch angle)
- Extension-adjusted velocity
**Historical aggregates:**
- Rolling averages (last 10 games, 30 days, etc.)
- Seasonal trends and momentum
- Career statistics vs. specific opponents
- Performance splits (home/away, day/night)
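As a rough sketch of the situational features above, the snippet below encodes the 24 base-out states and a platoon flag with pandas. The column names (`on_1b`, `outs`, `bat_side`, `pitch_hand`, and so on) are assumptions for illustration, not a particular data source's schema. A single integer code works for tree models; for linear models, one-hot encode `base_out_state` instead, since its numeric order is arbitrary.

```python
import pandas as pd

# Hypothetical plate-appearance rows (column names are assumptions)
pa = pd.DataFrame({
    'on_1b': [0, 1, 1, 0],
    'on_2b': [0, 0, 1, 0],
    'on_3b': [0, 0, 1, 1],
    'outs': [0, 1, 2, 1],
    'bat_side': ['L', 'R', 'L', 'R'],
    'pitch_hand': ['R', 'R', 'L', 'L'],
})

# Base state as a 3-bit number: 0 = bases empty ... 7 = bases loaded
pa['base_state'] = pa['on_1b'] + 2 * pa['on_2b'] + 4 * pa['on_3b']

# Combine with outs (0-2) into one of the 24 base-out states, coded 0-23
pa['base_out_state'] = pa['base_state'] * 3 + pa['outs']

# Platoon advantage: batter hits from the side opposite the pitcher's throwing hand
pa['platoon_advantage'] = (pa['bat_side'] != pa['pitch_hand']).astype(int)

print(pa[['base_state', 'outs', 'base_out_state', 'platoon_advantage']])
```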
### Interaction Features
Combine variables to capture relationships:
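```python
import pandas as pd

# df: one row per pitch or batted ball (file and column names are illustrative)
df = pd.read_csv('pitches.csv')
df['ev_x_la'] = df['exit_velocity'] * df['launch_angle']                   # batted ball quality
df['spin_x_velo'] = df['spin_rate'] * df['velocity']                        # pitch movement potential
df['leverage_x_platoon'] = df['count_leverage'] * df['platoon_advantage']  # situational edge
```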
### Temporal Features
Capture time-based patterns:
- Day of week, month of season
- Days rest for pitchers
- Games played in last week
- Career age and experience level
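A minimal pandas sketch of two workload features, days of rest and appearances in the last seven days, plus simple calendar features. It assumes a hypothetical appearance log with `pitcher_id` and `game_date` columns, and counts rest as full days off between appearances (so pitching on the 1st and again on the 6th is four days of rest).

```python
import pandas as pd

# Hypothetical appearance log: one row per pitcher per game (column names are assumptions)
games = pd.DataFrame({
    'pitcher_id': [1, 1, 1, 2, 2],
    'game_date': pd.to_datetime(['2024-04-01', '2024-04-06', '2024-04-08',
                                 '2024-04-02', '2024-04-03']),
})
games = games.sort_values(['pitcher_id', 'game_date'])

# Days of rest: full days off since the previous appearance (NaN for a first appearance)
games['days_rest'] = games.groupby('pitcher_id')['game_date'].diff().dt.days - 1


def apps_in_window(dates: pd.Series, days: int = 7) -> pd.Series:
    """Count appearances in the trailing `days`-day window ending on each date."""
    return dates.apply(
        lambda d: int(((dates > d - pd.Timedelta(days=days)) & (dates <= d)).sum())
    )


# Workload: appearances in the last 7 days, including the current game
games['apps_last_7d'] = games.groupby('pitcher_id')['game_date'].transform(apps_in_window)

# Calendar features
games['month'] = games['game_date'].dt.month
games['day_of_week'] = games['game_date'].dt.day_name()
print(games)
```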
### Normalization and Scaling
Many algorithms require feature scaling:
**When to scale:**
- Neural networks, SVM, KNN: Always scale
- Tree-based methods (Random Forest, XGBoost): Usually no need
- Linear models: Scale when using regularization, so the penalty treats all features equally
**Common scaling methods:**
- **Standardization**: Zero mean, unit variance (most common)
- **Min-max scaling**: Scale to [0,1] range
- **Robust scaling**: Use median/IQR (resistant to outliers)
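The key practical rule is to fit the scaler on training data and only apply it to test data. This sketch compares the three scalers above on toy velocity and spin-rate values (invented for illustration) using scikit-learn's `StandardScaler`, `MinMaxScaler`, and `RobustScaler`.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler

# Toy pitch features on very different scales: velocity (mph) and spin rate (rpm)
X_train = np.array([[92.0, 2200.0], [95.0, 2400.0], [88.0, 2100.0], [97.0, 2600.0]])
X_test = np.array([[94.0, 2350.0]])

scaled = {}
for scaler in (StandardScaler(), MinMaxScaler(), RobustScaler()):
    # Learn the scaling statistics from the training data only...
    X_train_scaled = scaler.fit_transform(X_train)
    # ...and reuse them on the test data so nothing leaks from the test set
    X_test_scaled = scaler.transform(X_test)
    scaled[type(scaler).__name__] = (X_train_scaled, X_test_scaled)
    print(type(scaler).__name__, X_test_scaled.round(2))
```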
### Handling Categorical Variables
Convert categories to numeric representations:
- **One-hot encoding**: Create binary column for each category (use for low cardinality)
- **Label encoding**: Assign numbers to categories (use with tree methods)
- **Target encoding**: Replace each category with its mean target value (powerful, but leaks the target unless computed out-of-fold on training data only)
- **Frequency encoding**: Use category frequency as feature
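Here is a small pandas sketch of the four encodings on a hypothetical pitch table (column names and values are made up). The target encoding is computed on the whole toy table only to keep the example short; in a real model, compute it out-of-fold on training data so a row's own outcome never feeds its encoded value.

```python
import pandas as pd

# Hypothetical pitch-level table (column names and values are assumptions)
df = pd.DataFrame({
    'pitch_type': ['FF', 'SL', 'FF', 'CH', 'FF', 'SL'],
    'whiff': [0, 1, 0, 1, 0, 1],
})

# One-hot encoding: one binary column per category (fine for low cardinality)
one_hot = pd.get_dummies(df['pitch_type'], prefix='pt', dtype=int)

# Label encoding: integer codes (categories sorted alphabetically: CH=0, FF=1, SL=2)
df['pitch_type_code'] = df['pitch_type'].astype('category').cat.codes

# Frequency encoding: share of rows belonging to each category
df['pitch_type_freq'] = df['pitch_type'].map(df['pitch_type'].value_counts(normalize=True))

# Naive target encoding: mean outcome per category (illustration only; leaks the
# target if used as-is, so compute it out-of-fold in practice)
df['pitch_type_whiff_rate'] = df.groupby('pitch_type')['whiff'].transform('mean')

print(one_hot)
print(df)
```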
### Dealing with Missing Data
Baseball data often has missing values:
**Strategies:**
- **Imputation**: Fill with mean, median, or model-based prediction
- **Indicator variables**: Create flag for missing status
- **Forward fill**: Use last known value (for time series)
- **Leave as-is**: XGBoost and similar algorithms handle missing data natively
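A sketch of three of these strategies using scikit-learn's `SimpleImputer` (median imputation with `add_indicator=True` for missing-value flags) and a per-pitcher forward fill in pandas. The table and its columns are hypothetical; in practice, fit the imputer on the training split only.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Hypothetical per-game pitcher rows with gaps in tracking data (columns are assumptions)
df = pd.DataFrame({
    'pitcher_id': [1, 1, 1, 2, 2],
    'spin_rate': [2300.0, np.nan, 2350.0, 2100.0, np.nan],
    'release_ext': [6.5, 6.4, np.nan, 6.0, 6.1],
})

# Median imputation; add_indicator appends one 0/1 "was missing" column per
# feature that had missing values during fit
imputer = SimpleImputer(strategy='median', add_indicator=True)
imputed = imputer.fit_transform(df[['spin_rate', 'release_ext']])
# Columns of `imputed`: spin_rate, release_ext, spin_rate_missing, release_ext_missing

# Forward fill within each pitcher's own history (time-series style)
df['spin_rate_ffill'] = df.groupby('pitcher_id')['spin_rate'].ffill()

print(imputed)
print(df)
```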
## Training/Testing Splits and Cross-Validation
Proper data splitting prevents overfitting and ensures models generalize to new data.
### The Train/Test Split
Always evaluate models on data they haven't seen during training.
**Standard approach:**
- **Training set** (70-80%): Used to fit model parameters
- **Test set** (20-30%): Used only for final evaluation
**Important for baseball:**
Use **temporal splits** when predicting future performance—train on earlier seasons, test on later seasons. This mimics real-world usage better than random splits.
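```python
import pandas as pd
from sklearn.model_selection import train_test_split

# data: one row per player-season with a 'season' column (file name is illustrative)
data = pd.read_csv('player_seasons.csv')

# Bad: random split (rows from later seasons can end up in training)
train, test = train_test_split(data, test_size=0.2, random_state=42)

# Good: temporal split (train on past seasons, test on the most recent one)
train = data[data['season'] < 2024]
test = data[data['season'] == 2024]
```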
### Cross-Validation
Cross-validation provides more robust performance estimates by testing on multiple data subsets.
**K-Fold Cross-Validation:**
- Split data into K folds (typically 5-10)
- Train on K-1 folds, test on remaining fold
- Repeat K times with different test fold
- Average performance across all folds
**Time Series Cross-Validation:**
For sequential data, use rolling or expanding windows:
```
Fold 1: Train [2018] Test [2019]
Fold 2: Train [2018-2019] Test [2020]
Fold 3: Train [2018-2020] Test [2021]
Fold 4: Train [2018-2021] Test [2022]
```
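The expanding-window scheme above can be written directly as a loop over seasons. This sketch uses synthetic data with a `season` column (2018 to 2022) and a Ridge model purely as a placeholder; swap in real features and your own model. Grouping by season, rather than using scikit-learn's row-based `TimeSeriesSplit`, keeps each season entirely on one side of the split.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_absolute_error

# Synthetic stand-in for player-season data (column names are assumptions)
rng = np.random.default_rng(0)
data = pd.DataFrame({
    'season': np.repeat(np.arange(2018, 2023), 50),
    'x1': rng.normal(size=250),
    'x2': rng.normal(size=250),
})
data['target'] = 2 * data['x1'] - data['x2'] + rng.normal(scale=0.5, size=250)

features = ['x1', 'x2']
seasons = sorted(data['season'].unique())
fold_mae = []

# Expanding window: each fold trains on every season before the test season
for test_season in seasons[1:]:
    train = data[data['season'] < test_season]
    test = data[data['season'] == test_season]
    model = Ridge().fit(train[features], train['target'])
    mae = mean_absolute_error(test['target'], model.predict(test[features]))
    fold_mae.append(mae)
    print(f"Train {seasons[0]}-{test_season - 1}, test {test_season}: MAE = {mae:.3f}")

print(f"Mean MAE across folds: {np.mean(fold_mae):.3f}")
```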
**Stratified Cross-Validation:**
Ensures each fold has similar class distributions—important for imbalanced outcomes (e.g., home runs are rare).
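To see why stratification matters, this sketch counts how many positives land in each test fold for plain `KFold` versus `StratifiedKFold` on a toy label vector with a 5% positive rate (an illustrative stand-in for home runs).

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Toy labels: 10 "home runs" (1) among 200 batted balls
y = np.array([1] * 10 + [0] * 190)
X = np.zeros((len(y), 1))  # features do not affect how folds are assigned

splitters = {
    'KFold': KFold(n_splits=5, shuffle=True, random_state=0),
    'StratifiedKFold': StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
}

hr_counts = {}
for name, cv in splitters.items():
    # Number of positives in each test fold
    hr_counts[name] = [int(y[test_idx].sum()) for _, test_idx in cv.split(X, y)]
    print(name, hr_counts[name])
```

Plain `KFold` can leave some folds with very few positives, while `StratifiedKFold` places two in every fold here.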
### Validation Sets
For hyperparameter tuning, use three-way split:
- **Training**: Fit model parameters
- **Validation**: Tune hyperparameters
- **Test**: Final performance evaluation
This prevents "overfitting to the test set" during model selection.
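One way to implement the three-way split with scikit-learn is sketched below on synthetic data: choose logistic regression's regularization strength `C` on the validation set, then report the test score once. With season-level data, you would split by season instead (for example, train through 2022, validate on 2023, test on 2024).

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic binary data stands in for real pitch outcomes
X, y = make_classification(n_samples=2000, n_features=8, random_state=0)

# 60% train / 20% validation / 20% test
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.25, random_state=0)

# Choose the regularization strength C using the validation set only
val_scores = {}
for C in [0.01, 0.1, 1.0, 10.0]:
    model = LogisticRegression(C=C, max_iter=1000).fit(X_train, y_train)
    val_scores[C] = roc_auc_score(y_val, model.predict_proba(X_val)[:, 1])
best_C = max(val_scores, key=val_scores.get)

# Refit on train + validation with the chosen C, then use the test set exactly once
final = LogisticRegression(C=best_C, max_iter=1000).fit(X_temp, y_temp)
test_auc = roc_auc_score(y_test, final.predict_proba(X_test)[:, 1])
print(f"Best C = {best_C}, test ROC AUC = {test_auc:.3f}")
```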
## Overfitting and Model Selection
Overfitting occurs when models learn noise rather than signal, performing well on training data but poorly on new data.
### Signs of Overfitting
- Large gap between training and test performance
- Model performs worse on recent data than historical data
- Overly complex decision boundaries
- High variance in cross-validation scores
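The first sign, a train/test gap, is easy to demonstrate. This sketch fits decision trees of increasing depth to synthetic noisy data and prints train and test R²; the unrestricted tree memorizes the training set, so its gap should be the largest.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic noisy regression data (a stand-in for, e.g., exit velocity)
X, y = make_regression(n_samples=500, n_features=5, noise=25.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

gaps = {}
for depth in [2, 4, 8, None]:  # None = grow until every leaf is pure
    tree = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X_train, y_train)
    train_r2 = tree.score(X_train, y_train)
    test_r2 = tree.score(X_test, y_test)
    gaps[depth] = train_r2 - test_r2
    print(f"max_depth={depth}: train R² = {train_r2:.2f}, test R² = {test_r2:.2f}")
```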
### Preventing Overfitting
**Regularization:**
Add penalties for model complexity:
- **L1 regularization (Lasso)**: Forces some coefficients to zero, performs feature selection
- **L2 regularization (Ridge)**: Shrinks coefficients toward zero
- **Elastic Net**: Combines L1 and L2
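A quick illustration of the difference between L1 and L2 penalties, on synthetic data where only two of ten features matter. The `alpha` values are arbitrary choices for this toy example, not recommendations.

```python
import numpy as np
from sklearn.linear_model import Lasso, Ridge

# Synthetic data: 10 candidate features, but only the first two drive the outcome
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 10))
y = 3 * X[:, 0] - 2 * X[:, 1] + rng.normal(scale=0.5, size=300)

lasso = Lasso(alpha=0.2).fit(X, y)
ridge = Ridge(alpha=10.0).fit(X, y)

# L1 drives irrelevant coefficients exactly to zero; L2 only shrinks them
print("Lasso:", lasso.coef_.round(2))
print("Ridge:", ridge.coef_.round(2))
print("Features kept by Lasso:", np.flatnonzero(lasso.coef_))
```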
**Tree-based regularization:**
- Limit tree depth
- Require minimum samples per leaf
- Limit the number of boosting rounds (adding trees does not overfit a random forest, but it can overfit a boosted model)
- Reduce learning rate (for boosting)
**Early stopping:**
Monitor validation performance and stop training when it stops improving.
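scikit-learn's `GradientBoostingClassifier` supports early stopping through its `validation_fraction` and `n_iter_no_change` parameters. This sketch on synthetic data allows up to 1,000 trees and reports how many were actually fit; XGBoost and LightGBM offer similar options through their own APIs.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Synthetic binary outcome with 10% label noise, standing in for, e.g., hit/out
X, y = make_classification(n_samples=3000, n_features=10, flip_y=0.1, random_state=0)

# Allow up to 1000 trees, but hold out 10% of the training data internally and
# stop once the validation loss fails to improve for 10 consecutive iterations
gb = GradientBoostingClassifier(
    n_estimators=1000,
    learning_rate=0.1,
    validation_fraction=0.1,
    n_iter_no_change=10,
    random_state=0,
)
gb.fit(X, y)
print(f"Trees actually fit: {gb.n_estimators_} of 1000")
```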
**Ensemble methods:**
Combine multiple models to reduce overfitting risk of individual models.
### Model Selection Strategies
**Compare multiple algorithms:**
Test various model types on your specific problem—different algorithms excel at different tasks.
**Hyperparameter tuning:**
- **Grid search**: Try all combinations of parameters
- **Random search**: Sample random parameter combinations
- **Bayesian optimization**: Use previous results to guide search
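A minimal random-search sketch using scikit-learn's `RandomizedSearchCV` on synthetic data. The parameter ranges are illustrative; for season-ordered data, pass a time-aware `cv` splitter instead of the default folds.

```python
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# Synthetic data standing in for a real classification problem
X, y = make_classification(n_samples=1000, n_features=10, random_state=0)

# Random search: sample 10 parameter combinations, each scored by 3-fold CV
search = RandomizedSearchCV(
    RandomForestClassifier(n_estimators=50, random_state=0),
    param_distributions={
        'max_depth': [3, 5, 10, None],
        'min_samples_leaf': randint(1, 20),  # integers 1-19
    },
    n_iter=10,
    cv=3,
    scoring='roc_auc',
    random_state=0,
)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```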
**Feature selection:**
Remove irrelevant or redundant features:
- Recursive feature elimination
- Feature importance from tree models
- L1 regularization for automatic selection
### Evaluation Metrics
Choose metrics appropriate for your problem:
**Classification metrics:**
- **Accuracy**: Overall correctness (can be misleading with imbalanced classes)
- **Precision**: Of predicted positives, how many are correct?
- **Recall**: Of actual positives, how many did we catch?
- **F1 Score**: Harmonic mean of precision and recall
- **ROC AUC**: Probability that the model ranks a random positive above a random negative; summarizes the true/false positive rate trade-off across all thresholds
- **Log Loss**: Penalizes confident wrong predictions
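A worked example makes the classification metrics concrete. The ten labels and probabilities below are invented; the comments show the hand calculation that each scikit-learn function should reproduce.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, f1_score, log_loss, precision_score,
                             recall_score, roc_auc_score)

# Ten hypothetical batted balls (1 = hit) and model-predicted hit probabilities
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
y_prob = np.array([0.9, 0.8, 0.4, 0.3, 0.6, 0.2, 0.1, 0.2, 0.1, 0.05])
y_pred = (y_prob >= 0.5).astype(int)  # -> [1, 1, 0, 0, 1, 0, 0, 0, 0, 0]

# Confusion counts: TP = 2, FN = 2, FP = 1, TN = 5
acc = accuracy_score(y_true, y_pred)    # (2 + 5) / 10 = 0.700
prec = precision_score(y_true, y_pred)  # 2 / (2 + 1) = 0.667
rec = recall_score(y_true, y_pred)      # 2 / (2 + 2) = 0.500
f1 = f1_score(y_true, y_pred)           # 2*TP / (2*TP + FP + FN) = 4/7 = 0.571
auc = roc_auc_score(y_true, y_prob)     # 22 of 24 positive/negative pairs ranked correctly = 0.917
ll = log_loss(y_true, y_prob)

for name, value in [('Accuracy', acc), ('Precision', prec), ('Recall', rec),
                    ('F1', f1), ('ROC AUC', auc), ('Log loss', ll)]:
    print(f"{name:<10} {value:.3f}")
```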
**Regression metrics:**
- **MAE (Mean Absolute Error)**: Average absolute prediction error, in the target's units
- **RMSE (Root Mean Squared Error)**: Penalizes large errors more
- **R² (R-squared)**: Proportion of variance explained
- **MAPE (Mean Absolute Percentage Error)**: Error relative to actual value (unstable when actual values are near zero)
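The same idea for the regression metrics, using four hypothetical exit velocities. Note that scikit-learn's `mean_absolute_percentage_error` returns a fraction rather than a percentage.

```python
import numpy as np
from sklearn.metrics import (mean_absolute_error, mean_absolute_percentage_error,
                             mean_squared_error, r2_score)

# Four hypothetical exit velocities (mph) and model predictions
y_true = np.array([90.0, 95.0, 100.0, 105.0])
y_pred = np.array([92.0, 94.0, 97.0, 109.0])  # actual - predicted: -2, +1, +3, -4

mae = mean_absolute_error(y_true, y_pred)              # (2 + 1 + 3 + 4) / 4 = 2.5 mph
rmse = np.sqrt(mean_squared_error(y_true, y_pred))     # sqrt((4 + 1 + 9 + 16) / 4) ≈ 2.74 mph
r2 = r2_score(y_true, y_pred)                          # 1 - 30 / 125 = 0.76
mape = mean_absolute_percentage_error(y_true, y_pred)  # a fraction: ≈ 0.025, i.e. about 2.5%

print(f"MAE {mae:.2f} | RMSE {rmse:.2f} | R² {r2:.2f} | MAPE {mape:.2%}")
```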
## Popular Machine Learning Libraries
### Python: scikit-learn
The standard Python ML library—comprehensive, well-documented, and easy to use.
**Strengths:**
- Consistent API across all algorithms
- Excellent documentation and examples
- Built-in preprocessing and evaluation tools
- Perfect for learning ML fundamentals
**Installation:**
```bash
pip install scikit-learn pandas numpy
```
### Python: XGBoost
Gradient boosting library that often wins ML competitions.
**Strengths:**
- State-of-the-art performance on structured data
- Handles missing data
- Built-in cross-validation
- Feature importance analysis
**Installation:**
```bash
pip install xgboost
```
### Python: TensorFlow/Keras
Deep learning frameworks for neural networks.
**Strengths:**
- Powerful for complex patterns
- Excellent for video/image data
- Production-ready deployment
- Large community
**Installation:**
```bash
pip install tensorflow
```
### R: caret
Comprehensive ML framework that unifies 200+ algorithms with consistent syntax.
**Strengths:**
- Streamlined workflow from preprocessing to evaluation
- Automated hyperparameter tuning
- Extensive model selection
- Excellent for research and exploration
**Installation:**
```r
install.packages("caret")
```
### R: tidymodels
Modern ML framework following tidyverse principles.
**Strengths:**
- Clean, intuitive syntax
- Consistent approach across all models
- Integrated with tidyverse ecosystem
- Excellent for production workflows
**Installation:**
```r
install.packages("tidymodels")
```
### Python: PyTorch
Deep learning framework with dynamic computation graphs.
**Strengths:**
- Flexible and pythonic
- Excellent for research
- Strong community in sports analytics
- Easy debugging
**Installation:**
```bash
pip install torch
```
## Real-World Applications in Baseball
### Pitch Prediction and Sequencing
Models predict next pitch type based on:
- Count, score, inning, base-out state
- Previous pitches in at-bat and game
- Batter tendencies and history vs. pitcher
- Pitcher repertoire and recent usage
**Value:** Helps batters prepare, informs hitting coaches, powers opponent scouting
### Launch Angle Optimization
ML identifies optimal swing paths for individual hitters:
- Analyze relationship between launch angle and outcomes
- Account for exit velocity, spray angle, ballpark factors
- Find each hitter's "sweet spot" zone
- Recommend swing adjustments
**Value:** Several teams credit launch angle optimization with power surges
### Defensive Positioning
Models predict batted ball location to optimize fielder positioning:
- Analyze spray charts with pitch type, count, velocity
- Account for ballpark dimensions
- Generate heat maps for positioning
- Real-time adjustments during at-bats
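**Value:** Better positioning turns would-be hits into outs. Since 2023, MLB rules require two infielders on each side of second base with all four on the infield dirt, so positioning models now optimize within those limits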
### Injury Prediction
ML identifies injury risk factors:
- Biomechanical data (arm slot changes, release point variance)
- Workload metrics (pitch counts, days rest)
- Velocity fluctuations
- Historical injury patterns
**Value:** Early warning can guide preventive rest, protecting player health and contract value
### Player Development
Models track skill development:
- Identify mechanical improvements
- Project minor leaguer performance in majors
- Recommend training priorities
- Monitor progress against development plans
**Value:** Better prospect evaluation, optimized development resources
### Win Probability Models
Real-time prediction of game outcomes:
- Updates after every pitch
- Accounts for score, inning, base-out state
- Includes team strength and pitcher quality
- Powers strategic decisions (e.g., when to bring in the closer)
**Value:** Optimal strategic timing, fan engagement
## Getting Started Workflow
Here's a practical workflow for your first baseball ML project.
### Step 1: Define Your Question
Be specific about what you're predicting:
- "Predict whether a pitch will be a called strike" (classification)
- "Estimate a player's home runs next season" (regression)
- "Group pitchers by arsenal similarity" (clustering)
### Step 2: Gather and Explore Data
Collect relevant data from sources like:
- Baseball-Reference, FanGraphs, Baseball Savant
- MLB Statcast API
- Retrosheet play-by-play
- Private scouting databases
Explore with summary statistics, visualizations, correlation analysis.
### Step 3: Engineer Features
Create meaningful variables:
- Calculate derived metrics
- Encode categorical variables
- Create interaction terms
- Add temporal features
### Step 4: Split Data
Create train/test splits:
- Use temporal splits for future prediction
- Reserve 20-30% for testing
- Consider a validation set for tuning
### Step 5: Train Baseline Model
Start simple:
- Logistic regression for classification
- Linear regression for continuous prediction
- Establish baseline performance to beat
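A baseline can be even simpler than logistic regression. This sketch, on synthetic imbalanced data, compares scikit-learn's `DummyClassifier` (which always predicts the training class frequencies) with a logistic regression using log loss; a model that cannot beat the dummy is not learning anything useful.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced binary target (about 30% positives) standing in for real data
X, y = make_classification(n_samples=2000, n_features=6, weights=[0.7, 0.3], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

# Trivial baseline: always predict the training-set class frequencies
dummy = DummyClassifier(strategy='prior').fit(X_train, y_train)
# Simple model baseline
logit = LogisticRegression(max_iter=1000).fit(X_train, y_train)

dummy_ll = log_loss(y_test, dummy.predict_proba(X_test))
logit_ll = log_loss(y_test, logit.predict_proba(X_test))
print(f"Dummy log loss: {dummy_ll:.3f} | Logistic regression log loss: {logit_ll:.3f}")
```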
### Step 6: Try Advanced Models
Experiment with:
- Random forests
- XGBoost
- Neural networks (if you have lots of data)
### Step 7: Evaluate and Iterate
- Compare models using appropriate metrics
- Analyze errors to find improvement opportunities
- Engineer new features based on insights
- Tune hyperparameters
- Evaluate the final model once on the held-out test set
### Step 8: Interpret Results
- Examine feature importance
- Analyze prediction errors
- Create visualizations
- Generate actionable insights
### Step 9: Deploy and Monitor
- Put model into production use
- Monitor performance on new data
- Retrain periodically as new data arrives
- Update features and algorithms as needed
## Code Examples
### Example 1: Ball/Strike Classification (Python)
This example builds a logistic regression model to predict ball vs. strike calls based on pitch location.
```python
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
import matplotlib.pyplot as plt
# Load pitch data (example structure; taken pitches only, since swings get no ball/strike call)
# Columns: plate_x, plate_z, pitch_type, called_strike (0=ball, 1=strike)
pitches = pd.read_csv('pitches.csv')
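# Feature engineering
pitches['distance_from_center'] = np.sqrt(pitches['plate_x']**2 +
                                          (pitches['plate_z'] - 2.5)**2)
# Approximate zone: half the 17-inch plate (0.71 ft) plus a ball radius gives ~0.83 ft;
# the fixed 1.5-3.5 ft height ignores batter-specific zone heights
pitches['in_zone'] = ((abs(pitches['plate_x']) < 0.83) &
                      (pitches['plate_z'] > 1.5) &
                      (pitches['plate_z'] < 3.5)).astype(int)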
# Select features
features = ['plate_x', 'plate_z', 'distance_from_center', 'in_zone']
X = pitches[features]
y = pitches['called_strike']
# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# Train logistic regression model
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
# Predictions
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)[:, 1]
# Evaluation
print("Ball/Strike Classification Results")
print("=" * 50)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(f"ROC AUC: {roc_auc_score(y_test, y_pred_proba):.3f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred,
target_names=['Ball', 'Strike']))
# Coefficients (change in log-odds per unit; features are unscaled, so compare magnitudes with care)
feature_importance = pd.DataFrame({
'feature': features,
'coefficient': model.coef_[0]
}).sort_values('coefficient', ascending=False)
print("\nFeature Importance:")
print(feature_importance)
# Plot test-set pitch locations colored by the actual call
plt.figure(figsize=(10, 6))
plt.scatter(X_test['plate_x'], X_test['plate_z'],
c=y_test, cmap='RdYlGn', alpha=0.6, edgecolors='black')
plt.xlabel('Horizontal Location (ft)')
plt.ylabel('Vertical Location (ft)')
plt.title('Pitch Locations: Balls vs Strikes')
plt.colorbar(label='Called Strike')
# Add strike zone rectangle
from matplotlib.patches import Rectangle
strike_zone = Rectangle((-0.83, 1.5), 1.66, 2.0,
fill=False, edgecolor='blue', linewidth=2)
plt.gca().add_patch(strike_zone)
plt.savefig('ball_strike_classification.png', dpi=300, bbox_inches='tight')
plt.show()
```
### Example 2: Exit Velocity Regression (Python)
Predict exit velocity based on swing and pitch characteristics using Random Forest.
```python
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
import numpy as np
# Load batted ball data
# Columns: launch_angle, pitch_velocity, swing_speed, contact_quality, exit_velocity
batted_balls = pd.read_csv('batted_balls.csv')
# Feature engineering
# (velocity_squared adds no information for tree models, which are unaffected by
# monotonic transforms; it mainly helps linear models)
batted_balls['velocity_squared'] = batted_balls['pitch_velocity'] ** 2
batted_balls['swing_pitch_interaction'] = (batted_balls['swing_speed'] *
                                           batted_balls['pitch_velocity'])
batted_balls['optimal_launch'] = (abs(batted_balls['launch_angle'] - 25) < 5).astype(int)
# Remove extreme outliers (exit velocities above 120 mph are very rare and may be tracking errors)
batted_balls = batted_balls[batted_balls['exit_velocity'] <= 120]
# Select features (including the engineered optimal_launch flag)
features = ['pitch_velocity', 'swing_speed', 'contact_quality',
            'launch_angle', 'velocity_squared', 'swing_pitch_interaction',
            'optimal_launch']
X = batted_balls[features]
y = batted_balls['exit_velocity']
# Train/test split (temporal if you have dates)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# Train Random Forest
rf_model = RandomForestRegressor(
n_estimators=100,
max_depth=15,
min_samples_split=10,
min_samples_leaf=5,
random_state=42,
n_jobs=-1
)
rf_model.fit(X_train, y_train)
# Predictions
y_pred = rf_model.predict(X_test)
# Evaluation
print("Exit Velocity Prediction Results")
print("=" * 50)
print(f"Mean Absolute Error: {mean_absolute_error(y_test, y_pred):.2f} mph")
print(f"Root Mean Squared Error: {np.sqrt(mean_squared_error(y_test, y_pred)):.2f} mph")
print(f"R² Score: {r2_score(y_test, y_pred):.3f}")
# Cross-validation
cv_scores = cross_val_score(rf_model, X, y, cv=5,
scoring='neg_mean_absolute_error')
print(f"\nCross-Validation MAE: {-cv_scores.mean():.2f} (+/- {cv_scores.std():.2f}) mph")
# Prediction vs Actual plot
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.5, edgecolors='black')
plt.plot([y_test.min(), y_test.max()],
[y_test.min(), y_test.max()],
'r--', lw=2, label='Perfect Prediction')
plt.xlabel('Actual Exit Velocity (mph)')
plt.ylabel('Predicted Exit Velocity (mph)')
plt.title('Exit Velocity Prediction: Actual vs Predicted')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('exit_velocity_predictions.png', dpi=300, bbox_inches='tight')
plt.show()
# Residual plot
residuals = y_test - y_pred
plt.figure(figsize=(10, 6))
plt.scatter(y_pred, residuals, alpha=0.5, edgecolors='black')
plt.axhline(y=0, color='r', linestyle='--', linewidth=2)
plt.xlabel('Predicted Exit Velocity (mph)')
plt.ylabel('Residual (Actual - Predicted)')
plt.title('Residual Plot')
plt.grid(True, alpha=0.3)
plt.savefig('exit_velocity_residuals.png', dpi=300, bbox_inches='tight')
plt.show()
```
### Example 3: Feature Importance Analysis (Python)
Analyze which features matter most for predictions using tree-based models.
```python
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
# Load data (example: predicting home runs)
# Columns: exit_velocity, launch_angle, spray_angle, spin_rate,
# hang_time, ballpark_factor, home_run (0/1)
data = pd.read_csv('batted_balls_hr.csv')
# Prepare features and target
feature_cols = ['exit_velocity', 'launch_angle', 'spray_angle',
'spin_rate', 'hang_time', 'ballpark_factor']
X = data[feature_cols]
y = data['home_run']
# Hold out a test set so permutation importance is measured on unseen data
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# Train Random Forest
rf = RandomForestClassifier(n_estimators=200, max_depth=10, random_state=42)
rf.fit(X_train, y_train)
# Impurity-based feature importance (derived from the training data)
importance_df = pd.DataFrame({
    'feature': feature_cols,
    'importance': rf.feature_importances_
}).sort_values('importance', ascending=False)
print("Feature Importance Ranking")
print("=" * 50)
print(importance_df.to_string(index=False))
# Visualize feature importance
plt.figure(figsize=(10, 6))
plt.barh(importance_df['feature'], importance_df['importance'])
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('Feature Importance for Home Run Prediction')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()
# Permutation importance on held-out data (slower, but avoids the training-set
# bias of impurity-based importance)
from sklearn.inspection import permutation_importance
perm_importance = permutation_importance(
    rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=-1,
    scoring='roc_auc'  # accuracy barely moves when home runs are rare
)
perm_importance_df = pd.DataFrame({
'feature': feature_cols,
'importance_mean': perm_importance.importances_mean,
'importance_std': perm_importance.importances_std
}).sort_values('importance_mean', ascending=False)
print("\nPermutation Feature Importance")
print("=" * 50)
print(perm_importance_df.to_string(index=False))
# Visualize with error bars
plt.figure(figsize=(10, 6))
plt.barh(perm_importance_df['feature'], perm_importance_df['importance_mean'],
xerr=perm_importance_df['importance_std'])
plt.xlabel('Permutation Importance')
plt.ylabel('Feature')
plt.title('Permutation Feature Importance for Home Run Prediction')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.savefig('permutation_importance.png', dpi=300, bbox_inches='tight')
plt.show()
# Partial dependence plots (show relationship between features and predictions)
from sklearn.inspection import PartialDependenceDisplay
features_to_plot = [0, 1] # exit_velocity, launch_angle
PartialDependenceDisplay.from_estimator(
rf, X, features_to_plot, feature_names=feature_cols
)
plt.suptitle('Partial Dependence Plots')
plt.tight_layout()
plt.savefig('partial_dependence.png', dpi=300, bbox_inches='tight')
plt.show()
```
### Example 4: Model Evaluation Metrics (Python)
Comprehensive evaluation of classification models with multiple metrics.
```python
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
f1_score, roc_auc_score, confusion_matrix,
classification_report, roc_curve)
import matplotlib.pyplot as plt
# Load data (example: predicting stolen base success)
# Columns: runner_speed, pitcher_time_to_plate, catcher_pop_time,
# lead_distance, pitcher_pickoff_move, success (0/1)
data = pd.read_csv('stolen_base_attempts.csv')
features = ['runner_speed', 'pitcher_time_to_plate', 'catcher_pop_time',
'lead_distance', 'pitcher_pickoff_move']
X = data[features]
y = data['success']
# Split data
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Train multiple models
models = {
'Logistic Regression': LogisticRegression(max_iter=1000),
'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
'Gradient Boosting': GradientBoostingClassifier(n_estimators=100, random_state=42)
}
results = []
for name, model in models.items():
    # Train
    model.fit(X_train, y_train)
    # Predict
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    # Calculate metrics
    results.append({
        'Model': name,
        'Accuracy': accuracy_score(y_test, y_pred),
        'Precision': precision_score(y_test, y_pred),
        'Recall': recall_score(y_test, y_pred),
        'F1 Score': f1_score(y_test, y_pred),
        'ROC AUC': roc_auc_score(y_test, y_pred_proba)
    })
# Display results
results_df = pd.DataFrame(results)
print("Model Comparison")
print("=" * 80)
print(results_df.to_string(index=False))
# Detailed report for the best model by test ROC AUC
# (selecting on the test set slightly inflates its score; use a validation set in practice)
best_model_name = results_df.loc[results_df['ROC AUC'].idxmax(), 'Model']
best_model = models[best_model_name]
y_pred_best = best_model.predict(X_test)
print(f"\nDetailed Report for {best_model_name}")
print("=" * 80)
print(classification_report(y_test, y_pred_best,
target_names=['Failed', 'Successful']))
# Confusion Matrix
cm = confusion_matrix(y_test, y_pred_best)
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title(f'Confusion Matrix - {best_model_name}')
plt.colorbar()
tick_marks = np.arange(2)
plt.xticks(tick_marks, ['Failed', 'Successful'])
plt.yticks(tick_marks, ['Failed', 'Successful'])
# Add text annotations
thresh = cm.max() / 2.
for i, j in np.ndindex(cm.shape):
    plt.text(j, i, format(cm[i, j], 'd'),
             ha="center", va="center",
             color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()
# ROC Curves for all models
plt.figure(figsize=(10, 8))
for name, model in models.items():
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
    auc = roc_auc_score(y_test, y_pred_proba)
    plt.plot(fpr, tpr, label=f'{name} (AUC = {auc:.3f})')
plt.plot([0, 1], [0, 1], 'k--', label='Random Classifier')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves - Stolen Base Success Prediction')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.savefig('roc_curves.png', dpi=300, bbox_inches='tight')
plt.show()
```
### Example 5: Basic Classification with R (caret)
Predicting pitch type using the caret framework in R.
```r
library(caret)
library(dplyr)
# Load pitch data
# Columns: velocity, spin_rate, horizontal_break, vertical_break, pitch_type
pitches <- read.csv('pitch_data.csv')
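# Convert pitch_type to factor; make.names() keeps class levels valid R names,
# which caret requires when classProbs = TRUE
pitches$pitch_type <- factor(make.names(pitches$pitch_type))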
# Feature engineering
pitches <- pitches %>%
  mutate(
    total_break = sqrt(horizontal_break^2 + vertical_break^2),
    break_angle = atan2(vertical_break, horizontal_break),
    # Spin per mph ("Bauer units"); true spin efficiency requires spin-axis data
    bauer_units = spin_rate / velocity
  )
# Select features
feature_cols <- c('velocity', 'spin_rate', 'horizontal_break',
                  'vertical_break', 'total_break', 'break_angle', 'bauer_units')
# Split data (75% training, 25% testing)
set.seed(42)
train_index <- createDataPartition(pitches$pitch_type, p = 0.75, list = FALSE)
train_data <- pitches[train_index, ]
test_data <- pitches[-train_index, ]
# Set up cross-validation
train_control <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,  # requires pitch_type levels to be valid R names (e.g. "FF", not "4-Seam Fastball")
  summaryFunction = multiClassSummary
)
# Train Random Forest model
rf_model <- train(
x = train_data[, feature_cols],
y = train_data$pitch_type,
method = "rf",
trControl = train_control,
tuneGrid = expand.grid(mtry = c(2, 3, 4)),
ntree = 100,
metric = "Accuracy"
)
# Print model results
print(rf_model)
# Make predictions
predictions <- predict(rf_model, newdata = test_data[, feature_cols])
# Confusion Matrix
cm <- confusionMatrix(predictions, test_data$pitch_type)
print(cm)
# Feature importance
importance <- varImp(rf_model)
print(importance)
plot(importance, main = "Feature Importance - Pitch Type Classification")
# Overall accuracy
accuracy <- sum(predictions == test_data$pitch_type) / length(predictions)
cat("\nTest Set Accuracy:", round(accuracy * 100, 2), "%\n")
```
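The derived features in the `mutate()` call are plain geometry. The sketch below reproduces the same arithmetic in Python for a single hypothetical pitch; the units (mph, rpm, inches of break) are assumptions, since the example CSV does not state them, and `pitch_features` and its dictionary keys are illustrative names.

```python
import math


def pitch_features(velocity, spin_rate, horizontal_break, vertical_break):
    """Derived features mirroring the R mutate() step (assumed units: mph, rpm, inches)."""
    return {
        # Length of the break vector: sqrt(h^2 + v^2), the Pythagorean theorem
        "total_break": math.hypot(horizontal_break, vertical_break),
        # Direction of the break vector in radians from the +horizontal axis;
        # atan2(y, x) keeps the quadrant, which atan(v / h) would lose
        "break_angle": math.atan2(vertical_break, horizontal_break),
        # Spin rate divided by velocity (rpm per mph, often called "Bauer units")
        "spin_per_mph": spin_rate / velocity,
    }


# Hypothetical pitch: 95 mph, 2375 rpm, 6 in of horizontal and 8 in of vertical break
f = pitch_features(95.0, 2375.0, 6.0, 8.0)
print(f["total_break"], f["spin_per_mph"], round(math.degrees(f["break_angle"]), 2))
```

Note that R's `atan2(y, x)` and Python's `math.atan2(y, x)` take their arguments in the same order, so a pitch breaking toward negative horizontal values gets an angle above 90 degrees rather than being folded back into the first quadrant.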
### Example 6: Regression with R (tidymodels)
Predicting player OPS (On-base Plus Slugging) using tidymodels.
```r
library(tidymodels)
library(dplyr)
# Load player statistics
# Columns: age, hard_hit_rate, barrel_rate, whiff_rate, walk_rate, ops
players <- read.csv('player_stats.csv')
# Initial split
set.seed(42)
data_split <- initial_split(players, prop = 0.75)
train_data <- training(data_split)
test_data <- testing(data_split)
# Create recipe (preprocessing steps)
rec <- recipe(ops ~ age + hard_hit_rate + barrel_rate + whiff_rate + walk_rate,
              data = train_data) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_poly(age, degree = 2)  # Replace age with orthogonal linear + quadratic terms (aging curve)
# Define model
rf_spec <- rand_forest(
trees = 100,
mtry = 3,
min_n = 5
) %>%
set_engine("ranger", importance = "impurity") %>%
set_mode("regression")
# Create workflow
wf <- workflow() %>%
add_recipe(rec) %>%
add_model(rf_spec)
# Train model
rf_fit <- wf %>% fit(data = train_data)
# Predictions on test set
predictions <- rf_fit %>%
predict(test_data) %>%
bind_cols(test_data)
# Evaluation metrics
test_metrics <- predictions %>%
  metrics(truth = ops, estimate = .pred)
print(test_metrics)
# Individual metrics
mae_val <- predictions %>% mae(truth = ops, estimate = .pred)
rmse_val <- predictions %>% rmse(truth = ops, estimate = .pred)
rsq_val <- predictions %>% rsq(truth = ops, estimate = .pred)
cat("Mean Absolute Error:", round(mae_val$.estimate, 4), "\n")
cat("Root Mean Squared Error:", round(rmse_val$.estimate, 4), "\n")
cat("R-squared:", round(rsq_val$.estimate, 4), "\n")
# Plot predictions vs actual
library(ggplot2)
ggplot(predictions, aes(x = ops, y = .pred)) +
geom_point(alpha = 0.6) +
geom_abline(slope = 1, intercept = 0, color = "red", linetype = "dashed") +
labs(x = "Actual OPS", y = "Predicted OPS",
title = "OPS Prediction: Actual vs Predicted") +
theme_minimal()
# Cross-validation for robust evaluation
folds <- vfold_cv(train_data, v = 5)
cv_results <- wf %>%
fit_resamples(
resamples = folds,
metrics = metric_set(rmse, mae, rsq)
)
cv_metrics <- cv_results %>% collect_metrics()
print(cv_metrics)
```
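The three yardstick metrics summarize the test-set residuals differently. One detail worth knowing: yardstick's `rsq()` is the squared correlation between truth and prediction, while the textbook R² is 1 - SS_res / SS_tot (available as `rsq_trad()`). The pure-Python sketch below uses hypothetical OPS values rather than real output and computes all four, so the difference is visible: predictions that are perfectly correlated with the truth but run .050 high get a correlation-based R² of 1 yet a traditional R² of only 0.2. `regression_metrics` is an illustrative helper, not a yardstick function.

```python
import math


def regression_metrics(actual, predicted):
    """MAE, RMSE, traditional R^2, and correlation-based R^2 for paired sequences.

    Assumes at least two distinct actual values (otherwise SS_tot is zero).
    """
    n = len(actual)
    residuals = [a - p for a, p in zip(actual, predicted)]
    mean_a = sum(actual) / n
    mean_p = sum(predicted) / n
    ss_res = sum(r * r for r in residuals)                # residual sum of squares
    ss_tot = sum((a - mean_a) ** 2 for a in actual)       # total sum of squares
    cov = sum((a - mean_a) * (p - mean_p) for a, p in zip(actual, predicted))
    var_p = sum((p - mean_p) ** 2 for p in predicted)
    return {
        "mae": sum(abs(r) for r in residuals) / n,
        "rmse": math.sqrt(ss_res / n),
        "rsq_trad": 1 - ss_res / ss_tot,                  # like yardstick::rsq_trad()
        "rsq_corr": cov * cov / (ss_tot * var_p),         # like yardstick::rsq()
    }


# Hypothetical OPS values: predictions track the truth exactly but run .050 high
actual = [0.700, 0.750, 0.800, 0.850]
predicted = [0.750, 0.800, 0.850, 0.900]
for name, value in regression_metrics(actual, predicted).items():
    print(f"{name}: {value:.3f}")
```

A large gap between `rsq()` and `rsq_trad()` on a test set is a hint that the model ranks players well but is systematically biased, something the correlation-based number alone would hide.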
## Conclusion and Next Steps
Machine learning has become indispensable in modern baseball analytics. From predicting pitch outcomes to optimizing player development, ML provides competitive advantages at every level of the game.
**Key takeaways:**
- Start with clear, specific questions
- Engineer meaningful features using baseball domain knowledge
- Use proper train/test splits to ensure generalization
- Compare multiple algorithms and tune hyperparameters
- Validate thoroughly before deploying models
- Interpret results and generate actionable insights
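In practice, the advice about proper splits and thorough validation usually comes down to k-fold cross-validation, which both R examples delegate to a framework (`trainControl(method = "cv", number = 5)` in caret, `vfold_cv(v = 5)` in rsample). As a minimal sketch of what those calls do, the hypothetical helper below shuffles row indices once, cuts them into k nearly equal folds, and lets each fold serve as the validation set exactly once. Unlike caret's fold creation (`createFolds()`, which stratifies on the outcome) or `vfold_cv()` given a `strata` column, it does not stratify.

```python
import random


def k_fold_indices(n, k, seed=42):
    """Yield (train_idx, val_idx) pairs for k-fold cross-validation over n rows."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # shuffle once so folds don't follow the file's row order
    # Spread the remainder so fold sizes differ by at most one row
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = idx[start:start + size]                # this fold validates...
        train = idx[:start] + idx[start + size:]     # ...while the rest train
        yield train, val
        start += size


for fold, (train_idx, val_idx) in enumerate(k_fold_indices(23, 5), start=1):
    print(f"fold {fold}: {len(train_idx)} train rows, {len(val_idx)} validation rows")
```

For 23 rows and k = 5 the folds hold 5, 5, 5, 4, and 4 rows, and every row lands in exactly one validation fold; averaging a metric over the five folds gives the kind of estimate `collect_metrics()` reports.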
**Resources for continued learning:**
- **Books**: "An Introduction to Statistical Learning" (free PDF), "Analyzing Baseball Data with R"
- **Courses**: Andrew Ng's Machine Learning course, Fast.ai
- **Data sources**: Baseball Savant, Retrosheet, pybaseball library
- **Communities**: r/Sabermetrics, Baseball Prospectus, FanGraphs
**Next steps:**
1. Download real Statcast data from Baseball Savant, either directly or through the pybaseball library
2. Replicate the code examples with actual data
3. Start with simple problems (pitch classification, hit prediction)
4. Gradually tackle more complex projects
5. Share your findings with the baseball analytics community
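Replicating Example 5 with actual data mostly means reshaping Statcast output into the columns it reads. A minimal sketch, assuming the `pybaseball` package's `statcast(start_dt=..., end_dt=...)` function and Baseball Savant's CSV columns `release_speed` (mph), `release_spin_rate` (rpm), and `pfx_x` / `pfx_z` (pitch movement in feet). Mapping these onto `velocity`, `spin_rate`, `horizontal_break`, and `vertical_break`, and converting movement to inches, are assumptions about what Example 5's `pitch_data.csv` was meant to contain; `to_example5_row` and `build_pitch_csv` are hypothetical helpers.

```python
import csv
import math

FEET_TO_INCHES = 12.0  # Savant reports pfx_x / pfx_z in feet
EXAMPLE5_COLUMNS = ["velocity", "spin_rate", "horizontal_break", "vertical_break", "pitch_type"]


def _as_float(value):
    """Return value as a float, or None if it is missing (None, NaN, pandas NA)."""
    try:
        x = float(value)
    except (TypeError, ValueError):
        return None
    return None if math.isnan(x) else x


def to_example5_row(row):
    """Map one Statcast pitch (a dict) onto Example 5's columns; None if data is missing."""
    speed = _as_float(row.get("release_speed"))
    spin = _as_float(row.get("release_spin_rate"))
    pfx_x = _as_float(row.get("pfx_x"))
    pfx_z = _as_float(row.get("pfx_z"))
    pitch_type = row.get("pitch_type")
    if None in (speed, spin, pfx_x, pfx_z) or not isinstance(pitch_type, str) or not pitch_type:
        return None  # skip pitches with incomplete tracking data
    return {
        "velocity": speed,
        "spin_rate": spin,
        "horizontal_break": pfx_x * FEET_TO_INCHES,
        "vertical_break": pfx_z * FEET_TO_INCHES,
        "pitch_type": pitch_type,
    }


def build_pitch_csv(start_dt, end_dt, path="pitch_data.csv"):
    """Download Statcast pitches and write Example 5's columns (needs network access)."""
    from pybaseball import statcast  # imported lazily; requires `pip install pybaseball`

    raw = statcast(start_dt=start_dt, end_dt=end_dt)
    rows = [r for r in map(to_example5_row, raw.to_dict("records")) if r is not None]
    with open(path, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=EXAMPLE5_COLUMNS)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```

Calling `build_pitch_csv("2024-04-01", "2024-04-07")` would write one week of pitches to `pitch_data.csv`. Statcast's short pitch codes (`FF`, `SL`, `CH`, ...) are valid R names, which matters for caret's `classProbs = TRUE`.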
Machine learning is a journey, not a destination. Start simple, iterate constantly, and always validate your models rigorously. Baseball knowledge and ML skills reinforce each other: your understanding of the game guides feature engineering, while ML uncovers patterns you might never find manually.
Happy modeling, and may your predictions be accurate!