In This Chapter
- Learning Objectives
- 30.1 The Silent Failure Problem
- 30.2 Monitoring vs. Observability
- 30.3 The Three Pillars Architecture: Prometheus, Grafana, and the ML Layer
- 30.4 Pillar 1: Data Quality Monitoring
- 30.5 Pillar 2: Model Performance Monitoring
- 30.6 Pillar 3: System Health Monitoring
- 30.7 Pillar 4: Business Metric Monitoring and the Feedback Loop
- 30.8 Drift Detection: Population Stability Index (PSI)
- 30.9 Drift Detection: The Kolmogorov-Smirnov Test
- 30.10 Drift Detection: Jensen-Shannon Divergence
- 30.11 Types of Drift
- 30.12 Alerting Infrastructure
- 30.13 Escalation and On-Call
- 30.14 Incident Response for ML Systems
- 30.15 Blameless Post-Mortems
- 30.16 Putting It Together: The Monitoring Architecture
- 30.17 Progressive Project M14: StreamRec Monitoring Dashboard
- 30.18 Summary
- References
Chapter 30: Monitoring, Observability, and Incident Response — Keeping ML Systems Healthy in Production
> "A model without monitoring is a model waiting to fail silently. Traditional software fails loudly — exceptions, crashes, HTTP 500s. ML models fail quietly — they return predictions that are syntactically correct, numerically plausible, and completely wrong." — Adapted from Sculley et al., "Hidden Technical Debt in Machine Learning Systems" (NeurIPS, 2015)
Learning Objectives
By the end of this chapter, you will be able to:
- Design monitoring dashboards that track data quality, model performance, system health, and business metrics in production ML systems
- Implement data drift detection using Population Stability Index (PSI), the Kolmogorov-Smirnov test, and Jensen-Shannon divergence
- Build alerting infrastructure with thresholds, escalation policies, and on-call rotation
- Conduct blameless post-mortems for ML-specific incidents and derive actionable preventive measures
- Distinguish monitoring (detecting known failure modes) from observability (diagnosing unknown failure modes) and design systems that support both
30.1 The Silent Failure Problem
On March 3, the StreamRec recommendation model began recommending stale content to 40% of users. Engagement dropped 12% over four days. No alert fired. No log entry indicated a problem. The serving infrastructure returned HTTP 200 on every request. The model scored every input and returned valid item IDs. By every metric the engineering team had instrumented — latency, throughput, error rate, uptime — the system was healthy.
The problem was invisible to traditional software monitoring because it was not a software problem. A feature pipeline change had silently shifted the days_since_last_interaction feature from a 0-365 integer to a Unix timestamp. The model interpreted timestamps in the billions as "extremely long since last interaction" and overweighted recency-insensitive content. The predictions were valid. The system was functional. The recommendations were garbage.
This incident illustrates the central challenge of this chapter: ML systems fail in ways that traditional monitoring cannot detect. Software monitoring answers the question "Is the system running?" ML monitoring must also answer "Is the system producing good predictions?" — a fundamentally harder question because "good" depends on the data distribution, the model's learned patterns, the business context, and the relationship between model outputs and user outcomes.
Production ML = Software Engineering: This chapter completes the production ML stack from Part V. Chapter 24 designed the system. Chapter 25 built the data infrastructure. Chapter 26 scaled training. Chapter 27 orchestrated the pipeline. Chapter 28 built the testing infrastructure. Chapter 29 deployed the model. This chapter asks the final question: once the model is in production, how do we know it is still working? The answer requires monitoring infrastructure that goes beyond traditional software — monitoring not just the system, but the data flowing through it, the model's behavior on that data, and the business impact of the model's predictions.
Know How Your Model Is Wrong: Monitoring and observability are the production manifestation of this recurring theme. During development, you know how your model is wrong through evaluation metrics, error analysis, and sliced performance. In production, you know how your model is wrong through drift detection, prediction monitoring, and business metric tracking. The transition from "knowing how the model was wrong on the test set" to "knowing how the model is wrong right now on live traffic" is the intellectual core of this chapter.
The chapter proceeds in six movements. Sections 30.2-30.3 establish the conceptual distinction between monitoring and observability and introduce the Prometheus-Grafana stack. Sections 30.4-30.7 cover the four pillars of ML monitoring: data quality, model performance, system health, and business metrics. Sections 30.8-30.11 cover drift detection: PSI, the KS test, Jensen-Shannon divergence, and the taxonomy of drift types. Sections 30.12-30.13 cover alerting infrastructure: thresholds, escalation, and on-call. Sections 30.14-30.15 cover incident response and blameless post-mortems. Sections 30.16-30.17 assemble the full monitoring architecture and apply it to the StreamRec progressive project.
30.2 Monitoring vs. Observability
The terms "monitoring" and "observability" are often used interchangeably. They should not be. They represent fundamentally different approaches to understanding system behavior, and an effective production ML system requires both.
Monitoring is the practice of collecting predefined metrics and checking them against known thresholds. A monitoring system answers questions you anticipated: "Is latency below 50ms? Is the error rate below 0.1%? Is the feature null rate below 5%?" Monitoring detects known unknowns — failure modes you have seen before or can imagine. You instrument the system, define thresholds, and receive alerts when those thresholds are violated.
Observability is the property of a system that allows you to understand its internal state from its external outputs. An observable system produces enough telemetry — metrics, logs, and traces — that you can diagnose problems you did not anticipate. Observability addresses unknown unknowns — failure modes you have never seen and could not have predicted. The March 3 incident above was an unknown unknown: no one anticipated that a feature pipeline would silently convert integers to timestamps, so no monitor existed for it.
The distinction maps to a medical analogy. Monitoring is the vital signs monitor in a hospital room: heart rate, blood pressure, oxygen saturation, temperature. These metrics are predefined, thresholds are well-understood, and alerts are clear. Observability is the ability to order a CT scan, blood panel, or biopsy when something seems wrong but the vital signs do not explain why. The vital signs tell you the patient is deteriorating; the diagnostic tests tell you why.
| Dimension | Monitoring | Observability |
|---|---|---|
| Question answered | "Is something wrong?" | "Why is something wrong?" |
| Failure modes | Known unknowns (anticipated) | Unknown unknowns (novel) |
| Approach | Predefined metrics + thresholds | Rich telemetry + ad-hoc queries |
| Output | Alerts (binary: firing/resolved) | Diagnostic data (exploratory) |
| Design time | You decide what to measure | You decide what to record, query later |
| Example | "PSI of feature X exceeds 0.25" | "What changed about feature X's distribution, when, and why?" |
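The distinction can be made concrete in code. The following is an illustrative sketch, not library code: `PredictionRecord`, `psi_alert_firing`, and `first_anomalous_record` are hypothetical names. The monitoring check is written before any incident and can only confirm what it was built to ask; the observability query is written during an investigation and works only if the raw records were retained.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PredictionRecord:
    """One logged prediction: raw material for observability queries."""
    timestamp: float
    feature_x: float
    score: float


def psi_alert_firing(psi_value: float, threshold: float = 0.25) -> bool:
    """Monitoring: a predefined check against a known threshold.

    Answers the question you anticipated: 'is feature X drifting?'
    """
    return psi_value > threshold


def first_anomalous_record(
    records: List[PredictionRecord], cutoff: float
) -> Optional[PredictionRecord]:
    """Observability: an ad-hoc query over rich telemetry, written
    only after something looks wrong ('when did feature_x jump?').
    """
    for rec in sorted(records, key=lambda r: r.timestamp):
        if rec.feature_x > cutoff:
            return rec
    return None
```

The March 3 incident could not have been caught by any `psi_alert_firing`-style check that no one thought to write; it could be diagnosed after the fact only because logged prediction records made a `first_anomalous_record`-style query possible.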
For ML systems, the monitoring-observability spectrum plays out across four signal types:
- Metrics: Numeric measurements sampled at regular intervals (latency p50/p95/p99, prediction mean, feature null rate). Metrics are the foundation of monitoring. They are cheap to collect, efficient to store, and fast to query. But they are aggregated and lossy — a mean prediction score of 0.5 tells you nothing about whether the bimodal distribution has shifted.
- Logs: Timestamped, structured records of discrete events (prediction requests, feature computation results, model loading events). Logs are the foundation of observability. They are rich and granular — a log entry can contain the full feature vector, the model's prediction, the latency breakdown, and any warnings. But they are expensive to store and slow to query at scale.
- Traces: Causally linked sequences of operations that represent a single request's path through the system. A trace for a recommendation request might span: API gateway → feature retrieval (Redis) → candidate retrieval (FAISS) → ranking model inference → re-ranking → response serialization. Traces allow you to pinpoint which component is slow, failing, or producing unexpected results.
- Prediction artifacts: The inputs and outputs of the model itself — the feature vectors, the predicted scores, the ranking order. These are ML-specific signals that have no analogue in traditional software monitoring. A trace tells you the model inference took 12ms; a prediction artifact tells you the model scored item #4471 at 0.93, which is anomalously high for that item category.
A well-designed monitoring system collects all four signal types at appropriate granularity and retention:
```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
from datetime import timedelta


class SignalType(Enum):
    METRIC = "metric"
    LOG = "log"
    TRACE = "trace"
    PREDICTION = "prediction"


class RetentionPolicy(Enum):
    HOT = "hot"    # In-memory or fast SSD: real-time queries
    WARM = "warm"  # Object storage: hourly/daily queries
    COLD = "cold"  # Archive: compliance/audit queries


@dataclass
class TelemetryConfig:
    """Configuration for ML system telemetry collection.

    Defines what signals to collect, at what granularity, and for
    how long to retain them. Designed for recommendation systems
    where prediction volume is high (millions/day) and storage
    costs constrain full-fidelity retention.

    Attributes:
        system_name: Name of the ML system being monitored.
        signals: List of signal configurations.
        default_metric_interval: Default scrape interval for metrics.
        prediction_sample_rate: Fraction of predictions to log
            at full fidelity (feature vectors + scores).
        trace_sample_rate: Fraction of requests to trace end-to-end.
    """

    system_name: str
    signals: List["SignalConfig"] = field(default_factory=list)
    default_metric_interval: timedelta = timedelta(seconds=15)
    prediction_sample_rate: float = 0.01  # 1% of predictions
    trace_sample_rate: float = 0.001      # 0.1% of requests

    def estimated_daily_storage_gb(
        self, daily_predictions: int
    ) -> Dict[str, float]:
        """Estimate daily storage requirements by signal type.

        Args:
            daily_predictions: Expected number of predictions per day.

        Returns:
            Dictionary mapping signal type to estimated GB/day.
        """
        # Rough estimates based on typical payload sizes
        metric_points_per_day = (
            len([s for s in self.signals
                 if s.signal_type == SignalType.METRIC])
            * 86400
            / self.default_metric_interval.total_seconds()
        )
        log_entries_per_day = daily_predictions * 0.1  # ~10% generate logs
        traced_requests = daily_predictions * self.trace_sample_rate
        sampled_predictions = (
            daily_predictions * self.prediction_sample_rate
        )
        return {
            "metrics": metric_points_per_day * 8 / 1e9,       # 8 bytes/point
            "logs": log_entries_per_day * 500 / 1e9,          # 500 bytes/entry
            "traces": traced_requests * 2000 / 1e9,           # 2 KB/trace
            "predictions": sampled_predictions * 4000 / 1e9,  # 4 KB/prediction
        }


@dataclass
class SignalConfig:
    """Configuration for a single telemetry signal.

    Attributes:
        name: Human-readable signal name.
        signal_type: Type of signal (metric, log, trace, prediction).
        description: What this signal measures and why it matters.
        collection_interval: How often to collect (for metrics).
        retention: Dictionary mapping retention tier to duration.
        alert_threshold: Optional threshold that triggers an alert.
    """

    name: str
    signal_type: SignalType
    description: str
    collection_interval: Optional[timedelta] = None
    retention: Dict[RetentionPolicy, timedelta] = field(
        default_factory=lambda: {
            RetentionPolicy.HOT: timedelta(days=7),
            RetentionPolicy.WARM: timedelta(days=90),
            RetentionPolicy.COLD: timedelta(days=365),
        }
    )
    alert_threshold: Optional[float] = None
```
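As a sanity check on these retention choices, the storage arithmetic can be reproduced standalone. The payload sizes (8 bytes per metric point, 500 bytes per log entry, 2 KB per trace, 4 KB per sampled prediction) and the 10%-of-requests log rate mirror the estimates in `estimated_daily_storage_gb` above; the 50-series metric count and 10 million predictions/day volume are illustrative assumptions:

```python
def storage_gb_per_day(
    daily_predictions: int,
    n_metric_series: int = 50,
    scrape_interval_s: float = 15.0,
    prediction_sample_rate: float = 0.01,
    trace_sample_rate: float = 0.001,
) -> dict:
    """Back-of-envelope daily storage, using the same payload-size
    assumptions as TelemetryConfig above (8 B/metric point,
    500 B/log entry, 2 KB/trace, 4 KB/sampled prediction,
    ~10% of requests generating a log entry)."""
    metric_points = n_metric_series * 86400 / scrape_interval_s
    return {
        "metrics": metric_points * 8 / 1e9,
        "logs": daily_predictions * 0.1 * 500 / 1e9,
        "traces": daily_predictions * trace_sample_rate * 2000 / 1e9,
        "predictions": daily_predictions * prediction_sample_rate * 4000 / 1e9,
    }


est = storage_gb_per_day(10_000_000)
# Sampled predictions dominate: 10M * 1% * 4 KB = 0.4 GB/day,
# versus ~0.002 GB/day for 50 metric series. Metrics are cheap;
# full-fidelity prediction logging is what the sample rate controls.
```

This is why the default `prediction_sample_rate` is 1%, not 100%: at full fidelity the prediction artifacts alone would run to roughly 40 GB/day at this volume.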
30.3 The Three Pillars Architecture: Prometheus, Grafana, and the ML Layer
The industry-standard monitoring stack for production systems is Prometheus (metrics collection and storage) + Grafana (visualization and alerting). This stack was designed for software systems. ML systems adopt it and extend it with an ML-specific layer that tracks data quality, model behavior, and drift.
Prometheus
Prometheus is a time-series database and monitoring system. It collects metrics by scraping HTTP endpoints at regular intervals (typically 15 seconds). Each metric is a time series: a sequence of (timestamp, value) pairs with a name and a set of labels.
The four Prometheus metric types map to ML monitoring needs:
| Metric Type | Description | ML Example |
|---|---|---|
| Counter | Monotonically increasing value | Total predictions served, total drift alerts fired |
| Gauge | Value that can go up or down | Current model version, feature null rate, prediction mean |
| Histogram | Distribution of values in configurable buckets | Prediction latency distribution, score distribution |
| Summary | Similar to histogram but computes quantiles client-side | Prediction score p50/p90/p99 |
```python
from prometheus_client import (
    Counter, Gauge, Histogram, Summary, Info, start_http_server
)
from typing import Dict, List, Optional
import time

# --- System-level metrics ---
PREDICTION_COUNTER = Counter(
    "streamrec_predictions_total",
    "Total number of predictions served",
    ["model_version", "endpoint"],
)
PREDICTION_LATENCY = Histogram(
    "streamrec_prediction_latency_seconds",
    "Prediction request latency in seconds",
    ["model_version", "stage"],
    buckets=[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],
)
PREDICTION_ERRORS = Counter(
    "streamrec_prediction_errors_total",
    "Total prediction errors by type",
    ["model_version", "error_type"],
)

# --- Model-level metrics ---
PREDICTION_SCORE = Histogram(
    "streamrec_prediction_score",
    "Distribution of prediction scores",
    ["model_version"],
    buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
)
FEATURE_NULL_RATE = Gauge(
    "streamrec_feature_null_rate",
    "Null rate for each feature at serving time",
    ["feature_name"],
)
FEATURE_MEAN = Gauge(
    "streamrec_feature_mean",
    "Running mean for numeric features at serving time",
    ["feature_name"],
)
MODEL_INFO = Info(
    "streamrec_model",
    "Current model metadata",
)

# --- Data quality metrics ---
DATA_FRESHNESS_SECONDS = Gauge(
    "streamrec_data_freshness_seconds",
    "Age of the most recent data point in the feature store",
    ["feature_source"],
)
DRIFT_PSI = Gauge(
    "streamrec_drift_psi",
    "Population Stability Index for each feature",
    ["feature_name"],
)


def instrument_prediction(
    model_version: str,
    features: Dict[str, Optional[float]],
    scores: List[float],
    latency_breakdown: Dict[str, float],
) -> None:
    """Record telemetry for a single prediction request.

    Called on every prediction to update Prometheus metrics.
    Designed for minimal overhead: gauge and counter updates
    are O(1); histogram observations are O(log B) where B
    is the number of buckets.

    Args:
        model_version: Currently serving model version string.
        features: Feature name to value mapping (None = missing).
        scores: List of predicted scores for ranked items.
        latency_breakdown: Latency in seconds by stage
            (e.g., 'feature_retrieval', 'inference', 'reranking').
    """
    # Count the prediction
    PREDICTION_COUNTER.labels(
        model_version=model_version, endpoint="/recommend"
    ).inc()

    # Record latency by stage
    total_latency = 0.0
    for stage, latency in latency_breakdown.items():
        PREDICTION_LATENCY.labels(
            model_version=model_version, stage=stage
        ).observe(latency)
        total_latency += latency
    PREDICTION_LATENCY.labels(
        model_version=model_version, stage="total"
    ).observe(total_latency)

    # Record score distribution (top 10 only, to bound per-request
    # observation volume on long recommendation lists)
    for score in scores[:10]:
        PREDICTION_SCORE.labels(
            model_version=model_version
        ).observe(score)

    # Record feature health. The gauge holds a per-request 0/1
    # missingness indicator sampled at scrape time; approximate
    # the actual null rate with avg_over_time() in PromQL.
    for feature_name, value in features.items():
        if value is None:
            FEATURE_NULL_RATE.labels(
                feature_name=feature_name
            ).set(1.0)
        else:
            FEATURE_NULL_RATE.labels(
                feature_name=feature_name
            ).set(0.0)
            FEATURE_MEAN.labels(
                feature_name=feature_name
            ).set(value)
```
Grafana
Grafana is a visualization and alerting platform that queries Prometheus (and other data sources) to render dashboards and evaluate alert rules. For ML systems, Grafana dashboards are organized into layers:
| Dashboard Layer | What It Shows | Who Watches |
|---|---|---|
| Executive | Business KPIs (engagement, CTR, revenue) | Product managers, leadership |
| Model Health | Prediction distributions, drift scores, accuracy proxies | ML engineers, data scientists |
| Data Quality | Feature null rates, freshness, schema violations | Data engineers, ML engineers |
| System Health | Latency, throughput, error rates, resource utilization | SREs, platform engineers |
The ML Monitoring Layer
Prometheus and Grafana handle metrics collection and visualization. But they were designed for software metrics (latency, throughput, error rate), not ML metrics (drift, prediction quality, feature distribution shifts). The ML monitoring layer bridges this gap by computing ML-specific metrics and exposing them as Prometheus gauges:
```python
import numpy as np
from scipy import stats
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timedelta
import logging

logger = logging.getLogger(__name__)


@dataclass
class ReferenceDistribution:
    """Reference distribution for a single feature, computed from
    training data.

    Stores bin edges and proportions for PSI computation, plus
    summary statistics for KS test and basic monitoring.

    Attributes:
        feature_name: Name of the feature.
        bin_edges: Edges of histogram bins (length = n_bins + 1).
        bin_proportions: Proportion of training data in each bin.
        mean: Training data mean.
        std: Training data standard deviation.
        quantiles: Dictionary of quantile values (e.g., {0.25: v}).
        n_samples: Number of training samples used to compute
            the reference.
        computed_at: Timestamp when the reference was computed.
    """

    feature_name: str
    bin_edges: np.ndarray
    bin_proportions: np.ndarray
    mean: float
    std: float
    quantiles: Dict[float, float] = field(default_factory=dict)
    n_samples: int = 0
    computed_at: Optional[datetime] = None

    @classmethod
    def from_array(
        cls,
        feature_name: str,
        values: np.ndarray,
        n_bins: int = 10,
    ) -> "ReferenceDistribution":
        """Compute a reference distribution from training data.

        Uses quantile-based bins to ensure each bin has roughly
        equal representation in the reference, which makes PSI
        more sensitive to shifts in the tails.

        Args:
            feature_name: Name of the feature.
            values: 1D array of training data values (non-null).
            n_bins: Number of bins for histogram.

        Returns:
            ReferenceDistribution with computed statistics.
        """
        values = values[~np.isnan(values)]
        quantile_edges = np.linspace(0, 1, n_bins + 1)
        bin_edges = np.quantile(values, quantile_edges)
        # Ensure unique bin edges (duplicates arise with
        # low-cardinality features)
        bin_edges = np.unique(bin_edges)
        if len(bin_edges) < 3:
            bin_edges = np.array([
                values.min(), values.mean(), values.max()
            ])
        counts, _ = np.histogram(values, bins=bin_edges)
        proportions = counts / counts.sum()
        # Avoid zero proportions (would cause division by zero in PSI)
        proportions = np.clip(proportions, 1e-6, None)
        proportions = proportions / proportions.sum()
        return cls(
            feature_name=feature_name,
            bin_edges=bin_edges,
            bin_proportions=proportions,
            mean=float(np.mean(values)),
            std=float(np.std(values)),
            quantiles={
                q: float(np.quantile(values, q))
                for q in [0.01, 0.05, 0.25, 0.50, 0.75, 0.95, 0.99]
            },
            n_samples=len(values),
            computed_at=datetime.utcnow(),
        )
```
30.4 Pillar 1: Data Quality Monitoring
Data quality monitoring is the first and most impactful layer. As Chapter 28 established, most production ML failures are data failures, not model failures. The testing infrastructure from Chapter 28 catches data problems at training time; the monitoring infrastructure in this chapter catches them at serving time, when the data flows continuously and there is no discrete "validation checkpoint."
Feature-Level Monitoring
For every feature the model consumes at serving time, monitor:
| Signal | What It Detects | Threshold Example |
|---|---|---|
| Null rate | Missing data, pipeline failures | > 5% for nullable features, > 0% for required features |
| Mean / median | Central tendency shift | > 2 standard deviations from training mean |
| Variance | Spread change (compression or explosion) | Ratio to training variance outside [0.5, 2.0] |
| Min / max | Out-of-range values, encoding errors | Values outside [training_min - margin, training_max + margin] |
| Cardinality | New categories, category collapse | Change > 10% in unique value count |
| Distribution (PSI) | Full distributional shift | PSI > 0.1 (investigate), > 0.25 (alert) |
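The first rows of this table reduce to direct comparisons against training-time statistics. A minimal self-contained sketch (the function name and default thresholds are illustrative; a production version would load the reference statistics from the model registry rather than take them as arguments):

```python
import math
from typing import Dict, List, Optional


def feature_health_checks(
    serving_values: List[Optional[float]],
    train_mean: float,
    train_std: float,
    train_min: float,
    train_max: float,
    max_null_rate: float = 0.05,
    margin: float = 0.0,
) -> Dict[str, bool]:
    """Evaluate the table's threshold rules for one numeric feature.

    Assumes at least one non-null serving value. Returns a mapping
    of check name to pass/fail.
    """
    def is_null(v: Optional[float]) -> bool:
        return v is None or (isinstance(v, float) and math.isnan(v))

    values = [v for v in serving_values if not is_null(v)]
    null_rate = 1 - len(values) / len(serving_values)
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return {
        # Null rate: > 5% for nullable features
        "null_rate_ok": null_rate <= max_null_rate,
        # Mean: within 2 training standard deviations
        "mean_ok": abs(mean - train_mean) <= 2 * train_std,
        # Variance: ratio to training variance inside [0.5, 2.0]
        "variance_ok": 0.5 <= var / train_std ** 2 <= 2.0,
        # Min/max: within training range plus a margin
        "range_ok": (min(values) >= train_min - margin
                     and max(values) <= train_max + margin),
    }
```

Run against the March 3 scenario (days_since_last_interaction suddenly in the billions against a 0-365 training range), the mean, variance, and range checks all fail while the null-rate check passes: precisely the failure that latency and error-rate monitoring could not see.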
Data Freshness Monitoring
Feature stores (Chapter 25) serve both batch features (computed daily or hourly) and streaming features (computed in real time). Freshness monitoring checks that each feature source is up to date:
```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, Optional
from enum import Enum


class FreshnessStatus(Enum):
    FRESH = "fresh"
    STALE = "stale"
    CRITICAL = "critical"


@dataclass
class FreshnessMonitor:
    """Monitor data freshness for feature store sources.

    Each feature source has an expected update cadence. The monitor
    checks the timestamp of the most recent update against the
    expected cadence and alerts when data is stale.

    Attributes:
        source_configs: Mapping from source name to expected
            maximum age (timedelta).
    """

    source_configs: Dict[str, timedelta]

    def check_freshness(
        self,
        source_timestamps: Dict[str, datetime],
        current_time: Optional[datetime] = None,
    ) -> Dict[str, "FreshnessResult"]:
        """Check freshness of all configured sources.

        Args:
            source_timestamps: Mapping from source name to timestamp
                of most recent data point.
            current_time: Current time (defaults to utcnow).

        Returns:
            Dictionary mapping source name to FreshnessResult.
        """
        if current_time is None:
            current_time = datetime.utcnow()
        results = {}
        for source_name, max_age in self.source_configs.items():
            if source_name not in source_timestamps:
                results[source_name] = FreshnessResult(
                    source_name=source_name,
                    status=FreshnessStatus.CRITICAL,
                    age=None,
                    max_age=max_age,
                    message=f"No timestamp found for {source_name}",
                )
                continue
            age = current_time - source_timestamps[source_name]
            if age <= max_age:
                status = FreshnessStatus.FRESH
            elif age <= max_age * 2:
                status = FreshnessStatus.STALE
            else:
                status = FreshnessStatus.CRITICAL
            results[source_name] = FreshnessResult(
                source_name=source_name,
                status=status,
                age=age,
                max_age=max_age,
                message=(
                    f"{source_name}: age={age}, max={max_age}, "
                    f"status={status.value}"
                ),
            )
        return results


@dataclass
class FreshnessResult:
    """Result of a freshness check for a single source.

    Attributes:
        source_name: Name of the data source.
        status: Fresh, stale, or critical.
        age: Current age of the data (None if not available).
        max_age: Configured maximum acceptable age.
        message: Human-readable status message.
    """

    source_name: str
    status: FreshnessStatus
    age: Optional[timedelta]
    max_age: timedelta
    message: str


# StreamRec freshness configuration
streamrec_freshness = FreshnessMonitor(
    source_configs={
        "user_batch_features": timedelta(hours=6),
        "item_batch_features": timedelta(hours=12),
        "user_streaming_features": timedelta(minutes=5),
        "item_popularity_features": timedelta(hours=1),
        "model_embeddings": timedelta(hours=24),
    }
)
```
Training-Serving Skew Detection
One of the subtlest data quality issues is training-serving skew: the features the model sees at serving time differ from the features it was trained on, even though both use the "same" feature definitions. Common causes:
- Computation skew: Training features are computed in batch (e.g., Spark); serving features are computed in real time (e.g., Redis + Python). Different code paths produce different values.
- Temporal skew: Training features include future information that is not available at serving time (label leakage).
- Staleness skew: Training features use the latest value; serving features use a cached value that is hours or days old.
The feature store architecture from Chapter 25 mitigates computation skew by using a shared feature computation layer. But even with a shared layer, staleness skew is inevitable for batch features. Monitor the distribution gap between training-time features and serving-time features:
```python
import numpy as np
from typing import Dict, Tuple


def compute_training_serving_skew(
    training_features: Dict[str, np.ndarray],
    serving_features: Dict[str, np.ndarray],
    method: str = "psi",
    n_bins: int = 10,
) -> Dict[str, Tuple[float, str]]:
    """Compare feature distributions between training and serving.

    Computes a distributional distance metric for each feature
    and classifies the result as 'ok', 'investigate', or 'alert'.

    Args:
        training_features: Feature name to array of training values.
        serving_features: Feature name to array of serving values.
        method: Distance metric ('psi', 'ks', or 'js').
        n_bins: Number of bins for PSI/JS computation.

    Returns:
        Dictionary mapping feature name to (distance, severity).
    """
    results = {}
    common_features = set(training_features) & set(serving_features)
    for feature_name in common_features:
        train_vals = training_features[feature_name]
        serve_vals = serving_features[feature_name]
        # Remove NaNs
        train_vals = train_vals[~np.isnan(train_vals)]
        serve_vals = serve_vals[~np.isnan(serve_vals)]
        if len(train_vals) < 100 or len(serve_vals) < 100:
            results[feature_name] = (float("nan"), "insufficient_data")
            continue
        if method == "psi":
            distance = _compute_psi(train_vals, serve_vals, n_bins)
            severity = (
                "ok" if distance < 0.1
                else "investigate" if distance < 0.25
                else "alert"
            )
        elif method == "ks":
            stat, p_value = _compute_ks(train_vals, serve_vals)
            distance = stat
            severity = (
                "ok" if p_value > 0.01
                else "investigate" if p_value > 0.001
                else "alert"
            )
        elif method == "js":
            distance = _compute_js_divergence(
                train_vals, serve_vals, n_bins
            )
            severity = (
                "ok" if distance < 0.05
                else "investigate" if distance < 0.15
                else "alert"
            )
        else:
            raise ValueError(f"Unknown method: {method}")
        results[feature_name] = (float(distance), severity)
    # Flag features present in training but missing in serving
    for feature_name in set(training_features) - common_features:
        results[feature_name] = (float("inf"), "missing_in_serving")
    return results


def _compute_psi(
    reference: np.ndarray,
    current: np.ndarray,
    n_bins: int,
) -> float:
    """Compute Population Stability Index.

    See Section 30.8 for detailed explanation.
    """
    # Use quantile-based bins from reference distribution
    bin_edges = np.quantile(
        reference, np.linspace(0, 1, n_bins + 1)
    )
    bin_edges = np.unique(bin_edges)
    ref_counts, _ = np.histogram(reference, bins=bin_edges)
    cur_counts, _ = np.histogram(current, bins=bin_edges)
    ref_props = ref_counts / ref_counts.sum()
    cur_props = cur_counts / cur_counts.sum()
    # Clip to avoid log(0)
    ref_props = np.clip(ref_props, 1e-6, None)
    cur_props = np.clip(cur_props, 1e-6, None)
    psi = np.sum((cur_props - ref_props) * np.log(cur_props / ref_props))
    return float(psi)


def _compute_ks(
    reference: np.ndarray,
    current: np.ndarray,
) -> Tuple[float, float]:
    """Compute two-sample Kolmogorov-Smirnov test.

    See Section 30.9 for detailed explanation.
    """
    from scipy.stats import ks_2samp

    stat, p_value = ks_2samp(reference, current)
    return float(stat), float(p_value)


def _compute_js_divergence(
    reference: np.ndarray,
    current: np.ndarray,
    n_bins: int,
) -> float:
    """Compute Jensen-Shannon divergence.

    See Section 30.10 for detailed explanation.
    """
    bin_edges = np.quantile(
        reference, np.linspace(0, 1, n_bins + 1)
    )
    bin_edges = np.unique(bin_edges)
    ref_counts, _ = np.histogram(reference, bins=bin_edges)
    cur_counts, _ = np.histogram(current, bins=bin_edges)
    ref_props = ref_counts / ref_counts.sum()
    cur_props = cur_counts / cur_counts.sum()
    ref_props = np.clip(ref_props, 1e-10, None)
    cur_props = np.clip(cur_props, 1e-10, None)
    m = 0.5 * (ref_props + cur_props)
    kl_ref_m = np.sum(ref_props * np.log(ref_props / m))
    kl_cur_m = np.sum(cur_props * np.log(cur_props / m))
    js = 0.5 * kl_ref_m + 0.5 * kl_cur_m
    return float(js)
```
30.5 Pillar 2: Model Performance Monitoring
Model performance monitoring answers the question: "Is the model still producing good predictions?" This is harder than it sounds, because ground truth — the actual outcome — is often delayed or unavailable.
The Ground Truth Delay Problem
In software monitoring, feedback is immediate: the server either returns the correct response or it does not. In ML, the feedback loop can take days, weeks, or months:
| Application | Prediction Delay | Ground Truth Delay |
|---|---|---|
| StreamRec (content recommendation) | Milliseconds | Minutes to hours (did the user engage?) |
| Credit scoring | Seconds | 6-18 months (did the borrower default?) |
| Fraud detection | Milliseconds | Days to weeks (was the transaction confirmed as fraud?) |
| Medical diagnosis | Seconds | Weeks to months (biopsy results, treatment outcomes) |
For StreamRec, the feedback loop is relatively short: we know within hours whether a user clicked, watched, or completed the recommended content. For Meridian Financial's credit scoring model, we may not know whether the prediction was correct for over a year. This asymmetry demands different monitoring strategies:
Short feedback loops (minutes to hours): Monitor actual model performance directly. Compute real-time accuracy, precision, recall, or custom business metrics using incoming ground truth labels.
Long feedback loops (weeks to months): Monitor proxy signals and distributional properties. If the model's prediction distribution shifts, the input distribution shifts, or the model's behavior on known test cases changes, something has changed — even if we cannot yet measure the actual impact on outcomes.
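For the short-loop case, the core mechanic is joining each prediction with its label when the label eventually arrives, then computing metrics over a rolling window of matched pairs. A minimal sketch (class and method names are illustrative; production systems typically perform this join in a stream processor rather than in the serving process):

```python
from collections import deque
from typing import Deque, Dict, Tuple


class DelayedLabelJoiner:
    """Buffer predictions until ground truth arrives, then emit
    (score, label) pairs for rolling-metric computation."""

    def __init__(self, window: int = 1000):
        self.pending: Dict[str, float] = {}  # request_id -> score
        # Rolling window of matched (score, clicked) pairs
        self.matched: Deque[Tuple[float, bool]] = deque(maxlen=window)

    def record_prediction(self, request_id: str, score: float) -> None:
        self.pending[request_id] = score

    def record_outcome(self, request_id: str, clicked: bool) -> None:
        # Ground truth arrives minutes or hours after the prediction
        score = self.pending.pop(request_id, None)
        if score is not None:
            self.matched.append((score, clicked))

    def rolling_accuracy(self, threshold: float = 0.5) -> float:
        """Fraction of matched pairs where the thresholded score
        agrees with the observed outcome."""
        if not self.matched:
            return float("nan")
        hits = sum(1 for s, y in self.matched if (s > threshold) == y)
        return hits / len(self.matched)
```

A real implementation also needs an eviction policy for `pending` entries whose labels never arrive, since "no outcome within the attribution window" is itself a label in many recommendation settings.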
Prediction Distribution Monitoring
Even without ground truth, the model's prediction distribution is an informative signal. A well-calibrated model produces a stable prediction distribution when the input distribution is stable. A shift in the prediction distribution signals either (a) the input distribution has changed (data drift), (b) the relationship between inputs and outcomes has changed (concept drift), or (c) the model itself has changed (deployment error).
import numpy as np
from dataclasses import dataclass, field
from typing import List, Optional, Dict
from datetime import datetime
import logging
logger = logging.getLogger(__name__)
@dataclass
class PredictionMonitor:
"""Monitor the distribution of model predictions over time.
Compares sliding windows of predictions against a reference
distribution (typically from the model's validation set at
training time) and flags significant deviations.
Attributes:
reference_scores: Prediction scores from the validation set
at training time.
window_size: Number of predictions per monitoring window.
psi_warn_threshold: PSI above this triggers a warning.
psi_alert_threshold: PSI above this triggers an alert.
score_buffer: Buffer of recent prediction scores.
"""
reference_scores: np.ndarray
window_size: int = 10000
psi_warn_threshold: float = 0.10
psi_alert_threshold: float = 0.25
score_buffer: List[float] = field(default_factory=list)
def __post_init__(self):
self._ref_distribution = ReferenceDistribution.from_array(
feature_name="prediction_score",
values=self.reference_scores,
n_bins=20,
)
def observe(self, scores: List[float]) -> Optional["DriftAlert"]:
"""Add prediction scores and check for distribution shift.
Called after each batch of predictions. When the buffer
reaches window_size, computes PSI against the reference
and resets the buffer.
Args:
scores: List of prediction scores from recent requests.
Returns:
DriftAlert if a threshold is exceeded, None otherwise.
"""
self.score_buffer.extend(scores)
if len(self.score_buffer) < self.window_size:
return None
# Compute PSI against reference
current_values = np.array(self.score_buffer[:self.window_size])
psi = _compute_psi(
self._ref_distribution.bin_proportions,
current_values,
self._ref_distribution.bin_edges,
)
# Update Prometheus gauge
DRIFT_PSI.labels(feature_name="prediction_score").set(psi)
# Reset buffer
self.score_buffer = self.score_buffer[self.window_size:]
if psi > self.psi_alert_threshold:
return DriftAlert(
feature_name="prediction_score",
metric="psi",
value=psi,
threshold=self.psi_alert_threshold,
severity="critical",
timestamp=datetime.utcnow(),
message=(
f"Prediction score PSI={psi:.4f} exceeds alert "
f"threshold {self.psi_alert_threshold}. "
f"Investigate for concept drift or data quality "
f"issues."
),
)
elif psi > self.psi_warn_threshold:
return DriftAlert(
feature_name="prediction_score",
metric="psi",
value=psi,
threshold=self.psi_warn_threshold,
severity="warning",
timestamp=datetime.utcnow(),
message=(
f"Prediction score PSI={psi:.4f} exceeds warning "
f"threshold {self.psi_warn_threshold}. "
f"Monitor trend."
),
)
return None
def _compute_psi(
ref_proportions: np.ndarray,
current_values: np.ndarray,
bin_edges: np.ndarray,
) -> float:
    """Compute PSI using precomputed reference proportions."""
    # Clip serving values into the reference range so tail values
    # land in the outermost bins instead of being silently dropped
    current_values = np.clip(
        current_values, bin_edges[0], bin_edges[-1]
    )
    cur_counts, _ = np.histogram(current_values, bins=bin_edges)
    cur_props = cur_counts / cur_counts.sum()
cur_props = np.clip(cur_props, 1e-6, None)
ref_props = np.clip(ref_proportions, 1e-6, None)
psi = np.sum(
(cur_props - ref_props) * np.log(cur_props / ref_props)
)
return float(psi)
@dataclass
class DriftAlert:
"""Alert generated when a drift threshold is exceeded.
Attributes:
feature_name: Name of the drifting feature (or
'prediction_score').
metric: Drift metric used (psi, ks, js).
value: Computed drift value.
threshold: Threshold that was exceeded.
severity: 'warning' or 'critical'.
timestamp: When the drift was detected.
message: Human-readable alert message.
"""
feature_name: str
metric: str
value: float
threshold: float
severity: str
timestamp: datetime
message: str
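To make the windowed PSI check concrete, here is a standalone numerical sketch. It uses synthetic Beta-distributed scores and re-implements the quantile binning and PSI arithmetic inline rather than importing the chapter's classes:

```python
import numpy as np

rng = np.random.default_rng(42)
reference = rng.beta(2, 5, size=50_000)  # validation-time scores
current = rng.beta(2, 3, size=10_000)    # shifted serving scores

# Quantile bins from the reference distribution (20 bins)
edges = np.quantile(reference, np.linspace(0, 1, 21))
ref_props = np.histogram(reference, bins=edges)[0] / 50_000
# Clip serving scores into the reference range so tail values
# are counted in the outermost bins
cur_clipped = np.clip(current, edges[0], edges[-1])
cur_props = np.histogram(cur_clipped, bins=edges)[0] / 10_000
ref_props = np.clip(ref_props, 1e-6, None)
cur_props = np.clip(cur_props, 1e-6, None)
psi = float(np.sum((cur_props - ref_props) * np.log(cur_props / ref_props)))
print(f"PSI = {psi:.3f}")  # well above the 0.25 alert threshold
```

With this deliberately shifted score distribution, a monitor configured with the thresholds above would emit a critical DriftAlert.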
Business Metric Monitoring
Model performance metrics (accuracy, Recall@20, NDCG@20) measure the model's technical quality. Business metrics measure the model's impact on the product. For StreamRec:
| Business Metric | Definition | Why It Matters |
|---|---|---|
| Click-through rate (CTR) | Clicks on recommended items / impressions | Direct engagement signal |
| Completion rate | Completions / starts for recommended content | Depth of engagement |
| Session length | Average minutes per user session | Overall platform stickiness |
| Recommendation coverage | Unique items recommended / total catalog size | Diversity and catalog utilization |
| Revenue per recommendation | Revenue attributed to recommendations / total recs | Business value |
The relationship between model metrics and business metrics is not always monotonic. Optimizing Recall@20 (showing items the user will click) may decrease recommendation coverage (showing the same popular items to everyone). Optimizing NDCG@20 (ranking preferred items higher) may decrease session length if users find their desired content immediately and leave. These trade-offs are why business metric monitoring cannot be replaced by model metric monitoring.
30.6 Pillar 3: System Health Monitoring
System health monitoring tracks the operational performance of the serving infrastructure. For ML systems, this includes the standard software metrics plus ML-specific operational concerns:
| Metric | Software Standard | ML-Specific Extension |
|---|---|---|
| Latency | API response time | Breakdown by stage (feature retrieval, inference, re-ranking) |
| Throughput | Requests per second | Predictions per second, batch size utilization |
| Error rate | HTTP 5xx / total | Model loading errors, feature store timeouts, OOM errors |
| Resource utilization | CPU, memory, disk | GPU utilization, GPU memory, model cache hit rate |
| Availability | Uptime percentage | Model availability (serving latest version vs. fallback) |
SLOs and SLIs for ML Systems
Site Reliability Engineering (SRE) uses Service Level Indicators (SLIs) and Service Level Objectives (SLOs) to define and measure reliability. An SLI is a quantitative measure of a service's behavior; an SLO is a target value for an SLI.
For ML systems, the standard SLIs expand to include prediction quality:
from dataclasses import dataclass
from typing import Dict
@dataclass
class MLServiceSLO:
"""Service Level Objectives for an ML serving system.
Defines the reliability targets that the ML system must meet.
SLOs are negotiated between the ML team (system owner) and
the product team (system consumer). Each SLO has an associated
error budget: the amount of allowed unreliability per period.
Attributes:
service_name: Name of the ML service.
slos: Mapping from SLI name to target value.
error_budget_period_days: Period over which error budget
is computed (typically 28 or 30 days).
"""
service_name: str
slos: Dict[str, float]
error_budget_period_days: int = 28
def compute_error_budget(
self, sli_name: str, sli_value: float
) -> float:
"""Compute remaining error budget for an SLI.
The error budget is the difference between the SLO target
and 100% (for availability-type SLIs) or the headroom
between the current value and the target (for latency-type
SLIs).
Args:
sli_name: Name of the SLI.
sli_value: Current measured value of the SLI.
Returns:
Remaining error budget as a fraction (0 = exhausted,
1 = fully available).
"""
if sli_name not in self.slos:
raise ValueError(f"Unknown SLI: {sli_name}")
target = self.slos[sli_name]
        if sli_name.endswith("_availability"):
            # Availability SLIs: target is a minimum (e.g., 99.9%).
            # Budget consumed = observed unreliability divided by
            # allowed unreliability, so sli_value == target means
            # the budget is exactly exhausted.
            total_budget = 1.0 - target
            consumed = (1.0 - sli_value) / total_budget
            return max(0, 1.0 - consumed)
        elif sli_name.endswith("_latency_p99") or sli_name.endswith(
            "_days"
        ):
            # Latency/staleness SLIs: target is a maximum
            # (e.g., 50ms p99, model at most 7 days old)
            if sli_value <= target:
                return 1.0
            return max(0, 1.0 - (sli_value - target) / target)
        else:
            # Generic: assume target is a minimum compliance rate
            total_budget = 1.0 - target
            if total_budget <= 0:
                return 1.0 if sli_value >= target else 0.0
            consumed = (1.0 - sli_value) / total_budget
            return max(0, 1.0 - consumed)
# StreamRec SLOs
streamrec_slo = MLServiceSLO(
service_name="StreamRec Recommendation API",
slos={
"prediction_availability": 0.999, # 99.9% uptime
"prediction_latency_p99": 0.050, # 50ms p99
"feature_freshness_compliance": 0.995, # 99.5% of requests
# use fresh features
"model_staleness_days": 7.0, # Model < 7 days old
"drift_compliance": 0.99, # 99% of hours with
# PSI < 0.25
},
)
Error Budgets and Deployment Decisions
The error budget is the inverse of the SLO: if the SLO is 99.9% availability, the error budget is 0.1% — roughly 43 minutes of downtime per month. The error budget creates a quantitative framework for balancing reliability against velocity:
- Error budget available: The team has headroom. Deploy new models, experiment with features, take calculated risks.
- Error budget exhausted: The team has no headroom. Freeze deployments, focus on reliability improvements, fix the issues that consumed the budget.
For ML systems, the error budget framework applies to both system reliability (uptime, latency) and model reliability (drift compliance, prediction quality). A model deployment that causes a drift alert consumes model reliability error budget. A feature store outage that causes stale features consumes system reliability error budget. Both must be tracked.
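The 43-minute figure follows directly from the SLO arithmetic; a quick standalone check (assuming a 30-day month):

```python
slo = 0.999
minutes_per_month = 30 * 24 * 60          # 43,200 minutes
budget_minutes = (1.0 - slo) * minutes_per_month
print(f"{budget_minutes:.1f} minutes of allowed downtime")  # 43.2 minutes
```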
30.7 Pillar 4: Business Metric Monitoring and the Feedback Loop
The fourth pillar connects model behavior to business outcomes. This is the most important layer and the most difficult to instrument, because the relationship between model predictions and business outcomes is indirect, delayed, and confounded.
Closing the Feedback Loop
For StreamRec, the feedback loop works as follows:
- The model recommends items to a user (prediction).
- The user interacts with the recommendations (feedback).
- The interaction data flows into the feature store (data pipeline).
- The model retrains on the updated data (continuous training, Chapter 29).
- The retrained model begins serving (deployment).
Each step introduces a delay, and the cumulative delay determines how quickly the system adapts to changes. Monitoring this feedback loop requires tracking not just the model's predictions but also the user's responses and the time between them.
from dataclasses import dataclass
from typing import Dict, List, Optional
from datetime import datetime, timedelta
import numpy as np
@dataclass
class FeedbackLoopMonitor:
"""Monitor the health of the ML feedback loop.
Tracks the delay between predictions and observed outcomes,
the rate at which feedback is received, and the alignment
between predicted and observed engagement.
Attributes:
prediction_log: List of (timestamp, predicted_score, item_id).
outcome_log: List of (timestamp, item_id, engaged).
max_feedback_delay: Maximum expected delay between prediction
and outcome (items with longer delays are considered lost).
"""
    prediction_log: Optional[List[Dict]] = None
    outcome_log: Optional[List[Dict]] = None
max_feedback_delay: timedelta = timedelta(hours=24)
def __post_init__(self):
if self.prediction_log is None:
self.prediction_log = []
if self.outcome_log is None:
self.outcome_log = []
def compute_feedback_rate(
self,
window: timedelta = timedelta(hours=1),
current_time: Optional[datetime] = None,
) -> Dict[str, float]:
"""Compute the rate at which predictions receive feedback.
A dropping feedback rate may indicate a logging pipeline
failure, a client-side bug, or a change in user behavior.
Args:
window: Time window for computation.
current_time: Current time (defaults to utcnow).
Returns:
Dictionary with feedback rate metrics.
"""
if current_time is None:
current_time = datetime.utcnow()
        mature_cutoff = current_time - self.max_feedback_delay
        window_start = mature_cutoff - window
        # Predictions old enough to have received feedback, within
        # the most recent mature window
        mature_predictions = [
            p for p in self.prediction_log
            if window_start <= p["timestamp"] <= mature_cutoff
        ]
if not mature_predictions:
return {
"feedback_rate": 0.0,
"mean_feedback_delay_seconds": 0.0,
"n_mature_predictions": 0,
}
# Match predictions to outcomes
outcome_map = {}
for o in self.outcome_log:
key = (o["item_id"], o.get("user_id"))
outcome_map[key] = o
matched = 0
delays = []
for p in mature_predictions:
key = (p["item_id"], p.get("user_id"))
if key in outcome_map:
matched += 1
delay = (
outcome_map[key]["timestamp"] - p["timestamp"]
)
delays.append(delay.total_seconds())
return {
"feedback_rate": matched / len(mature_predictions),
"mean_feedback_delay_seconds": (
float(np.mean(delays)) if delays else 0.0
),
"n_mature_predictions": len(mature_predictions),
}
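To illustrate the matching logic, here is a minimal standalone sketch with hypothetical log entries, using the same (item_id, user_id) join key as compute_feedback_rate above:

```python
from datetime import datetime, timedelta

now = datetime(2024, 6, 1, 12, 0)
max_delay = timedelta(hours=24)
# Hypothetical logs: two mature predictions, one observed outcome
predictions = [
    {"item_id": "a", "user_id": 1, "timestamp": now - timedelta(hours=30)},
    {"item_id": "b", "user_id": 2, "timestamp": now - timedelta(hours=26)},
]
outcomes = [
    {"item_id": "a", "user_id": 1,
     "timestamp": now - timedelta(hours=29), "engaged": True},
]
mature_cutoff = now - max_delay
mature = [p for p in predictions if p["timestamp"] <= mature_cutoff]
outcome_keys = {(o["item_id"], o["user_id"]) for o in outcomes}
matched = sum(
    1 for p in mature if (p["item_id"], p["user_id"]) in outcome_keys
)
print(matched / len(mature))  # 0.5: one of two mature predictions matched
```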
30.8 Drift Detection: Population Stability Index (PSI)
Chapter 28 introduced PSI briefly for data validation. This section provides the full treatment needed for production drift detection.
Definition
The Population Stability Index measures the shift between a reference distribution $P$ (typically from training data) and a current distribution $Q$ (from serving data):
$$\text{PSI} = \sum_{i=1}^{B} (q_i - p_i) \cdot \ln\left(\frac{q_i}{p_i}\right)$$
where $p_i$ and $q_i$ are the proportions of reference and current data in bin $i$, and $B$ is the number of bins.
PSI is a symmetrized form of the Kullback-Leibler divergence; expanding the sum shows $\text{PSI} = D_{\text{KL}}(Q \| P) + D_{\text{KL}}(P \| Q)$. It has several properties that make it well-suited for production monitoring:
- Non-negative: PSI $\geq 0$, with equality only when $P = Q$.
- Symmetric: Unlike KL divergence, PSI is symmetric: $\text{PSI}(P, Q) = \text{PSI}(Q, P)$.
- Interpretable thresholds: Industry-standard thresholds (from credit risk modeling, where PSI is required by regulation) provide a calibrated interpretation.
- Decomposable: The total PSI is the sum of per-bin contributions, allowing you to localize which part of the distribution shifted.
Interpretation
| PSI Range | Interpretation | Recommended Action |
|---|---|---|
| < 0.10 | No significant shift | Continue monitoring |
| 0.10 - 0.25 | Moderate shift | Investigate root cause; may require model retraining |
| > 0.25 | Significant shift | Alert; investigate immediately; consider model rollback |
Implementation Details
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
@dataclass
class PSIResult:
"""Result of a PSI computation.
Attributes:
feature_name: Name of the feature being monitored.
total_psi: Total PSI value across all bins.
bin_contributions: Per-bin PSI contributions, allowing
localization of the shift.
bin_edges: Bin edges used for the computation.
reference_proportions: Proportions in each bin for the
reference distribution.
current_proportions: Proportions in each bin for the
current distribution.
severity: Classification based on standard thresholds.
"""
feature_name: str
total_psi: float
bin_contributions: List[float]
bin_edges: List[float]
reference_proportions: List[float]
current_proportions: List[float]
severity: str
def top_contributing_bins(self, k: int = 3) -> List[Dict]:
"""Return the k bins contributing most to the PSI.
Useful for root cause analysis: if PSI is high, which
part of the distribution shifted most?
Args:
k: Number of top bins to return.
Returns:
List of dicts with bin range, contribution, and
direction of shift.
"""
indexed = [
{
"bin_index": i,
"range": (
float(self.bin_edges[i]),
float(self.bin_edges[i + 1]),
),
"contribution": self.bin_contributions[i],
"reference_proportion": self.reference_proportions[i],
"current_proportion": self.current_proportions[i],
"direction": (
"increase"
if self.current_proportions[i]
> self.reference_proportions[i]
else "decrease"
),
}
for i in range(len(self.bin_contributions))
]
indexed.sort(key=lambda x: x["contribution"], reverse=True)
return indexed[:k]
def compute_psi_detailed(
feature_name: str,
reference: np.ndarray,
current: np.ndarray,
n_bins: int = 10,
use_quantile_bins: bool = True,
) -> PSIResult:
"""Compute PSI with full diagnostics.
Uses quantile-based bins by default, which ensures equal
representation in the reference distribution and makes PSI
more sensitive to shifts in the tails.
Args:
feature_name: Name of the feature.
reference: Reference (training) distribution values.
current: Current (serving) distribution values.
n_bins: Number of bins.
use_quantile_bins: If True, use quantile-based bins from
the reference distribution. If False, use equal-width
bins spanning the union of both distributions.
Returns:
PSIResult with total PSI, per-bin contributions, and
severity classification.
"""
reference = reference[~np.isnan(reference)]
current = current[~np.isnan(current)]
    if use_quantile_bins:
        bin_edges = np.quantile(
            reference, np.linspace(0, 1, n_bins + 1)
        )
        # Quantile edges come from the reference only, so clip
        # current values into range: otherwise serving values in
        # the tails would be silently dropped from the histogram
        current = np.clip(current, bin_edges[0], bin_edges[-1])
    else:
        combined_min = min(reference.min(), current.min())
        combined_max = max(reference.max(), current.max())
        bin_edges = np.linspace(combined_min, combined_max, n_bins + 1)
    bin_edges = np.unique(bin_edges)
    ref_counts, _ = np.histogram(reference, bins=bin_edges)
    cur_counts, _ = np.histogram(current, bins=bin_edges)
ref_props = ref_counts / ref_counts.sum()
cur_props = cur_counts / cur_counts.sum()
# Clip to avoid division by zero or log(0)
ref_props = np.clip(ref_props, 1e-6, None)
cur_props = np.clip(cur_props, 1e-6, None)
# Renormalize after clipping
ref_props = ref_props / ref_props.sum()
cur_props = cur_props / cur_props.sum()
bin_contributions = (
(cur_props - ref_props) * np.log(cur_props / ref_props)
)
total_psi = float(np.sum(bin_contributions))
severity = (
"ok" if total_psi < 0.10
else "investigate" if total_psi < 0.25
else "alert"
)
return PSIResult(
feature_name=feature_name,
total_psi=total_psi,
bin_contributions=bin_contributions.tolist(),
bin_edges=bin_edges.tolist(),
reference_proportions=ref_props.tolist(),
current_proportions=cur_props.tolist(),
severity=severity,
)
PSI Limitations
PSI has important limitations:
- Bin sensitivity. PSI depends on the number and placement of bins. Too few bins miss localized shifts; too many bins introduce noise. Quantile-based bins (from the reference distribution) mitigate this but do not eliminate it.
- Sample size sensitivity. With small samples, PSI can be noisy. A rough guideline: require at least 100 observations per bin (i.e., at least $100 \times B$ total observations).
- No statistical significance. PSI is a descriptive measure, not a hypothesis test. It does not produce a p-value. A PSI of 0.12 means "moderate shift," but it does not tell you whether the shift is statistically significant given the sample sizes.
- Insensitivity to distribution shape within bins. PSI measures the KL-divergence-like distance between bin proportions. Two distributions with the same bin proportions but very different shapes within bins will have PSI = 0.
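The last limitation is easy to demonstrate. In this sketch (synthetic data, with an assumed four-bin layout), the current distribution concentrates all its mass at bin centers, leaving the bin proportions, and hence PSI, essentially unchanged, while the KS test sees a clear shift:

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)
reference = rng.uniform(0, 1, 20_000)
# Collapse each value onto the center of its quarter-wide bin:
# the four bin proportions are unchanged, but the shape within
# each bin is radically different.
current = (np.floor(rng.uniform(0, 1, 20_000) * 4) + 0.5) / 4

edges = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
ref_p = np.histogram(reference, bins=edges)[0] / 20_000
cur_p = np.histogram(current, bins=edges)[0] / 20_000
psi = float(np.sum((cur_p - ref_p) * np.log(
    np.clip(cur_p, 1e-6, None) / np.clip(ref_p, 1e-6, None))))
d, p = ks_2samp(reference, current)
print(f"PSI = {psi:.4f}, KS D = {d:.3f}")  # PSI near 0, D near 0.125
```

This is one reason the chapter recommends pairing PSI with a binning-free method such as the KS test.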
30.9 Drift Detection: The Kolmogorov-Smirnov Test
The Kolmogorov-Smirnov (KS) test is a nonparametric hypothesis test for whether two samples come from the same distribution. Unlike PSI, it does not require binning and produces a p-value.
Definition
The KS statistic is the maximum absolute difference between the empirical cumulative distribution functions (ECDFs) of the two samples:
$$D = \max_x |F_{\text{ref}}(x) - F_{\text{cur}}(x)|$$
where $F_{\text{ref}}$ and $F_{\text{cur}}$ are the ECDFs of the reference and current samples. The null hypothesis is that both samples come from the same distribution; a small p-value rejects this null.
Implementation
import numpy as np
from scipy.stats import ks_2samp
from dataclasses import dataclass
from typing import Optional
@dataclass
class KSTestResult:
"""Result of a Kolmogorov-Smirnov two-sample test.
Attributes:
feature_name: Name of the feature.
statistic: KS statistic (max ECDF difference).
p_value: P-value from the KS test.
n_reference: Number of reference samples.
n_current: Number of current samples.
significant: Whether the test rejects H0 at the
specified significance level.
significance_level: Alpha used for the significance test.
"""
feature_name: str
statistic: float
p_value: float
n_reference: int
n_current: int
significant: bool
significance_level: float = 0.01
def compute_ks_test(
feature_name: str,
reference: np.ndarray,
current: np.ndarray,
significance_level: float = 0.01,
) -> KSTestResult:
"""Run a two-sample KS test for drift detection.
Args:
feature_name: Name of the feature.
reference: Reference (training) distribution values.
current: Current (serving) distribution values.
significance_level: Alpha for significance testing.
Default 0.01 (stricter than typical 0.05) to reduce
false alarms in continuous monitoring.
Returns:
KSTestResult with test statistic, p-value, and
significance decision.
"""
reference = reference[~np.isnan(reference)]
current = current[~np.isnan(current)]
stat, p_value = ks_2samp(reference, current)
return KSTestResult(
feature_name=feature_name,
statistic=float(stat),
p_value=float(p_value),
n_reference=len(reference),
n_current=len(current),
significant=p_value < significance_level,
significance_level=significance_level,
)
KS Test: Strengths and Limitations
Strengths:
- No binning required, avoiding PSI's bin sensitivity
- Produces a p-value with well-understood statistical properties
- Sensitive to any difference between distributions (location, scale, shape)
- Nonparametric: no distributional assumptions
Limitations:
- Overpowered at large sample sizes. With millions of serving predictions, the KS test will detect trivially small shifts that have no practical impact. A PSI of 0.01 alongside a KS p-value of $10^{-15}$ is not unusual: the shift is statistically significant but operationally meaningless.
- Single-point summary. The KS statistic tells you the maximum ECDF gap but not where it occurs or how broad the shift is. PSI's per-bin decomposition is more informative for root cause analysis.
- Not additive. Unlike PSI, the KS statistic of a multivariate distribution is not simply a function of the marginal KS statistics.
Addressing the Overpowering Problem
For production monitoring, use the KS statistic (the D value), not the p-value, as the alerting signal. Unlike the p-value, D does not automatically shrink toward significance as the sample grows: it estimates a population quantity that ranges from 0 (identical distributions) to 1 (non-overlapping distributions):
| KS Statistic D | Interpretation |
|---|---|
| < 0.05 | Negligible difference |
| 0.05 - 0.10 | Small but detectable difference |
| 0.10 - 0.20 | Moderate difference, investigate |
| > 0.20 | Large difference, alert |
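The overpowering effect is easy to reproduce. In this synthetic example, a mean shift of two percent of one standard deviation is operationally negligible by the table above, yet statistically decisive:

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(7)
reference = rng.normal(0.0, 1.0, 2_000_000)
current = rng.normal(0.02, 1.0, 2_000_000)  # tiny mean shift

d, p = ks_2samp(reference, current)
print(f"D = {d:.4f}, p = {p:.2e}")
# D lands in the "negligible" band, yet p is far below any alpha
```

Alerting on the p-value here would page the on-call engineer for a shift no user would ever notice; alerting on D would not.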
30.10 Drift Detection: Jensen-Shannon Divergence
The Jensen-Shannon (JS) divergence addresses a key limitation of the Kullback-Leibler divergence (which underlies PSI): KL divergence is undefined when the current distribution assigns zero probability to a bin where the reference distribution assigns nonzero probability. JS divergence is always defined, bounded, and symmetric.
Definition
$$\text{JSD}(P \| Q) = \frac{1}{2} D_{\text{KL}}(P \| M) + \frac{1}{2} D_{\text{KL}}(Q \| M)$$
where $M = \frac{1}{2}(P + Q)$ is the midpoint distribution. The JS divergence is bounded: $0 \leq \text{JSD} \leq \ln(2) \approx 0.693$ (using natural logarithm). Taking the square root gives the Jensen-Shannon distance, which is a proper metric (satisfies the triangle inequality).
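A minimal sketch of the definition over discrete bin proportions (a hand-rolled helper, `js_divergence`, using natural logs; `scipy.spatial.distance.jensenshannon` offers the distance form):

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete
    distributions, in nats (natural logarithm)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # 0 * log(0) = 0 by convention
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

print(js_divergence([0.5, 0.5], [0.5, 0.5]))  # 0.0: identical
print(js_divergence([1.0, 0.0], [0.0, 1.0]))  # ln 2 ~ 0.6931: disjoint
```

Note that the midpoint $M$ is strictly positive wherever either distribution has mass, which is exactly why the zero-bin problem disappears.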
Comparison of Drift Detection Methods
| Property | PSI | KS Test | JS Divergence |
|---|---|---|---|
| Requires binning | Yes | No | Yes |
| Produces p-value | No | Yes | No (but bootstrappable) |
| Bounded | No ($0$ to $\infty$) | Yes ($0$ to $1$) | Yes ($0$ to $\ln 2$) |
| Symmetric | Yes | Yes | Yes |
| Decomposable by bin | Yes | No | Yes |
| Handles zero bins | Requires clipping | N/A | Naturally handles via midpoint |
| Industry standard | Credit risk, banking | General statistics | Information theory, NLP |
| Best for | Feature monitoring, regulatory | Statistical significance testing | Multi-modal distributions |
Practical Recommendation
Use all three methods in production, with different roles:
- PSI as the primary alerting metric for individual features (well-calibrated thresholds, regulatory precedent, per-bin decomposition for root cause analysis).
- KS test as a confirmation metric when PSI flags a shift (provides statistical rigor, useful for "is this shift real or noise?").
- JS divergence for comparing distributions of categorical features and prediction scores (naturally handles zero-probability categories, bounded range simplifies dashboard visualization).
import numpy as np
from dataclasses import dataclass, field
from typing import Dict, List
@dataclass
class DriftReport:
"""Comprehensive drift report for a single feature.
Combines PSI, KS, and JS metrics for a holistic view of
distributional change. Used for both automated alerting
and manual root cause analysis.
Attributes:
feature_name: Name of the feature.
psi: PSI result.
ks: KS test result.
js_divergence: JS divergence value.
overall_severity: Maximum severity across all metrics.
recommended_action: Suggested next step.
"""
feature_name: str
psi: "PSIResult"
ks: "KSTestResult"
js_divergence: float
overall_severity: str = ""
recommended_action: str = ""
def __post_init__(self):
severities = []
if self.psi.total_psi > 0.25:
severities.append("alert")
elif self.psi.total_psi > 0.10:
severities.append("investigate")
else:
severities.append("ok")
if self.ks.significant:
severities.append("investigate")
else:
severities.append("ok")
if self.js_divergence > 0.15:
severities.append("alert")
elif self.js_divergence > 0.05:
severities.append("investigate")
else:
severities.append("ok")
severity_order = {"ok": 0, "investigate": 1, "alert": 2}
self.overall_severity = max(
severities, key=lambda s: severity_order[s]
)
if self.overall_severity == "alert":
self.recommended_action = (
"Immediate investigation required. Check upstream "
"data pipelines, feature computation code, and "
"recent deployments. Consider model rollback if "
"business metrics are affected."
)
elif self.overall_severity == "investigate":
self.recommended_action = (
"Monitor trend over next 24 hours. Check if shift "
"is gradual (seasonal) or sudden (pipeline issue). "
"Correlate with business metrics."
)
else:
self.recommended_action = "Continue standard monitoring."
def generate_drift_report(
feature_name: str,
reference: np.ndarray,
current: np.ndarray,
n_bins: int = 10,
) -> DriftReport:
"""Generate a comprehensive drift report for a feature.
Args:
feature_name: Name of the feature.
reference: Reference distribution values.
current: Current distribution values.
n_bins: Number of bins for PSI and JS computation.
Returns:
DriftReport combining PSI, KS, and JS metrics.
"""
psi_result = compute_psi_detailed(
feature_name, reference, current, n_bins
)
ks_result = compute_ks_test(feature_name, reference, current)
    # Compute JS divergence over quantile bins from the reference
    bin_edges = np.quantile(reference, np.linspace(0, 1, n_bins + 1))
    bin_edges = np.unique(bin_edges)
    ref_counts, _ = np.histogram(reference, bins=bin_edges)
    # Clip so current values outside the reference range are
    # counted in the outermost bins rather than dropped
    cur_counts, _ = np.histogram(
        np.clip(current, bin_edges[0], bin_edges[-1]), bins=bin_edges
    )
ref_props = ref_counts / ref_counts.sum()
cur_props = cur_counts / cur_counts.sum()
ref_props = np.clip(ref_props, 1e-10, None)
cur_props = np.clip(cur_props, 1e-10, None)
m = 0.5 * (ref_props + cur_props)
kl_ref_m = np.sum(ref_props * np.log(ref_props / m))
kl_cur_m = np.sum(cur_props * np.log(cur_props / m))
js = 0.5 * kl_ref_m + 0.5 * kl_cur_m
return DriftReport(
feature_name=feature_name,
psi=psi_result,
ks=ks_result,
js_divergence=float(js),
)
30.11 Types of Drift
Drift detection methods measure that a distribution has shifted. Understanding what has shifted requires distinguishing between drift types:
Covariate Shift
Definition: The distribution of input features $P(X)$ changes, but the conditional distribution $P(Y|X)$ remains the same. The relationship between features and outcomes is stable; the population has changed.
Example: StreamRec launches in a new country. The user demographics shift (younger, more mobile-first), but the relationship between user features and content preferences is the same. The model may underperform because it was trained on a different population, but the underlying relationship has not changed.
Detection: Feature-level PSI/KS/JS monitoring. If input features drift but the model's prediction distribution does not (because the model is robust to the shift), the drift may be benign. If input features drift and predictions drift, retraining on the new population is warranted.
Concept Drift
Definition: The conditional distribution $P(Y|X)$ changes, even though $P(X)$ may remain the same. The relationship between features and outcomes has changed.
Example: A pandemic changes viewing patterns: users who previously watched comedies during evening hours now watch documentaries. The user features are the same (same demographics, same platform), but the mapping from features to preferences has shifted. The model's predictions are based on an outdated relationship.
Detection: Concept drift is harder to detect because it requires comparing predictions to outcomes. Monitor the gap between predicted and observed engagement rates over time. If the gap grows — the model consistently over- or under-predicts — concept drift is likely. Statistical process control charts (CUSUM, EWMA) on the prediction-error series are effective.
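The EWMA idea can be sketched as follows. This is a hedged illustration, not the chapter's infrastructure: the function name `ewma_drift_flags` and all parameters are illustrative, and it assumes the first quarter of the error series is in-control and usable as a baseline:

```python
import numpy as np

def ewma_drift_flags(errors, lam=0.1, k=3.0):
    """Flag points where the EWMA of the prediction-error series
    leaves a +/- k-sigma control band around the baseline mean.

    Assumes the first quarter of `errors` is in-control and uses
    it to estimate the baseline mean and standard deviation.
    """
    errors = np.asarray(errors, dtype=float)
    baseline = errors[: len(errors) // 4]
    mu, sigma = baseline.mean(), baseline.std(ddof=1)
    # Steady-state standard deviation of the EWMA statistic
    ewma_sigma = sigma * np.sqrt(lam / (2.0 - lam))
    z = mu
    flags = []
    for e in errors:
        z = lam * e + (1.0 - lam) * z
        flags.append(abs(z - mu) > k * ewma_sigma)
    return np.array(flags)

rng = np.random.default_rng(1)
stable = rng.normal(0.0, 0.05, 400)    # well-calibrated period
drifted = rng.normal(0.08, 0.05, 200)  # model begins over-predicting
flags = ewma_drift_flags(np.concatenate([stable, drifted]))
print(flags[:400].mean(), flags[400:].mean())  # low rate, then high
```

The smoothing parameter `lam` trades detection speed against false alarms: smaller values average over more history and react more slowly but more reliably.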
Label Drift
Definition: The marginal distribution of outcomes $P(Y)$ changes. This may or may not reflect a change in $P(Y|X)$.
Example: Meridian Financial's default rate rises from 3% to 5% due to an economic downturn. If the model's features capture the economic conditions that drive defaults, the model may adapt. If the features do not capture these conditions, the model's calibration is wrong.
Detection: Monitor the overall prediction distribution and compare to observed outcome rates (when available). For long-feedback-delay systems like credit scoring, monitor early warning signals: delinquency rates at 30, 60, 90 days as proxies for eventual default.
Prediction Drift
Definition: The distribution of model predictions $P(\hat{Y})$ changes, regardless of whether input or output distributions have changed. This can result from any of the above drift types, from a model deployment change, or from a feature pipeline change.
Detection: Monitor the prediction score distribution over time using PSI, KS, or JS divergence. Prediction drift is the most general signal — it indicates that something has changed, but root cause analysis is needed to determine what.
| Drift Type | What Changed | Detection Method | Response |
|---|---|---|---|
| Covariate shift | $P(X)$ | Feature PSI/KS/JS | Retrain on new population |
| Concept drift | $P(Y\|X)$ | Prediction-outcome gap | Retrain with recent data |
| Label drift | $P(Y)$ | Outcome rate monitoring | Recalibrate or retrain |
| Prediction drift | $P(\hat{Y})$ | Prediction PSI/KS/JS | Investigate root cause |
30.12 Alerting Infrastructure
Drift detection without alerting is a dashboard that nobody watches. Effective alerting transforms detection into action.
Alert Design Principles
- Actionable. Every alert must have a clear next step. "Feature X PSI is 0.31" is not actionable. "Feature X (user_engagement_rate) PSI is 0.31, exceeding the 0.25 threshold. Top contributing bin: [0.0, 0.1) shifted from 15% to 35%. Runbook: https://wiki/runbooks/feature-drift" is actionable.
- Tiered. Not every anomaly is an emergency. Use severity levels:
  - Info: Logged for analysis, no notification (PSI 0.10-0.15)
  - Warning: Notification to Slack channel (PSI 0.15-0.25)
  - Critical: PagerDuty alert to on-call (PSI > 0.25 plus business metric impact)
- Deduplicated. If a feature is drifting, the alert fires once, not every 15 seconds. Alerts have a cooldown period during which repeat alerts are suppressed.
- Correlated. Multiple feature drift alerts firing simultaneously suggests a common root cause (upstream pipeline failure). Correlate alerts and present them as a single incident.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Callable
from datetime import datetime, timedelta
from enum import Enum
import logging
logger = logging.getLogger(__name__)
class AlertSeverity(Enum):
INFO = "info"
WARNING = "warning"
CRITICAL = "critical"
class EscalationTarget(Enum):
LOG_ONLY = "log_only"
SLACK_CHANNEL = "slack_channel"
PAGERDUTY = "pagerduty"
EMAIL = "email"
@dataclass
class AlertRule:
"""A single alerting rule with threshold and escalation.
Attributes:
name: Human-readable rule name.
metric_name: Prometheus metric name or computed metric.
condition: Lambda that takes the metric value and returns
True if the alert should fire.
severity: Alert severity level.
escalation: Where to send the alert.
cooldown: Minimum time between consecutive firings.
runbook_url: Link to the runbook for this alert.
message_template: Template for the alert message.
"""
name: str
metric_name: str
condition: Callable[[float], bool]
severity: AlertSeverity
escalation: EscalationTarget
cooldown: timedelta = timedelta(hours=1)
runbook_url: str = ""
message_template: str = ""
@dataclass
class AlertState:
"""Tracks the state of a single alert rule.
Attributes:
rule: The alert rule this state tracks.
is_firing: Whether the alert is currently firing.
last_fired: Timestamp of the most recent firing.
fire_count: Total number of times this alert has fired.
acknowledged: Whether the current firing has been ack'd.
acknowledged_by: Who acknowledged the alert.
"""
rule: AlertRule
is_firing: bool = False
last_fired: Optional[datetime] = None
fire_count: int = 0
acknowledged: bool = False
acknowledged_by: str = ""
@dataclass
class AlertManager:
"""Manages alert evaluation, deduplication, and escalation.
Evaluates all configured alert rules against current metric
values, respects cooldown periods, and dispatches alerts to
the appropriate targets.
Attributes:
rules: List of alert rules to evaluate.
states: Mapping from rule name to current alert state.
escalation_handlers: Mapping from target to handler function.
"""
rules: List[AlertRule] = field(default_factory=list)
states: Dict[str, AlertState] = field(default_factory=dict)
escalation_handlers: Dict[
EscalationTarget, Callable[[str, AlertSeverity], None]
] = field(default_factory=dict)
def __post_init__(self):
for rule in self.rules:
if rule.name not in self.states:
self.states[rule.name] = AlertState(rule=rule)
def evaluate(
self,
metrics: Dict[str, float],
current_time: Optional[datetime] = None,
) -> List["AlertEvent"]:
"""Evaluate all rules against current metrics.
Args:
metrics: Mapping from metric name to current value.
current_time: Current time (defaults to utcnow).
Returns:
List of AlertEvents for rules that fired.
"""
if current_time is None:
current_time = datetime.utcnow()
events = []
for rule in self.rules:
if rule.metric_name not in metrics:
continue
value = metrics[rule.metric_name]
state = self.states[rule.name]
should_fire = rule.condition(value)
if should_fire and not state.is_firing:
# New alert
if (
state.last_fired is not None
and current_time - state.last_fired < rule.cooldown
):
continue # In cooldown period
state.is_firing = True
state.last_fired = current_time
state.fire_count += 1
state.acknowledged = False
message = rule.message_template.format(
metric_name=rule.metric_name,
value=value,
severity=rule.severity.value,
runbook=rule.runbook_url,
)
event = AlertEvent(
rule_name=rule.name,
severity=rule.severity,
metric_name=rule.metric_name,
metric_value=value,
message=message,
timestamp=current_time,
runbook_url=rule.runbook_url,
)
events.append(event)
# Dispatch to escalation target
handler = self.escalation_handlers.get(
rule.escalation
)
if handler:
handler(message, rule.severity)
else:
logger.warning(
f"No handler for {rule.escalation.value}"
)
elif not should_fire and state.is_firing:
# Alert resolved
state.is_firing = False
logger.info(f"Alert resolved: {rule.name}")
return events
@dataclass
class AlertEvent:
"""Record of a single alert firing.
Attributes:
rule_name: Name of the rule that fired.
severity: Severity level.
metric_name: Metric that triggered the alert.
metric_value: Value that crossed the threshold.
message: Formatted alert message.
timestamp: When the alert fired.
runbook_url: Link to the runbook.
"""
rule_name: str
severity: AlertSeverity
metric_name: str
metric_value: float
message: str
timestamp: datetime
runbook_url: str = ""
StreamRec Alert Configuration
streamrec_alerts = AlertManager(
rules=[
# --- Data quality alerts ---
AlertRule(
name="feature_drift_critical",
metric_name="streamrec_drift_psi",
condition=lambda v: v > 0.25,
severity=AlertSeverity.CRITICAL,
escalation=EscalationTarget.PAGERDUTY,
cooldown=timedelta(hours=4),
runbook_url="https://wiki.streamrec.dev/runbooks/feature-drift",
message_template=(
":rotating_light: CRITICAL: Feature drift detected. "
"{metric_name}={value:.4f} > 0.25. "
"Runbook: {runbook}"
),
),
AlertRule(
name="feature_drift_warning",
metric_name="streamrec_drift_psi",
condition=lambda v: 0.10 < v <= 0.25,
severity=AlertSeverity.WARNING,
escalation=EscalationTarget.SLACK_CHANNEL,
cooldown=timedelta(hours=12),
runbook_url="https://wiki.streamrec.dev/runbooks/feature-drift",
message_template=(
":warning: WARNING: Moderate feature drift. "
"{metric_name}={value:.4f}. "
"Monitor trend. Runbook: {runbook}"
),
),
AlertRule(
name="data_freshness_critical",
metric_name="streamrec_data_freshness_seconds",
condition=lambda v: v > 7200, # > 2 hours
severity=AlertSeverity.CRITICAL,
escalation=EscalationTarget.PAGERDUTY,
cooldown=timedelta(hours=1),
runbook_url=(
"https://wiki.streamrec.dev/runbooks/data-freshness"
),
message_template=(
":rotating_light: CRITICAL: Stale data. "
"Feature source age {value:.0f}s exceeds the 7200s "
"threshold. Runbook: {runbook}"
),
),
# --- Model performance alerts ---
AlertRule(
name="prediction_drift",
metric_name="prediction_score_psi",
condition=lambda v: v > 0.20,
severity=AlertSeverity.WARNING,
escalation=EscalationTarget.SLACK_CHANNEL,
cooldown=timedelta(hours=6),
runbook_url=(
"https://wiki.streamrec.dev/runbooks/prediction-drift"
),
message_template=(
":chart_with_upwards_trend: Prediction score "
"distribution shift. PSI={value:.4f}. "
"Runbook: {runbook}"
),
),
# --- System health alerts ---
AlertRule(
name="latency_p99_critical",
metric_name="streamrec_prediction_latency_p99",
condition=lambda v: v > 0.100, # > 100ms
severity=AlertSeverity.CRITICAL,
escalation=EscalationTarget.PAGERDUTY,
cooldown=timedelta(minutes=15),
runbook_url=(
"https://wiki.streamrec.dev/runbooks/high-latency"
),
message_template=(
":rotating_light: CRITICAL: p99 latency "
"{value:.3f}s exceeds the 100ms (0.100s) SLO. "
"Runbook: {runbook}"
),
),
AlertRule(
name="error_rate_warning",
metric_name="streamrec_error_rate",
condition=lambda v: v > 0.001, # > 0.1%
severity=AlertSeverity.WARNING,
escalation=EscalationTarget.SLACK_CHANNEL,
cooldown=timedelta(hours=1),
runbook_url=(
"https://wiki.streamrec.dev/runbooks/error-rate"
),
message_template=(
":warning: Error rate {value:.4%} > 0.1%. "
"Runbook: {runbook}"
),
),
],
)
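To make the cooldown behaviour concrete, here is a standalone simulation (names like `tick` and `fire_log` are illustrative, not part of the AlertManager API) of a PSI value that stays critical for six hours under the four-hour cooldown configured above:

```python
from datetime import datetime, timedelta

cooldown = timedelta(hours=4)  # matches feature_drift_critical above
last_fired = None
fire_log = []

def tick(now: datetime, psi_value: float) -> None:
    """Fire at most once per cooldown window while PSI stays critical."""
    global last_fired
    if psi_value > 0.25 and (
        last_fired is None or now - last_fired >= cooldown
    ):
        last_fired = now
        fire_log.append(now)

# PSI holds at 0.30 for six hours; the rule is evaluated every 15 minutes.
t0 = datetime(2024, 1, 1, 0, 0)
for minutes in range(0, 360, 15):
    tick(t0 + timedelta(minutes=minutes), psi_value=0.30)

print(len(fire_log))  # 2 -- once at t0, once after the cooldown expires
```

Twenty-four evaluations produce two pages instead of twenty-four, which is the difference between an actionable alert and a muted pager.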
30.13 Escalation and On-Call
Alerting tells you something is wrong. Escalation ensures the right person responds.
On-Call Rotation
ML systems require a specialized on-call rotation. Unlike traditional software on-call (where the responder needs system administration skills), ML on-call requires a hybrid of software engineering and data science expertise. The responder must be able to:
- Diagnose whether a problem is a system issue (latency, errors, capacity) or a model issue (drift, degraded predictions, data quality)
- Roll back a model to a previous version if the current model is degrading
- Investigate feature pipeline failures and assess their impact on model quality
- Make the judgment call: "Is this a data problem that will self-resolve, or a model problem that requires retraining?"
Escalation Policy
from dataclasses import dataclass
from typing import List
from datetime import timedelta
@dataclass
class EscalationLevel:
"""A single level in the escalation chain.
Attributes:
level: Numeric level (1 = first responder).
role: Who is paged at this level.
notify_after: Time after initial alert before escalating
to this level.
contact_method: How to reach the responder.
"""
level: int
role: str
notify_after: timedelta
contact_method: str
@dataclass
class EscalationPolicy:
"""Escalation policy for ML system incidents.
Defines who to page, when to escalate, and how to reach
each responder. Designed for ML-specific incidents where
the first responder is an ML engineer and escalation reaches
data scientists and engineering leadership.
Attributes:
name: Policy name.
levels: Ordered list of escalation levels.
"""
name: str
levels: List[EscalationLevel]
streamrec_escalation = EscalationPolicy(
name="StreamRec ML Incidents",
levels=[
EscalationLevel(
level=1,
role="ML Engineer On-Call",
notify_after=timedelta(minutes=0),
contact_method="PagerDuty + Slack DM",
),
EscalationLevel(
level=2,
role="Senior ML Engineer",
notify_after=timedelta(minutes=15),
contact_method="PagerDuty + Phone",
),
EscalationLevel(
level=3,
role="ML Team Lead",
notify_after=timedelta(minutes=30),
contact_method="PagerDuty + Phone + SMS",
),
EscalationLevel(
level=4,
role="VP Engineering",
notify_after=timedelta(hours=1),
contact_method="Phone + SMS",
),
],
)
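Given a policy like the one above, the set of roles that should have been engaged after a given elapsed time falls directly out of `notify_after`. A minimal self-contained sketch (in practice PagerDuty computes this; `roles_to_notify` is a hypothetical helper):

```python
from dataclasses import dataclass
from datetime import timedelta
from typing import List

@dataclass
class Level:
    level: int
    role: str
    notify_after: timedelta

def roles_to_notify(levels: List[Level], elapsed: timedelta) -> List[str]:
    """Everyone whose notify_after threshold has passed gets paged."""
    return [lv.role for lv in levels if elapsed >= lv.notify_after]

levels = [
    Level(1, "ML Engineer On-Call", timedelta(minutes=0)),
    Level(2, "Senior ML Engineer", timedelta(minutes=15)),
    Level(3, "ML Team Lead", timedelta(minutes=30)),
]

# Twenty minutes into the incident, levels 1 and 2 are engaged.
print(roles_to_notify(levels, timedelta(minutes=20)))
# ['ML Engineer On-Call', 'Senior ML Engineer']
```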
Runbooks
A runbook is a documented procedure for responding to a specific alert. Good runbooks reduce mean time to resolution (MTTR) by converting on-call response from an improvisation exercise into a checklist:
| Runbook Section | Content |
|---|---|
| Alert description | What triggered the alert and what it means |
| Impact assessment | Who is affected and how severely |
| Diagnostic steps | Step-by-step commands/queries to diagnose the issue |
| Mitigation options | Immediate actions to reduce impact (e.g., model rollback, feature fallback) |
| Root cause investigation | Deeper analysis once mitigation is in place |
| Escalation criteria | When to escalate to the next level |
| Communication template | Message template for stakeholder communication |
Example runbook excerpt for the feature_drift_critical alert:
# Runbook: Feature Drift Critical (PSI > 0.25)
# URL: https://wiki.streamrec.dev/runbooks/feature-drift
## 1. Impact Assessment (first 5 minutes)
- Check which feature(s) are drifting: Grafana dashboard > Model Health > Feature Drift
- Check if business metrics are affected: Grafana dashboard > Executive > CTR / Engagement
- Check if prediction distribution has shifted: Grafana dashboard > Model Health > Score Distribution
## 2. Diagnostic Steps
- Identify top contributing PSI bins:
SELECT feature_name, bin_range, psi_contribution
FROM drift_results
WHERE timestamp > NOW() - INTERVAL '1 hour'
ORDER BY psi_contribution DESC
LIMIT 10;
- Check for upstream pipeline changes:
git log --since="24 hours ago" -- src/features/
- Check feature store health:
curl -s http://feature-store:8080/health | jq '.sources'
## 3. Mitigation
- If business metrics are affected AND PSI > 0.50:
Deploy fallback model: kubectl rollout undo deployment/streamrec-model
- If business metrics are stable:
Continue monitoring for 4 hours. If PSI remains elevated, retrain model.
## 4. Escalation
- Escalate to Level 2 if: mitigation does not reduce PSI within 1 hour
- Escalate to Level 3 if: business metric degradation exceeds 5%
30.14 Incident Response for ML Systems
Incident response for ML systems follows the same principles as traditional incident response (detection, triage, mitigation, resolution, post-mortem) but with ML-specific considerations at each stage.
The ML Incident Lifecycle
Detection → Triage → Mitigation → Resolution → Post-Mortem
│
▼
Preventive Action
(new monitors, tests,
runbooks, design changes)
Detection. ML incidents are detected through three channels:
1. Automated alerts: drift detection, latency spikes, error rate increases
2. Business metric anomalies: CTR drop, engagement decline (often the first signal for concept drift)
3. User reports: "The recommendations have been terrible lately" — subjective but valuable
Triage. Classify the incident along two dimensions:
| Dimension | Categories |
|---|---|
| Severity | SEV-1 (revenue impact > $X/hour), SEV-2 (user-visible degradation), SEV-3 (detectable but limited impact), SEV-4 (monitoring-only) |
| Type | System (infrastructure failure), Data (pipeline/quality issue), Model (drift/degradation), Integration (upstream/downstream change) |
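The severity dimension can be mechanized as a first-pass classifier. This sketch uses a $1,000/hour cut-off purely as a placeholder for the "$X" in the table (each team sets its own threshold):

```python
def triage_severity(
    revenue_impact_per_hour: float,
    user_visible: bool,
    detectable: bool,
) -> str:
    """First-pass severity per the table above. The $1,000/hour
    threshold is an illustrative placeholder, not a standard."""
    if revenue_impact_per_hour > 1_000:
        return "SEV-1"
    if user_visible:
        return "SEV-2"
    if detectable:
        return "SEV-3"
    return "SEV-4"  # monitoring-only

print(triage_severity(0, user_visible=True, detectable=True))  # SEV-2
```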
Mitigation. The priority is to reduce impact, not to fix the root cause. Common mitigation actions:
| Action | When to Use | Latency |
|---|---|---|
| Model rollback | Current model is worse than previous | Minutes |
| Feature fallback | A feature is corrupted; fall back to default value | Minutes |
| Traffic shift | Route traffic to a fallback model or rule-based system | Minutes |
| Retraining | Data has shifted and a new model is needed | Hours |
| Pipeline fix | Data pipeline bug is the root cause | Hours to days |
Resolution. Fix the root cause, validate the fix, and restore normal operation.
Post-mortem. The most important step — and the most commonly skipped.
30.15 Blameless Post-Mortems
A blameless post-mortem is a structured analysis of an incident that focuses on systemic causes rather than individual blame. The "blameless" framing is not soft-hearted kindness — it is an engineering decision. Blame causes people to hide mistakes, which prevents the organization from learning. Blamelessness causes people to report problems early, which allows the organization to fix systemic weaknesses.
Post-Mortem Structure
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from datetime import datetime, timedelta
@dataclass
class IncidentTimeline:
"""Single event in the incident timeline.
Attributes:
timestamp: When the event occurred.
description: What happened.
actor: Who performed the action (person or system).
category: Type of event (detection, communication,
mitigation, investigation, resolution).
"""
timestamp: datetime
description: str
actor: str
category: str
@dataclass
class PostMortem:
"""Blameless post-mortem for an ML system incident.
Based on the Google SRE post-mortem template, extended
for ML-specific root causes and preventive actions.
Attributes:
title: Incident title.
incident_id: Unique incident identifier.
severity: SEV-1 through SEV-4.
date: Date of the incident.
duration: Total incident duration.
impact: Description of user/business impact.
summary: 2-3 sentence summary of what happened.
timeline: Ordered list of events during the incident.
root_causes: List of root causes (systemic, not personal).
contributing_factors: Factors that made the incident worse
or delayed detection/resolution.
detection: How the incident was detected and how long
it took.
mitigation: What was done to reduce impact.
resolution: What was done to fix the root cause.
action_items: Concrete, assigned, deadlined action items.
lessons_learned: What the team learned.
ml_specific: ML-specific analysis (drift type, model
impact, data lineage).
"""
title: str
incident_id: str
severity: str
date: datetime
duration: timedelta
impact: str
summary: str
timeline: List[IncidentTimeline] = field(default_factory=list)
root_causes: List[str] = field(default_factory=list)
contributing_factors: List[str] = field(default_factory=list)
detection: str = ""
mitigation: str = ""
resolution: str = ""
action_items: List[Dict[str, str]] = field(default_factory=list)
lessons_learned: List[str] = field(default_factory=list)
ml_specific: Dict[str, str] = field(default_factory=dict)
def time_to_detect(self) -> Optional[timedelta]:
"""Compute time from incident start to detection.
Returns:
Timedelta from first anomaly to first detection event,
or None if timeline is empty.
"""
anomaly_events = [
e for e in self.timeline
if e.category in ("anomaly", "trigger")
]
detection_events = [
e for e in self.timeline
if e.category == "detection"
]
if anomaly_events and detection_events:
return (
detection_events[0].timestamp
- anomaly_events[0].timestamp
)
return None
def time_to_mitigate(self) -> Optional[timedelta]:
"""Compute time from detection to mitigation.
Returns:
Timedelta from first detection to first mitigation
event, or None if timeline is incomplete.
"""
detection_events = [
e for e in self.timeline
if e.category == "detection"
]
mitigation_events = [
e for e in self.timeline
if e.category == "mitigation"
]
if detection_events and mitigation_events:
return (
mitigation_events[0].timestamp
- detection_events[0].timestamp
)
return None
def format_for_review(self) -> str:
"""Format the post-mortem for team review.
Returns:
Markdown-formatted post-mortem document.
"""
lines = [
f"# Post-Mortem: {self.title}",
f"**Incident ID:** {self.incident_id}",
f"**Severity:** {self.severity}",
f"**Date:** {self.date.strftime('%Y-%m-%d')}",
f"**Duration:** {self.duration}",
"",
f"## Impact",
self.impact,
"",
f"## Summary",
self.summary,
"",
f"## Timeline",
]
for event in self.timeline:
lines.append(
f"- **{event.timestamp.strftime('%H:%M')}** "
f"[{event.category}] {event.description} "
f"({event.actor})"
)
lines.append("")
lines.append("## Root Causes")
for cause in self.root_causes:
lines.append(f"- {cause}")
lines.append("")
lines.append("## Contributing Factors")
for factor in self.contributing_factors:
lines.append(f"- {factor}")
lines.append("")
lines.append("## ML-Specific Analysis")
for key, value in self.ml_specific.items():
lines.append(f"- **{key}:** {value}")
lines.append("")
lines.append("## Action Items")
for item in self.action_items:
lines.append(
f"- [ ] {item['action']} "
f"(Owner: {item.get('owner', 'TBD')}, "
f"Due: {item.get('due', 'TBD')})"
)
lines.append("")
lines.append("## Lessons Learned")
for lesson in self.lessons_learned:
lines.append(f"- {lesson}")
return "\n".join(lines)
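The two timing methods above reduce to simple timeline arithmetic. A standalone sketch with a hypothetical drift incident, using (timestamp, category) pairs in place of the full dataclasses:

```python
from datetime import datetime, timedelta

# Hypothetical timeline for a feature-drift incident.
timeline = [
    (datetime(2024, 5, 1, 3, 10), "anomaly"),    # drift begins upstream
    (datetime(2024, 5, 1, 5, 40), "detection"),  # PSI alert fires
    (datetime(2024, 5, 1, 6, 5), "mitigation"),  # model rolled back
]

def first(category: str) -> datetime:
    """Timestamp of the first event in the given category."""
    return next(ts for ts, cat in timeline if cat == category)

time_to_detect = first("detection") - first("anomaly")
time_to_mitigate = first("mitigation") - first("detection")
print(time_to_detect, time_to_mitigate)  # 2:30:00 0:25:00
```

Tracking these two numbers across post-mortems turns anecdotes into a measurable MTTD/MTTR trend.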
ML-Specific Root Cause Categories
Traditional post-mortems categorize root causes as: human error, software bug, hardware failure, capacity, dependency. ML post-mortems add:
| ML Root Cause | Description | Example |
|---|---|---|
| Data pipeline change | Upstream data schema, semantics, or volume changed | Feature encoded as timestamp instead of integer |
| Distribution shift | Input distribution changed beyond training range | New country launch, seasonal change |
| Concept drift | Relationship between features and outcomes changed | Pandemic changes viewing patterns |
| Training-serving skew | Training and serving compute different features | Batch vs. streaming computation mismatch |
| Label leakage | Training data contained future information | Feature computed from post-prediction events |
| Feedback loop | Model's own predictions influenced training data | Popular items get recommended more, appear more popular |
| Stale model | Model was not retrained despite distribution change | Retraining pipeline was silently broken for 3 weeks |
30.16 Putting It Together: The Monitoring Architecture
The complete monitoring architecture for an ML system integrates data quality, model performance, system health, and business metrics into a unified observability platform:
┌─────────────────────────────────────────────────────────────┐
│ Business Metrics │
│ (CTR, engagement, revenue, user satisfaction) │
├─────────────────────────────────────────────────────────────┤
│ Model Performance │
│ (Prediction drift, accuracy proxy, calibration) │
├─────────────────────────────────────────────────────────────┤
│ Data Quality │
│ (Feature drift PSI/KS/JS, null rates, freshness, skew) │
├─────────────────────────────────────────────────────────────┤
│ System Health │
│ (Latency, throughput, errors, GPU utilization, SLOs) │
├─────────────────────────────────────────────────────────────┤
│ Telemetry Infrastructure │
│ (Prometheus metrics, structured logs, distributed traces) │
└─────────────────────────────────────────────────────────────┘
         │               │               │
    ┌────▼────┐     ┌────▼────┐     ┌────▼────┐
    │ Grafana │     │  Alert  │     │   Log   │
    │Dashboard│     │ Manager │     │  Store  │
    │         │     │         │     │ (Loki)  │
    └─────────┘     └─────────┘     └─────────┘
Each layer monitors a different aspect of system health, and problems propagate upward: a data quality issue (layer 3) manifests as a model performance degradation (layer 2), which eventually manifests as a business metric decline (layer 1). Effective monitoring detects the problem at the lowest layer possible, before it propagates.
30.17 Progressive Project M14: StreamRec Monitoring Dashboard
Milestone M14 builds on M12 (testing infrastructure, Chapter 28) and M13 (CI/CD and deployment, Chapter 29). By the end of this milestone, the StreamRec system has a comprehensive monitoring dashboard with automated drift detection, alerting, and incident response procedures.
Task 1: Reference Distribution Computation
Compute reference distributions for all features used by the StreamRec model. Store them as ReferenceDistribution objects serialized to JSON. These reference distributions become the baseline for drift detection.
Acceptance criteria:
- Reference distributions computed for all serving features ($\geq 30$ features)
- Each reference computed from the same data used to train the current production model
- Distributions stored in a versioned artifact store (MLflow or similar)
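Task 1 can be sketched as follows for a single feature. The JSON schema and the feature name are assumptions for illustration, not the definitive `ReferenceDistribution` format:

```python
import json
import numpy as np

# Stand-in for the training data behind the current production model.
rng = np.random.default_rng(42)
training_values = rng.normal(loc=0.0, scale=1.0, size=10_000)

# Ten equal-width bins over the training range; the same edges are
# reused later when binning serving data for drift detection.
counts, edges = np.histogram(training_values, bins=10)
reference = {
    "feature_name": "watch_time_minutes",  # illustrative name
    "bin_edges": edges.tolist(),
    "proportions": (counts / counts.sum()).tolist(),
}

# Round-trip through JSON, as it would be stored in the artifact store.
restored = json.loads(json.dumps(reference))
assert abs(sum(restored["proportions"]) - 1.0) < 1e-9
```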
Task 2: Drift Detection Pipeline
Implement a drift detection pipeline that runs hourly on serving data:
- Collect a sample of the last hour's feature vectors (from the prediction log)
- Compute PSI, KS, and JS divergence for each feature against its reference distribution
- Compute PSI for the prediction score distribution
- Publish all drift metrics to Prometheus
Acceptance criteria:
- Pipeline runs hourly with $< 5$ minutes execution time
- PSI, KS statistic, and JS divergence available as Prometheus gauges for every feature
- Prediction score distribution PSI available as a separate metric
- Dashboard panel in Grafana shows drift trends for top 10 features
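The PSI step of this pipeline reduces to a few lines once serving features are binned against their reference edges. A sketch of the standard formula $\text{PSI} = \sum_i (a_i - e_i)\ln(a_i/e_i)$, with an epsilon guard for empty bins:

```python
import numpy as np

def psi(expected: np.ndarray, actual: np.ndarray, eps: float = 1e-6) -> float:
    """Population Stability Index between two binned proportion vectors.
    eps guards against empty bins (log of zero)."""
    e = np.clip(expected, eps, None)
    a = np.clip(actual, eps, None)
    return float(np.sum((a - e) * np.log(a / e)))

ref = np.array([0.10, 0.20, 0.40, 0.20, 0.10])      # reference proportions
same = np.array([0.10, 0.20, 0.40, 0.20, 0.10])     # identical window
shifted = np.array([0.05, 0.10, 0.30, 0.30, 0.25])  # mass moved rightward

print(round(psi(ref, same), 4))     # 0.0
print(round(psi(ref, shifted), 4))  # ~0.31, above the 0.25 critical threshold
```

The resulting value is what the pipeline publishes as the `streamrec_drift_psi` gauge for each feature.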
Task 3: Alerting Configuration
Configure the AlertManager with rules for:
- Feature drift (PSI warning at 0.10, critical at 0.25)
- Data freshness (critical if any source > 2x expected age)
- Prediction drift (warning at 0.15, critical at 0.25)
- Latency (warning at 75ms p99, critical at 100ms p99)
- Error rate (warning at 0.1%, critical at 0.5%)
Acceptance criteria:
- All rules configured with appropriate cooldown periods
- Warning alerts go to Slack channel
- Critical alerts go to PagerDuty
- Each alert includes a link to the corresponding runbook
Task 4: Grafana Dashboard
Build a four-layer Grafana dashboard:
| Layer | Panels |
|---|---|
| Business | CTR (7-day rolling), completion rate, session length, recommendation coverage |
| Model | Prediction score distribution (histogram), drift PSI (time series), accuracy proxy (if available) |
| Data | Top-10 feature PSI (heatmap), null rates (time series), freshness (gauge per source) |
| System | Latency p50/p95/p99, throughput (QPS), error rate, GPU utilization, SLO error budget |
Acceptance criteria:
- All four layers visible on a single dashboard
- Time range selector allows zooming from 1 hour to 30 days
- Annotations mark model deployments and pipeline changes
Task 5: Runbooks and Escalation
Write runbooks for the 5 most likely incident types:
1. Feature drift (data pipeline change)
2. Stale data (feature store failure)
3. Prediction drift (concept drift or model issue)
4. High latency (model or infrastructure issue)
5. Business metric drop (unknown root cause)
Configure the escalation policy with 4 levels.
Acceptance criteria:
- Each runbook includes: impact assessment, diagnostic steps, mitigation options, escalation criteria
- Escalation policy documented and integrated with PagerDuty
- On-call rotation established with $\geq 3$ trained responders
Track A (Complete): Implement Tasks 1-5 with the provided configurations.
Track B (Extended): Add training-serving skew detection (compare training-time feature distributions to serving-time distributions). Implement concept drift detection using prediction-outcome gap monitoring.
Track C (Production): Implement a full post-mortem template and conduct a tabletop exercise: simulate a feature pipeline failure and walk through the detection, triage, mitigation, and post-mortem process. Write the post-mortem document.
30.18 Summary
A model without monitoring is a model waiting to fail silently. Traditional software monitoring — latency, throughput, error rate — is necessary but not sufficient for ML systems. The four pillars of ML monitoring are data quality (feature distributions, null rates, freshness, training-serving skew), model performance (prediction distributions, accuracy proxies, business metric alignment), system health (latency, throughput, errors, resource utilization, SLOs), and business metrics (CTR, engagement, revenue).
Drift detection is the technical core of ML monitoring. PSI provides interpretable, decomposable drift scores with industry-standard thresholds. The KS test provides statistical rigor and sample-size-independent sensitivity. JS divergence handles zero-probability bins naturally and provides bounded, symmetric scores. In production, use all three: PSI for alerting, KS for confirmation, JS for categorical features.
Alerting transforms detection into action. Good alerts are actionable (include runbook links), tiered (info/warning/critical), deduplicated (cooldown periods), and correlated (group related alerts). Escalation policies ensure the right person responds at the right time.
Incident response for ML systems follows the standard lifecycle (detection, triage, mitigation, resolution, post-mortem) with ML-specific extensions: model rollback as a mitigation action, drift type analysis as part of root cause investigation, and training-serving skew as a root cause category. Blameless post-mortems are the most important step — they convert individual incidents into systemic improvements.
The next chapter (Chapter 31) turns from operational concerns to ethical ones: fairness in machine learning. The monitoring infrastructure built here will be extended in Chapter 31 to track fairness metrics in production, detect disparate impact drift, and alert on fairness SLO violations — because fairness, like model quality, can degrade silently.
References
Beyer, Betsy, Chris Jones, Jennifer Petoff, and Niall Richard Murphy, eds. Site Reliability Engineering: How Google Runs Production Systems. O'Reilly Media, 2016.
Breck, Eric, Shanqing Cai, Eric Nielsen, Michael Salib, and D. Sculley. "The ML Test Score: A Rubric for ML Production Readiness and Technical Debt Reduction." In Proceedings of IEEE Big Data, 2017.
Klaise, Janis, Arnaud Van Looveren, Giovanni Vacanti, and Alexandru Coca. "Monitoring Machine Learning Models in Production." arXiv preprint arXiv:2007.06299, 2020.
Lu, Jie, Anjin Liu, Fan Dong, Feng Gu, Joao Gama, and Guangquan Zhang. "Learning under Concept Drift: A Review." IEEE Transactions on Knowledge and Data Engineering 31, no. 12 (2019): 2346-2363.
Sculley, D., Gary Holt, Daniel Golovin, Eugene Davydov, Todd Phillips, Dietmar Ebner, Vinay Chaudhary, Michael Young, Jean-Francois Crespo, and Dan Dennison. "Hidden Technical Debt in Machine Learning Systems." In Advances in Neural Information Processing Systems (NeurIPS), 2015.
Prometheus Authors. "Prometheus Documentation." https://prometheus.io/docs/.
Grafana Labs. "Grafana Documentation." https://grafana.com/docs/.