Case Study 1: StreamRec Training Pipeline — From Manual Notebooks to Dagster Orchestration
Context
Six months after the StreamRec production system went live (Chapter 24, Case Study 1), the ML platform team has a problem that no amount of model improvement can solve: the daily training pipeline is unreliable.
The pipeline was originally a sequence of Jupyter notebooks, executed manually by the on-call ML engineer every morning:
- Run 01_extract.ipynb — pull yesterday's interaction events from the data lake
- Run 02_validate.ipynb — check data quality; if bad, post in Slack and wait for the data engineering team
- Run 03_features.ipynb — compute training features via Feast
- Run 04_train_retrieval.ipynb — train the two-tower retrieval model on 4x A100s
- Run 05_train_ranking.ipynb — train the DCN-V2 ranking model
- Run 06_evaluate.ipynb — compute Recall@20, NDCG@20, AUC, and logloss
- Run 07_register.ipynb — if metrics pass thresholds, register in MLflow and trigger deployment
On a good day, the on-call engineer starts at 6am, the pipeline finishes by 10am, and the new model is serving traffic by noon. On a bad day — and there have been many bad days — the pipeline fails at step 4, the engineer spends two hours debugging a GPU out-of-memory error, restarts the training with a smaller batch size, and the new model is not ready until 4pm. On one occasion, the engineer was sick, nobody ran the pipeline, and the production model served 3-day-old recommendations for an entire weekend. Engagement dropped 8% on Sunday.
The VP of Engineering has mandated: "The training pipeline must run automatically every day, handle failures without human intervention, and page someone only when it truly cannot recover."
The Migration
Phase 1: Extract Business Logic from Notebooks
The first step is separating computation from orchestration. Each notebook contains a mix of business logic (data validation rules, feature computation, model training), infrastructure code (S3 access, GPU setup), and debugging artifacts (plots, print statements, commented-out cells).
The team extracts the business logic into Python modules:
"""streamrec_pipeline/validation.py — Pure validation logic.
No Dagster, no Airflow, no Prefect. Just Python functions
that take DataFrames and return DataFrames or raise exceptions.
Testable with pytest.
"""
import pandas as pd
from typing import Dict
# Data contracts as constants
MIN_DAILY_INTERACTIONS = 100_000
MAX_NULL_RATE = 0.05
REQUIRED_COLUMNS = {"user_id", "item_id", "event_type", "timestamp"}
ALLOWED_EVENT_TYPES = {"view", "click", "complete", "share", "save"}
def validate_interactions(df: pd.DataFrame) -> pd.DataFrame:
"""Validate interaction data against quality contracts.
Args:
df: Raw interaction DataFrame from the data lake.
Returns:
Validated DataFrame with invalid event types filtered.
Raises:
DataQualityError: If any critical quality check fails.
"""
check_row_count(df)
check_schema(df)
check_null_rates(df)
df = filter_invalid_events(df)
return df
def check_row_count(df: pd.DataFrame) -> None:
"""Verify minimum row count."""
if len(df) < MIN_DAILY_INTERACTIONS:
raise DataQualityError(
f"Row count {len(df)} below minimum {MIN_DAILY_INTERACTIONS}. "
"Possible upstream data pipeline failure."
)
def check_schema(df: pd.DataFrame) -> None:
"""Verify all required columns are present."""
missing = REQUIRED_COLUMNS - set(df.columns)
if missing:
raise DataQualityError(f"Missing required columns: {missing}")
def check_null_rates(df: pd.DataFrame) -> Dict[str, float]:
"""Check null rates per column, raising on violations."""
null_report = {}
for col in df.columns:
null_rate = df[col].isna().mean()
null_report[col] = null_rate
if null_rate > MAX_NULL_RATE:
raise DataQualityError(
f"Column '{col}' null rate {null_rate:.3f} "
f"exceeds threshold {MAX_NULL_RATE}."
)
return null_report
def filter_invalid_events(df: pd.DataFrame) -> pd.DataFrame:
"""Filter rows with unrecognized event types."""
valid_mask = df["event_type"].isin(ALLOWED_EVENT_TYPES)
n_filtered = (~valid_mask).sum()
if n_filtered > 0:
df = df[valid_mask].copy()
return df
class DataQualityError(Exception):
"""Raised when data fails quality validation."""
pass
The key insight: by extracting validation logic into streamrec_pipeline/validation.py, the team can write 15 unit tests that run in 2 seconds, covering every validation path. None of these tests require Dagster, Airflow, or any infrastructure. The same logic is then wrapped in a Dagster asset with a thin orchestration layer.
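A sketch of what those tests look like, with the null-rate check inlined so the example runs standalone (the team's actual suite imports from streamrec_pipeline/validation.py and runs under pytest):

```python
import pandas as pd

# Inlined from streamrec_pipeline/validation.py so this sketch is
# self-contained; the real tests import the module directly.
MAX_NULL_RATE = 0.05


class DataQualityError(Exception):
    pass


def check_null_rates(df: pd.DataFrame) -> dict:
    null_report = {}
    for col in df.columns:
        null_rate = df[col].isna().mean()
        null_report[col] = null_rate
        if null_rate > MAX_NULL_RATE:
            raise DataQualityError(
                f"Column '{col}' null rate {null_rate:.3f} exceeds {MAX_NULL_RATE}."
            )
    return null_report


def test_clean_data_passes():
    df = pd.DataFrame({"user_id": [1, 2, 3, 4]})
    assert check_null_rates(df) == {"user_id": 0.0}


def test_null_spike_raises():
    df = pd.DataFrame({"user_id": [1, None, None, 4]})
    try:
        check_null_rates(df)
        assert False, "expected DataQualityError"
    except DataQualityError:
        pass  # contract violation correctly detected


test_clean_data_passes()
test_null_spike_raises()
```

Because the function takes a DataFrame and returns a plain dict, each test constructs its fixture inline — no mocks, no orchestrator, no I/O.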
Phase 2: Build the Dagster Asset Graph
With business logic extracted, the Dagster assets become thin wrappers:
"""streamrec_pipeline/assets.py — Dagster asset definitions.
Each asset wraps a business logic function with orchestration
concerns: IO management, metadata logging, retry policies,
and partition handling.
"""
from dagster import (
asset,
DailyPartitionsDefinition,
Output,
MetadataValue,
RetryPolicy,
Backoff,
)
from streamrec_pipeline.validation import validate_interactions
from streamrec_pipeline.features import compute_training_features
from streamrec_pipeline.training import train_two_tower, train_dcn_v2
from streamrec_pipeline.evaluation import evaluate_retrieval_model, evaluate_ranking_model
import pandas as pd
from typing import Dict, Any
daily_partitions = DailyPartitionsDefinition(start_date="2025-01-01")
gpu_retry = RetryPolicy(max_retries=2, delay=120, backoff=Backoff.EXPONENTIAL)
@asset(partitions_def=daily_partitions, group_name="data_prep")
def raw_interactions(context) -> Output[pd.DataFrame]:
"""Extract interactions from the data lake for one partition."""
from deltalake import DeltaTable
partition_date = context.partition_key
dt = DeltaTable("s3://streamrec-datalake/interactions/")
df = dt.to_pandas(filters=[("event_date", "=", partition_date)])
return Output(
df,
metadata={
"num_rows": MetadataValue.int(len(df)),
"partition": MetadataValue.text(partition_date),
},
)
@asset(partitions_def=daily_partitions, group_name="data_prep")
def validated_interactions(
context, raw_interactions: pd.DataFrame
) -> Output[pd.DataFrame]:
"""Validate interactions — delegates to pure validation module."""
validated = validate_interactions(raw_interactions)
return Output(
validated,
metadata={
"num_rows": MetadataValue.int(len(validated)),
"num_filtered": MetadataValue.int(
len(raw_interactions) - len(validated)
),
},
)
@asset(
partitions_def=daily_partitions,
group_name="training",
retry_policy=gpu_retry,
)
def retrieval_model(
context, training_features: pd.DataFrame
) -> Output[Dict[str, Any]]:
"""Train two-tower retrieval model — delegates to training module."""
result = train_two_tower(
training_df=training_features,
partition_date=context.partition_key,
epochs=10,
batch_size=4096,
learning_rate=1e-3,
)
return Output(
result,
metadata={
"mlflow_run_id": MetadataValue.text(result["run_id"]),
"training_time_s": MetadataValue.float(
result["metrics"]["training_time_seconds"]
),
},
)
The pattern is consistent: each asset function is 10-15 lines, delegating to a business logic function and adding Dagster-specific metadata. The business logic modules are 50-100 lines each, thoroughly tested, and framework-agnostic.
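As a rough illustration of what the gpu_retry policy implies, the schedule below assumes exponential backoff doubles the base delay on each successive retry (consult the Dagster RetryPolicy documentation for the exact formula your version uses):

```python
def retry_delays(base_delay_s: float, max_retries: int) -> list:
    """Hypothetical wait times before each retry under doubling backoff.

    This is an illustration of the policy's semantics, not Dagster's
    internal implementation.
    """
    return [base_delay_s * (2 ** attempt) for attempt in range(max_retries)]


# RetryPolicy(max_retries=2, delay=120, backoff=Backoff.EXPONENTIAL)
print(retry_delays(120, 2))  # → [120, 240]
```

Two minutes before the first retry gives a transiently overloaded GPU node time to free memory; four minutes before the second usually means the run has landed on a fresh worker.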
Phase 3: Failure Handling
The team configures failure handling based on six months of incident data from the manual pipeline:
| Incident Type | Frequency | Root Cause | Mitigation |
|---|---|---|---|
| GPU OOM during training | 2x/month | Batch size too large for long-sequence days | Retry with exponential backoff (Dagster restarts on a fresh worker) |
| Data lake partition late | 1x/week | Upstream Spark job delayed | Dagster sensor polls every 5 min; 2-hour timeout; P1 alert if timeout |
| Feast offline store timeout | 1x/month | Feature store under heavy load | Retry policy on training_features asset (2 retries, 5-min delay) |
| Null rate spike in interaction data | 1x/quarter | Upstream schema change | validate_interactions raises DataQualityError; P1 alert; pipeline halts |
| Model quality degradation | 1x/month | Distribution shift; sometimes data issue | Evaluation gate prevents promotion; P2 alert; yesterday's model stays in production |
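The late-partition mitigation boils down to a small decision that each 5-minute sensor tick makes; a framework-free sketch of that logic (function name and return values are illustrative, not Dagster API):

```python
from datetime import datetime, timedelta


def late_partition_decision(
    partition_ready: bool,
    waiting_since: datetime,
    now: datetime,
    timeout: timedelta = timedelta(hours=2),
) -> str:
    """What the upstream-data sensor should do on each polling tick."""
    if partition_ready:
        return "trigger_run"   # data landed: kick off the pipeline
    if now - waiting_since > timeout:
        return "page_p1"       # 2-hour timeout breached: alert a human
    return "keep_waiting"      # poll again in 5 minutes


t0 = datetime(2025, 6, 1, 2, 0)
assert late_partition_decision(True, t0, t0 + timedelta(minutes=30)) == "trigger_run"
assert late_partition_decision(False, t0, t0 + timedelta(hours=3)) == "page_p1"
assert late_partition_decision(False, t0, t0 + timedelta(minutes=30)) == "keep_waiting"
```

Keeping this as a pure function means the timeout behavior is unit-testable; the actual Dagster sensor only supplies the timestamps and translates the decision into a RunRequest or an alert.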
The GPU OOM issue deserves special attention. The root cause is that on some days, a viral video produces interaction sequences that are 3x longer than average, causing the training batch to exceed GPU memory. The manual workaround was to reduce the batch size and restart — exactly the kind of intervention that should be automated.
The team adds a gradient-accumulation fallback: if training crashes with a CUDA OOM error, the retry uses gradient accumulation to simulate the original batch size with smaller micro-batches. This is implemented in the training module (not in the Dagster asset), making it testable:
"""streamrec_pipeline/training.py — Model training logic."""
import torch
from typing import Dict, Any
def train_two_tower(
training_df: "pd.DataFrame",
partition_date: str,
epochs: int,
batch_size: int,
learning_rate: float,
gradient_accumulation_steps: int = 1,
) -> Dict[str, Any]:
"""Train the two-tower retrieval model.
If a previous attempt failed with OOM (detected via retry context),
gradient_accumulation_steps is automatically increased to reduce
per-step memory usage while maintaining the effective batch size.
Args:
training_df: Feature-enriched training DataFrame.
partition_date: Training date for artifact naming.
epochs: Number of training epochs.
batch_size: Effective batch size.
learning_rate: Optimizer learning rate.
gradient_accumulation_steps: Number of micro-batches per
optimizer step. Increase to reduce memory usage.
Returns:
Dictionary with artifact path, MLflow run ID, and metrics.
"""
import mlflow
micro_batch_size = batch_size // gradient_accumulation_steps
with mlflow.start_run(run_name=f"retrieval_{partition_date}") as run:
mlflow.log_params({
"batch_size": batch_size,
"micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": gradient_accumulation_steps,
"epochs": epochs,
"learning_rate": learning_rate,
})
# Training loop with gradient accumulation
# (Chapter 26 covers the DDP implementation)
artifact_path = (
f"s3://streamrec-pipeline/models/{partition_date}/retrieval/"
)
metrics = {
"final_train_loss": 0.0,
"training_time_seconds": 0.0,
}
mlflow.log_metrics(metrics)
return {
"artifact_path": artifact_path,
"run_id": run.info.run_id,
"metrics": metrics,
}
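The elided training loop rests on a simple identity: for a mean-reduced loss, summing micro-batch gradients weighted by each micro-batch's share of the full batch reproduces the full-batch gradient exactly. A framework-free check on a toy one-parameter least-squares model (numbers are illustrative):

```python
def grad_mse(w, xs, ys):
    """Gradient of mean squared error for the 1-parameter model y = w * x."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

# Gradient over the full batch of 4 examples.
full_batch = grad_mse(w, xs, ys)

# Accumulate over micro-batches of size 2, weighting each micro-batch
# gradient by its fraction of the effective batch.
micro = 2
accumulated = 0.0
for i in range(0, len(xs), micro):
    accumulated += grad_mse(w, xs[i:i + micro], ys[i:i + micro]) * (micro / len(xs))

assert abs(accumulated - full_batch) < 1e-12
```

This is why the OOM fallback preserves model quality: halving the micro-batch size and doubling gradient_accumulation_steps changes peak memory, not the optimizer's update.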
Phase 4: Observability and Alerting
The team builds a Dagster sensor that monitors pipeline health and sends structured alerts:
"""streamrec_pipeline/monitoring.py — Pipeline health monitoring."""
from dagster import sensor, RunRequest, SensorEvaluationContext
from datetime import datetime, timedelta
from typing import Optional
PIPELINE_SLA_HOURS = 6 # Must complete by 8am (started at 2am)
def check_pipeline_sla(
start_time: datetime,
end_time: Optional[datetime],
sla_hours: float = PIPELINE_SLA_HOURS,
) -> dict:
"""Check whether the pipeline run met its SLA.
Args:
start_time: Pipeline run start time.
end_time: Pipeline run completion time (None if still running).
sla_hours: Maximum allowed duration in hours.
Returns:
Dictionary with SLA status and details.
"""
now = datetime.utcnow()
elapsed = (end_time or now) - start_time
elapsed_hours = elapsed.total_seconds() / 3600
if end_time is not None and elapsed_hours <= sla_hours:
return {"status": "met", "duration_hours": elapsed_hours}
elif end_time is None and elapsed_hours > sla_hours:
return {
"status": "breached",
"duration_hours": elapsed_hours,
"message": f"Pipeline still running after {elapsed_hours:.1f}h "
f"(SLA: {sla_hours}h)",
}
elif end_time is not None and elapsed_hours > sla_hours:
return {
"status": "breached",
"duration_hours": elapsed_hours,
"message": f"Pipeline completed in {elapsed_hours:.1f}h "
f"(SLA: {sla_hours}h)",
}
else:
return {
"status": "on_track",
"duration_hours": elapsed_hours,
"remaining_hours": sla_hours - elapsed_hours,
}
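Because check_pipeline_sla is a pure function, covering every outcome takes a handful of spot checks. A condensed, behaviorally equivalent inline version (so this sketch runs standalone) with those checks:

```python
from datetime import datetime, timedelta

SLA_HOURS = 6.0


def check_pipeline_sla(start_time, end_time, sla_hours=SLA_HOURS):
    # Condensed from streamrec_pipeline/monitoring.py: any run over the
    # SLA is "breached" whether finished or not; otherwise a finished
    # run is "met" and an unfinished one is "on_track".
    elapsed = (end_time or datetime.utcnow()) - start_time
    elapsed_hours = elapsed.total_seconds() / 3600
    if elapsed_hours > sla_hours:
        return {"status": "breached", "duration_hours": elapsed_hours}
    if end_time is not None:
        return {"status": "met", "duration_hours": elapsed_hours}
    return {
        "status": "on_track",
        "duration_hours": elapsed_hours,
        "remaining_hours": sla_hours - elapsed_hours,
    }


start = datetime(2025, 6, 1, 2, 0)
assert check_pipeline_sla(start, start + timedelta(hours=3))["status"] == "met"
assert check_pipeline_sla(start, start + timedelta(hours=7))["status"] == "breached"
assert check_pipeline_sla(start, None)["status"] in {"on_track", "breached"}
```

The actual sensor only has to feed run timestamps from Dagster's instance API into this function and route a "breached" result to the paging system.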
Results
After 90 days of operating the Dagster pipeline:
| Metric | Before (Manual) | After (Dagster) |
|---|---|---|
| Pipeline runs per 30 days | 22 (missed 8 days) | 30 (100% completion) |
| Mean pipeline duration | 4.2 hours | 3.1 hours |
| P95 pipeline duration | 7.5 hours | 4.8 hours |
| Human intervention required | 15 times/month | 2 times/month |
| SLA breaches (> 6 hours) | 4/month | 0.3/month |
| Model freshness (median age) | 1.8 days | 1.0 day |
| Weekend engagement drop | -8% (stale models) | -0.5% (within noise) |
The remaining human interventions fell into two categories: (1) a genuine data quality issue that required a data engineering fix (the pipeline correctly halted and alerted), and (2) an infrastructure change that required updating a Dagster resource configuration. Neither was a pipeline failure — both were legitimate operational events that required human judgment.
Lessons Learned
1. Separate business logic from orchestration. The team's unit test suite (47 tests, 3-second runtime) catches 90% of bugs before they reach the Dagster runtime. The orchestration layer is thin and rarely changes.
2. Design for the bad day, not the good day. The GPU OOM fallback, the data validation halt, and the model quality gate all handle scenarios that occurred in the first month of operation. Without them, the pipeline would have required human intervention on those days.
3. Idempotency enables fearless backfill. When the team discovered a feature computation bug in week 6, they backfilled 14 days of data in 5 hours (parallel, max_parallel=3). Every output was overwritten cleanly. The pre-Dagster response to the same scenario — manual notebook re-execution for each day — would have taken two engineers a full day.
4. Metadata is not optional. The Dagster UI's asset metadata (row counts, null rates, MLflow run IDs, training times) eliminated 80% of the "let me check the logs" debugging sessions. When a model's NDCG@20 dropped, the engineer could trace from the model to the training run to the feature set to the raw data in under 2 minutes.
5. Start with sensors, not schedules. The switch from a fixed 2am cron to a sensor that waits for the upstream data pipeline eliminated all "stale data" incidents. The pipeline starts when data is ready, not when the clock says so.