Case Study 1: StreamFlow's FastAPI + Docker Deployment


Background

StreamFlow's data science team has built, tuned, and tracked a subscriber churn prediction model across this textbook. They trained a logistic regression baseline in Chapter 11, compared tree-based methods in Chapters 13--14, tuned hyperparameters with Optuna in Chapter 18, interpreted the model with SHAP in Chapter 19, refactored the code into a proper Python package in Chapter 29, and tracked everything in MLflow in Chapter 30.

The model is good: an AUC of 0.8834 on the holdout set. The SHAP values are intuitive --- contract_type, tenure, and monthly_charges are the top drivers, which matches the business team's understanding. The model is registered in the MLflow Model Registry as streamflow-churn-predictor@champion.

Now the product team has a request: when a subscriber logs into the StreamFlow portal and navigates to "Manage Subscription," the portal should display a personalized retention banner if the subscriber's churn probability exceeds 0.6. The banner shows the top three reasons for the churn risk (in plain language) and offers a targeted discount.

That means the portal backend needs to call the churn model in real time --- a REST API call during page load, returning a churn probability and SHAP explanations within 200 milliseconds. The model can no longer live in a notebook. It needs an API. And the API needs a container. And the container needs a cloud deployment.

This case study follows StreamFlow's ML engineer, Priya, as she takes the model from the MLflow registry to a production endpoint on AWS.


Phase 1: Exporting the Model

Priya starts by loading the champion model from the MLflow registry and saving it as a self-contained artifact.

import mlflow
import joblib

mlflow.set_tracking_uri("http://mlflow.streamflow.internal:5000")

# Load the champion model from the registry
model = mlflow.pyfunc.load_model(
    "models:/streamflow-churn-predictor@champion"
)

# The underlying model is an XGBoost classifier
# Extract it for direct scikit-learn-compatible usage
xgb_model = model._model_impl.xgb_model

# Save with joblib for the serving layer
joblib.dump(xgb_model, "model/churn_model.joblib")
print(f"Model saved. Feature names: {xgb_model.get_booster().feature_names}")

Practical Tip --- Priya extracts the native XGBoost model from the MLflow wrapper because it is simpler to serve directly; note that _model_impl is a private MLflow attribute and may change between MLflow versions. An alternative is to serve the MLflow model with mlflow models serve, which handles the loading automatically. Both approaches work; the choice depends on how much control you need over the API layer.

She records the model's metadata:

model_metadata = {
    "model_version": "v2.3.1",
    "mlflow_run_id": "a7c3e9f2b4d8",
    "training_date": "2024-03-10",
    "training_data_version": "streamflow-v3-2024-03",
    "test_auc": 0.8834,
    "test_f1": 0.7821,
    "features": [
        "tenure", "monthly_charges", "total_charges", "contract_type",
        "payment_method", "num_support_tickets", "internet_service",
        "streaming_services", "paperless_billing", "senior_citizen",
    ],
}

This metadata will be embedded in every API response so that downstream systems can trace any prediction back to the exact model version.
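
One way to keep this metadata with the artifact is a JSON sidecar file that the API loads at startup, next to churn_model.joblib. This is a sketch; the model/metadata.json path and function names are assumptions, not part of StreamFlow's actual pipeline:

```python
import json
from pathlib import Path

# Hypothetical sidecar path next to model/churn_model.joblib
METADATA_PATH = Path("model/metadata.json")


def save_metadata(metadata: dict, path: Path = METADATA_PATH) -> None:
    """Persist model metadata as a JSON sidecar next to the artifact."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(metadata, indent=2))


def load_metadata(path: Path = METADATA_PATH) -> dict:
    """Load the sidecar at API startup so responses can embed it."""
    return json.loads(path.read_text())
```

The API's startup code could then call load_metadata() alongside joblib.load, so every response carries the exact model_version and mlflow_run_id without hard-coding them.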


Phase 2: Building the API

Priya follows the pattern from the chapter, with a few production additions.

Project Structure

streamflow-churn-api/
    app.py
    schemas.py
    preprocessing.py
    test_api.py
    model/
        churn_model.joblib
    requirements.txt
    Dockerfile
    docker-compose.yml
    .dockerignore

She separates preprocessing into its own module so it can be tested independently.

preprocessing.py

# preprocessing.py --- Feature encoding logic
import pandas as pd
from schemas import ChurnPredictionRequest

FEATURE_ORDER = [
    "tenure", "monthly_charges", "total_charges", "contract_type",
    "payment_method", "num_support_tickets", "internet_service",
    "streaming_services", "paperless_billing", "senior_citizen",
]

CONTRACT_MAP = {"month-to-month": 0, "one-year": 1, "two-year": 2}
PAYMENT_MAP = {
    "electronic_check": 0, "mailed_check": 1,
    "bank_transfer": 2, "credit_card": 3,
}
INTERNET_MAP = {"none": 0, "dsl": 1, "fiber_optic": 2}


def encode_request(request: ChurnPredictionRequest) -> pd.DataFrame:
    """Convert a validated request into a DataFrame matching training schema."""
    data = {
        "tenure": request.tenure,
        "monthly_charges": request.monthly_charges,
        "total_charges": request.total_charges,
        "contract_type": CONTRACT_MAP[request.contract_type],
        "payment_method": PAYMENT_MAP[request.payment_method],
        "num_support_tickets": request.num_support_tickets,
        "internet_service": INTERNET_MAP[request.internet_service],
        "streaming_services": request.streaming_services,
        "paperless_billing": int(request.paperless_billing),
        "senior_citizen": int(request.senior_citizen),
    }
    return pd.DataFrame([data], columns=FEATURE_ORDER)


def encode_batch(requests: list[ChurnPredictionRequest]) -> pd.DataFrame:
    """Convert a list of requests into a single DataFrame for vectorized inference."""
    rows = []
    for req in requests:
        rows.append({
            "tenure": req.tenure,
            "monthly_charges": req.monthly_charges,
            "total_charges": req.total_charges,
            "contract_type": CONTRACT_MAP[req.contract_type],
            "payment_method": PAYMENT_MAP[req.payment_method],
            "num_support_tickets": req.num_support_tickets,
            "internet_service": INTERNET_MAP[req.internet_service],
            "streaming_services": req.streaming_services,
            "paperless_billing": int(req.paperless_billing),
            "senior_citizen": int(req.senior_citizen),
        })
    return pd.DataFrame(rows, columns=FEATURE_ORDER)
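
Because the encoders are plain dictionaries, they can be unit-tested without loading the model or pandas. A minimal standalone sketch (the maps are copied inline so it runs on its own; encode_categoricals is illustrative, not part of the module):

```python
# Standalone sketch of the categorical encoding used in preprocessing.py
CONTRACT_MAP = {"month-to-month": 0, "one-year": 1, "two-year": 2}
PAYMENT_MAP = {
    "electronic_check": 0, "mailed_check": 1,
    "bank_transfer": 2, "credit_card": 3,
}


def encode_categoricals(contract_type: str, payment_method: str) -> tuple[int, int]:
    """Map raw category strings to the integer codes used at training time.

    Raises KeyError on an unseen category, which is the desired behavior:
    an unknown value should fail loudly rather than silently encode to a default.
    """
    return CONTRACT_MAP[contract_type], PAYMENT_MAP[payment_method]
```

In the real service a bad category never reaches this code, because the Pydantic schema rejects it first; the KeyError is a second line of defense.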

schemas.py

# schemas.py --- Pydantic models for request/response
from pydantic import BaseModel, Field
from typing import Literal


class ChurnPredictionRequest(BaseModel):
    tenure: int = Field(..., ge=0, le=120, description="Months subscribed")
    monthly_charges: float = Field(..., ge=0, description="Monthly charge ($)")
    total_charges: float = Field(..., ge=0, description="Total charges to date ($)")
    contract_type: Literal["month-to-month", "one-year", "two-year"] = Field(
        ..., description="Contract type"
    )
    payment_method: Literal[
        "electronic_check", "mailed_check", "bank_transfer", "credit_card"
    ] = Field(..., description="Payment method")
    num_support_tickets: int = Field(
        ..., ge=0, description="Support tickets in last 6 months"
    )
    internet_service: Literal["none", "dsl", "fiber_optic"] = Field(
        ..., description="Internet service type"
    )
    streaming_services: int = Field(
        ..., ge=0, le=4, description="Number of streaming services (0-4)"
    )
    paperless_billing: bool = Field(..., description="Uses paperless billing")
    senior_citizen: bool = Field(..., description="Is senior citizen")


class ShapReason(BaseModel):
    feature: str
    value: float
    shap_contribution: float
    direction: Literal["increases_risk", "decreases_risk"]
    plain_language: str = Field(
        ..., description="Human-readable explanation for the portal UI"
    )


class ChurnPredictionResponse(BaseModel):
    churn_probability: float = Field(..., ge=0, le=1)
    risk_tier: Literal["low", "medium", "high"]
    top_reasons: list[ShapReason]
    model_version: str
    show_retention_banner: bool = Field(
        ..., description="True if churn_probability > 0.6"
    )

Note the Literal type for categorical fields. Instead of accepting any string and checking it in the encoding function, Pydantic rejects invalid values immediately with a clear error message. And note show_retention_banner --- a boolean derived from the probability that the portal can use directly without implementing threshold logic.
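
The validation that Literal buys can be approximated in the standard library with typing.get_args, which is useful for seeing what is going on. This is a conceptual sketch, not how Pydantic is implemented internally:

```python
from typing import Literal, get_args

ContractType = Literal["month-to-month", "one-year", "two-year"]


def validate_contract(value: str) -> str:
    """Reject any value outside the Literal's allowed set, as Pydantic does."""
    allowed = get_args(ContractType)
    if value not in allowed:
        raise ValueError(f"contract_type must be one of {allowed}, got {value!r}")
    return value
```

Pydantic performs this check automatically for every Literal field and folds the failure into the 422 response, so the handler never sees an invalid category.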

Also note plain_language in ShapReason. The portal will display these to customers, so technical feature names need to be translated.

app.py

# app.py --- StreamFlow Churn Prediction API
import os
import time
import logging
from contextlib import asynccontextmanager

import joblib
import numpy as np
import pandas as pd
import shap
from fastapi import FastAPI, Request, HTTPException
from schemas import (
    ChurnPredictionRequest, ChurnPredictionResponse, ShapReason,
)
from preprocessing import encode_request, FEATURE_ORDER

# --- Configuration ---
MODEL_VERSION = os.getenv("MODEL_VERSION", "v2.3.1")
MODEL_PATH = os.getenv("MODEL_PATH", "model/churn_model.joblib")
CHURN_THRESHOLD = float(os.getenv("CHURN_THRESHOLD", "0.6"))

# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("churn-api")

# --- Plain-language mapping for SHAP reasons ---
FEATURE_DESCRIPTIONS = {
    "tenure": "account age",
    "monthly_charges": "monthly bill amount",
    "total_charges": "total amount spent",
    "contract_type": "contract type",
    "payment_method": "payment method",
    "num_support_tickets": "recent support tickets",
    "internet_service": "internet service type",
    "streaming_services": "number of streaming add-ons",
    "paperless_billing": "paperless billing setting",
    "senior_citizen": "senior citizen status",
}

DIRECTION_TEMPLATES = {
    "tenure": {
        "increases_risk": "Short account history increases churn risk",
        "decreases_risk": "Long account history lowers churn risk",
    },
    "monthly_charges": {
        "increases_risk": "Higher monthly charges increase churn risk",
        "decreases_risk": "Lower monthly charges reduce churn risk",
    },
    "total_charges": {
        "increases_risk": "Lower total spend suggests less engagement",
        "decreases_risk": "Higher total spend indicates engagement",
    },
    "contract_type": {
        "increases_risk": "Month-to-month contract increases churn risk",
        "decreases_risk": "Longer contract commitment lowers churn risk",
    },
    "payment_method": {
        "increases_risk": "Payment method is associated with higher churn",
        "decreases_risk": "Payment method is associated with lower churn",
    },
    "num_support_tickets": {
        "increases_risk": "Recent support issues increase churn risk",
        "decreases_risk": "Few support issues suggest satisfaction",
    },
    "internet_service": {
        "increases_risk": "Internet service type is associated with higher churn",
        "decreases_risk": "Internet service type is associated with lower churn",
    },
    "streaming_services": {
        "increases_risk": "Streaming bundle level affects churn risk",
        "decreases_risk": "Multiple streaming services increase engagement",
    },
    "paperless_billing": {
        "increases_risk": "Paperless billing is associated with higher churn",
        "decreases_risk": "Billing preference lowers churn risk",
    },
    "senior_citizen": {
        "increases_risk": "Senior citizen status affects churn risk",
        "decreases_risk": "Age group is associated with lower churn",
    },
}


# --- Model State (populated at startup) ---
model = None
explainer = None


# --- Application Lifecycle ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model and explainer at startup; clean up at shutdown."""
    global model, explainer
    logger.info(f"Loading model from {MODEL_PATH} (version {MODEL_VERSION})")
    model = joblib.load(MODEL_PATH)
    explainer = shap.TreeExplainer(model)
    logger.info("Model and explainer loaded successfully")
    yield
    logger.info("Shutting down")


app = FastAPI(
    title="StreamFlow Churn Predictor",
    description="Real-time churn prediction with SHAP explanations",
    version=MODEL_VERSION,
    lifespan=lifespan,
)


# --- Middleware ---
@app.middleware("http")
async def log_requests(request: Request, call_next):
    start = time.time()
    response = await call_next(request)
    elapsed_ms = (time.time() - start) * 1000
    response.headers["X-Response-Time-Ms"] = f"{elapsed_ms:.2f}"
    response.headers["X-Model-Version"] = MODEL_VERSION
    logger.info(
        f"{request.method} {request.url.path} "
        f"status={response.status_code} time={elapsed_ms:.1f}ms"
    )
    return response


# --- Helper Functions ---
def get_risk_tier(probability: float) -> str:
    if probability < 0.3:
        return "low"
    elif probability < 0.6:
        return "medium"
    else:
        return "high"


def get_plain_language(feature: str, direction: str) -> str:
    """Generate a human-readable explanation for a SHAP reason."""
    templates = DIRECTION_TEMPLATES.get(feature, {})
    return templates.get(
        direction, f"{FEATURE_DESCRIPTIONS.get(feature, feature)} affects churn risk"
    )


def compute_shap_reasons(
    features_df: pd.DataFrame, n: int = 3
) -> list[ShapReason]:
    shap_values = explainer.shap_values(features_df)

    if isinstance(shap_values, list):
        sv = shap_values[1][0]
    else:
        sv = shap_values[0]

    contributions = []
    feature_values = features_df.iloc[0].to_dict()

    for fname, shap_val in zip(FEATURE_ORDER, sv):
        direction = "increases_risk" if shap_val > 0 else "decreases_risk"
        contributions.append(ShapReason(
            feature=fname,
            value=float(feature_values[fname]),
            shap_contribution=round(float(shap_val), 4),
            direction=direction,
            plain_language=get_plain_language(fname, direction),
        ))

    contributions.sort(key=lambda x: abs(x.shap_contribution), reverse=True)
    return contributions[:n]


# --- Endpoints ---
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "model_version": MODEL_VERSION,
        "model_path": MODEL_PATH,
    }


@app.post("/predict", response_model=ChurnPredictionResponse)
def predict(request: ChurnPredictionRequest):
    features_df = encode_request(request)
    probability = float(model.predict_proba(features_df)[0, 1])
    reasons = compute_shap_reasons(features_df, n=3)

    return ChurnPredictionResponse(
        churn_probability=round(probability, 4),
        risk_tier=get_risk_tier(probability),
        top_reasons=reasons,
        model_version=MODEL_VERSION,
        show_retention_banner=probability > CHURN_THRESHOLD,
    )

Three production details worth highlighting:

  1. The lifespan context manager loads the model at startup and holds it in memory. The model is loaded once, not on every request. This is critical for latency.
  2. The logging middleware records the response time and model version for every request. This feeds into the monitoring system covered in Chapter 32.
  3. The plain_language field translates SHAP features into sentences the portal can display directly. Technical feature names are useless to customers.

Phase 3: Testing

Priya writes tests before building the container. If the tests fail, there is no point in containerizing a broken API.

# test_api.py
from fastapi.testclient import TestClient
from app import app

client = TestClient(app)

VALID_PAYLOAD = {
    "tenure": 14,
    "monthly_charges": 89.50,
    "total_charges": 1253.00,
    "contract_type": "month-to-month",
    "payment_method": "electronic_check",
    "num_support_tickets": 3,
    "internet_service": "fiber_optic",
    "streaming_services": 2,
    "paperless_billing": True,
    "senior_citizen": False,
}


def test_health():
    r = client.get("/health")
    assert r.status_code == 200
    assert r.json()["status"] == "healthy"
    assert "model_version" in r.json()


def test_predict_valid():
    r = client.post("/predict", json=VALID_PAYLOAD)
    assert r.status_code == 200
    data = r.json()
    assert 0 <= data["churn_probability"] <= 1
    assert data["risk_tier"] in ("low", "medium", "high")
    assert len(data["top_reasons"]) == 3
    assert isinstance(data["show_retention_banner"], bool)

    # Verify SHAP reasons have plain language
    for reason in data["top_reasons"]:
        assert reason["plain_language"]
        assert reason["direction"] in ("increases_risk", "decreases_risk")


def test_predict_missing_field():
    payload = {"tenure": 14}
    r = client.post("/predict", json=payload)
    assert r.status_code == 422


def test_predict_invalid_contract():
    payload = {**VALID_PAYLOAD, "contract_type": "weekly"}
    r = client.post("/predict", json=payload)
    assert r.status_code == 422


def test_predict_negative_tenure():
    payload = {**VALID_PAYLOAD, "tenure": -1}
    r = client.post("/predict", json=payload)
    assert r.status_code == 422


def test_predict_tenure_above_max():
    payload = {**VALID_PAYLOAD, "tenure": 999}
    r = client.post("/predict", json=payload)
    assert r.status_code == 422


def test_response_time_header():
    r = client.post("/predict", json=VALID_PAYLOAD)
    assert "X-Response-Time-Ms" in r.headers
    assert "X-Model-Version" in r.headers


def test_low_risk_no_banner():
    # Long tenure + two-year contract + low tickets = low risk
    payload = {
        "tenure": 72,
        "monthly_charges": 29.95,
        "total_charges": 2156.40,
        "contract_type": "two-year",
        "payment_method": "bank_transfer",
        "num_support_tickets": 0,
        "internet_service": "dsl",
        "streaming_services": 1,
        "paperless_billing": False,
        "senior_citizen": False,
    }
    r = client.post("/predict", json=payload)
    data = r.json()
    # A loyal customer with a two-year contract should be low risk
    assert data["churn_probability"] < 0.6
    assert data["show_retention_banner"] is False
She runs the suite:

pytest test_api.py -v
# 8 passed

Phase 4: Containerization

Dockerfile

# --- Build Stage ---
FROM python:3.11-slim AS builder

WORKDIR /build
COPY requirements.txt .
RUN pip install --no-cache-dir --prefix=/install -r requirements.txt

# --- Runtime Stage ---
FROM python:3.11-slim

WORKDIR /app

# Copy installed packages
COPY --from=builder /install /usr/local

# Copy application code
COPY app.py .
COPY schemas.py .
COPY preprocessing.py .
COPY model/ model/

# Non-root user
RUN useradd --create-home appuser
USER appuser

EXPOSE 8000

HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]

The --workers 2 flag runs two Uvicorn worker processes so the API can handle concurrent requests. Because model inference is CPU-bound, the number of workers should roughly match the number of CPU cores available to the container.
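
A common heuristic is to derive the worker count from the CPU count visible at container start. A sketch (the function name and reserved_cores parameter are illustrative):

```python
import os


def suggested_workers(reserved_cores: int = 0) -> int:
    """Roughly match Uvicorn workers to CPU cores for CPU-bound inference.

    reserved_cores leaves headroom for the OS or sidecar processes.
    Always returns at least 1 so the server can start on tiny instances.
    """
    cores = os.cpu_count() or 1
    return max(1, cores - reserved_cores)
```

In practice this is often injected at deploy time rather than hard-coded in the image; Uvicorn also reads the WEB_CONCURRENCY environment variable as a default for --workers.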

Build, Run, Test

# Build
docker build -t streamflow-churn-api:v2.3.1 .

# Run
docker run -d --name churn-api -p 8000:8000 streamflow-churn-api:v2.3.1

# Test health
curl http://localhost:8000/health
# {"status":"healthy","model_version":"v2.3.1","model_path":"model/churn_model.joblib"}

# Test prediction
curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -d '{
    "tenure": 14,
    "monthly_charges": 89.50,
    "total_charges": 1253.00,
    "contract_type": "month-to-month",
    "payment_method": "electronic_check",
    "num_support_tickets": 3,
    "internet_service": "fiber_optic",
    "streaming_services": 2,
    "paperless_billing": true,
    "senior_citizen": false
  }'

The response:

{
  "churn_probability": 0.7312,
  "risk_tier": "high",
  "top_reasons": [
    {
      "feature": "contract_type",
      "value": 0.0,
      "shap_contribution": 0.1847,
      "direction": "increases_risk",
      "plain_language": "Month-to-month contract increases churn risk"
    },
    {
      "feature": "tenure",
      "value": 14.0,
      "shap_contribution": 0.1523,
      "direction": "increases_risk",
      "plain_language": "Short account history increases churn risk"
    },
    {
      "feature": "num_support_tickets",
      "value": 3.0,
      "shap_contribution": 0.0891,
      "direction": "increases_risk",
      "plain_language": "Recent support issues increase churn risk"
    }
  ],
  "model_version": "v2.3.1",
  "show_retention_banner": true
}

The portal team can consume this directly. show_retention_banner triggers the UI component. The plain_language strings appear in the banner. The model_version is logged for debugging. No interpretation required on the frontend.
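
On the portal backend, consuming the response reduces to a few lines. A sketch of that glue code, assuming the field names shown in the response above (the function name and return shape are illustrative):

```python
import json
from typing import Optional


def render_banner_data(response_body: str) -> Optional[dict]:
    """Turn the API response into banner data, or None when no banner is shown."""
    resp = json.loads(response_body)
    if not resp["show_retention_banner"]:
        return None
    return {
        # Ready-to-display sentences for the retention banner
        "messages": [r["plain_language"] for r in resp["top_reasons"]],
        # Logged alongside the page view for prediction traceability
        "model_version": resp["model_version"],
    }
```

Because the API already applies the threshold and translates features into sentences, the portal code contains no ML logic at all.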


Phase 5: Deploying to AWS ECS

Priya's team uses AWS. The deployment follows the pattern from the chapter.

Push to ECR

# Authenticate
aws ecr get-login-password --region us-east-1 | \
  docker login --username AWS --password-stdin \
  123456789.dkr.ecr.us-east-1.amazonaws.com

# Create repository (first time only)
aws ecr create-repository --repository-name streamflow-churn-api

# Tag and push
docker tag streamflow-churn-api:v2.3.1 \
  123456789.dkr.ecr.us-east-1.amazonaws.com/streamflow-churn-api:v2.3.1
docker push \
  123456789.dkr.ecr.us-east-1.amazonaws.com/streamflow-churn-api:v2.3.1

ECS Service Configuration

Priya uses Terraform (the infrastructure-as-code tool her team adopted in Chapter 29) to define the ECS service:

Service: streamflow-churn-api
Task:    1 vCPU, 2 GB RAM (Fargate)
Tasks:   3 (for high availability)
ALB:     Internal (not public-facing; only the portal backend calls it)
Health:  GET /health every 30s
Scaling: Target tracking on CPU utilization (target: 60%)

Three tasks behind an internal ALB means:

  - If one container fails, two others continue serving
  - The ALB distributes requests across healthy containers
  - Auto-scaling adds containers if CPU exceeds 60%

Latency Validation

After deployment, Priya measures latency from an EC2 instance in the same VPC, sending 1,000 sequential requests:

import requests
import numpy as np
import time

url = "http://internal-churn-alb-123456.us-east-1.elb.amazonaws.com/predict"
payload = {
    "tenure": 14,
    "monthly_charges": 89.50,
    "total_charges": 1253.00,
    "contract_type": "month-to-month",
    "payment_method": "electronic_check",
    "num_support_tickets": 3,
    "internet_service": "fiber_optic",
    "streaming_services": 2,
    "paperless_billing": True,
    "senior_citizen": False,
}

latencies = []
for _ in range(1000):
    start = time.time()
    r = requests.post(url, json=payload)
    elapsed = (time.time() - start) * 1000
    latencies.append(elapsed)

print(f"Mean:   {np.mean(latencies):.1f} ms")
print(f"Median: {np.median(latencies):.1f} ms")
print(f"P95:    {np.percentile(latencies, 95):.1f} ms")
print(f"P99:    {np.percentile(latencies, 99):.1f} ms")

Results:

Mean:   47.3 ms
Median: 38.2 ms
P95:    89.1 ms
P99:    142.7 ms

The p99 of 142 ms is under the 200 ms requirement. Most of the latency is SHAP computation. Priya notes that if the requirement tightens to 100 ms, she will need to pre-compute SHAP values or switch to an approximate explainer.
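
For readers who want the arithmetic behind the percentile summary: the nearest-rank method reproduces it with the standard library (np.percentile interpolates by default, so values can differ slightly):

```python
import math


def nearest_rank_percentile(samples: list, p: float) -> float:
    """Nearest-rank percentile: the value at or below which ~p% of samples fall."""
    s = sorted(samples)
    # Rank is ceil(p% of n); subtract 1 for zero-based indexing
    k = max(0, math.ceil(p / 100 * len(s)) - 1)
    return s[k]
```

A p99 of 142.7 ms therefore means that 990 of the 1,000 requests completed in 142.7 ms or less; the requirement is about the tail, not the mean.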


Phase 6: Canary Deployment for Model Updates

Two weeks after the initial deployment, the team retrains the model on updated data. The new model (v2.4.0) has an AUC of 0.8891 on the holdout set --- a small improvement. But AUC on a holdout set does not guarantee better performance in production.

Priya deploys v2.4.0 as a canary:

  1. Build and push the new image: streamflow-churn-api:v2.4.0
  2. Create a second ECS task definition pointing to the new image
  3. Configure the ALB to route 10% of traffic to v2.4.0 and 90% to v2.3.1
  4. Monitor for 48 hours
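
The ALB's weighted target groups implement step 3 natively; the underlying logic is weighted random selection, sketched here for intuition (the version labels match this deployment, the function itself is illustrative):

```python
import random

# The 90/10 split used for the canary
CANARY_WEIGHTS = {"v2.3.1": 90, "v2.4.0": 10}


def pick_version(weights: dict, rng: random.Random) -> str:
    """Choose a target version in proportion to its traffic weight."""
    total = sum(weights.values())
    r = rng.uniform(0, total)
    cumulative = 0
    for version, weight in weights.items():
        cumulative += weight
        if r <= cumulative:
            return version
    return version  # guard against floating-point edge cases
```

Shifting to 50/50 and then 100% is just a change of weights, which is why canary rollouts compose so cleanly with a load balancer.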

The monitoring dashboard (Chapter 32 territory) shows:

  - Prediction distribution is stable (v2.4.0 is not dramatically different)
  - Latency is comparable (45 ms mean vs. 47 ms)
  - No increase in error rate
  - The SHAP top-3 features are the same for both versions

After 48 hours with clean metrics, Priya shifts to 50/50, then 100% on v2.4.0, and decommissions v2.3.1.


Lessons Learned

  1. The model was the easy part. Training took one afternoon. The deployment pipeline --- API code, tests, Dockerfile, ECR, ECS, ALB, health checks, monitoring --- took two weeks. This ratio is normal.

  2. Pydantic Literal types eliminated an entire class of bugs. Instead of accepting any string and checking it downstream, the schema rejects "weekly" for contract_type immediately. Three bugs in the first week of development were caught by Pydantic before the model ever ran.

  3. SHAP is the latency bottleneck. Model inference takes 2 ms. SHAP computation takes 30--80 ms. If explanations are not required on every call, make them optional or serve them from a cache.

  4. plain_language translations saved the frontend team a week of work. Instead of mapping feature names to user-facing strings in JavaScript, the API returns ready-to-display text. The ML team owns the translation because the ML team understands what the features mean.

  5. Never deploy on Friday. Priya's team deploys on Tuesday mornings. If something goes wrong, the team is fresh, alert, and available for the full week. Friday deployments create weekend incidents.
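
The caching suggestion in lesson 3 can be sketched with functools.lru_cache, keyed on the encoded feature tuple. The stand-in compute function below is illustrative; in the API it would be the explainer.shap_values call, and the cache would need clearing on every model deploy to avoid serving stale explanations:

```python
from functools import lru_cache


def expensive_shap_call(features: tuple) -> tuple:
    """Stand-in for the real SHAP computation (30--80 ms in production)."""
    return tuple(round(f * 0.1, 4) for f in features)


@lru_cache(maxsize=10_000)
def cached_explanation(features: tuple) -> tuple:
    """Cache explanations per unique feature vector; the key must be hashable."""
    return expensive_shap_call(features)
```

This only pays off when identical feature vectors recur, for example the same subscriber reloading the page within a session; for high-cardinality numeric features, a per-subscriber cache with a short TTL is usually the better fit.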


This case study supports Chapter 31: Model Deployment. Return to the chapter for the foundational concepts.