Case Study 2: From Notebook to Production Script
Overview
Every data analyst has experienced this: you start exploring data in a Jupyter notebook, and before you know it, you have 80 cells of code that kind of works but is fragile, poorly organized, and impossible to run automatically. In this case study, we take a realistic exploratory notebook — one that analyzes prediction market calibration — and refactor it into a clean, tested Python module with proper structure, error handling, and documentation.
This is one of the most practical skills in the data science workflow. The notebook is where you think; the module is where you build.
The Starting Point: A Messy Notebook
Here is what our fictional notebook looks like. It was written over several sessions of exploration, with cells added and re-run in various orders. It works (most of the time) but has numerous problems.
Original Notebook: calibration_analysis.ipynb
Cell 1:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sqlite3
%matplotlib inline
Cell 2:
# Load data
conn = sqlite3.connect('data/markets.db')
df = pd.read_sql("SELECT * FROM markets WHERE resolution IS NOT NULL", conn)
prices = pd.read_sql("SELECT * FROM price_snapshots", conn)
conn.close()
print(f"Markets: {len(df)}, Price snapshots: {len(prices)}")
Cell 3:
# oops, need this
import seaborn as sns
sns.set()
Cell 4:
# get the last price before resolution for each market
last_prices = prices.sort_values('timestamp').groupby('market_id').last()
last_prices.head()
Cell 5:
# merge
merged = df.merge(last_prices[['yes_price']], left_on='id', right_index=True, how='inner')
merged['outcome'] = merged['resolution'].astype(int)
print(len(merged))
Cell 6:
# calibration calculation
# bin the predictions and calculate frequency of positive outcomes in each bin
n_bins = 10
merged['prob_bin'] = pd.cut(merged['yes_price'], bins=n_bins)
cal = merged.groupby('prob_bin').agg(
mean_predicted=('yes_price', 'mean'),
mean_outcome=('outcome', 'mean'),
count=('outcome', 'count')
)
cal
Cell 7:
# plot calibration
plt.figure(figsize=(8,8))
plt.plot([0,1],[0,1],'k--', label='Perfect')
plt.scatter(cal['mean_predicted'], cal['mean_outcome'], s=cal['count']*2)
for idx, row in cal.iterrows():
plt.annotate(f"n={int(row['count'])}", (row['mean_predicted'], row['mean_outcome']),
fontsize=8, ha='center', va='bottom')
plt.xlabel('Predicted Probability')
plt.ylabel('Observed Frequency')
plt.title('Market Calibration')
plt.legend()
plt.savefig('calibration.png')
plt.show()
Cell 8:
# brier score
brier = np.mean((merged['yes_price'] - merged['outcome'])**2)
print(f"Brier score: {brier:.4f}")
Cell 9:
# TODO: add log score
# TODO: break down by category
# TODO: handle the case where there are fewer than 10 resolved markets
Cell 10:
# quick check by category
for cat in merged['category'].unique():
subset = merged[merged['category'] == cat]
if len(subset) >= 5:
b = np.mean((subset['yes_price'] - subset['outcome'])**2)
print(f"{cat}: Brier={b:.4f} (n={len(subset)})")
Problems With This Notebook
- Scattered imports: seaborn is imported in Cell 3 because it was forgotten in Cell 1.
- Hardcoded paths: 'data/markets.db' is hardcoded, making it inflexible.
- No error handling: What if the database is empty? What if a column is missing?
- No functions: Everything is top-level code, making reuse impossible.
- Hidden state: Cells depend on variables created in earlier cells, but the order is not obvious.
- No type hints or documentation: Future readers (including yourself) will struggle to understand the code.
- Incomplete features: The TODOs in Cell 9 were never implemented.
- Magic numbers: n_bins = 10 and s=cal['count']*2 are unexplained.
- No tests: How do we know the calibration calculation is correct?
- Cannot be run from the command line: It is a notebook, not a script.
The Refactoring Process
We will transform this notebook into a well-structured module through a systematic process.
Step 1: Extract Functions
The first step is to identify distinct operations and wrap them in functions with clear inputs and outputs.
From the notebook, we identify these operations:
1. Load data from the database
2. Calculate the last price before resolution for each market
3. Calculate calibration statistics (bin predictions, compare to outcomes)
4. Calculate scoring metrics (Brier, log score)
5. Generate visualizations
6. Break down by category
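A reasonable first pass is to turn each operation into a named stub before filling in any details. The sketch below uses the names the final module settles on; the bodies come later:
def load_resolved_markets(db_path):
    """Load resolved markets joined with their final pre-resolution price."""
    ...

def calculate_calibration(predicted, actual, n_bins=10):
    """Bin predictions and compare mean prediction to observed frequency."""
    ...

def calculate_scores(predicted, actual):
    """Compute Brier and log scores."""
    ...

def category_breakdown(data, min_category_size=5):
    """Compute scores for each market category."""
    ...

def plot_calibration_report(result, save_path=None):
    """Render the calibration report figure."""
    ...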
Step 2: Define Data Contracts
What data goes in and what comes out? We define this explicitly with type hints and docstrings.
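For example, the contract for the calibration calculation can be written down before any logic exists. This sketch mirrors the signature used in the final module; CalibrationBin is the dataclass defined there:
from typing import Sequence

def calculate_calibration(
    predicted: Sequence[float],   # probabilities in [0, 1]
    actual: Sequence[int],        # outcomes, 0 or 1
    n_bins: int = 10,
    min_bin_count: int = 1,
) -> "list[CalibrationBin]":      # CalibrationBin is defined in the module below
    """Bin predictions and return per-bin calibration statistics.

    Raises ValueError if the inputs are empty, mismatched in length,
    or outside their valid ranges.
    """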
Step 3: Add Error Handling
What can go wrong? Empty databases, missing columns, insufficient data for binning, division by zero.
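The pattern is to validate early and raise an error that tells the user what to fix. Here is a condensed sketch of the checks that calculate_calibration performs; the helper name is illustrative, since the module runs these checks inline:
import numpy as np

def validate_forecasts(predicted, actual):
    # Fail fast with actionable messages rather than letting NumPy
    # raise a confusing error several steps later.
    predicted = np.asarray(predicted, dtype=float)
    actual = np.asarray(actual, dtype=int)
    if len(predicted) == 0:
        raise ValueError("No predictions provided")
    if len(predicted) != len(actual):
        raise ValueError(
            f"Length mismatch: {len(predicted)} predictions vs {len(actual)} outcomes"
        )
    if np.any((predicted < 0) | (predicted > 1)):
        raise ValueError("Predictions must be between 0 and 1")
    return predicted, actual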
Step 4: Write Tests
We write tests alongside the code, not after.
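For example, as soon as calculate_scores exists, a test pins down a value we can verify by hand. This is a standalone version of the uniform-predictions check that appears in the test file below:
import pytest
from calibration import calculate_scores

def test_uniform_predictions_give_quarter_brier():
    # Predicting 50% on a perfectly balanced set of outcomes
    # gives a Brier score of exactly 0.25.
    scores = calculate_scores([0.5] * 100, [1] * 50 + [0] * 50)
    assert scores['brier'] == pytest.approx(0.25)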
Step 5: Create the Command-Line Interface
Make the module runnable from the terminal.
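The shape of the CLI is small: parse arguments, run the pipeline, translate failures into exit codes. A condensed sketch of the module's entry point (the full version appears below):
import argparse
import sys

def main(argv=None):
    parser = argparse.ArgumentParser(description="Analyze prediction market calibration")
    parser.add_argument('--db', required=True, help='Path to SQLite database')
    args = parser.parse_args(argv)
    try:
        # run_calibration_analysis is defined later in the module
        result = run_calibration_analysis(db_path=args.db)
        print(result.summary())
        return 0
    except (FileNotFoundError, ValueError) as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

if __name__ == "__main__":
    sys.exit(main())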
The Result: A Clean Module
calibration.py
"""
Calibration Analysis Module
============================
Analyzes the calibration of prediction market prices: are events priced at
X% actually occurring X% of the time?
This module can be used as a library (import individual functions) or as a
command-line tool:
python calibration.py --db data/markets.db --output reports/calibration.png
Functions:
load_resolved_markets: Load resolved markets with their final prices
calculate_calibration: Compute calibration curve data
calculate_scores: Compute Brier and log scoring metrics
category_breakdown: Analyze calibration by market category
plot_calibration_report: Generate a comprehensive calibration report figure
"""
from __future__ import annotations
import argparse
import logging
import sqlite3
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Sequence
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class CalibrationBin:
"""A single bin in the calibration analysis."""
bin_lower: float
bin_upper: float
mean_predicted: float
mean_observed: float
count: int
@property
def bin_center(self) -> float:
return (self.bin_lower + self.bin_upper) / 2
@property
def deviation(self) -> float:
"""How far off calibration is this bin (positive = overconfident)."""
return self.mean_predicted - self.mean_observed
@dataclass
class CalibrationResult:
"""Complete calibration analysis results."""
bins: list[CalibrationBin]
brier_score: float
log_score: float
n_markets: int
category_scores: dict[str, dict] = field(default_factory=dict)
@property
def mean_absolute_calibration_error(self) -> float:
"""Weighted mean absolute deviation from perfect calibration."""
if not self.bins:
return float('nan')
total_count = sum(b.count for b in self.bins)
if total_count == 0:
return float('nan')
weighted_sum = sum(
abs(b.deviation) * b.count for b in self.bins
)
return weighted_sum / total_count
def summary(self) -> str:
"""Return a human-readable summary."""
lines = [
"Calibration Analysis Summary",
"=" * 40,
f"Markets analyzed: {self.n_markets}",
f"Brier score: {self.brier_score:.4f}",
f"Log score: {self.log_score:.4f}",
f"Mean abs cal err: {self.mean_absolute_calibration_error:.4f}",
"",
f"{'Bin':>12} {'Predicted':>10} {'Observed':>10} {'Count':>6} {'Dev':>8}",
"-" * 48,
]
for b in self.bins:
bin_label = f"{b.bin_lower:.1f}-{b.bin_upper:.1f}"
lines.append(
f"{bin_label:>12} {b.mean_predicted:>10.3f} "
f"{b.mean_observed:>10.3f} {b.count:>6} "
f"{b.deviation:>+8.3f}"
)
if self.category_scores:
lines.append("")
lines.append("By Category:")
lines.append(f"{'Category':>20} {'Brier':>8} {'N':>6}")
lines.append("-" * 36)
for cat, scores in sorted(self.category_scores.items()):
lines.append(
f"{cat:>20} {scores['brier']:.4f} "
f"{scores['count']:>6}"
)
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_resolved_markets(
db_path: str,
min_volume: float = 0
) -> pd.DataFrame:
"""
Load resolved markets with their final price before resolution.
Args:
db_path: Path to the SQLite database
min_volume: Minimum total volume to include a market
Returns:
DataFrame with columns: market_id, title, category, yes_price,
outcome (0 or 1), volume
Raises:
FileNotFoundError: If the database file does not exist
ValueError: If no resolved markets are found
"""
db_file = Path(db_path)
if not db_file.exists():
raise FileNotFoundError(f"Database not found: {db_path}")
conn = sqlite3.connect(str(db_file))
try:
# Load resolved markets
markets = pd.read_sql("""
SELECT id, title, category, resolution, volume
FROM markets
WHERE resolution IS NOT NULL
""", conn)
if markets.empty:
raise ValueError(
"No resolved markets found in the database. "
"Calibration analysis requires markets that have resolved."
)
# Load the last price snapshot for each market
last_prices = pd.read_sql("""
SELECT market_id, yes_price, no_price, timestamp
FROM price_snapshots p1
WHERE timestamp = (
SELECT MAX(timestamp)
FROM price_snapshots p2
WHERE p2.market_id = p1.market_id
)
""", conn)
if last_prices.empty:
raise ValueError(
"No price snapshots found. Cannot determine final prices."
)
finally:
conn.close()
# Merge markets with their final prices
merged = markets.merge(
last_prices[['market_id', 'yes_price']],
left_on='id',
right_on='market_id',
how='inner'
)
if merged.empty:
raise ValueError(
"No overlap between resolved markets and price snapshots. "
"Check that market IDs are consistent."
)
# Apply volume filter
if min_volume > 0 and 'volume' in merged.columns:
before = len(merged)
merged = merged[merged['volume'] >= min_volume]
logger.info(
f"Volume filter ({min_volume}): {before} -> {len(merged)} markets"
)
# Clean up
merged['outcome'] = merged['resolution'].astype(int)
merged = merged[['id', 'title', 'category', 'yes_price', 'outcome', 'volume']]
merged = merged.rename(columns={'id': 'market_id'})
logger.info(f"Loaded {len(merged)} resolved markets for calibration analysis")
return merged
# ---------------------------------------------------------------------------
# Calibration calculation
# ---------------------------------------------------------------------------
def calculate_calibration(
predicted: Sequence[float],
actual: Sequence[int],
n_bins: int = 10,
min_bin_count: int = 1
) -> list[CalibrationBin]:
"""
Calculate calibration curve by binning predictions.
Args:
predicted: Predicted probabilities (0 to 1)
actual: Actual outcomes (0 or 1)
n_bins: Number of equal-width bins
min_bin_count: Minimum number of observations in a bin to include it
Returns:
List of CalibrationBin objects
Raises:
ValueError: If inputs are empty or have different lengths
"""
predicted = np.asarray(predicted, dtype=float)
actual = np.asarray(actual, dtype=int)
if len(predicted) == 0:
raise ValueError("No predictions provided")
if len(predicted) != len(actual):
raise ValueError(
f"Length mismatch: {len(predicted)} predictions vs "
f"{len(actual)} outcomes"
)
# Validate ranges
if np.any((predicted < 0) | (predicted > 1)):
raise ValueError("Predictions must be between 0 and 1")
if not np.all(np.isin(actual, [0, 1])):
raise ValueError("Outcomes must be 0 or 1")
bin_edges = np.linspace(0, 1, n_bins + 1)
bins = []
for i in range(n_bins):
lower, upper = bin_edges[i], bin_edges[i + 1]
if i < n_bins - 1:
mask = (predicted >= lower) & (predicted < upper)
else:
# Last bin includes right edge
mask = (predicted >= lower) & (predicted <= upper)
count = mask.sum()
if count >= min_bin_count:
bins.append(CalibrationBin(
bin_lower=float(lower),
bin_upper=float(upper),
mean_predicted=float(predicted[mask].mean()),
mean_observed=float(actual[mask].mean()),
count=int(count)
))
if not bins:
logger.warning(
f"No bins met the minimum count threshold ({min_bin_count}). "
f"Total observations: {len(predicted)}"
)
return bins
def calculate_scores(
predicted: Sequence[float],
actual: Sequence[int],
epsilon: float = 1e-10
) -> dict[str, float]:
"""
Calculate forecast scoring metrics.
Args:
predicted: Predicted probabilities (0 to 1)
actual: Actual outcomes (0 or 1)
epsilon: Small value to avoid log(0)
Returns:
Dictionary with 'brier' and 'log_score' keys
"""
predicted = np.asarray(predicted, dtype=float)
actual = np.asarray(actual, dtype=int)
# Brier score: mean squared error of probabilities
brier = float(np.mean((predicted - actual) ** 2))
# Logarithmic score
clipped = np.clip(predicted, epsilon, 1 - epsilon)
log_scores = actual * np.log(clipped) + (1 - actual) * np.log(1 - clipped)
log_score = float(np.mean(log_scores))
return {
'brier': brier,
'log_score': log_score,
}
def category_breakdown(
data: pd.DataFrame,
min_category_size: int = 5
) -> dict[str, dict]:
"""
Calculate calibration scores broken down by category.
Args:
data: DataFrame with 'category', 'yes_price', and 'outcome' columns
min_category_size: Minimum markets per category to include
Returns:
Dictionary mapping category names to score dictionaries
"""
results = {}
if 'category' not in data.columns:
logger.warning("No 'category' column found. Skipping breakdown.")
return results
for category, group in data.groupby('category'):
if len(group) < min_category_size:
logger.debug(
f"Skipping category '{category}': "
f"only {len(group)} markets (min: {min_category_size})"
)
continue
scores = calculate_scores(
group['yes_price'].values,
group['outcome'].values
)
scores['count'] = len(group)
results[str(category)] = scores
return results
# ---------------------------------------------------------------------------
# Full analysis
# ---------------------------------------------------------------------------
def run_calibration_analysis(
db_path: str,
n_bins: int = 10,
min_volume: float = 0,
min_bin_count: int = 1,
min_category_size: int = 5
) -> CalibrationResult:
"""
Run the complete calibration analysis pipeline.
Args:
db_path: Path to the SQLite database
n_bins: Number of calibration bins
min_volume: Minimum market volume to include
min_bin_count: Minimum observations per bin
min_category_size: Minimum markets per category
Returns:
CalibrationResult with all analysis results
"""
# Load data
data = load_resolved_markets(db_path, min_volume=min_volume)
# Calculate calibration
bins = calculate_calibration(
data['yes_price'].values,
data['outcome'].values,
n_bins=n_bins,
min_bin_count=min_bin_count
)
# Calculate scores
scores = calculate_scores(
data['yes_price'].values,
data['outcome'].values
)
# Category breakdown
cat_scores = category_breakdown(
data, min_category_size=min_category_size
)
result = CalibrationResult(
bins=bins,
brier_score=scores['brier'],
log_score=scores['log_score'],
n_markets=len(data),
category_scores=cat_scores
)
logger.info(f"Analysis complete: Brier={result.brier_score:.4f}, "
f"markets={result.n_markets}")
return result
# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------
def plot_calibration_report(
result: CalibrationResult,
title: str = "Prediction Market Calibration Report",
save_path: Optional[str] = None,
figsize: tuple = (14, 10)
) -> plt.Figure:
"""
Generate a comprehensive calibration report figure.
Creates a 2x2 grid:
- Top left: Calibration curve
- Top right: Prediction distribution histogram
- Bottom left: Category comparison (bar chart)
- Bottom right: Summary statistics (text)
Args:
result: CalibrationResult from run_calibration_analysis
title: Figure title
save_path: Path to save the figure (None to skip saving)
figsize: Figure size
Returns:
matplotlib Figure
"""
sns.set_theme(style="whitegrid")
fig, axes = plt.subplots(2, 2, figsize=figsize)
fig.suptitle(title, fontsize=16, fontweight='bold', y=0.98)
# --- Top Left: Calibration Curve ---
ax_cal = axes[0, 0]
if result.bins:
centers = [b.mean_predicted for b in result.bins]
observed = [b.mean_observed for b in result.bins]
counts = [b.count for b in result.bins]
# Scale point sizes
max_count = max(counts)
sizes = [max(30, (c / max_count) * 200) for c in counts]
ax_cal.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect calibration')
ax_cal.scatter(
centers, observed, s=sizes, c='#2196F3',
edgecolors='white', linewidth=1, zorder=3, label='Observed'
)
ax_cal.plot(centers, observed, '-', color='#2196F3', alpha=0.4, linewidth=1)
# Annotate bin counts
for pred, obs, n in zip(centers, observed, counts):
ax_cal.annotate(
f'n={n}', (pred, obs),
textcoords="offset points", xytext=(0, 10),
fontsize=7, ha='center', color='gray'
)
ax_cal.set_xlabel('Predicted Probability')
ax_cal.set_ylabel('Observed Frequency')
ax_cal.set_title('Calibration Curve')
ax_cal.set_xlim(-0.02, 1.02)
ax_cal.set_ylim(-0.02, 1.02)
ax_cal.set_aspect('equal')
ax_cal.legend(loc='upper left', fontsize=9)
# --- Top Right: Prediction Distribution ---
ax_hist = axes[0, 1]
if result.bins:
bin_edges = [b.bin_lower for b in result.bins] + [result.bins[-1].bin_upper]
counts = [b.count for b in result.bins]
bar_centers = [b.bin_center for b in result.bins]
bar_width = (bin_edges[1] - bin_edges[0]) * 0.8
ax_hist.bar(
bar_centers, counts, width=bar_width,
color='#2196F3', alpha=0.7, edgecolor='white'
)
ax_hist.set_xlabel('Predicted Probability')
ax_hist.set_ylabel('Number of Markets')
ax_hist.set_title('Distribution of Predictions')
ax_hist.set_xlim(-0.02, 1.02)
# --- Bottom Left: Category Comparison ---
ax_cat = axes[1, 0]
if result.category_scores:
categories = sorted(result.category_scores.keys())
brier_scores = [result.category_scores[c]['brier'] for c in categories]
cat_counts = [result.category_scores[c]['count'] for c in categories]
y_pos = np.arange(len(categories))
bars = ax_cat.barh(
y_pos, brier_scores,
color='#2196F3', alpha=0.7, edgecolor='white'
)
# Annotate with count
for i, (score, count) in enumerate(zip(brier_scores, cat_counts)):
ax_cat.text(
score + 0.005, i, f'n={count}',
va='center', fontsize=8, color='gray'
)
ax_cat.set_yticks(y_pos)
ax_cat.set_yticklabels(categories, fontsize=9)
ax_cat.set_xlabel('Brier Score (lower is better)')
ax_cat.set_title('Calibration by Category')
ax_cat.axvline(x=result.brier_score, color='red', linestyle='--',
alpha=0.5, label=f'Overall ({result.brier_score:.3f})')
ax_cat.legend(fontsize=8)
else:
ax_cat.text(
0.5, 0.5, 'No category data available',
ha='center', va='center', fontsize=12, color='gray',
transform=ax_cat.transAxes
)
ax_cat.set_title('Calibration by Category')
# --- Bottom Right: Summary Statistics ---
ax_summary = axes[1, 1]
ax_summary.axis('off')
summary_text = (
f"Overall Metrics\n"
f"{'─' * 30}\n"
f"Markets analyzed: {result.n_markets:>6}\n"
f"Brier score: {result.brier_score:>6.4f}\n"
f"Log score: {result.log_score:>6.4f}\n"
f"Mean abs cal error: {result.mean_absolute_calibration_error:>6.4f}\n"
f"\n"
f"Interpretation\n"
f"{'─' * 30}\n"
)
if result.brier_score < 0.15:
summary_text += "Brier: Excellent calibration\n"
elif result.brier_score < 0.25:
summary_text += "Brier: Good calibration\n"
else:
summary_text += "Brier: Poor calibration\n"
if result.mean_absolute_calibration_error < 0.05:
summary_text += "MACE: Well-calibrated\n"
elif result.mean_absolute_calibration_error < 0.10:
summary_text += "MACE: Moderately calibrated\n"
else:
summary_text += "MACE: Poorly calibrated\n"
ax_summary.text(
0.1, 0.9, summary_text,
transform=ax_summary.transAxes,
fontsize=11, verticalalignment='top',
fontfamily='monospace',
bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8)
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
if save_path:
fig.savefig(save_path, dpi=150, bbox_inches='tight')
logger.info(f"Report saved to {save_path}")
return fig
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Analyze prediction market calibration",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python calibration.py --db data/markets.db
python calibration.py --db data/markets.db --output reports/cal.png --bins 15
python calibration.py --db data/markets.db --min-volume 1000 --verbose
"""
)
parser.add_argument(
'--db', required=True,
help='Path to SQLite database'
)
parser.add_argument(
'--output', default=None,
help='Path to save the calibration report figure'
)
parser.add_argument(
'--bins', type=int, default=10,
help='Number of calibration bins (default: 10)'
)
parser.add_argument(
'--min-volume', type=float, default=0,
help='Minimum market volume to include (default: 0)'
)
parser.add_argument(
'--min-category-size', type=int, default=5,
help='Minimum markets per category (default: 5)'
)
parser.add_argument(
'--verbose', action='store_true',
help='Enable verbose logging'
)
parser.add_argument(
'--no-plot', action='store_true',
help='Skip generating the plot (text output only)'
)
return parser.parse_args(argv)
def main(argv: Optional[list[str]] = None) -> int:
"""Main entry point for command-line usage."""
args = parse_args(argv)
# Configure logging
log_level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s | %(name)s | %(levelname)s | %(message)s',
datefmt='%H:%M:%S'
)
try:
# Run analysis
result = run_calibration_analysis(
db_path=args.db,
n_bins=args.bins,
min_volume=args.min_volume,
min_category_size=args.min_category_size
)
# Print summary
print()
print(result.summary())
print()
# Generate plot
if not args.no_plot:
import matplotlib
if args.output:
matplotlib.use('Agg')
fig = plot_calibration_report(result, save_path=args.output)
if not args.output:
plt.show()
return 0
except FileNotFoundError as e:
logger.error(str(e))
return 1
except ValueError as e:
logger.error(str(e))
return 1
except Exception as e:
logger.error(f"Unexpected error: {e}", exc_info=True)
return 2
if __name__ == "__main__":
sys.exit(main())
Tests: test_calibration.py
"""Tests for the calibration analysis module."""
import pytest
import numpy as np
from calibration import (
CalibrationBin,
CalibrationResult,
calculate_calibration,
calculate_scores,
)
class TestCalibrationBin:
def test_bin_center(self):
b = CalibrationBin(0.0, 0.2, 0.10, 0.12, 50)
assert b.bin_center == pytest.approx(0.1)
def test_deviation_overconfident(self):
"""Predicted > observed means overconfident."""
b = CalibrationBin(0.6, 0.8, 0.70, 0.60, 30)
assert b.deviation == pytest.approx(0.10)
def test_deviation_underconfident(self):
"""Predicted < observed means underconfident."""
b = CalibrationBin(0.2, 0.4, 0.30, 0.40, 25)
assert b.deviation == pytest.approx(-0.10)
class TestCalculateCalibration:
def test_perfectly_calibrated(self):
"""Perfect predictions should show bins near the diagonal."""
np.random.seed(42)
n = 1000
# Generate predictions and then outcomes that match
predictions = np.random.uniform(0, 1, n)
outcomes = (np.random.uniform(0, 1, n) < predictions).astype(int)
bins = calculate_calibration(predictions, outcomes, n_bins=5)
assert len(bins) > 0
for b in bins:
# With enough data, observed should be close to predicted
assert abs(b.deviation) < 0.15, (
f"Bin {b.bin_lower:.1f}-{b.bin_upper:.1f}: "
f"deviation {b.deviation:.3f} too large"
)
def test_empty_input_raises(self):
with pytest.raises(ValueError, match="No predictions"):
calculate_calibration([], [])
def test_length_mismatch_raises(self):
with pytest.raises(ValueError, match="Length mismatch"):
calculate_calibration([0.5, 0.6], [1])
def test_invalid_predictions_raise(self):
with pytest.raises(ValueError, match="between 0 and 1"):
calculate_calibration([1.5], [1])
def test_invalid_outcomes_raise(self):
with pytest.raises(ValueError, match="must be 0 or 1"):
calculate_calibration([0.5], [2])
def test_min_bin_count(self):
"""Bins with too few observations should be excluded."""
predictions = [0.1, 0.9, 0.9, 0.9, 0.9]
outcomes = [0, 1, 1, 1, 0]
bins = calculate_calibration(
predictions, outcomes, n_bins=10, min_bin_count=2
)
# The bin containing 0.1 has only 1 observation and should be excluded
for b in bins:
assert b.count >= 2
def test_all_same_prediction(self):
"""All predictions in one bin should produce one bin."""
predictions = [0.5] * 100
outcomes = [1] * 60 + [0] * 40
bins = calculate_calibration(predictions, outcomes, n_bins=10)
non_empty = [b for b in bins if b.count > 0]
assert len(non_empty) == 1
assert non_empty[0].mean_observed == pytest.approx(0.6)
class TestCalculateScores:
def test_perfect_brier(self):
scores = calculate_scores([1.0, 0.0, 1.0], [1, 0, 1])
assert scores['brier'] == pytest.approx(0.0)
def test_worst_brier(self):
scores = calculate_scores([0.0, 1.0], [1, 0])
assert scores['brier'] == pytest.approx(1.0)
def test_log_score_perfect(self):
"""Near-perfect predictions should have log score close to 0."""
scores = calculate_scores([0.999, 0.001], [1, 0])
assert scores['log_score'] > -0.01
def test_log_score_bad(self):
"""Bad predictions should have very negative log score."""
scores = calculate_scores([0.01, 0.99], [1, 0])
assert scores['log_score'] < -4.0
def test_uniform_predictions(self):
"""50% predictions on balanced data should give Brier = 0.25."""
scores = calculate_scores([0.5] * 100, [1] * 50 + [0] * 50)
assert scores['brier'] == pytest.approx(0.25)
class TestCalibrationResult:
def test_mace_calculation(self):
bins = [
CalibrationBin(0.0, 0.5, 0.25, 0.30, 100), # dev = -0.05
CalibrationBin(0.5, 1.0, 0.75, 0.70, 100), # dev = +0.05
]
result = CalibrationResult(
bins=bins, brier_score=0.20, log_score=-0.5, n_markets=200
)
assert result.mean_absolute_calibration_error == pytest.approx(0.05)
def test_summary_string(self):
result = CalibrationResult(
bins=[], brier_score=0.20, log_score=-0.5, n_markets=100
)
summary = result.summary()
assert "100" in summary
assert "0.2000" in summary
def test_empty_bins_mace(self):
result = CalibrationResult(
bins=[], brier_score=0.0, log_score=0.0, n_markets=0
)
assert np.isnan(result.mean_absolute_calibration_error)
What Changed: A Comparison
| Aspect | Notebook | Refactored Module |
|---|---|---|
| Imports | Scattered across cells | Single block at the top |
| Functions | None (top-level code) | 8 well-defined functions |
| Type hints | None | Full type annotations |
| Error handling | None | ValueError, FileNotFoundError, logging |
| Data structures | Raw dicts/DataFrames | CalibrationBin, CalibrationResult dataclasses |
| Documentation | Sparse comments | Module, class, and function docstrings |
| Configurability | Hardcoded values | Command-line arguments |
| Testing | Manual inspection | 15+ automated test cases |
| Reusability | Copy-paste only | Import as a library |
| Execution | Jupyter only | CLI, library, or Jupyter |
| Output | Inline only | File, display, or text summary |
The Refactoring Process, Step by Step
For reference, here are the exact steps we followed:
- Read through the entire notebook and understand what it does holistically.
- List the distinct operations (load, transform, calculate, visualize).
- Define data structures for intermediate and final results (CalibrationBin, CalibrationResult).
- Extract each operation into a function with clear inputs, outputs, and docstrings.
- Add input validation at the beginning of each function.
- Add error handling for every expected failure mode.
- Add type hints to all function signatures.
- Write tests for each function, including edge cases.
- Add a command-line interface using argparse.
- Add logging at appropriate levels (DEBUG for details, INFO for flow, WARNING/ERROR for problems).
- Run all tests to verify nothing broke.
- Run the module end-to-end from the command line to verify it works as a script.
Key Lessons
1. Notebooks Are for Exploration, Modules Are for Reuse
The notebook was valuable for developing the analysis interactively. But once the approach is validated, extracting it into a module makes it reliable, testable, and composable with other tools.
2. Functions Are the Unit of Reuse
By wrapping each step in a function, we can reuse calculate_calibration in other analyses without dragging along the database loading or visualization code.
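For example, the same functions can score any set of probability forecasts, not just market prices (a minimal sketch; the forecasts below are invented for illustration):
from calibration import calculate_calibration, calculate_scores

forecasts = [0.9, 0.7, 0.3, 0.8, 0.2, 0.6]   # hypothetical analyst forecasts
outcomes = [1, 1, 0, 1, 0, 0]                # what actually happened

bins = calculate_calibration(forecasts, outcomes, n_bins=5)
scores = calculate_scores(forecasts, outcomes)
print(f"Brier: {scores['brier']:.3f}, log score: {scores['log_score']:.3f}")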
3. Data Classes Clarify Intent
CalibrationBin is clearer than a DataFrame row. It documents exactly what fields exist, what their types are, and what computed properties are available.
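Compare pulling the same fact out of the two representations (the bin values below are invented for illustration):
from calibration import CalibrationBin

# With the notebook's DataFrame, the meaning of a column difference
# lives in the analyst's head:
#   dev = row['mean_predicted'] - row['mean_outcome']

# With the dataclass, it lives in the code:
b = CalibrationBin(bin_lower=0.6, bin_upper=0.8,
                   mean_predicted=0.70, mean_observed=0.64, count=42)
b.bin_center   # midpoint of the bin, 0.7
b.deviation    # ~+0.06: slightly overconfident in this bin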
4. Tests Are Documentation
The test names (test_perfectly_calibrated, test_empty_input_raises, test_min_bin_count) document the expected behavior better than any comment.
5. Command-Line Interfaces Make Scripts Useful
The argparse interface lets other scripts, cron jobs, or colleagues use the analysis without opening Jupyter.
Discussion Questions
- What parts of the analysis might still benefit from remaining in notebook form? Why?
- How would you extend this module to compare calibration across different time periods (e.g., has calibration improved over the past year)?
- The current module uses SQLite directly. How would you modify it to also accept a pandas DataFrame as input, making it usable without a database?
- What additional visualizations would be useful in the calibration report?
- How would you set up continuous integration (CI) to run the tests automatically on every code change?