Case Study 2: From Notebook to Production Script

Overview

Every data analyst has experienced this: you start exploring data in a Jupyter notebook, and before you know it, you have 80 cells of code that kind of works but is fragile, poorly organized, and impossible to run automatically. In this case study, we take a realistic exploratory notebook — one that analyzes prediction market calibration — and refactor it into a clean, tested Python module with proper structure, error handling, and documentation.

This is one of the most practical skills in the data science workflow. The notebook is where you think; the module is where you build.

The Starting Point: A Messy Notebook

Here is what our fictional notebook looks like. It was written over several sessions of exploration, with cells added and re-run in various orders. It works (most of the time) but has numerous problems.

Original Notebook: calibration_analysis.ipynb

Cell 1:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sqlite3
%matplotlib inline

Cell 2:

# Load data
conn = sqlite3.connect('data/markets.db')
df = pd.read_sql("SELECT * FROM markets WHERE resolution IS NOT NULL", conn)
prices = pd.read_sql("SELECT * FROM price_snapshots", conn)
conn.close()
print(f"Markets: {len(df)}, Price snapshots: {len(prices)}")

Cell 3:

# oops, need this
import seaborn as sns
sns.set()

Cell 4:

# get the last price before resolution for each market
last_prices = prices.sort_values('timestamp').groupby('market_id').last()
last_prices.head()

Cell 5:

# merge
merged = df.merge(last_prices[['yes_price']], left_on='id', right_index=True, how='inner')
merged['outcome'] = merged['resolution'].astype(int)
print(len(merged))

Cell 6:

# calibration calculation
# bin the predictions and calculate frequency of positive outcomes in each bin
n_bins = 10
merged['prob_bin'] = pd.cut(merged['yes_price'], bins=n_bins)
cal = merged.groupby('prob_bin').agg(
    mean_predicted=('yes_price', 'mean'),
    mean_outcome=('outcome', 'mean'),
    count=('outcome', 'count')
)
cal

Cell 7:

# plot calibration
plt.figure(figsize=(8,8))
plt.plot([0,1],[0,1],'k--', label='Perfect')
plt.scatter(cal['mean_predicted'], cal['mean_outcome'], s=cal['count']*2)
for idx, row in cal.iterrows():
    plt.annotate(f"n={int(row['count'])}", (row['mean_predicted'], row['mean_outcome']),
                fontsize=8, ha='center', va='bottom')
plt.xlabel('Predicted Probability')
plt.ylabel('Observed Frequency')
plt.title('Market Calibration')
plt.legend()
plt.savefig('calibration.png')
plt.show()

Cell 8:

# brier score
brier = np.mean((merged['yes_price'] - merged['outcome'])**2)
print(f"Brier score: {brier:.4f}")

Cell 9:

# TODO: add log score
# TODO: break down by category
# TODO: handle the case where there are fewer than 10 resolved markets

Cell 10:

# quick check by category
for cat in merged['category'].unique():
    subset = merged[merged['category'] == cat]
    if len(subset) >= 5:
        b = np.mean((subset['yes_price'] - subset['outcome'])**2)
        print(f"{cat}: Brier={b:.4f} (n={len(subset)})")

Problems With This Notebook

  1. Scattered imports: seaborn is imported in Cell 3 because it was forgotten in Cell 1.
  2. Hardcoded paths: 'data/markets.db' is hardcoded, making it inflexible.
  3. No error handling: What if the database is empty? What if a column is missing?
  4. No functions: Everything is top-level code, making reuse impossible.
  5. Hidden state: Cells depend on variables created in earlier cells, but the order is not obvious.
  6. No type hints or documentation: Future readers (including yourself) will struggle to understand the code.
  7. Incomplete features: The TODOs in Cell 9 were never implemented.
  8. Magic numbers: n_bins = 10 and s=cal['count']*2 are unexplained.
  9. No tests: How do we know the calibration calculation is correct?
  10. Cannot be run from the command line: It is a notebook, not a script.

The Refactoring Process

We will transform this notebook into a well-structured module through a systematic process.

Step 1: Extract Functions

The first step is to identify distinct operations and wrap them in functions with clear inputs and outputs.

From the notebook, we identify these operations:

  1. Load data from the database
  2. Calculate the last price before resolution for each market
  3. Calculate calibration statistics (bin predictions, compare to outcomes)
  4. Calculate scoring metrics (Brier, log score)
  5. Generate visualizations
  6. Break down by category
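
A first pass can be nothing more than a stub per operation, so the shape of the future module is visible before any body is written. This is only a sketch; the signatures are simplified versions of the ones in the finished module.

# One function per operation; bodies come later.
def load_resolved_markets(db_path): ...            # operations 1-2: load markets and final prices
def calculate_calibration(predicted, actual): ...  # operation 3: bin predictions, compare to outcomes
def calculate_scores(predicted, actual): ...       # operation 4: Brier and log scores
def plot_calibration_report(result): ...           # operation 5: visualizations
def category_breakdown(data): ...                  # operation 6: per-category scores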

Step 2: Define Data Contracts

What data goes in and what comes out? We define this explicitly with type hints and docstrings.
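
For example, the contract for the calibration calculation can be written down before its implementation: typed inputs, a typed return value, and a docstring stating what the function promises. This is a sketch of the idea; the full version appears in the module below.

from __future__ import annotations

from typing import Sequence

def calculate_calibration(
    predicted: Sequence[float],   # predicted probabilities, 0 to 1
    actual: Sequence[int],        # resolved outcomes, 0 or 1
    n_bins: int = 10,
) -> list["CalibrationBin"]:
    """Bin predictions into equal-width bins and compare each bin's mean
    prediction to the observed frequency of positive outcomes."""
    ...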

Step 3: Add Error Handling

What can go wrong? Empty databases, missing columns, insufficient data for binning, division by zero.
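
In code, this means guard clauses at the top of each function that turn bad input into a specific, early error instead of a confusing failure later. A minimal sketch of the kind of checks involved; the finished module performs them inline in calculate_calibration and load_resolved_markets rather than in a separate helper.

import numpy as np

def check_calibration_inputs(predicted, actual) -> None:
    # Illustrative helper only; the module inlines these checks.
    predicted = np.asarray(predicted, dtype=float)
    actual = np.asarray(actual, dtype=int)
    if len(predicted) == 0:
        raise ValueError("No predictions provided")
    if len(predicted) != len(actual):
        raise ValueError("Length mismatch between predictions and outcomes")
    if np.any((predicted < 0) | (predicted > 1)):
        raise ValueError("Predictions must be between 0 and 1")
    if not np.all(np.isin(actual, [0, 1])):
        raise ValueError("Outcomes must be 0 or 1")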

Step 4: Write Tests

We write tests alongside the code, not after.
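
Even one small test written immediately pins down behavior before the implementation drifts. For instance, this case (which also appears in the full test file below) locks in the error for mismatched inputs:

import pytest

from calibration import calculate_calibration

def test_length_mismatch_raises():
    # Two predictions but only one outcome must be rejected loudly.
    with pytest.raises(ValueError, match="Length mismatch"):
        calculate_calibration([0.5, 0.6], [1])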

Step 5: Create the Command-Line Interface

Make the module runnable from the terminal.
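
The pattern is small: an argparse parser, a main(argv) function that returns an exit code, and a __main__ guard. A stripped-down sketch; the full CLI with all its options is in the module below.

import argparse
import sys

def main(argv=None) -> int:
    parser = argparse.ArgumentParser(
        description="Analyze prediction market calibration"
    )
    parser.add_argument("--db", required=True, help="Path to SQLite database")
    args = parser.parse_args(argv)
    # ... run the analysis for args.db and print the summary ...
    return 0

if __name__ == "__main__":
    sys.exit(main())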

The Result: A Clean Module

calibration.py

"""
Calibration Analysis Module
============================

Analyzes the calibration of prediction market prices: are events priced at
X% actually occurring X% of the time?

This module can be used as a library (import individual functions) or as a
command-line tool:

    python calibration.py --db data/markets.db --output reports/calibration.png

Functions:
    load_resolved_markets: Load resolved markets with their final prices
    calculate_calibration: Compute calibration curve data
    calculate_scores: Compute Brier and log scoring metrics
    category_breakdown: Analyze calibration by market category
    run_calibration_analysis: Run the full analysis pipeline end to end
    plot_calibration_report: Generate a comprehensive calibration report figure
"""

from __future__ import annotations

import argparse
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------

@dataclass
class CalibrationBin:
    """A single bin in the calibration analysis."""
    bin_lower: float
    bin_upper: float
    mean_predicted: float
    mean_observed: float
    count: int

    @property
    def bin_center(self) -> float:
        return (self.bin_lower + self.bin_upper) / 2

    @property
    def deviation(self) -> float:
        """How far off calibration is this bin (positive = overconfident)."""
        return self.mean_predicted - self.mean_observed


@dataclass
class CalibrationResult:
    """Complete calibration analysis results."""
    bins: list[CalibrationBin]
    brier_score: float
    log_score: float
    n_markets: int
    category_scores: dict[str, dict] = field(default_factory=dict)

    @property
    def mean_absolute_calibration_error(self) -> float:
        """Weighted mean absolute deviation from perfect calibration."""
        if not self.bins:
            return float('nan')
        total_count = sum(b.count for b in self.bins)
        if total_count == 0:
            return float('nan')
        weighted_sum = sum(
            abs(b.deviation) * b.count for b in self.bins
        )
        return weighted_sum / total_count

    def summary(self) -> str:
        """Return a human-readable summary."""
        lines = [
            "Calibration Analysis Summary",
            "=" * 40,
            f"Markets analyzed:  {self.n_markets}",
            f"Brier score:       {self.brier_score:.4f}",
            f"Log score:         {self.log_score:.4f}",
            f"Mean abs cal err:  {self.mean_absolute_calibration_error:.4f}",
            "",
            f"{'Bin':>12} {'Predicted':>10} {'Observed':>10} {'Count':>6} {'Dev':>8}",
            "-" * 48,
        ]
        for b in self.bins:
            bin_label = f"{b.bin_lower:.1f}-{b.bin_upper:.1f}"
            lines.append(
                f"{bin_label:>12} {b.mean_predicted:>10.3f} "
                f"{b.mean_observed:>10.3f} {b.count:>6} "
                f"{b.deviation:>+8.3f}"
            )

        if self.category_scores:
            lines.append("")
            lines.append("By Category:")
            lines.append(f"{'Category':>20} {'Brier':>8} {'N':>6}")
            lines.append("-" * 36)
            for cat, scores in sorted(self.category_scores.items()):
                lines.append(
                    f"{cat:>20} {scores['brier']:.4f} "
                    f"{scores['count']:>6}"
                )

        return "\n".join(lines)


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def load_resolved_markets(
    db_path: str,
    min_volume: float = 0
) -> pd.DataFrame:
    """
    Load resolved markets with their final price before resolution.

    Args:
        db_path: Path to the SQLite database
        min_volume: Minimum total volume to include a market

    Returns:
        DataFrame with columns: market_id, title, category, yes_price,
        outcome (0 or 1), volume

    Raises:
        FileNotFoundError: If the database file does not exist
        ValueError: If no resolved markets are found
    """
    import sqlite3

    db_file = Path(db_path)
    if not db_file.exists():
        raise FileNotFoundError(f"Database not found: {db_path}")

    conn = sqlite3.connect(str(db_file))
    try:
        # Load resolved markets
        markets = pd.read_sql("""
            SELECT id, title, category, resolution, volume
            FROM markets
            WHERE resolution IS NOT NULL
        """, conn)

        if markets.empty:
            raise ValueError(
                "No resolved markets found in the database. "
                "Calibration analysis requires markets that have resolved."
            )

        # Load the last price snapshot for each market
        last_prices = pd.read_sql("""
            SELECT market_id, yes_price, no_price, timestamp
            FROM price_snapshots p1
            WHERE timestamp = (
                SELECT MAX(timestamp)
                FROM price_snapshots p2
                WHERE p2.market_id = p1.market_id
            )
        """, conn)

        if last_prices.empty:
            raise ValueError(
                "No price snapshots found. Cannot determine final prices."
            )

    finally:
        conn.close()

    # Merge markets with their final prices
    merged = markets.merge(
        last_prices[['market_id', 'yes_price']],
        left_on='id',
        right_on='market_id',
        how='inner'
    )

    if merged.empty:
        raise ValueError(
            "No overlap between resolved markets and price snapshots. "
            "Check that market IDs are consistent."
        )

    # Apply volume filter
    if min_volume > 0 and 'volume' in merged.columns:
        before = len(merged)
        merged = merged[merged['volume'] >= min_volume]
        logger.info(
            f"Volume filter ({min_volume}): {before} -> {len(merged)} markets"
        )

    # Clean up
    merged['outcome'] = merged['resolution'].astype(int)
    merged = merged[['id', 'title', 'category', 'yes_price', 'outcome', 'volume']]
    merged = merged.rename(columns={'id': 'market_id'})

    logger.info(f"Loaded {len(merged)} resolved markets for calibration analysis")
    return merged


# ---------------------------------------------------------------------------
# Calibration calculation
# ---------------------------------------------------------------------------

def calculate_calibration(
    predicted: Sequence[float],
    actual: Sequence[int],
    n_bins: int = 10,
    min_bin_count: int = 1
) -> list[CalibrationBin]:
    """
    Calculate calibration curve by binning predictions.

    Args:
        predicted: Predicted probabilities (0 to 1)
        actual: Actual outcomes (0 or 1)
        n_bins: Number of equal-width bins
        min_bin_count: Minimum number of observations in a bin to include it

    Returns:
        List of CalibrationBin objects

    Raises:
        ValueError: If inputs are empty or have different lengths
    """
    predicted = np.asarray(predicted, dtype=float)
    actual = np.asarray(actual, dtype=int)

    if len(predicted) == 0:
        raise ValueError("No predictions provided")
    if len(predicted) != len(actual):
        raise ValueError(
            f"Length mismatch: {len(predicted)} predictions vs "
            f"{len(actual)} outcomes"
        )

    # Validate ranges
    if np.any((predicted < 0) | (predicted > 1)):
        raise ValueError("Predictions must be between 0 and 1")
    if not np.all(np.isin(actual, [0, 1])):
        raise ValueError("Outcomes must be 0 or 1")

    bin_edges = np.linspace(0, 1, n_bins + 1)
    bins = []

    for i in range(n_bins):
        lower, upper = bin_edges[i], bin_edges[i + 1]

        if i < n_bins - 1:
            mask = (predicted >= lower) & (predicted < upper)
        else:
            # Last bin includes right edge
            mask = (predicted >= lower) & (predicted <= upper)

        count = mask.sum()

        if count >= min_bin_count:
            bins.append(CalibrationBin(
                bin_lower=float(lower),
                bin_upper=float(upper),
                mean_predicted=float(predicted[mask].mean()),
                mean_observed=float(actual[mask].mean()),
                count=int(count)
            ))

    if not bins:
        logger.warning(
            f"No bins met the minimum count threshold ({min_bin_count}). "
            f"Total observations: {len(predicted)}"
        )

    return bins


def calculate_scores(
    predicted: Sequence[float],
    actual: Sequence[int],
    epsilon: float = 1e-10
) -> dict[str, float]:
    """
    Calculate forecast scoring metrics.

    Args:
        predicted: Predicted probabilities (0 to 1)
        actual: Actual outcomes (0 or 1)
        epsilon: Small value to avoid log(0)

    Returns:
        Dictionary with 'brier' and 'log_score' keys
    """
    predicted = np.asarray(predicted, dtype=float)
    actual = np.asarray(actual, dtype=int)

    # Brier score: mean squared error of probabilities
    brier = float(np.mean((predicted - actual) ** 2))

    # Logarithmic score
    clipped = np.clip(predicted, epsilon, 1 - epsilon)
    log_scores = actual * np.log(clipped) + (1 - actual) * np.log(1 - clipped)
    log_score = float(np.mean(log_scores))

    return {
        'brier': brier,
        'log_score': log_score,
    }


def category_breakdown(
    data: pd.DataFrame,
    min_category_size: int = 5
) -> dict[str, dict]:
    """
    Calculate calibration scores broken down by category.

    Args:
        data: DataFrame with 'category', 'yes_price', and 'outcome' columns
        min_category_size: Minimum markets per category to include

    Returns:
        Dictionary mapping category names to score dictionaries
    """
    results = {}

    if 'category' not in data.columns:
        logger.warning("No 'category' column found. Skipping breakdown.")
        return results

    for category, group in data.groupby('category'):
        if len(group) < min_category_size:
            logger.debug(
                f"Skipping category '{category}': "
                f"only {len(group)} markets (min: {min_category_size})"
            )
            continue

        scores = calculate_scores(
            group['yes_price'].values,
            group['outcome'].values
        )
        scores['count'] = len(group)
        results[str(category)] = scores

    return results


# ---------------------------------------------------------------------------
# Full analysis
# ---------------------------------------------------------------------------

def run_calibration_analysis(
    db_path: str,
    n_bins: int = 10,
    min_volume: float = 0,
    min_bin_count: int = 1,
    min_category_size: int = 5
) -> CalibrationResult:
    """
    Run the complete calibration analysis pipeline.

    Args:
        db_path: Path to the SQLite database
        n_bins: Number of calibration bins
        min_volume: Minimum market volume to include
        min_bin_count: Minimum observations per bin
        min_category_size: Minimum markets per category

    Returns:
        CalibrationResult with all analysis results
    """
    # Load data
    data = load_resolved_markets(db_path, min_volume=min_volume)

    # Calculate calibration
    bins = calculate_calibration(
        data['yes_price'].values,
        data['outcome'].values,
        n_bins=n_bins,
        min_bin_count=min_bin_count
    )

    # Calculate scores
    scores = calculate_scores(
        data['yes_price'].values,
        data['outcome'].values
    )

    # Category breakdown
    cat_scores = category_breakdown(
        data, min_category_size=min_category_size
    )

    result = CalibrationResult(
        bins=bins,
        brier_score=scores['brier'],
        log_score=scores['log_score'],
        n_markets=len(data),
        category_scores=cat_scores
    )

    logger.info(f"Analysis complete: Brier={result.brier_score:.4f}, "
                f"markets={result.n_markets}")

    return result


# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------

def plot_calibration_report(
    result: CalibrationResult,
    title: str = "Prediction Market Calibration Report",
    save_path: Optional[str] = None,
    figsize: tuple = (14, 10)
) -> plt.Figure:
    """
    Generate a comprehensive calibration report figure.

    Creates a 2x2 grid:
    - Top left: Calibration curve
    - Top right: Prediction distribution histogram
    - Bottom left: Category comparison (bar chart)
    - Bottom right: Summary statistics (text)

    Args:
        result: CalibrationResult from run_calibration_analysis
        title: Figure title
        save_path: Path to save the figure (None to skip saving)
        figsize: Figure size

    Returns:
        matplotlib Figure
    """
    sns.set_theme(style="whitegrid")
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle(title, fontsize=16, fontweight='bold', y=0.98)

    # --- Top Left: Calibration Curve ---
    ax_cal = axes[0, 0]
    if result.bins:
        centers = [b.mean_predicted for b in result.bins]
        observed = [b.mean_observed for b in result.bins]
        counts = [b.count for b in result.bins]

        # Scale point sizes
        max_count = max(counts)
        sizes = [max(30, (c / max_count) * 200) for c in counts]

        ax_cal.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect calibration')
        ax_cal.scatter(
            centers, observed, s=sizes, c='#2196F3',
            edgecolors='white', linewidth=1, zorder=3, label='Observed'
        )
        ax_cal.plot(centers, observed, '-', color='#2196F3', alpha=0.4, linewidth=1)

        # Annotate bin counts
        for pred, obs, n in zip(centers, observed, counts):
            ax_cal.annotate(
                f'n={n}', (pred, obs),
                textcoords="offset points", xytext=(0, 10),
                fontsize=7, ha='center', color='gray'
            )

    ax_cal.set_xlabel('Predicted Probability')
    ax_cal.set_ylabel('Observed Frequency')
    ax_cal.set_title('Calibration Curve')
    ax_cal.set_xlim(-0.02, 1.02)
    ax_cal.set_ylim(-0.02, 1.02)
    ax_cal.set_aspect('equal')
    ax_cal.legend(loc='upper left', fontsize=9)

    # --- Top Right: Prediction Distribution ---
    ax_hist = axes[0, 1]
    if result.bins:
        bin_edges = [b.bin_lower for b in result.bins] + [result.bins[-1].bin_upper]
        counts = [b.count for b in result.bins]
        bar_centers = [b.bin_center for b in result.bins]
        bar_width = (bin_edges[1] - bin_edges[0]) * 0.8

        ax_hist.bar(
            bar_centers, counts, width=bar_width,
            color='#2196F3', alpha=0.7, edgecolor='white'
        )

    ax_hist.set_xlabel('Predicted Probability')
    ax_hist.set_ylabel('Number of Markets')
    ax_hist.set_title('Distribution of Predictions')
    ax_hist.set_xlim(-0.02, 1.02)

    # --- Bottom Left: Category Comparison ---
    ax_cat = axes[1, 0]
    if result.category_scores:
        categories = sorted(result.category_scores.keys())
        brier_scores = [result.category_scores[c]['brier'] for c in categories]
        cat_counts = [result.category_scores[c]['count'] for c in categories]

        y_pos = np.arange(len(categories))
        bars = ax_cat.barh(
            y_pos, brier_scores,
            color='#2196F3', alpha=0.7, edgecolor='white'
        )

        # Annotate with count
        for i, (score, count) in enumerate(zip(brier_scores, cat_counts)):
            ax_cat.text(
                score + 0.005, i, f'n={count}',
                va='center', fontsize=8, color='gray'
            )

        ax_cat.set_yticks(y_pos)
        ax_cat.set_yticklabels(categories, fontsize=9)
        ax_cat.set_xlabel('Brier Score (lower is better)')
        ax_cat.set_title('Calibration by Category')
        ax_cat.axvline(x=result.brier_score, color='red', linestyle='--',
                       alpha=0.5, label=f'Overall ({result.brier_score:.3f})')
        ax_cat.legend(fontsize=8)
    else:
        ax_cat.text(
            0.5, 0.5, 'No category data available',
            ha='center', va='center', fontsize=12, color='gray',
            transform=ax_cat.transAxes
        )
        ax_cat.set_title('Calibration by Category')

    # --- Bottom Right: Summary Statistics ---
    ax_summary = axes[1, 1]
    ax_summary.axis('off')

    summary_text = (
        f"Overall Metrics\n"
        f"{'─' * 30}\n"
        f"Markets analyzed:    {result.n_markets:>6}\n"
        f"Brier score:         {result.brier_score:>6.4f}\n"
        f"Log score:           {result.log_score:>6.4f}\n"
        f"Mean abs cal error:  {result.mean_absolute_calibration_error:>6.4f}\n"
        f"\n"
        f"Interpretation\n"
        f"{'─' * 30}\n"
    )

    if result.brier_score < 0.15:
        summary_text += "Brier: Excellent calibration\n"
    elif result.brier_score < 0.25:
        summary_text += "Brier: Good calibration\n"
    else:
        summary_text += "Brier: Poor calibration\n"

    if result.mean_absolute_calibration_error < 0.05:
        summary_text += "MACE: Well-calibrated\n"
    elif result.mean_absolute_calibration_error < 0.10:
        summary_text += "MACE: Moderately calibrated\n"
    else:
        summary_text += "MACE: Poorly calibrated\n"

    ax_summary.text(
        0.1, 0.9, summary_text,
        transform=ax_summary.transAxes,
        fontsize=11, verticalalignment='top',
        fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8)
    )

    plt.tight_layout(rect=[0, 0, 1, 0.95])

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        logger.info(f"Report saved to {save_path}")

    return fig


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------

def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Analyze prediction market calibration",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python calibration.py --db data/markets.db
    python calibration.py --db data/markets.db --output reports/cal.png --bins 15
    python calibration.py --db data/markets.db --min-volume 1000 --verbose
        """
    )
    parser.add_argument(
        '--db', required=True,
        help='Path to SQLite database'
    )
    parser.add_argument(
        '--output', default=None,
        help='Path to save the calibration report figure'
    )
    parser.add_argument(
        '--bins', type=int, default=10,
        help='Number of calibration bins (default: 10)'
    )
    parser.add_argument(
        '--min-volume', type=float, default=0,
        help='Minimum market volume to include (default: 0)'
    )
    parser.add_argument(
        '--min-category-size', type=int, default=5,
        help='Minimum markets per category (default: 5)'
    )
    parser.add_argument(
        '--verbose', action='store_true',
        help='Enable verbose logging'
    )
    parser.add_argument(
        '--no-plot', action='store_true',
        help='Skip generating the plot (text output only)'
    )

    return parser.parse_args(argv)


def main(argv: Optional[list[str]] = None) -> int:
    """Main entry point for command-line usage."""
    args = parse_args(argv)

    # Configure logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s | %(name)s | %(levelname)s | %(message)s',
        datefmt='%H:%M:%S'
    )

    try:
        # Run analysis
        result = run_calibration_analysis(
            db_path=args.db,
            n_bins=args.bins,
            min_volume=args.min_volume,
            min_category_size=args.min_category_size
        )

        # Print summary
        print()
        print(result.summary())
        print()

        # Generate plot
        if not args.no_plot:
            import matplotlib
            if args.output:
                matplotlib.use('Agg')
            fig = plot_calibration_report(result, save_path=args.output)
            if not args.output:
                plt.show()

        return 0

    except FileNotFoundError as e:
        logger.error(str(e))
        return 1
    except ValueError as e:
        logger.error(str(e))
        return 1
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        return 2


if __name__ == "__main__":
    sys.exit(main())

Tests: test_calibration.py

"""Tests for the calibration analysis module."""

import pytest
import numpy as np
from calibration import (
    CalibrationBin,
    CalibrationResult,
    calculate_calibration,
    calculate_scores,
)


class TestCalibrationBin:
    def test_bin_center(self):
        b = CalibrationBin(0.0, 0.2, 0.10, 0.12, 50)
        assert b.bin_center == pytest.approx(0.1)

    def test_deviation_overconfident(self):
        """Predicted > observed means overconfident."""
        b = CalibrationBin(0.6, 0.8, 0.70, 0.60, 30)
        assert b.deviation == pytest.approx(0.10)

    def test_deviation_underconfident(self):
        """Predicted < observed means underconfident."""
        b = CalibrationBin(0.2, 0.4, 0.30, 0.40, 25)
        assert b.deviation == pytest.approx(-0.10)


class TestCalculateCalibration:
    def test_perfectly_calibrated(self):
        """Perfect predictions should show bins near the diagonal."""
        np.random.seed(42)
        n = 1000
        # Generate predictions and then outcomes that match
        predictions = np.random.uniform(0, 1, n)
        outcomes = (np.random.uniform(0, 1, n) < predictions).astype(int)

        bins = calculate_calibration(predictions, outcomes, n_bins=5)
        assert len(bins) > 0

        for b in bins:
            # With enough data, observed should be close to predicted
            assert abs(b.deviation) < 0.15, (
                f"Bin {b.bin_lower:.1f}-{b.bin_upper:.1f}: "
                f"deviation {b.deviation:.3f} too large"
            )

    def test_empty_input_raises(self):
        with pytest.raises(ValueError, match="No predictions"):
            calculate_calibration([], [])

    def test_length_mismatch_raises(self):
        with pytest.raises(ValueError, match="Length mismatch"):
            calculate_calibration([0.5, 0.6], [1])

    def test_invalid_predictions_raise(self):
        with pytest.raises(ValueError, match="between 0 and 1"):
            calculate_calibration([1.5], [1])

    def test_invalid_outcomes_raise(self):
        with pytest.raises(ValueError, match="must be 0 or 1"):
            calculate_calibration([0.5], [2])

    def test_min_bin_count(self):
        """Bins with too few observations should be excluded."""
        predictions = [0.1, 0.9, 0.9, 0.9, 0.9]
        outcomes = [0, 1, 1, 1, 0]

        bins = calculate_calibration(
            predictions, outcomes, n_bins=10, min_bin_count=2
        )

        # The bin containing 0.1 has only 1 observation and should be excluded
        for b in bins:
            assert b.count >= 2

    def test_all_same_prediction(self):
        """All predictions in one bin should produce one bin."""
        predictions = [0.5] * 100
        outcomes = [1] * 60 + [0] * 40

        bins = calculate_calibration(predictions, outcomes, n_bins=10)
        non_empty = [b for b in bins if b.count > 0]
        assert len(non_empty) == 1
        assert non_empty[0].mean_observed == pytest.approx(0.6)


class TestCalculateScores:
    def test_perfect_brier(self):
        scores = calculate_scores([1.0, 0.0, 1.0], [1, 0, 1])
        assert scores['brier'] == pytest.approx(0.0)

    def test_worst_brier(self):
        scores = calculate_scores([0.0, 1.0], [1, 0])
        assert scores['brier'] == pytest.approx(1.0)

    def test_log_score_perfect(self):
        """Near-perfect predictions should have log score close to 0."""
        scores = calculate_scores([0.999, 0.001], [1, 0])
        assert scores['log_score'] > -0.01

    def test_log_score_bad(self):
        """Bad predictions should have very negative log score."""
        scores = calculate_scores([0.01, 0.99], [1, 0])
        assert scores['log_score'] < -4.0

    def test_uniform_predictions(self):
        """50% predictions on balanced data should give Brier = 0.25."""
        scores = calculate_scores([0.5] * 100, [1] * 50 + [0] * 50)
        assert scores['brier'] == pytest.approx(0.25)


class TestCalibrationResult:
    def test_mace_calculation(self):
        bins = [
            CalibrationBin(0.0, 0.5, 0.25, 0.30, 100),  # dev = -0.05
            CalibrationBin(0.5, 1.0, 0.75, 0.70, 100),   # dev = +0.05
        ]
        result = CalibrationResult(
            bins=bins, brier_score=0.20, log_score=-0.5, n_markets=200
        )
        assert result.mean_absolute_calibration_error == pytest.approx(0.05)

    def test_summary_string(self):
        result = CalibrationResult(
            bins=[], brier_score=0.20, log_score=-0.5, n_markets=100
        )
        summary = result.summary()
        assert "100" in summary
        assert "0.2000" in summary

    def test_empty_bins_mace(self):
        result = CalibrationResult(
            bins=[], brier_score=0.0, log_score=0.0, n_markets=0
        )
        assert np.isnan(result.mean_absolute_calibration_error)

What Changed: A Comparison

Aspect           Notebook                Refactored Module
---------------  ----------------------  ---------------------------------------------
Imports          Scattered across cells  Single block at the top
Functions        None (top-level code)   Eight focused functions
Type hints       None                    Full type annotations
Error handling   None                    ValueError, FileNotFoundError, logging
Data structures  Raw dicts/DataFrames    CalibrationBin, CalibrationResult dataclasses
Documentation    Sparse comments         Module, class, and function docstrings
Configurability  Hardcoded values        Command-line arguments
Testing          Manual inspection       15+ automated test cases
Reusability      Copy-paste only         Import as a library
Execution        Jupyter only            CLI, library, or Jupyter
Output           Inline only             File, display, or text summary

The Refactoring Process, Step by Step

For reference, here are the exact steps we followed:

  1. Read through the entire notebook and understand what it does holistically.
  2. List the distinct operations (load, transform, calculate, visualize).
  3. Define data structures for intermediate and final results (CalibrationBin, CalibrationResult).
  4. Extract each operation into a function with clear inputs, outputs, and docstrings.
  5. Add input validation at the beginning of each function.
  6. Add error handling for every expected failure mode.
  7. Add type hints to all function signatures.
  8. Write tests for each function, including edge cases.
  9. Add a command-line interface using argparse.
  10. Add logging at appropriate levels (DEBUG for details, INFO for flow, WARNING/ERROR for problems).
  11. Run all tests to verify nothing broke.
  12. Run the module end-to-end from the command line to verify it works as a script.

Key Lessons

1. Notebooks Are for Exploration, Modules Are for Reuse

The notebook was valuable for developing the analysis interactively. But once the approach is validated, extracting it into a module makes it reliable, testable, and composable with other tools.

2. Functions Are the Unit of Reuse

By wrapping each step in a function, we can reuse calculate_calibration in other analyses without dragging along the database loading or visualization code.
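
Because calculate_calibration and calculate_scores only need sequences of predictions and outcomes, they work just as well on data that never touched the database. For example, with some made-up forecasts:

from calibration import calculate_calibration, calculate_scores

predicted = [0.2, 0.7, 0.9, 0.4, 0.8]   # illustrative forecasts
actual = [0, 1, 1, 0, 1]                # illustrative outcomes

bins = calculate_calibration(predicted, actual, n_bins=5)
scores = calculate_scores(predicted, actual)
print(scores["brier"], scores["log_score"])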

3. Data Classes Clarify Intent

CalibrationBin is clearer than a DataFrame row. It documents exactly what fields exist, what their types are, and what computed properties are available.
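
A small illustration of the difference: the dataclass carries its schema and derived values with it, while a DataFrame row only offers string keys.

from calibration import CalibrationBin

b = CalibrationBin(bin_lower=0.6, bin_upper=0.7,
                   mean_predicted=0.64, mean_observed=0.55, count=42)
print(b.bin_center)   # midpoint of the bin, ~0.65
print(b.deviation)    # ~+0.09: this bin is overconfident

# The DataFrame equivalent hides the same information behind string keys,
# e.g. row['mean_predicted'] - row['mean_outcome'], with no type checking
# and nowhere to attach properties like bin_center.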

4. Tests Are Documentation

The test names (test_perfectly_calibrated, test_empty_input_raises, test_min_bin_count) document the expected behavior better than any comment.

5. Command-Line Interfaces Make Scripts Useful

The argparse interface lets other scripts, cron jobs, or colleagues use the analysis without opening Jupyter.
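
And because main() accepts an explicit argv list, the same entry point can be driven from another Python script or a scheduled job without going through a shell:

from calibration import main

exit_code = main([
    "--db", "data/markets.db",
    "--output", "reports/calibration.png",
])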

Discussion Questions

  1. What parts of the analysis might still benefit from remaining in notebook form? Why?
  2. How would you extend this module to compare calibration across different time periods (e.g., has calibration improved over the past year)?
  3. The current module uses SQLite directly. How would you modify it to also accept a pandas DataFrame as input, making it usable without a database?
  4. What additional visualizations would be useful in the calibration report?
  5. How would you set up continuous integration (CI) to run the tests automatically on every code change?