Chapter 27: ML Pipeline Orchestration — Airflow, Dagster, Prefect, and Designing Robust Data Workflows

"Data pipelines are not ETL jobs. They are the circulatory system of a machine learning product. When they stop, everything stops." — Maxime Beauchemin, creator of Apache Airflow


Learning Objectives

By the end of this chapter, you will be able to:

  1. Design ML pipelines as directed acyclic graphs (DAGs) with explicit dependency management, idempotent tasks, and well-defined data contracts between stages
  2. Implement the same ML pipeline in Apache Airflow and Dagster, and articulate the philosophical differences between imperative (task-centric) and declarative (asset-centric) orchestration
  3. Handle pipeline failures robustly using retry policies, exponential backoff, alerting, backfill strategies, and dead letter queues
  4. Version pipelines alongside data and model artifacts using experiment tracking systems (MLflow, Weights & Biases) and artifact stores
  5. Design pipeline testing strategies that include unit tests for individual tasks, integration tests for end-to-end pipelines, and contract tests for inter-stage data schemas

27.1 The Orchestration Problem

Chapter 24 established the two-loop architecture of a production ML system: the inner loop serves predictions in milliseconds; the outer loop retrains models in hours. Chapter 25 built the data infrastructure — the feature store, the data lake, the streaming pipeline — that provides the raw materials. Chapter 26 scaled the training itself across multiple GPUs. This chapter addresses the question that sits between them: who runs the training pipeline, in what order, on what schedule, and what happens when something goes wrong?

The question is deceptively simple. Consider the StreamRec training pipeline designed in Chapter 24's Case Study 1:

  1. Extract interaction events from the last 24 hours from the data lake (Delta Lake on S3)
  2. Validate the extracted data (schema checks, row count assertions, null rate thresholds)
  3. Compute features using the feature store (Feast, from Chapter 25), writing point-in-time correct training examples
  4. Train the two-tower retrieval model and the DCN-V2 ranking model (DDP on 4x A100s, from Chapter 26)
  5. Evaluate both models on the held-out set (Recall@20, NDCG@20 for retrieval; AUC, logloss for ranking)
  6. Compare against the current production model; gate on minimum quality thresholds
  7. Register the new models in MLflow with metrics, feature schema, and data lineage
  8. Deploy (trigger the CI/CD pipeline from Chapter 29, which handles canary deployment)

Run sequentially by a competent engineer on a good day, this pipeline takes 3.5 hours. The engineer monitors it, handles failures, reruns failed steps, and eventually confirms that the new models are in the registry.

Now consider what happens in reality:

  • The data lake partition for yesterday is not available at the expected time because an upstream Spark job ran long. Step 1 fails.
  • The data passes schema checks but contains a column where 30% of values are null — well above the 5% threshold in the data contract. Step 2 catches this, but what happens next? Does the pipeline stop? Use fallback data? Alert someone?
  • The 4-GPU training job in step 4 crashes 90 minutes in because one GPU runs out of memory during a batch with unusually long sequences. Is the 90 minutes of compute wasted?
  • The evaluation in step 5 reveals that NDCG@20 dropped from 0.182 to 0.174. Is this noise or a real degradation? Should the model be promoted?
  • It is Saturday. Nobody is watching the Slack channel.

Pipeline orchestration is the engineering discipline that handles all of these scenarios automatically. An orchestrator schedules tasks, manages dependencies, retries failures, sends alerts, provides observability, and — critically — enables the pipeline to run unattended, every day, without human intervention.

Production ML = Software Engineering: A model training pipeline that requires a human operator to babysit it is not a pipeline — it is a script with a prayer. The transition from manual execution to automated orchestration is one of the defining transitions from prototype to production ML. Every principle from software engineering — separation of concerns, idempotency, observability, testing — applies with full force.


27.2 Pipelines as Directed Acyclic Graphs

The fundamental abstraction in pipeline orchestration is the directed acyclic graph (DAG). Each node in the DAG is a task (a unit of work), and each directed edge represents a dependency (task B cannot start until task A completes).

Why DAGs?

The DAG constraint — no cycles — ensures that the pipeline has a well-defined execution order. If task A depends on task B and task B depends on task A, there is no valid execution order: a cycle. By forbidding cycles, the orchestrator can compute a topological ordering of tasks and execute them in a sequence that respects all dependencies.

The DAG structure also enables parallelism: tasks with no dependency relationship can run simultaneously. In the StreamRec pipeline, once data extraction (step 1) and validation (step 2) are complete, the two-tower retrieval model and the DCN-V2 ranking model can train in parallel — they share the same training data but have no other dependency on each other.

Formalizing the StreamRec Pipeline as a DAG

from dataclasses import dataclass, field
from typing import Dict, List, Set
from enum import Enum


class TaskStatus(Enum):
    """Execution status for a pipeline task."""
    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"
    UPSTREAM_FAILED = "upstream_failed"


@dataclass
class PipelineTask:
    """A single task in an ML pipeline DAG.

    Attributes:
        task_id: Unique identifier for this task.
        upstream_tasks: Set of task IDs that must complete before this task runs.
        retry_count: Number of retry attempts on failure.
        retry_delay_seconds: Base delay between retries (exponential backoff).
        timeout_seconds: Maximum execution time before the task is killed.
        is_idempotent: Whether the task produces the same output given the same input.
        status: Current execution status.
    """
    task_id: str
    upstream_tasks: Set[str] = field(default_factory=set)
    retry_count: int = 3
    retry_delay_seconds: int = 60
    timeout_seconds: int = 3600
    is_idempotent: bool = True
    status: TaskStatus = TaskStatus.PENDING

    def can_run(self, completed_tasks: Set[str]) -> bool:
        """Check whether all upstream dependencies are satisfied.

        Args:
            completed_tasks: Set of task IDs that have completed successfully.

        Returns:
            True if all upstream tasks are in completed_tasks.
        """
        return self.upstream_tasks.issubset(completed_tasks)


@dataclass
class PipelineDAG:
    """A directed acyclic graph representing an ML pipeline.

    Attributes:
        name: Pipeline name (e.g., 'streamrec_daily_training').
        tasks: Dictionary mapping task_id to PipelineTask.
        schedule: Cron expression for scheduled execution.
    """
    name: str
    tasks: Dict[str, PipelineTask] = field(default_factory=dict)
    schedule: str = "0 2 * * *"  # Daily at 2am UTC

    def add_task(self, task: PipelineTask) -> None:
        """Add a task to the DAG.

        Raises:
            ValueError: If adding this task would create a cycle.
        """
        self.tasks[task.task_id] = task
        if self._has_cycle():
            del self.tasks[task.task_id]
            raise ValueError(
                f"Adding task '{task.task_id}' would create a cycle in the DAG."
            )

    def _has_cycle(self) -> bool:
        """Detect cycles using depth-first search.

        Returns:
            True if the DAG contains a cycle.
        """
        visited: Set[str] = set()
        rec_stack: Set[str] = set()

        def dfs(task_id: str) -> bool:
            visited.add(task_id)
            rec_stack.add(task_id)
            task = self.tasks.get(task_id)
            if task is None:
                return False
            for upstream_id in task.upstream_tasks:
                if upstream_id not in visited:
                    if dfs(upstream_id):
                        return True
                elif upstream_id in rec_stack:
                    return True
            rec_stack.discard(task_id)
            return False

        for task_id in self.tasks:
            if task_id not in visited:
                if dfs(task_id):
                    return True
        return False

    def topological_sort(self) -> List[str]:
        """Return tasks in a valid execution order.

        Returns:
            List of task IDs in topological order.

        Raises:
            ValueError: If the DAG contains a cycle.
        """
        # Kahn's algorithm: in-degree = number of upstream dependencies
        in_degree: Dict[str, int] = {
            tid: len(self.tasks[tid].upstream_tasks) for tid in self.tasks
        }
        # Reverse adjacency: map each task to the tasks that depend on it
        reverse_adj: Dict[str, List[str]] = {tid: [] for tid in self.tasks}
        for task in self.tasks.values():
            for upstream_id in task.upstream_tasks:
                if upstream_id in reverse_adj:
                    reverse_adj[upstream_id].append(task.task_id)

        queue = [tid for tid, deg in in_degree.items() if deg == 0]
        result: List[str] = []

        while queue:
            current = queue.pop(0)
            result.append(current)
            for downstream_id in reverse_adj[current]:
                in_degree[downstream_id] -= 1
                if in_degree[downstream_id] == 0:
                    queue.append(downstream_id)

        if len(result) != len(self.tasks):
            raise ValueError("DAG contains a cycle; topological sort is impossible.")
        return result

    def get_parallel_stages(self) -> List[List[str]]:
        """Identify groups of tasks that can execute in parallel.

        Returns:
            List of stages, where each stage is a list of task IDs that
            have no mutual dependencies and can run concurrently.
        """
        remaining = set(self.tasks.keys())
        completed: Set[str] = set()
        stages: List[List[str]] = []

        while remaining:
            ready = [
                tid for tid in remaining
                if self.tasks[tid].can_run(completed)
            ]
            if not ready:
                raise ValueError("DAG contains a cycle or unsatisfiable dependency.")
            stages.append(ready)
            completed.update(ready)
            remaining -= set(ready)

        return stages


# Build the StreamRec training pipeline DAG
streamrec_pipeline = PipelineDAG(name="streamrec_daily_training")

tasks = [
    PipelineTask(
        task_id="extract_interactions",
        upstream_tasks=set(),
        timeout_seconds=1800,
    ),
    PipelineTask(
        task_id="validate_data",
        upstream_tasks={"extract_interactions"},
        retry_count=1,
        timeout_seconds=600,
    ),
    PipelineTask(
        task_id="compute_features",
        upstream_tasks={"validate_data"},
        timeout_seconds=3600,
    ),
    PipelineTask(
        task_id="train_retrieval_model",
        upstream_tasks={"compute_features"},
        retry_count=2,
        timeout_seconds=7200,
    ),
    PipelineTask(
        task_id="train_ranking_model",
        upstream_tasks={"compute_features"},
        retry_count=2,
        timeout_seconds=7200,
    ),
    PipelineTask(
        task_id="evaluate_retrieval",
        upstream_tasks={"train_retrieval_model"},
        timeout_seconds=1800,
    ),
    PipelineTask(
        task_id="evaluate_ranking",
        upstream_tasks={"train_ranking_model"},
        timeout_seconds=1800,
    ),
    PipelineTask(
        task_id="register_models",
        upstream_tasks={"evaluate_retrieval", "evaluate_ranking"},
        timeout_seconds=600,
    ),
    PipelineTask(
        task_id="trigger_deployment",
        upstream_tasks={"register_models"},
        timeout_seconds=300,
    ),
]

for task in tasks:
    streamrec_pipeline.add_task(task)

# Inspect the parallel execution stages
stages = streamrec_pipeline.get_parallel_stages()
for i, stage in enumerate(stages):
    print(f"Stage {i}: {stage}")
# Stage 0: ['extract_interactions']
# Stage 1: ['validate_data']
# Stage 2: ['compute_features']
# Stage 3: ['train_retrieval_model', 'train_ranking_model']
# Stage 4: ['evaluate_retrieval', 'evaluate_ranking']
# Stage 5: ['register_models']
# Stage 6: ['trigger_deployment']

Stage 3 reveals the parallelism: the two training jobs run simultaneously, and Stage 4 shows that their evaluations also run in parallel. This cuts the end-to-end pipeline time by nearly 2 hours compared to sequential execution.

Data Intervals and Logical Dates

A subtlety that trips up every new orchestration user: the logical date (data interval) is not the execution date. When a daily pipeline runs at 2am on March 15, it processes data from March 14. The logical date is March 14; the execution date is March 15 at 02:00.

This distinction is critical for:

  • Idempotency. If the pipeline fails and is rerun, it must process the same data interval — March 14 — not whatever data is available at rerun time.
  • Backfill. If the pipeline was down for three days, backfilling means running the pipeline for each missed data interval (March 12, 13, 14), not running it once with all three days of data.
  • Auditing. A model registered from the March 14 run should have metadata indicating it was trained on data up to March 14, regardless of when the pipeline actually executed.

All three major orchestration frameworks enforce this concept, though they name it differently: Airflow uses "data interval" (historically "execution date"), Dagster uses "partition," and Prefect uses "flow run parameters."
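
The distinction can be made concrete with a small sketch: given a fixed logical date, a task always derives the same partition path, and a backfill is simply one run per missed interval. The date arithmetic below is illustrative, not any framework's API, and the partition path layout is an assumption:

```python
from datetime import date, timedelta

def partition_path(logical_date: date) -> str:
    """Derive the input partition from the logical date, not the wall clock.

    Rerunning with the same logical date reads the same data: idempotent.
    """
    return (
        "s3://streamrec-datalake/interactions/"
        f"event_date={logical_date.isoformat()}/"
    )

def backfill_intervals(last_success: date, today: date) -> list[date]:
    """One logical date per missed daily interval, oldest first."""
    n_missed = (today - last_success).days - 1
    return [last_success + timedelta(days=i + 1) for i in range(n_missed)]

# Pipeline was down March 12-14; last successful interval was March 11.
missed = backfill_intervals(date(2025, 3, 11), date(2025, 3, 15))
print(missed)  # -> [date(2025, 3, 12), date(2025, 3, 13), date(2025, 3, 14)]
```

Note that the backfill produces three separate runs, each scoped to one day's partition, rather than one run over three days of data.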


27.3 Apache Airflow — The Industry Standard

Apache Airflow, created by Maxime Beauchemin at Airbnb in 2014 and open-sourced in 2015, is the most widely deployed pipeline orchestration framework. It introduced the DAG-as-Python-code paradigm that every subsequent orchestrator has adopted or extended.

Core Concepts

| Concept | Description |
| --- | --- |
| DAG | A directed acyclic graph defined in a Python file. Represents a pipeline. |
| Operator | A template for a task. `PythonOperator` runs a function. `BashOperator` runs a shell command. `SparkSubmitOperator` submits a Spark job. |
| Task | An instance of an operator within a DAG, with specific parameters. |
| Task Instance | A single execution of a task for a specific data interval. |
| XCom | Cross-communication mechanism for passing small data (metadata, file paths, metrics) between tasks. |
| Pool | A resource constraint that limits the number of concurrent task instances (e.g., max 4 GPU training jobs). |
| Sensor | A special operator that waits for an external condition (file existence, partition availability, API response). |
| Trigger Rule | A rule that determines when a task runs based on upstream task states (`all_success`, `all_failed`, `one_success`, etc.). |
| Connection | A stored credential for connecting to external systems (databases, cloud storage, APIs). |

Airflow's Philosophy: Imperative, Task-Centric

Airflow's mental model is imperative: you define tasks and the order in which they run. The orchestrator's job is to execute those tasks according to the schedule, retry failures, and provide a web UI for monitoring.

This is analogous to writing procedural code: "First do A, then do B, then do C." It is intuitive and maps directly to how engineers think about pipelines.

The StreamRec Pipeline in Airflow

"""StreamRec daily training pipeline — Apache Airflow implementation.

This DAG orchestrates the full model training lifecycle:
data extraction, validation, feature computation, model training,
evaluation, registration, and deployment triggering.

Schedule: Daily at 02:00 UTC.
"""
from datetime import datetime, timedelta
from typing import Any, Dict

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.trigger_rule import TriggerRule
from airflow.models import Variable

# ── Default arguments applied to all tasks ──────────────────────────────────
default_args: Dict[str, Any] = {
    "owner": "ml-platform",
    "depends_on_past": False,
    "email": ["ml-oncall@streamrec.com"],
    "email_on_failure": True,
    "email_on_retry": False,
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "retry_exponential_backoff": True,
    "max_retry_delay": timedelta(minutes=30),
    "execution_timeout": timedelta(hours=2),
}

# ── DAG definition ──────────────────────────────────────────────────────────
with DAG(
    dag_id="streamrec_daily_training",
    default_args=default_args,
    description="Daily retraining of StreamRec retrieval and ranking models.",
    schedule_interval="0 2 * * *",
    start_date=datetime(2025, 1, 1),
    catchup=False,
    max_active_runs=1,
    tags=["ml", "training", "streamrec"],
) as dag:

    # ── Task functions ──────────────────────────────────────────────────────

    def extract_interactions(**context: Any) -> str:
        """Extract interaction events for the data interval.

        Reads from the Delta Lake table partitioned by date,
        filtering to the logical date's partition.

        Returns:
            S3 path to the extracted Parquet file.
        """
        from deltalake import DeltaTable
        import pyarrow.parquet as pq

        logical_date = context["data_interval_start"].strftime("%Y-%m-%d")
        delta_path = "s3://streamrec-datalake/interactions/"
        dt = DeltaTable(delta_path)
        # Read only the logical date's partition (DNF filter format)
        df = dt.to_pyarrow_table(
            filters=[("event_date", "=", logical_date)]
        )

        output_path = (
            f"s3://streamrec-pipeline/extracts/{logical_date}/interactions.parquet"
        )
        pq.write_table(df, output_path)
        print(f"Extracted {len(df)} rows for {logical_date} -> {output_path}")
        return output_path

    def validate_data(**context: Any) -> Dict[str, Any]:
        """Validate extracted data against quality contracts.

        Checks:
            - Minimum row count (>= 100,000 interactions/day)
            - Schema conformance (required columns present)
            - Null rate per column (<= 5%)
            - Value range checks (e.g., rating in [1, 5])

        Returns:
            Dictionary with validation results and extracted path.

        Raises:
            ValueError: If any critical validation check fails.
        """
        import pyarrow.parquet as pq

        ti = context["ti"]
        extract_path = ti.xcom_pull(task_ids="extract_interactions")
        df = pq.read_table(extract_path)
        num_rows = len(df)

        results: Dict[str, Any] = {"path": extract_path, "num_rows": num_rows}

        # Critical checks
        min_rows = int(Variable.get("min_daily_interactions", default_var=100_000))
        if num_rows < min_rows:
            raise ValueError(
                f"Row count {num_rows} below minimum {min_rows}. "
                "Possible upstream data pipeline failure."
            )

        required_columns = {"user_id", "item_id", "event_type", "timestamp"}
        actual_columns = set(df.column_names)
        missing = required_columns - actual_columns
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Null rate checks
        max_null_rate = 0.05
        for col_name in df.column_names:
            null_count = df.column(col_name).null_count
            null_rate = null_count / num_rows if num_rows > 0 else 0
            results[f"null_rate_{col_name}"] = null_rate
            if null_rate > max_null_rate:
                raise ValueError(
                    f"Column '{col_name}' null rate {null_rate:.3f} "
                    f"exceeds threshold {max_null_rate}."
                )

        print(f"Validation passed: {num_rows} rows, all checks OK.")
        return results

    def compute_features(**context: Any) -> str:
        """Compute point-in-time correct training features via Feast.

        Joins interaction data with the offline feature store,
        ensuring temporal correctness (no future leakage).

        Returns:
            S3 path to the feature-enriched training dataset.
        """
        from feast import FeatureStore
        import pandas as pd

        ti = context["ti"]
        validation_results = ti.xcom_pull(task_ids="validate_data")
        extract_path = validation_results["path"]
        logical_date = context["data_interval_start"].strftime("%Y-%m-%d")

        store = FeatureStore(repo_path="/opt/feast/streamrec/")
        entity_df = pd.read_parquet(extract_path)

        training_df = store.get_historical_features(
            entity_df=entity_df,
            features=[
                "user_features:watch_count_7d",
                "user_features:avg_session_length",
                "user_features:genre_preference_vector",
                "item_features:popularity_score",
                "item_features:content_embedding",
                "item_features:days_since_release",
            ],
        ).to_df()

        output_path = (
            f"s3://streamrec-pipeline/features/{logical_date}/training_set.parquet"
        )
        training_df.to_parquet(output_path, index=False)
        print(f"Feature computation complete: {len(training_df)} training examples.")
        return output_path

    def train_model(model_type: str, **context: Any) -> Dict[str, Any]:
        """Train a model (retrieval or ranking) using DDP.

        Args:
            model_type: Either 'retrieval' (two-tower) or 'ranking' (DCN-V2).

        Returns:
            Dictionary with model artifact path and training metrics.
        """
        import torch
        import torch.distributed as dist
        import mlflow

        ti = context["ti"]
        features_path = ti.xcom_pull(task_ids="compute_features")
        logical_date = context["data_interval_start"].strftime("%Y-%m-%d")

        config = {
            "retrieval": {
                "model_class": "TwoTowerRetrieval",
                "epochs": 10,
                "batch_size": 4096,
                "learning_rate": 1e-3,
                "gpus": 4,
            },
            "ranking": {
                "model_class": "DCNV2Ranking",
                "epochs": 5,
                "batch_size": 2048,
                "learning_rate": 5e-4,
                "gpus": 4,
            },
        }[model_type]

        with mlflow.start_run(run_name=f"{model_type}_{logical_date}") as run:
            mlflow.log_param("model_type", model_type)
            mlflow.log_param("training_date", logical_date)
            mlflow.log_param("features_path", features_path)
            mlflow.log_params(config)

            # Training occurs here via torchrun / DDP
            # (Chapter 26 covers the distributed training implementation)
            artifact_path = (
                f"s3://streamrec-pipeline/models/{logical_date}/{model_type}/"
            )

            # After training completes, log the final metrics
            metrics: Dict[str, float] = {
                "final_loss": 0.0,  # Populated by training loop
                "training_time_seconds": 0.0,
            }
            mlflow.log_metrics(metrics)
            # log_artifact() requires a local path; record the S3 URI as a tag
            mlflow.set_tag("artifact_path", artifact_path)

        return {
            "artifact_path": artifact_path,
            "run_id": run.info.run_id,
            "metrics": metrics,
        }

    def evaluate_model(model_type: str, **context: Any) -> Dict[str, Any]:
        """Evaluate a trained model against the held-out test set.

        Computes primary metrics and compares against the current
        production model. Returns evaluation results including
        a go/no-go recommendation.

        Args:
            model_type: Either 'retrieval' or 'ranking'.

        Returns:
            Dictionary with evaluation metrics and promotion decision.
        """
        import mlflow

        ti = context["ti"]
        train_result = ti.xcom_pull(task_ids=f"train_{model_type}")

        metric_thresholds = {
            "retrieval": {"recall_at_20": 0.15, "ndcg_at_20": 0.10},
            "ranking": {"auc": 0.75, "logloss": 0.50},
        }[model_type]

        # Load model, compute metrics on held-out test set
        # (Implementation details depend on model_type)
        eval_metrics: Dict[str, float] = {}  # Populated by evaluation

        # Gate: all metrics must meet minimum thresholds
        passed = all(
            eval_metrics.get(metric, 0.0) >= threshold
            for metric, threshold in metric_thresholds.items()
        )

        # Compare against current production model
        prod_run_id = Variable.get(
            f"production_{model_type}_run_id", default_var=None
        )
        improvement = None
        if prod_run_id:
            prod_metrics = mlflow.get_run(prod_run_id).data.metrics
            primary_metric = list(metric_thresholds.keys())[0]
            current = prod_metrics.get(primary_metric, 0.0)
            candidate = eval_metrics.get(primary_metric, 0.0)
            improvement = candidate - current

        result = {
            "model_type": model_type,
            "metrics": eval_metrics,
            "passed_threshold": passed,
            "improvement_over_production": improvement,
            "recommend_promotion": passed and (
                improvement is None or improvement >= -0.005
            ),
        }

        with mlflow.start_run(run_id=train_result["run_id"]):
            mlflow.log_metrics(
                {f"eval_{k}": v for k, v in eval_metrics.items()}
            )
            mlflow.set_tag("promotion_recommended", str(result["recommend_promotion"]))

        return result

    def register_models(**context: Any) -> Dict[str, str]:
        """Register evaluated models in MLflow Model Registry.

        Only registers models that passed evaluation. Transitions them
        to 'Staging' stage for downstream canary deployment.

        Returns:
            Dictionary mapping model_type to registry version.
        """
        import mlflow
        from mlflow.tracking import MlflowClient

        ti = context["ti"]
        client = MlflowClient()
        registered: Dict[str, str] = {}

        for model_type in ["retrieval", "ranking"]:
            eval_result = ti.xcom_pull(task_ids=f"evaluate_{model_type}")
            # xcom_pull returns None if the evaluation task failed
            if eval_result is None or not eval_result["recommend_promotion"]:
                print(
                    f"{model_type} model did not pass evaluation. "
                    "Skipping registration."
                )
                continue

            train_result = ti.xcom_pull(task_ids=f"train_{model_type}")
            model_uri = f"runs:/{train_result['run_id']}/model"
            model_name = f"streamrec-{model_type}"

            mv = mlflow.register_model(model_uri, model_name)
            client.transition_model_version_stage(
                name=model_name,
                version=mv.version,
                stage="Staging",
            )
            registered[model_type] = mv.version
            print(f"Registered {model_name} v{mv.version} -> Staging")

        return registered

    # ── Task definitions ────────────────────────────────────────────────────

    # Wait for the upstream data pipeline to finish
    wait_for_data = ExternalTaskSensor(
        task_id="wait_for_data_pipeline",
        external_dag_id="data_ingestion_pipeline",
        external_task_id="write_to_delta_lake",
        timeout=3600,
        poke_interval=120,
        mode="reschedule",  # Release worker slot while waiting
    )

    extract = PythonOperator(
        task_id="extract_interactions",
        python_callable=extract_interactions,
    )

    validate = PythonOperator(
        task_id="validate_data",
        python_callable=validate_data,
        retries=1,
    )

    features = PythonOperator(
        task_id="compute_features",
        python_callable=compute_features,
        execution_timeout=timedelta(hours=1),
        pool="feature_store_pool",  # Limit concurrent feature store access
    )

    train_retrieval = PythonOperator(
        task_id="train_retrieval",
        python_callable=train_model,
        op_kwargs={"model_type": "retrieval"},
        execution_timeout=timedelta(hours=3),
        pool="gpu_pool",  # Limit concurrent GPU jobs
    )

    train_ranking = PythonOperator(
        task_id="train_ranking",
        python_callable=train_model,
        op_kwargs={"model_type": "ranking"},
        execution_timeout=timedelta(hours=3),
        pool="gpu_pool",
    )

    eval_retrieval = PythonOperator(
        task_id="evaluate_retrieval",
        python_callable=evaluate_model,
        op_kwargs={"model_type": "retrieval"},
    )

    eval_ranking = PythonOperator(
        task_id="evaluate_ranking",
        python_callable=evaluate_model,
        op_kwargs={"model_type": "ranking"},
    )

    register = PythonOperator(
        task_id="register_models",
        python_callable=register_models,
        trigger_rule=TriggerRule.ALL_DONE,  # Run once both evals finish, even if one failed
    )

    trigger_deploy = TriggerDagRunOperator(
        task_id="trigger_deployment",
        trigger_dag_id="streamrec_canary_deployment",
        trigger_rule=TriggerRule.ALL_SUCCESS,
    )

    # ── Dependencies ────────────────────────────────────────────────────────
    wait_for_data >> extract >> validate >> features
    features >> [train_retrieval, train_ranking]
    train_retrieval >> eval_retrieval
    train_ranking >> eval_ranking
    [eval_retrieval, eval_ranking] >> register >> trigger_deploy

Airflow Strengths and Weaknesses

Strengths:

  • Massive ecosystem: 80+ provider packages (AWS, GCP, Azure, Kubernetes, Spark, dbt)
  • Battle-tested at enormous scale (Airbnb, Google, Netflix, Uber)
  • Rich web UI for monitoring, debugging, and manual intervention
  • Mature community with extensive documentation

Weaknesses:

  • Task-centric model makes data lineage opaque: Airflow knows that train_retrieval depends on compute_features, but not what data flows between them
  • XCom serialization limits: large data must be stored externally (S3, GCS), with only a reference passed through XCom
  • DAG parsing overhead: the scheduler must parse every DAG file periodically, which degrades performance with hundreds of DAGs
  • Testing is difficult: tasks are tightly coupled to the Airflow runtime, so unit tests must mock the entire execution context
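
The standard workaround for XCom's size limits is for the producing task to write its payload to object storage and push only the URI. A minimal sketch of the pattern, using a local temp file as a stand-in for S3 and a plain dict as a stand-in for Airflow's XCom table:

```python
import json
import os
import tempfile

def produce_task(xcom: dict) -> None:
    """Write a large payload externally; push only a small reference."""
    payload = {"rows": list(range(1000))}  # Too large for XCom in practice
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as f:
        json.dump(payload, f)
    xcom["produce_task"] = path  # Only a short string crosses the "XCom"

def consume_task(xcom: dict) -> int:
    """Pull the reference, then load the actual data from storage."""
    path = xcom["produce_task"]
    with open(path) as f:
        payload = json.load(f)
    return len(payload["rows"])

xcom: dict = {}  # Stand-in for Airflow's XCom metadata table
produce_task(xcom)
print(consume_task(xcom))  # -> 1000
```

This is exactly what the StreamRec DAG above does: every task returns an S3 path or a small metrics dictionary, never the training data itself.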


27.4 Dagster — Software-Defined Assets

Dagster, created by Nick Schrock (co-creator of GraphQL) in 2019, represents a philosophical departure from Airflow. Where Airflow asks "what tasks should I run?", Dagster asks "what data assets should exist, and how are they derived?"

Core Concepts

| Concept | Description |
| --- | --- |
| Asset | A persistent data object (table, file, model artifact). The unit of orchestration. |
| Software-Defined Asset (SDA) | A Python function that defines how an asset is computed from upstream assets. |
| Op | A unit of computation (analogous to an Airflow operator). Lower-level than assets. |
| Job | A graph of ops or a selection of assets to materialize together. |
| Resource | An injectable dependency (database connection, API client, GPU cluster). Supports configuration and testing. |
| IO Manager | A resource that handles loading and storing assets. Decouples computation from storage. |
| Partition | A logical slice of an asset (e.g., one day of data). Enables backfill and incremental computation. |
| Sensor | A function that monitors external state and triggers runs when conditions are met. |
| Schedule | A cron-based trigger for asset materialization. |

Dagster's Philosophy: Declarative, Asset-Centric

Dagster's mental model is declarative: you define the data assets that should exist and their derivation logic. The orchestrator infers the DAG from asset dependencies, manages materialization, and tracks the freshness and quality of every asset.

This is analogous to writing a Makefile: you declare what the outputs are and how to build them from inputs. The build system figures out the execution order and rebuilds only what has changed.

The asset-centric model has profound implications:

  1. Data lineage is first-class. Dagster knows that training_features depends on validated_interactions, which depends on raw_interactions. This lineage is visible in the UI and queryable via the API.
  2. Partial materialization is natural. You can rematerialize a single asset without rerunning the entire pipeline.
  3. Freshness policies replace schedules. Instead of "run at 2am," you declare "this asset should be no more than 24 hours old," and Dagster schedules accordingly.
  4. Testing is easier. Assets are pure functions with explicit inputs and outputs; mocking an IO manager is simpler than mocking the entire Airflow context.

The StreamRec Pipeline in Dagster

"""StreamRec daily training pipeline — Dagster implementation.

Assets represent the data and model artifacts produced at each stage.
Dagster infers the DAG from asset dependencies automatically.
"""
from dagster import (
    asset,
    define_asset_job,
    AssetSelection,
    DailyPartitionsDefinition,
    Output,
    MetadataValue,
    IOManager,
    io_manager,
    InputContext,
    OutputContext,
    ScheduleDefinition,
    Definitions,
    RetryPolicy,
    Backoff,
)
from typing import Any, Dict
import pandas as pd


# ── Configuration ───────────────────────────────────────────────────────────

daily_partitions = DailyPartitionsDefinition(start_date="2025-01-01")

training_retry_policy = RetryPolicy(
    max_retries=2,
    delay=60,
    backoff=Backoff.EXPONENTIAL,
)


# ── IO Manager for S3-backed storage ───────────────────────────────────────

class S3ParquetIOManager(IOManager):
    """IO Manager that reads and writes Parquet files to S3.

    Organizes files by asset name and partition key for deterministic
    path resolution, enabling idempotent re-materialization.

    Attributes:
        bucket: S3 bucket name.
        prefix: Key prefix for all managed objects.
    """

    def __init__(self, bucket: str, prefix: str = "pipeline"):
        self.bucket = bucket
        self.prefix = prefix

    def _get_path(self, context: Any) -> str:
        """Compute the S3 path for an asset + partition.

        Args:
            context: Dagster IO context (InputContext or OutputContext).

        Returns:
            S3 path string.
        """
        asset_key = context.asset_key.path[-1]
        partition_key = context.partition_key if context.has_partition_key else "unpartitioned"
        return f"s3://{self.bucket}/{self.prefix}/{asset_key}/{partition_key}/data.parquet"

    def handle_output(self, context: OutputContext, obj: pd.DataFrame) -> None:
        """Write a DataFrame to S3 as Parquet.

        Args:
            context: Output context with asset and partition metadata.
            obj: DataFrame to persist.
        """
        path = self._get_path(context)
        obj.to_parquet(path, index=False)
        context.add_output_metadata(
            {
                "path": MetadataValue.text(path),
                "num_rows": MetadataValue.int(len(obj)),
                "columns": MetadataValue.text(str(list(obj.columns))),
            }
        )

    def load_input(self, context: InputContext) -> pd.DataFrame:
        """Load a DataFrame from S3 Parquet.

        Args:
            context: Input context identifying the upstream asset and partition.

        Returns:
            DataFrame loaded from S3.
        """
        path = self._get_path(context)
        return pd.read_parquet(path)


@io_manager
def s3_parquet_io_manager(_context) -> S3ParquetIOManager:
    """Factory for S3 Parquet IO Manager."""
    return S3ParquetIOManager(
        bucket="streamrec-pipeline",
        prefix="dagster-assets",
    )


# ── Assets ──────────────────────────────────────────────────────────────────

@asset(
    partitions_def=daily_partitions,
    group_name="data_preparation",
    description="Raw interaction events extracted from the data lake.",
    metadata={"source": "Delta Lake", "table": "interactions"},
)
def raw_interactions(context) -> Output[pd.DataFrame]:
    """Extract raw interaction events for a single day's partition.

    Reads from the Delta Lake interactions table, filtering to the
    partition date. Returns a DataFrame with user_id, item_id,
    event_type, timestamp, and contextual features.
    """
    from deltalake import DeltaTable

    partition_date = context.partition_key
    dt = DeltaTable("s3://streamrec-datalake/interactions/")
    df = dt.to_pandas(
        filters=[("event_date", "=", partition_date)]
    )

    context.log.info(f"Extracted {len(df)} interactions for {partition_date}")
    return Output(
        df,
        metadata={
            "num_rows": MetadataValue.int(len(df)),
            "partition_date": MetadataValue.text(partition_date),
        },
    )


@asset(
    partitions_def=daily_partitions,
    group_name="data_preparation",
    description="Validated interaction events with quality checks passed.",
)
def validated_interactions(context, raw_interactions: pd.DataFrame) -> Output[pd.DataFrame]:
    """Validate extracted interactions against data quality contracts.

    Checks:
        - Minimum 100,000 rows per day
        - Required columns present
        - Null rate per column <= 5%
        - Event type in allowed set

    Raises:
        ValueError: If any critical check fails.
    """
    num_rows = len(raw_interactions)
    min_rows = 100_000

    if num_rows < min_rows:
        raise ValueError(
            f"Row count {num_rows} below minimum {min_rows}."
        )

    required_columns = {"user_id", "item_id", "event_type", "timestamp"}
    missing = required_columns - set(raw_interactions.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    max_null_rate = 0.05
    null_report = {}
    for col in raw_interactions.columns:
        null_rate = raw_interactions[col].isna().mean()
        null_report[col] = null_rate
        if null_rate > max_null_rate:
            raise ValueError(
                f"Column '{col}' null rate {null_rate:.3f} exceeds {max_null_rate}."
            )

    allowed_events = {"view", "click", "complete", "share", "save"}
    invalid_events = set(raw_interactions["event_type"].unique()) - allowed_events
    if invalid_events:
        context.log.warning(f"Filtering {len(invalid_events)} unknown event types.")
        raw_interactions = raw_interactions[
            raw_interactions["event_type"].isin(allowed_events)
        ]

    context.log.info(f"Validation passed: {len(raw_interactions)} rows.")
    return Output(
        raw_interactions,
        metadata={
            "num_rows": MetadataValue.int(len(raw_interactions)),
            "null_report": MetadataValue.json(null_report),
        },
    )


@asset(
    partitions_def=daily_partitions,
    group_name="feature_engineering",
    description="Point-in-time correct training features from Feast.",
)
def training_features(context, validated_interactions: pd.DataFrame) -> Output[pd.DataFrame]:
    """Compute training features by joining interactions with the offline feature store.

    Uses Feast's get_historical_features to ensure point-in-time correctness,
    preventing future information leakage.
    """
    from feast import FeatureStore

    store = FeatureStore(repo_path="/opt/feast/streamrec/")

    training_df = store.get_historical_features(
        entity_df=validated_interactions[["user_id", "item_id", "timestamp"]],
        features=[
            "user_features:watch_count_7d",
            "user_features:avg_session_length",
            "user_features:genre_preference_vector",
            "item_features:popularity_score",
            "item_features:content_embedding",
            "item_features:days_since_release",
        ],
    ).to_df()

    context.log.info(
        f"Feature computation complete: {len(training_df)} examples, "
        f"{len(training_df.columns)} features."
    )
    return Output(
        training_df,
        metadata={
            "num_examples": MetadataValue.int(len(training_df)),
            "num_features": MetadataValue.int(len(training_df.columns)),
        },
    )


@asset(
    partitions_def=daily_partitions,
    group_name="model_training",
    description="Trained two-tower retrieval model artifact.",
    retry_policy=training_retry_policy,
)
def retrieval_model(context, training_features: pd.DataFrame) -> Output[Dict[str, Any]]:
    """Train the two-tower retrieval model via DDP on 4x A100 GPUs.

    Returns a dictionary with the artifact path, training metrics,
    and MLflow run ID for downstream evaluation and registration.
    """
    import mlflow

    partition_date = context.partition_key

    with mlflow.start_run(run_name=f"retrieval_{partition_date}") as run:
        mlflow.log_param("model_type", "two_tower_retrieval")
        mlflow.log_param("training_date", partition_date)
        mlflow.log_param("num_examples", len(training_features))
        mlflow.log_param("epochs", 10)
        mlflow.log_param("batch_size", 4096)
        mlflow.log_param("learning_rate", 1e-3)

        # DDP training (implementation from Chapter 26)
        # ... training loop ...

        artifact_path = (
            f"s3://streamrec-pipeline/models/{partition_date}/retrieval/"
        )

        metrics = {
            "final_train_loss": 0.0,  # Populated by training loop
            "training_time_seconds": 0.0,
        }
        mlflow.log_metrics(metrics)

    result = {
        "artifact_path": artifact_path,
        "run_id": run.info.run_id,
        "metrics": metrics,
    }

    return Output(
        result,
        metadata={
            "mlflow_run_id": MetadataValue.text(run.info.run_id),
            "artifact_path": MetadataValue.text(artifact_path),
        },
    )


@asset(
    partitions_def=daily_partitions,
    group_name="model_training",
    description="Trained DCN-V2 ranking model artifact.",
    retry_policy=training_retry_policy,
)
def ranking_model(context, training_features: pd.DataFrame) -> Output[Dict[str, Any]]:
    """Train the DCN-V2 ranking model via DDP on 4x A100 GPUs.

    Returns a dictionary with the artifact path, training metrics,
    and MLflow run ID for downstream evaluation and registration.
    """
    import mlflow

    partition_date = context.partition_key

    with mlflow.start_run(run_name=f"ranking_{partition_date}") as run:
        mlflow.log_param("model_type", "dcn_v2_ranking")
        mlflow.log_param("training_date", partition_date)
        mlflow.log_param("num_examples", len(training_features))
        mlflow.log_param("epochs", 5)
        mlflow.log_param("batch_size", 2048)
        mlflow.log_param("learning_rate", 5e-4)

        # DDP training (implementation from Chapter 26)
        # ... training loop ...

        artifact_path = (
            f"s3://streamrec-pipeline/models/{partition_date}/ranking/"
        )

        metrics = {
            "final_train_loss": 0.0,  # Populated by training loop
            "training_time_seconds": 0.0,
        }
        mlflow.log_metrics(metrics)

    result = {
        "artifact_path": artifact_path,
        "run_id": run.info.run_id,
        "metrics": metrics,
    }

    return Output(
        result,
        metadata={
            "mlflow_run_id": MetadataValue.text(run.info.run_id),
            "artifact_path": MetadataValue.text(artifact_path),
        },
    )


@asset(
    partitions_def=daily_partitions,
    group_name="evaluation",
    description="Evaluation results for the retrieval model.",
)
def retrieval_evaluation(
    context, retrieval_model: Dict[str, Any], training_features: pd.DataFrame
) -> Output[Dict[str, Any]]:
    """Evaluate the retrieval model on the held-out test set.

    Computes Recall@20 and NDCG@20 on a time-split test partition.
    Compares against the current production model.
    """
    import mlflow

    # Split training_features into train/test by timestamp
    # Compute retrieval metrics
    eval_metrics = {
        "recall_at_20": 0.0,   # Populated by evaluation
        "ndcg_at_20": 0.0,     # Populated by evaluation
        "hit_at_10": 0.0,      # Populated by evaluation
    }

    thresholds = {"recall_at_20": 0.15, "ndcg_at_20": 0.10}
    passed = all(
        eval_metrics.get(m, 0.0) >= t for m, t in thresholds.items()
    )

    result = {
        "model_type": "retrieval",
        "metrics": eval_metrics,
        "passed_threshold": passed,
        "recommend_promotion": passed,
    }

    with mlflow.start_run(run_id=retrieval_model["run_id"]):
        mlflow.log_metrics({f"eval_{k}": v for k, v in eval_metrics.items()})

    return Output(
        result,
        metadata={
            "passed": MetadataValue.bool(passed),
            "recall_at_20": MetadataValue.float(eval_metrics["recall_at_20"]),
            "ndcg_at_20": MetadataValue.float(eval_metrics["ndcg_at_20"]),
        },
    )


@asset(
    partitions_def=daily_partitions,
    group_name="evaluation",
    description="Evaluation results for the ranking model.",
)
def ranking_evaluation(
    context, ranking_model: Dict[str, Any], training_features: pd.DataFrame
) -> Output[Dict[str, Any]]:
    """Evaluate the ranking model on the held-out test set.

    Computes AUC and logloss on a time-split test partition.
    Compares against the current production model.
    """
    import mlflow

    eval_metrics = {
        "auc": 0.0,       # Populated by evaluation
        "logloss": 0.0,   # Populated by evaluation
    }

    thresholds = {"auc": 0.75, "logloss": 0.50}
    # AUC must clear a floor; logloss must stay under a ceiling.
    passed = all(
        eval_metrics.get(m, float("inf")) <= t
        if m == "logloss"
        else eval_metrics.get(m, 0.0) >= t
        for m, t in thresholds.items()
    )

    result = {
        "model_type": "ranking",
        "metrics": eval_metrics,
        "passed_threshold": passed,
        "recommend_promotion": passed,
    }

    with mlflow.start_run(run_id=ranking_model["run_id"]):
        mlflow.log_metrics({f"eval_{k}": v for k, v in eval_metrics.items()})

    return Output(
        result,
        metadata={
            "passed": MetadataValue.bool(passed),
            "auc": MetadataValue.float(eval_metrics["auc"]),
            "logloss": MetadataValue.float(eval_metrics["logloss"]),
        },
    )


@asset(
    partitions_def=daily_partitions,
    group_name="deployment",
    description="Registered model versions in MLflow Model Registry.",
)
def registered_models(
    context,
    retrieval_model: Dict[str, Any],
    ranking_model: Dict[str, Any],
    retrieval_evaluation: Dict[str, Any],
    ranking_evaluation: Dict[str, Any],
) -> Output[Dict[str, str]]:
    """Register models that passed evaluation in MLflow Model Registry.

    Transitions promoted models to 'Staging' for canary deployment.
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    client = MlflowClient()
    registered: Dict[str, str] = {}

    for model_type, train_result, eval_result in [
        ("retrieval", retrieval_model, retrieval_evaluation),
        ("ranking", ranking_model, ranking_evaluation),
    ]:
        if not eval_result["recommend_promotion"]:
            context.log.warning(
                f"{model_type} model did not pass evaluation; skipping."
            )
            continue

        model_name = f"streamrec-{model_type}"
        model_uri = f"runs:/{train_result['run_id']}/model"
        mv = mlflow.register_model(model_uri, model_name)
        client.transition_model_version_stage(
            name=model_name,
            version=mv.version,
            stage="Staging",
        )
        registered[model_type] = str(mv.version)
        context.log.info(f"Registered {model_name} v{mv.version} -> Staging")

    return Output(
        registered,
        metadata={
            "num_registered": MetadataValue.int(len(registered)),
        },
    )


# ── Job and Schedule ────────────────────────────────────────────────────────

streamrec_training_job = define_asset_job(
    name="streamrec_daily_training",
    selection=AssetSelection.groups("data_preparation", "feature_engineering",
                                     "model_training", "evaluation", "deployment"),
    description="Full StreamRec training pipeline: extract -> validate -> "
                "features -> train -> evaluate -> register.",
)

streamrec_schedule = ScheduleDefinition(
    job=streamrec_training_job,
    cron_schedule="0 2 * * *",
)


# ── Definitions (Dagster entry point) ──────────────────────────────────────

defs = Definitions(
    assets=[
        raw_interactions,
        validated_interactions,
        training_features,
        retrieval_model,
        ranking_model,
        retrieval_evaluation,
        ranking_evaluation,
        registered_models,
    ],
    jobs=[streamrec_training_job],
    schedules=[streamrec_schedule],
    resources={
        "io_manager": s3_parquet_io_manager,
    },
)

Comparing the Two Implementations

The Dagster implementation is noticeably different from Airflow in several respects:

| Dimension | Airflow | Dagster |
| --- | --- | --- |
| Organizing principle | Tasks (what to run) | Assets (what to produce) |
| Dependencies | Explicit via >> operator | Implicit via function arguments |
| Data passing | XCom (small metadata) or external storage | IO Managers (pluggable, type-aware) |
| Lineage | Opaque (task-to-task only) | Full asset lineage with metadata |
| Partitioning | Manual via templates and macros | First-class PartitionsDefinition |
| Testing | Requires Airflow context mocking | Pure functions; mock IO managers |
| UI | Task-centric Gantt charts | Asset-centric materialization graph |
| Backfill | Trigger DAG runs for each date | Select asset partitions to rematerialize |
| Maturity | 10+ years, massive community | 6 years, growing rapidly |

Neither is universally superior. Airflow excels in environments with diverse workloads (ETL, ML, analytics, dbt) that benefit from its vast operator ecosystem. Dagster excels in ML-heavy environments where data lineage, asset freshness, and testability are priorities.


27.5 Prefect — Python-Native Orchestration

Prefect, created by Jeremiah Lowin in 2018, takes a third approach: minimal orchestration overhead on top of standard Python code. Where Airflow requires learning its operator model and Dagster requires learning its asset model, Prefect requires only adding @flow and @task decorators to existing Python functions.

Core Concepts

| Concept | Description |
| --- | --- |
| Flow | A Python function decorated with @flow. The top-level orchestration unit. |
| Task | A Python function decorated with @task. A unit of work within a flow. |
| Deployment | A flow configured with infrastructure, schedule, and parameters for remote execution. |
| Work Pool | A queue that routes flow runs to specific infrastructure (Kubernetes, Docker, local). |
| Block | A typed, configurable storage for credentials and configuration (S3, GCS, Slack). |
| State | The result of a task or flow execution (Completed, Failed, Cancelled, etc.). |
| Artifact | A rich output (table, Markdown, link) attached to a flow run for observability. |

Prefect's Philosophy: Negative Engineering

Prefect describes its approach as eliminating "negative engineering" — the defensive code that handles retries, logging, state management, and failure notification. Without an orchestrator, an ML pipeline is 40% business logic and 60% boilerplate for error handling, retrying, and alerting. Prefect aims to eliminate the 60% by making Python functions orchestrable with minimal changes.
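To make the "negative engineering" claim concrete, here is a sketch of the hand-rolled boilerplate that a single @task(retries=2, retry_delay_seconds=[60, 300]) line replaces. This is an illustration of the pattern, not Prefect's internals:

```python
import logging
import time
from typing import Callable, Optional, Sequence, TypeVar

T = TypeVar("T")
logger = logging.getLogger("pipeline")


def run_with_retries(
    fn: Callable[[], T],
    delays: Sequence[float] = (60, 300),
) -> T:
    """Hand-rolled retry loop with per-attempt delays and logging.

    Equivalent in spirit to retries=len(delays) with
    retry_delay_seconds=delays on a Prefect task.
    """
    last_exc: Optional[Exception] = None
    for attempt, delay in enumerate([0.0, *delays]):
        if delay:
            time.sleep(delay)  # Back off before each retry attempt
        try:
            return fn()
        except Exception as exc:
            last_exc = exc
            logger.warning("Attempt %d failed: %s", attempt + 1, exc)
    raise RuntimeError("All retries exhausted") from last_exc
```

Multiply this by logging, state persistence, timeout enforcement, and failure notification across every task, and the 60% figure stops looking like hyperbole.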

The StreamRec Pipeline in Prefect

"""StreamRec daily training pipeline — Prefect implementation.

Flows and tasks are standard Python functions with decorators.
Prefect handles retries, logging, state management, and deployment.
"""
from prefect import flow, task, get_run_logger
from prefect.tasks import task_input_hash
from prefect.artifacts import create_table_artifact
from datetime import timedelta, date
from typing import Any, Dict, Optional
import pandas as pd


# ── Tasks ───────────────────────────────────────────────────────────────────

@task(
    retries=2,
    retry_delay_seconds=[60, 300],  # 1 min, then 5 min
    timeout_seconds=1800,
    cache_key_fn=task_input_hash,
    cache_expiration=timedelta(hours=12),
    tags=["data", "extraction"],
)
def extract_interactions(logical_date: str) -> pd.DataFrame:
    """Extract interaction events from the data lake.

    Args:
        logical_date: Date string (YYYY-MM-DD) identifying the data partition.

    Returns:
        DataFrame of interaction events for the specified date.
    """
    from deltalake import DeltaTable

    logger = get_run_logger()
    dt = DeltaTable("s3://streamrec-datalake/interactions/")
    df = dt.to_pandas(filters=[("event_date", "=", logical_date)])
    logger.info(f"Extracted {len(df)} interactions for {logical_date}")
    return df


@task(
    retries=1,
    timeout_seconds=600,
    tags=["data", "validation"],
)
def validate_data(df: pd.DataFrame) -> pd.DataFrame:
    """Validate interaction data against quality contracts.

    Args:
        df: Raw interaction DataFrame.

    Returns:
        Validated DataFrame (unchanged if all checks pass).

    Raises:
        ValueError: If any critical quality check fails.
    """
    logger = get_run_logger()
    num_rows = len(df)

    if num_rows < 100_000:
        raise ValueError(f"Row count {num_rows} below minimum 100,000.")

    required_columns = {"user_id", "item_id", "event_type", "timestamp"}
    missing = required_columns - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    max_null_rate = 0.05
    for col in df.columns:
        null_rate = df[col].isna().mean()
        if null_rate > max_null_rate:
            raise ValueError(
                f"Column '{col}' null rate {null_rate:.3f} exceeds {max_null_rate}."
            )

    logger.info(f"Validation passed: {num_rows} rows.")
    return df


@task(
    retries=1,
    timeout_seconds=3600,
    tags=["features"],
)
def compute_features(validated_df: pd.DataFrame, logical_date: str) -> pd.DataFrame:
    """Compute training features via Feast point-in-time join.

    Args:
        validated_df: Validated interaction DataFrame.
        logical_date: Date string for the training partition.

    Returns:
        Feature-enriched training DataFrame.
    """
    from feast import FeatureStore

    logger = get_run_logger()
    store = FeatureStore(repo_path="/opt/feast/streamrec/")

    training_df = store.get_historical_features(
        entity_df=validated_df[["user_id", "item_id", "timestamp"]],
        features=[
            "user_features:watch_count_7d",
            "user_features:avg_session_length",
            "user_features:genre_preference_vector",
            "item_features:popularity_score",
            "item_features:content_embedding",
            "item_features:days_since_release",
        ],
    ).to_df()

    logger.info(f"Computed features: {len(training_df)} rows, {len(training_df.columns)} cols")
    return training_df


@task(
    retries=2,
    retry_delay_seconds=[120, 600],
    timeout_seconds=10800,  # 3 hours
    tags=["training", "gpu"],
)
def train_model(
    training_df: pd.DataFrame,
    model_type: str,
    logical_date: str,
) -> Dict[str, Any]:
    """Train a model (retrieval or ranking) using DDP.

    Args:
        training_df: Feature-enriched training DataFrame.
        model_type: Either 'retrieval' or 'ranking'.
        logical_date: Training date for artifact naming.

    Returns:
        Dictionary with artifact path, MLflow run ID, and metrics.
    """
    import mlflow

    logger = get_run_logger()

    config = {
        "retrieval": {"epochs": 10, "batch_size": 4096, "lr": 1e-3},
        "ranking": {"epochs": 5, "batch_size": 2048, "lr": 5e-4},
    }[model_type]

    with mlflow.start_run(run_name=f"{model_type}_{logical_date}") as run:
        mlflow.log_params(config)
        mlflow.log_param("model_type", model_type)
        mlflow.log_param("num_examples", len(training_df))

        # DDP training (Chapter 26)
        artifact_path = f"s3://streamrec-pipeline/models/{logical_date}/{model_type}/"

        metrics = {"final_loss": 0.0, "training_time_seconds": 0.0}
        mlflow.log_metrics(metrics)

    logger.info(f"Training complete: {model_type}, run_id={run.info.run_id}")
    return {
        "artifact_path": artifact_path,
        "run_id": run.info.run_id,
        "metrics": metrics,
    }


@task(
    timeout_seconds=1800,
    tags=["evaluation"],
)
def evaluate_model(
    training_df: pd.DataFrame,
    train_result: Dict[str, Any],
    model_type: str,
) -> Dict[str, Any]:
    """Evaluate a trained model on the held-out test set.

    Args:
        training_df: Training data (will be split for evaluation).
        train_result: Training output with artifact path and run ID.
        model_type: Either 'retrieval' or 'ranking'.

    Returns:
        Evaluation results with metrics and promotion recommendation.
    """
    import mlflow

    logger = get_run_logger()

    thresholds = {
        "retrieval": {"recall_at_20": 0.15, "ndcg_at_20": 0.10},
        "ranking": {"auc": 0.75},  # Floor-style thresholds only in this check
    }[model_type]

    eval_metrics: Dict[str, float] = {}  # Populated by evaluation logic

    passed = all(
        eval_metrics.get(m, 0.0) >= t for m, t in thresholds.items()
    )

    with mlflow.start_run(run_id=train_result["run_id"]):
        mlflow.log_metrics({f"eval_{k}": v for k, v in eval_metrics.items()})

    logger.info(f"{model_type} evaluation: passed={passed}, metrics={eval_metrics}")
    return {
        "model_type": model_type,
        "metrics": eval_metrics,
        "passed": passed,
    }


@task(
    timeout_seconds=600,
    tags=["deployment"],
)
def register_models(
    retrieval_train: Dict[str, Any],
    ranking_train: Dict[str, Any],
    retrieval_eval: Dict[str, Any],
    ranking_eval: Dict[str, Any],
) -> Dict[str, str]:
    """Register models that passed evaluation in MLflow Model Registry.

    Args:
        retrieval_train: Retrieval model training results.
        ranking_train: Ranking model training results.
        retrieval_eval: Retrieval model evaluation results.
        ranking_eval: Ranking model evaluation results.

    Returns:
        Dictionary mapping model type to registered version.
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    logger = get_run_logger()
    client = MlflowClient()
    registered: Dict[str, str] = {}

    for model_type, train_result, eval_result in [
        ("retrieval", retrieval_train, retrieval_eval),
        ("ranking", ranking_train, ranking_eval),
    ]:
        if not eval_result["passed"]:
            logger.warning(f"{model_type} model skipped registration.")
            continue

        model_name = f"streamrec-{model_type}"
        model_uri = f"runs:/{train_result['run_id']}/model"
        mv = mlflow.register_model(model_uri, model_name)
        client.transition_model_version_stage(
            name=model_name, version=mv.version, stage="Staging"
        )
        registered[model_type] = str(mv.version)
        logger.info(f"Registered {model_name} v{mv.version}")

    return registered


# ── Flow ────────────────────────────────────────────────────────────────────

@flow(
    name="streamrec-daily-training",
    description="Full StreamRec training pipeline.",
    retries=0,
    timeout_seconds=21600,  # 6 hours total
)
def streamrec_training_pipeline(logical_date: Optional[str] = None) -> Dict[str, str]:
    """Orchestrate the StreamRec daily training pipeline.

    Args:
        logical_date: Override for the training date. Defaults to yesterday.

    Returns:
        Dictionary of registered model versions.
    """
    if logical_date is None:
        logical_date = (date.today() - timedelta(days=1)).isoformat()

    # Data preparation (sequential)
    raw_df = extract_interactions(logical_date)
    validated_df = validate_data(raw_df)
    training_df = compute_features(validated_df, logical_date)

    # Model training (parallel via .submit())
    retrieval_future = train_model.submit(training_df, "retrieval", logical_date)
    ranking_future = train_model.submit(training_df, "ranking", logical_date)

    retrieval_result = retrieval_future.result()
    ranking_result = ranking_future.result()

    # Evaluation (parallel)
    retrieval_eval_future = evaluate_model.submit(
        training_df, retrieval_result, "retrieval"
    )
    ranking_eval_future = evaluate_model.submit(
        training_df, ranking_result, "ranking"
    )

    retrieval_eval = retrieval_eval_future.result()
    ranking_eval = ranking_eval_future.result()

    # Registration
    registered = register_models(
        retrieval_result, ranking_result, retrieval_eval, ranking_eval
    )

    # Create observability artifact
    create_table_artifact(
        key="training-summary",
        table=[
            {
                "Model": "Retrieval",
                "Passed": str(retrieval_eval["passed"]),
                "Registered": registered.get("retrieval", "N/A"),
            },
            {
                "Model": "Ranking",
                "Passed": str(ranking_eval["passed"]),
                "Registered": registered.get("ranking", "N/A"),
            },
        ],
        description=f"Training summary for {logical_date}",
    )

    return registered

Three Frameworks, One Pipeline: A Comparative Summary

| Concern | Airflow | Dagster | Prefect |
| --- | --- | --- | --- |
| Defining the pipeline | DAG object + operators | @asset functions + auto-DAG | @flow + @task decorators |
| Passing data | XCom (JSON-serializable, <48KB default) | IO managers (pluggable, any type) | Task return values (any picklable object) |
| Parallelism | [task_a, task_b] syntax | Automatic from asset graph | .submit() for concurrent tasks |
| Scheduling | schedule_interval on DAG | ScheduleDefinition or freshness policies | Deployment with cron or interval |
| Retry configuration | retries + retry_delay in default_args | RetryPolicy on assets/ops | retries + retry_delay_seconds on tasks |
| Caching | Not built-in; use idempotent writes | Asset materialization is cached | cache_key_fn + cache_expiration |
| Local development | airflow standalone (heavy) | dagster dev (fast) | python my_flow.py (instant) |
| Cloud offering | Astronomer, MWAA (AWS), Cloud Composer (GCP) | Dagster Cloud | Prefect Cloud |

27.6 Idempotency: The Cardinal Virtue of Pipeline Engineering

A task is idempotent if executing it multiple times with the same input produces the same output and the same side effects. Idempotency is the single most important property of a production pipeline task, because it makes retries and backfills safe.

Why Idempotency Matters

Consider the extract_interactions task. If it fails after writing half the data to S3, a retry must produce the same complete output — not append another half on top of the partial result. Without idempotency, retrying a failed task can produce duplicated data, corrupted model training sets, and silently degraded model quality.
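The distinction shows up even in a toy example. Below is a minimal sketch (using a plain dict as a stand-in for a real store): an append-style update double-counts when retried, while an upsert keyed on the partition date converges to the same state no matter how many times it runs.

```python
def append_count(store: dict, partition_date: str, n: int) -> None:
    """NOT idempotent: each retry adds n again."""
    store[partition_date] = store.get(partition_date, 0) + n


def upsert_count(store: dict, partition_date: str, n: int) -> None:
    """Idempotent: re-running with the same input yields the same state."""
    store[partition_date] = n


appended, upserted = {}, {}
for _ in range(3):  # simulate an initial run plus two retries
    append_count(appended, "2025-03-14", 10)
    upsert_count(upserted, "2025-03-14", 10)

print(appended)  # {'2025-03-14': 30} -- retries corrupted the count
print(upserted)  # {'2025-03-14': 10} -- retries were harmless
```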

Patterns for Achieving Idempotency

1. Write-then-rename (atomic writes).

Write output to a temporary location, then atomically rename it to the final path. If the task fails before the rename, the incomplete temporary file is simply discarded; if it fails after, the output is already complete. Either way, the final path never holds a partial write.

import os
import tempfile
import shutil
from typing import Callable
import pandas as pd


def idempotent_write(
    df: pd.DataFrame,
    final_path: str,
    write_fn: Callable[[pd.DataFrame, str], None],
) -> str:
    """Write a DataFrame to a final path idempotently.

    Uses a temporary file with atomic rename to ensure that the
    output is either complete or absent — never partial.

    Args:
        df: DataFrame to write.
        final_path: Destination path for the output.
        write_fn: Function that writes the DataFrame to a path.

    Returns:
        The final path.
    """
    temp_dir = os.path.dirname(final_path)
    with tempfile.NamedTemporaryFile(
        dir=temp_dir, suffix=".tmp", delete=False
    ) as tmp:
        temp_path = tmp.name

    try:
        write_fn(df, temp_path)
        shutil.move(temp_path, final_path)
    except Exception:
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise

    return final_path

2. Partition-based overwrites.

Write output to a partition-specific path (/data/2025-03-14/) and overwrite the entire partition on each run. Partial writes are replaced by complete writes on retry.

def partition_path(base_path: str, partition_date: str, filename: str) -> str:
    """Construct a deterministic, partition-specific output path.

    The path is a function of the base, partition date, and filename only —
    not of execution time, retry count, or run ID. This ensures that
    re-running the task for the same partition overwrites the previous output.

    Args:
        base_path: Root directory for the asset.
        partition_date: Logical date (YYYY-MM-DD).
        filename: Output filename.

    Returns:
        Full path for the partitioned output.
    """
    return f"{base_path}/{partition_date}/{filename}"


# Example: every run for the same date overwrites the same file
output = partition_path(
    "s3://streamrec-pipeline/features",
    "2025-03-14",
    "training_set.parquet",
)
# -> "s3://streamrec-pipeline/features/2025-03-14/training_set.parquet"

3. Deterministic model artifact paths.

Model artifacts should be stored at paths determined by the training date and configuration hash — not by timestamps or random UUIDs. This ensures that retraining on the same data and configuration produces an artifact at the same location.

import hashlib
import json
from typing import Dict, Any


def model_artifact_path(
    base_path: str,
    partition_date: str,
    model_type: str,
    config: Dict[str, Any],
) -> str:
    """Compute a deterministic artifact path from date and configuration.

    The configuration hash ensures that changing hyperparameters
    produces a different path, while rerunning with the same
    configuration overwrites the previous artifact.

    Args:
        base_path: Root directory for model artifacts.
        partition_date: Training date.
        model_type: Model identifier (e.g., 'retrieval', 'ranking').
        config: Hyperparameter dictionary.

    Returns:
        Deterministic artifact path.
    """
    config_hash = hashlib.sha256(
        json.dumps(config, sort_keys=True).encode()
    ).hexdigest()[:8]

    return f"{base_path}/{partition_date}/{model_type}/{config_hash}/"

Non-Idempotent Operations and Compensating Actions

Some operations are inherently non-idempotent:

  • Sending notifications. Retrying a task that sends a Slack alert will send duplicate alerts. Solution: send notifications from a separate, non-retried task.
  • Incrementing counters. Writing to an append-only log or incrementing a database counter will double-count on retry. Solution: use upsert (insert-or-update) operations keyed on the partition date.
  • External API calls with side effects. Calling an API that triggers a payment or sends an email. Solution: use idempotency keys (a unique identifier per request that the API uses to deduplicate).
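A minimal sketch of the idempotency-key pattern (the key derivation and the toy service are illustrative, not from any particular API): the key is a deterministic function of the run's identity, so a retried request carries the same key and the receiving service can deduplicate it.

```python
import hashlib


def idempotency_key(pipeline: str, task: str, partition_date: str) -> str:
    """Derive a deterministic idempotency key from the run's identity.

    The same (pipeline, task, partition_date) always yields the same key,
    so retries of the same logical request deduplicate server-side.
    """
    raw = f"{pipeline}:{task}:{partition_date}"
    return hashlib.sha256(raw.encode()).hexdigest()[:16]


class DedupingService:
    """Toy stand-in for an external API that honors idempotency keys."""

    def __init__(self) -> None:
        self.processed: set = set()
        self.side_effects = 0

    def call(self, key: str) -> None:
        if key in self.processed:
            return  # duplicate request: no new side effect
        self.processed.add(key)
        self.side_effects += 1


service = DedupingService()
key = idempotency_key("streamrec_daily_training", "send_report", "2025-03-14")
for _ in range(3):  # original attempt plus two retries
    service.call(key)

print(service.side_effects)  # 1 -- the side effect happened exactly once
```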

27.7 Failure Handling: Retries, Backoff, and Alerting

Pipelines fail. Hardware fails, networks partition, upstream dependencies are late, data is corrupted. Robust orchestration is defined by how gracefully a pipeline handles failure, not by how often it succeeds.

Retry Policies

A retry policy specifies:

  • Max retries: How many times to re-attempt a failed task (typically 2-3)
  • Delay: How long to wait before the first retry
  • Backoff strategy: How the delay grows between retries
  • Jitter: Random noise added to the delay to prevent thundering herd problems

from dataclasses import dataclass
from enum import Enum
import random


class BackoffStrategy(Enum):
    """Backoff strategy for retry delays."""
    CONSTANT = "constant"       # Same delay every time
    LINEAR = "linear"           # delay * attempt_number
    EXPONENTIAL = "exponential"  # delay * 2^(attempt_number - 1)


@dataclass
class RetryConfig:
    """Configuration for task retry behavior.

    Attributes:
        max_retries: Maximum number of retry attempts.
        base_delay_seconds: Base delay before the first retry.
        backoff: Backoff strategy for increasing delays.
        max_delay_seconds: Upper bound on the retry delay.
        jitter_fraction: Random jitter as a fraction of the computed delay.
    """
    max_retries: int = 3
    base_delay_seconds: float = 60.0
    backoff: BackoffStrategy = BackoffStrategy.EXPONENTIAL
    max_delay_seconds: float = 1800.0  # 30 minutes
    jitter_fraction: float = 0.1

    def compute_delay(self, attempt: int) -> float:
        """Compute the delay before retry attempt `attempt`.

        The delay grows according to the backoff strategy and is
        capped at max_delay_seconds. Jitter is added to prevent
        synchronized retries across parallel tasks.

        Args:
            attempt: The retry attempt number (1-indexed).

        Returns:
            Delay in seconds before this retry attempt.
        """
        if self.backoff == BackoffStrategy.CONSTANT:
            delay = self.base_delay_seconds
        elif self.backoff == BackoffStrategy.LINEAR:
            delay = self.base_delay_seconds * attempt
        elif self.backoff == BackoffStrategy.EXPONENTIAL:
            delay = self.base_delay_seconds * (2 ** (attempt - 1))
        else:
            delay = self.base_delay_seconds

        delay = min(delay, self.max_delay_seconds)

        # Add jitter: uniform random in [-jitter, +jitter]
        jitter = delay * self.jitter_fraction
        delay += random.uniform(-jitter, jitter)

        return max(0, delay)


# Example: exponential backoff with 60s base
config = RetryConfig(
    max_retries=3,
    base_delay_seconds=60,
    backoff=BackoffStrategy.EXPONENTIAL,
)

for attempt in range(1, config.max_retries + 1):
    delay = config.compute_delay(attempt)
    print(f"Attempt {attempt}: wait {delay:.0f}s")
# Attempt 1: wait ~60s
# Attempt 2: wait ~120s
# Attempt 3: wait ~240s

Exponential Backoff: The Mathematics

With exponential backoff, the capped delay for attempt $k$ is

$$\tilde{d}_k = \min\left(d_0 \cdot 2^{k-1},\; d_{\max}\right)$$

and the actual delay adds jitter to the capped value:

$$d_k = \tilde{d}_k + \text{Uniform}\left(-f \tilde{d}_k,\; f \tilde{d}_k\right)$$

where $d_0$ is the base delay, $f$ is the jitter fraction, and $d_{\max}$ is the maximum delay. (This mirrors `compute_delay` above, which caps the delay before applying jitter.)

The total maximum wait time across $n$ retries (ignoring jitter) is:

$$T_{\max} = d_0 \sum_{k=0}^{n-1} 2^k = d_0 \cdot (2^n - 1)$$

For $d_0 = 60\text{s}$ and $n = 3$: $T_{\max} = 60 \cdot (2^3 - 1) = 420\text{s} = 7$ minutes. This is acceptable for a pipeline that runs in hours, but would be too slow for a real-time serving path.

Alerting Strategy

Not all failures deserve the same response. A taxonomy of failure severity guides alerting:

| Severity | Example | Response | Alert Channel |
|---|---|---|---|
| P0 — Critical | All retries exhausted; pipeline halted | Oncall page | PagerDuty |
| P1 — High | Data validation failed; training skipped today | Immediate attention | Slack #ml-alerts + oncall |
| P2 — Medium | Training succeeded but metrics degraded; model not promoted | Investigate within 24h | Slack #ml-alerts |
| P3 — Low | Retry succeeded on second attempt | Review in weekly digest | Logged only |

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class AlertSeverity(Enum):
    """Alert severity levels."""
    P0_CRITICAL = "P0"
    P1_HIGH = "P1"
    P2_MEDIUM = "P2"
    P3_LOW = "P3"


@dataclass
class PipelineAlert:
    """A structured alert from the ML pipeline.

    Attributes:
        severity: Alert severity level.
        pipeline_name: Name of the pipeline that generated the alert.
        task_name: Task that triggered the alert.
        partition_date: Data interval being processed.
        message: Human-readable description of the issue.
        run_url: URL to the orchestrator UI for this run.
    """
    severity: AlertSeverity
    pipeline_name: str
    task_name: str
    partition_date: str
    message: str
    run_url: Optional[str] = None

    def should_page(self) -> bool:
        """Whether this alert should page the oncall engineer."""
        return self.severity == AlertSeverity.P0_CRITICAL

    def format_slack_message(self) -> str:
        """Format the alert for Slack notification.

        Returns:
            Slack-formatted message string.
        """
        emoji = {
            AlertSeverity.P0_CRITICAL: ":rotating_light:",
            AlertSeverity.P1_HIGH: ":warning:",
            AlertSeverity.P2_MEDIUM: ":information_source:",
            AlertSeverity.P3_LOW: ":memo:",
        }[self.severity]

        lines = [
            f"{emoji} *[{self.severity.value}] Pipeline Alert*",
            f"*Pipeline:* {self.pipeline_name}",
            f"*Task:* {self.task_name}",
            f"*Partition:* {self.partition_date}",
            f"*Message:* {self.message}",
        ]
        if self.run_url:
            lines.append(f"*Run:* <{self.run_url}|View in UI>")

        return "\n".join(lines)

Dead Letter Queues for Data Pipelines

When a data validation task encounters records that fail quality checks but the overall batch is acceptable (e.g., 0.1% of records have malformed fields), a dead letter queue (DLQ) captures the bad records for later investigation without halting the pipeline.

from dataclasses import dataclass, field
from typing import List, Dict, Any
import pandas as pd


@dataclass
class DeadLetterRecord:
    """A record that failed validation, captured for investigation.

    Attributes:
        record: The original data record.
        failure_reason: Description of the validation failure.
        pipeline_name: Pipeline that produced this dead letter.
        partition_date: Data interval of the failed record.
        task_name: Task where the failure was detected.
    """
    record: Dict[str, Any]
    failure_reason: str
    pipeline_name: str
    partition_date: str
    task_name: str


@dataclass
class DeadLetterQueue:
    """Accumulates validation failures for downstream investigation.

    Provides a threshold: if the dead letter rate exceeds
    max_failure_rate, the pipeline should halt rather than
    proceeding with severely degraded data.

    Attributes:
        max_failure_rate: Maximum fraction of records that can fail
            before the pipeline halts.
        letters: Accumulated dead letter records.
        total_records: Total records processed (for rate calculation).
    """
    max_failure_rate: float = 0.01
    letters: List[DeadLetterRecord] = field(default_factory=list)
    total_records: int = 0

    def add(self, record: DeadLetterRecord) -> None:
        """Add a dead letter record."""
        self.letters.append(record)

    def failure_rate(self) -> float:
        """Compute the current dead letter rate.

        Returns:
            Fraction of total records that are dead letters.
        """
        if self.total_records == 0:
            return 0.0
        return len(self.letters) / self.total_records

    def should_halt(self) -> bool:
        """Check whether the failure rate exceeds the threshold.

        Returns:
            True if the pipeline should halt due to excessive failures.
        """
        return self.failure_rate() > self.max_failure_rate

    def flush_to_storage(self, path: str) -> int:
        """Write dead letters to persistent storage for investigation.

        Args:
            path: Storage path (e.g., S3 prefix) for dead letter output.

        Returns:
            Number of dead letters written.
        """
        if not self.letters:
            return 0

        records = [
            {
                **dl.record,
                "_failure_reason": dl.failure_reason,
                "_pipeline_name": dl.pipeline_name,
                "_partition_date": dl.partition_date,
                "_task_name": dl.task_name,
            }
            for dl in self.letters
        ]
        df = pd.DataFrame(records)
        df.to_parquet(f"{path}/dead_letters.parquet", index=False)
        count = len(self.letters)
        self.letters.clear()
        return count

27.8 Backfill: Reprocessing Historical Data

Backfill is the process of running a pipeline for historical data intervals that were missed (due to pipeline downtime) or need reprocessing (due to a bug fix in the feature computation code, a schema change in the upstream data, or a new model architecture that requires retraining on all historical data).

Backfill is one of the most operationally complex tasks in pipeline management. Done correctly, it fills gaps in data and model history. Done incorrectly, it overwrites correct data, triggers cascading recomputations, and exhausts infrastructure resources.

Backfill Strategies

Sequential backfill. Process each missing data interval one at a time, in chronological order. Safe but slow. Appropriate when tasks have temporal dependencies (today's features depend on yesterday's computed aggregates).

Parallel backfill. Process multiple data intervals simultaneously. Fast but resource-intensive. Appropriate when tasks are truly independent across partitions (e.g., extracting raw events for different dates).

Prioritized backfill. Process the most recent missing intervals first, then work backward. Ensures that the current model is trained on the freshest available data while historical gaps are filled asynchronously.

from dataclasses import dataclass
from typing import List, Optional
from datetime import date, timedelta
from enum import Enum
import math


class BackfillStrategy(Enum):
    """Execution strategy for historical backfills."""
    SEQUENTIAL = "sequential"
    PARALLEL = "parallel"
    PRIORITIZED = "prioritized"


@dataclass
class BackfillPlan:
    """A plan for reprocessing historical pipeline runs.

    Attributes:
        pipeline_name: Pipeline to backfill.
        start_date: First data interval to process (inclusive).
        end_date: Last data interval to process (inclusive).
        strategy: Execution strategy.
        max_parallel: Maximum concurrent runs for parallel/prioritized.
        exclude_dates: Dates to skip (e.g., known holidays with no data).
    """
    pipeline_name: str
    start_date: date
    end_date: date
    strategy: BackfillStrategy = BackfillStrategy.SEQUENTIAL
    max_parallel: int = 4
    exclude_dates: Optional[List[date]] = None

    def __post_init__(self):
        if self.exclude_dates is None:
            self.exclude_dates = []

    def get_intervals(self) -> List[date]:
        """Generate the list of data intervals to process.

        Returns intervals in execution order based on the strategy:
        - SEQUENTIAL / PARALLEL: chronological
        - PRIORITIZED: reverse chronological (newest first)
        """
        intervals = []
        current = self.start_date
        while current <= self.end_date:
            if current not in self.exclude_dates:
                intervals.append(current)
            current += timedelta(days=1)

        if self.strategy == BackfillStrategy.PRIORITIZED:
            intervals.reverse()

        return intervals

    def estimate_duration_hours(
        self, avg_run_duration_hours: float
    ) -> float:
        """Estimate total backfill duration.

        Args:
            avg_run_duration_hours: Average pipeline run duration in hours.

        Returns:
            Estimated total duration in hours.
        """
        num_intervals = len(self.get_intervals())

        if self.strategy == BackfillStrategy.SEQUENTIAL:
            return num_intervals * avg_run_duration_hours
        else:
            # Parallel: ceil(intervals / max_parallel) batches
            num_batches = math.ceil(num_intervals / self.max_parallel)
            return num_batches * avg_run_duration_hours


# Example: backfill 7 days of missed pipeline runs
plan = BackfillPlan(
    pipeline_name="streamrec_daily_training",
    start_date=date(2025, 3, 8),
    end_date=date(2025, 3, 14),
    strategy=BackfillStrategy.PRIORITIZED,
    max_parallel=3,
)

intervals = plan.get_intervals()
# [date(2025, 3, 14), date(2025, 3, 13), ..., date(2025, 3, 8)]

duration = plan.estimate_duration_hours(avg_run_duration_hours=3.5)
# ceil(7 / 3) * 3.5 = 3 * 3.5 = 10.5 hours

Backfill Safety: The Overwrite Problem

The most dangerous backfill error is overwriting production data with historical reprocessing. If the feature computation logic has changed since the original run, backfilling with the new logic will produce features that are inconsistent with the models trained on the original features.

Mitigation strategies:

  1. Version-tagged output paths. Include a pipeline version or code hash in the output path: /features/v2/2025-03-14/ vs. /features/v1/2025-03-14/. Backfilling with a new version writes to a new path, leaving the original intact.
  2. Dry-run mode. Run the backfill pipeline with writes redirected to a staging area. Inspect the output before promoting it to production.
  3. Comparison reports. After backfilling, automatically compare the new output against the original (row counts, column statistics, value distributions) and flag significant differences.
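Mitigation 3 can be sketched as a simple comparison function (the tolerance and the choice of statistics are illustrative): compare row counts and per-column means, and flag any shift beyond a relative threshold for human review before the backfilled output is promoted.

```python
import pandas as pd


def compare_backfill(
    original: pd.DataFrame,
    backfilled: pd.DataFrame,
    rel_tolerance: float = 0.05,
) -> list:
    """Flag significant differences between original and backfilled output.

    Checks row counts and the means of numeric columns shared by both
    DataFrames; shifts beyond rel_tolerance are flagged for review.
    """
    flags = []
    if len(original) != len(backfilled):
        flags.append(f"row count: {len(original)} -> {len(backfilled)}")
    for col in sorted(set(original.columns) & set(backfilled.columns)):
        if pd.api.types.is_numeric_dtype(original[col]):
            old, new = original[col].mean(), backfilled[col].mean()
            if old != 0 and abs(new - old) / abs(old) > rel_tolerance:
                flags.append(f"column '{col}' mean: {old:.4f} -> {new:.4f}")
    return flags


original = pd.DataFrame({"clicks": [10, 12, 11], "user": ["a", "b", "c"]})
backfilled = pd.DataFrame({"clicks": [10, 12, 30], "user": ["a", "b", "c"]})
print(compare_backfill(original, backfilled))
# ["column 'clicks' mean: 11.0000 -> 17.3333"]
```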

27.9 Pipeline Versioning and Artifact Management

A production ML pipeline produces three categories of versioned artifacts:

  1. Data artifacts: Training datasets, feature tables, evaluation splits
  2. Model artifacts: Serialized models, ONNX exports, FAISS indices
  3. Pipeline artifacts: DAG definitions, configuration files, dependency specifications

All three must be versioned, linked, and reproducible. Given a model in production, you should be able to trace back to the exact pipeline code, configuration, training data, and feature computation logic that produced it.
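One building block of that traceability is an order-independent configuration fingerprint, the same idea used in `model_artifact_path` earlier. A minimal sketch: serializing with sorted keys makes the hash depend only on the configuration's contents, never on dictionary ordering.

```python
import hashlib
import json


def config_fingerprint(config: dict) -> str:
    """Deterministic fingerprint of a configuration dictionary.

    sort_keys=True makes the fingerprint independent of key order, so
    two configs with identical contents always map to the same version.
    """
    blob = json.dumps(config, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest()[:12]


a = config_fingerprint({"lr": 0.001, "batch_size": 4096})
b = config_fingerprint({"batch_size": 4096, "lr": 0.001})  # same contents
c = config_fingerprint({"lr": 0.01, "batch_size": 4096})   # changed lr

print(a == b, a == c)  # True False
```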

The Artifact Lineage Graph

Pipeline code (git SHA: a1b2c3d)
    ├── Config (hyperparams.yaml, v12)
    │
    ├── Training Data
    │   ├── raw_interactions (2025-03-14, Delta Lake version 847)
    │   ├── validated_interactions (2025-03-14, pipeline run abc123)
    │   └── training_features (2025-03-14, Feast v3.2, feature set v7)
    │
    ├── Retrieval Model
    │   ├── Artifact: s3://models/2025-03-14/retrieval/e8f7a2b1/
    │   ├── MLflow Run ID: run_xyz789
    │   ├── Recall@20: 0.187, NDCG@20: 0.121
    │   └── Registry: streamrec-retrieval v42 (Staging)
    │
    └── Ranking Model
        ├── Artifact: s3://models/2025-03-14/ranking/c3d4e5f6/
        ├── MLflow Run ID: run_uvw456
        ├── AUC: 0.812, LogLoss: 0.423
        └── Registry: streamrec-ranking v38 (Staging)

Experiment Tracking Integration

MLflow and Weights & Biases (W&B) are the two dominant experiment tracking systems. Both record parameters, metrics, and artifacts for each training run. The pipeline orchestrator should integrate with the experiment tracker at three points:

  1. Before training: Log the data interval, feature set version, and hyperparameters
  2. During training: Log training metrics (loss curves, learning rate schedule)
  3. After evaluation: Log evaluation metrics and the promotion decision

from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional


@dataclass
class PipelineRunMetadata:
    """Complete metadata for a single pipeline run.

    This record links the pipeline execution to its data inputs,
    code version, configuration, and model outputs, enabling
    full reproducibility and lineage tracing.

    Attributes:
        run_id: Unique identifier for this pipeline run.
        pipeline_name: Name of the pipeline.
        partition_date: Logical date (data interval).
        git_sha: Git commit hash of the pipeline code.
        config_version: Version identifier for the configuration.
        data_versions: Dictionary mapping data asset names to version identifiers.
        model_runs: Dictionary mapping model types to MLflow run IDs.
        metrics: Dictionary mapping metric names to values.
        status: Final pipeline status.
        start_time: Pipeline start time (ISO 8601).
        end_time: Pipeline end time (ISO 8601).
    """
    run_id: str
    pipeline_name: str
    partition_date: str
    git_sha: str
    config_version: str
    data_versions: Dict[str, str] = field(default_factory=dict)
    model_runs: Dict[str, str] = field(default_factory=dict)
    metrics: Dict[str, float] = field(default_factory=dict)
    status: str = "running"
    start_time: str = ""
    end_time: str = ""

    def to_mlflow_tags(self) -> Dict[str, str]:
        """Convert metadata to MLflow tags for experiment tracking.

        Returns:
            Dictionary of string key-value pairs suitable for
            mlflow.set_tags().
        """
        tags = {
            "pipeline_run_id": self.run_id,
            "pipeline_name": self.pipeline_name,
            "partition_date": self.partition_date,
            "git_sha": self.git_sha,
            "config_version": self.config_version,
            "pipeline_status": self.status,
        }
        for asset_name, version in self.data_versions.items():
            tags[f"data_version.{asset_name}"] = version
        return tags

27.10 Testing ML Pipelines

A pipeline that is not tested is a pipeline that will fail silently. ML pipeline testing requires three levels: unit tests for individual tasks, integration tests for end-to-end pipelines, and contract tests for inter-stage data schemas.

Unit Tests: Testing Tasks in Isolation

Each pipeline task should be testable as a pure function, without running the orchestrator. This requires that tasks separate computation from orchestration concerns (connections, XCom, context).

"""Unit tests for StreamRec pipeline tasks.

Tests run without Airflow/Dagster/Prefect runtime.
Data dependencies are provided directly as function arguments.
"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime


def make_test_interactions(
    num_rows: int = 10_000,
    null_rate: float = 0.0,
    include_invalid_events: bool = False,
) -> pd.DataFrame:
    """Create a synthetic interaction DataFrame for testing.

    Args:
        num_rows: Number of rows to generate.
        null_rate: Fraction of values to set to null (per column).
        include_invalid_events: Whether to include invalid event types.

    Returns:
        DataFrame with the same schema as production interaction data.
    """
    rng = np.random.default_rng(42)

    event_types = ["view", "click", "complete", "share", "save"]
    if include_invalid_events:
        event_types.append("unknown_event")

    df = pd.DataFrame({
        "user_id": rng.integers(1, 100_000, size=num_rows).astype(str),
        "item_id": rng.integers(1, 50_000, size=num_rows).astype(str),
        "event_type": rng.choice(event_types, size=num_rows),
        "timestamp": pd.date_range(
            "2025-03-14", periods=num_rows, freq="s"
        ),
    })

    if null_rate > 0:
        for col in df.columns:
            mask = rng.random(num_rows) < null_rate
            df.loc[mask, col] = None

    return df


class TestValidateData:
    """Unit tests for the data validation task."""

    def test_valid_data_passes(self):
        """Validation should pass for data meeting all contracts."""
        df = make_test_interactions(num_rows=200_000)
        # validate_data should return without raising
        result = validate_data_logic(df)
        assert len(result) == 200_000

    def test_insufficient_rows_raises(self):
        """Validation should fail when row count is below minimum."""
        df = make_test_interactions(num_rows=50_000)
        with pytest.raises(ValueError, match="below minimum"):
            validate_data_logic(df)

    def test_missing_columns_raises(self):
        """Validation should fail when required columns are absent."""
        df = make_test_interactions()
        df = df.drop(columns=["user_id"])
        with pytest.raises(ValueError, match="Missing required columns"):
            validate_data_logic(df)

    def test_high_null_rate_raises(self):
        """Validation should fail when null rate exceeds 5%."""
        df = make_test_interactions(num_rows=200_000, null_rate=0.10)
        with pytest.raises(ValueError, match="null rate"):
            validate_data_logic(df)

    def test_invalid_events_filtered(self):
        """Invalid event types should be filtered, not cause failure."""
        df = make_test_interactions(
            num_rows=200_000, include_invalid_events=True
        )
        result = validate_data_logic(df)
        assert "unknown_event" not in result["event_type"].values


def validate_data_logic(df: pd.DataFrame) -> pd.DataFrame:
    """Pure validation logic extracted from orchestrator-specific task.

    This function contains the same logic as the Airflow/Dagster/Prefect
    task but with no runtime dependencies, making it unit-testable.
    """
    num_rows = len(df)
    if num_rows < 100_000:
        raise ValueError(f"Row count {num_rows} below minimum 100,000.")

    required_columns = {"user_id", "item_id", "event_type", "timestamp"}
    missing = required_columns - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    max_null_rate = 0.05
    for col in df.columns:
        null_rate = df[col].isna().mean()
        if null_rate > max_null_rate:
            raise ValueError(
                f"Column '{col}' null rate {null_rate:.3f} exceeds {max_null_rate}."
            )

    allowed_events = {"view", "click", "complete", "share", "save"}
    df = df[df["event_type"].isin(allowed_events)]
    return df

Integration Tests: Testing the Full Pipeline

Integration tests run the entire pipeline on synthetic data, verifying that data flows correctly through all stages and that the final output has the expected schema and properties.

"""Integration test for the StreamRec training pipeline.

Runs the full pipeline on synthetic data, verifying end-to-end
correctness without requiring production infrastructure.
"""
import pytest
from unittest.mock import patch, MagicMock
import pandas as pd
import numpy as np


class TestStreamRecPipelineIntegration:
    """End-to-end pipeline integration tests."""

    @pytest.fixture
    def synthetic_data(self) -> pd.DataFrame:
        """Generate a minimal but complete synthetic dataset."""
        rng = np.random.default_rng(42)
        num_rows = 150_000

        return pd.DataFrame({
            "user_id": rng.integers(1, 10_000, size=num_rows).astype(str),
            "item_id": rng.integers(1, 5_000, size=num_rows).astype(str),
            "event_type": rng.choice(
                ["view", "click", "complete", "share", "save"],
                size=num_rows,
            ),
            "timestamp": pd.date_range(
                "2025-03-14", periods=num_rows, freq="s"
            ),
            "watch_count_7d": rng.integers(0, 50, size=num_rows),
            "avg_session_length": rng.uniform(1, 60, size=num_rows),
            "popularity_score": rng.uniform(0, 1, size=num_rows),
            "days_since_release": rng.integers(0, 365, size=num_rows),
        })

    def test_pipeline_produces_registered_models(self, synthetic_data):
        """The full pipeline should produce at least one registered model."""
        # Mock external dependencies
        with patch("feast.FeatureStore") as mock_feast, \
             patch("mlflow.start_run") as mock_mlflow_run, \
             patch("mlflow.register_model") as mock_register:

            mock_store = mock_feast.return_value
            features = mock_store.get_historical_features.return_value
            features.to_df.return_value = synthetic_data

            mock_run = MagicMock()
            mock_run.info.run_id = "test_run_123"
            mock_mlflow_run.return_value.__enter__ = MagicMock(return_value=mock_run)
            mock_mlflow_run.return_value.__exit__ = MagicMock(return_value=False)

            mock_version = MagicMock()
            mock_version.version = "1"
            mock_register.return_value = mock_version

            # Run pipeline logic (framework-agnostic)
            validated = validate_data_logic(synthetic_data)
            assert len(validated) > 100_000

    def test_validation_failure_halts_pipeline(self):
        """When validation fails, downstream tasks should not execute."""
        bad_data = pd.DataFrame({
            "user_id": ["1"] * 100,
            "item_id": ["2"] * 100,
            "event_type": ["view"] * 100,
            "timestamp": pd.date_range("2025-03-14", periods=100, freq="s"),
        })

        with pytest.raises(ValueError, match="below minimum"):
            validate_data_logic(bad_data)

Contract Tests: Validating Inter-Stage Data Schemas

Contract tests verify that the output of one task matches the expected input schema of the next task. They catch schema evolution bugs early — when a feature engineer adds a column to the feature computation, the contract test fails if the training task does not expect the new column.

from dataclasses import dataclass, field
from typing import Dict, Set
import pandas as pd


@dataclass
class DataContract:
    """Schema contract between pipeline stages.

    Specifies the required columns, their types, and value constraints
    that a DataFrame must satisfy to be passed to the next stage.

    Attributes:
        name: Human-readable contract name.
        required_columns: Mapping from column name to expected dtype string.
        non_null_columns: Columns that must not contain any null values.
        min_rows: Minimum number of rows.
        value_constraints: Column-level value constraints
            (e.g., {'age': ('>=', 0), 'probability': ('<=', 1.0)}).
    """
    name: str
    required_columns: Dict[str, str] = field(default_factory=dict)
    non_null_columns: Set[str] = field(default_factory=set)
    min_rows: int = 0
    value_constraints: Dict[str, tuple] = field(default_factory=dict)

    def validate(self, df: pd.DataFrame) -> None:
        """Validate a DataFrame against this contract.

        Args:
            df: DataFrame to validate.

        Raises:
            ContractViolation: If any contract clause is violated.
        """
        violations = []

        # Check row count
        if len(df) < self.min_rows:
            violations.append(
                f"Row count {len(df)} below minimum {self.min_rows}."
            )

        # Check required columns
        for col, expected_dtype in self.required_columns.items():
            if col not in df.columns:
                violations.append(f"Missing required column: '{col}'.")
            elif not str(df[col].dtype).startswith(expected_dtype):
                violations.append(
                    f"Column '{col}' has dtype {df[col].dtype}, "
                    f"expected {expected_dtype}."
                )

        # Check non-null columns
        for col in self.non_null_columns:
            if col in df.columns and df[col].isna().any():
                null_count = df[col].isna().sum()
                violations.append(
                    f"Column '{col}' has {null_count} null values "
                    "(expected none)."
                )

        # Check value constraints, e.g. {'probability': ('<=', 1.0)}
        for col, (op, bound) in self.value_constraints.items():
            if col not in df.columns:
                continue
            if op == ">=" and (df[col] < bound).any():
                violations.append(
                    f"Column '{col}' has values below {bound}."
                )
            elif op == "<=" and (df[col] > bound).any():
                violations.append(
                    f"Column '{col}' has values above {bound}."
                )

        if violations:
            raise ContractViolation(self.name, violations)


class ContractViolation(Exception):
    """Raised when a data contract is violated."""

    def __init__(self, contract_name: str, violations: list):
        self.contract_name = contract_name
        self.violations = violations
        msg = f"Contract '{contract_name}' violated:\n" + "\n".join(
            f"  - {v}" for v in violations
        )
        super().__init__(msg)


# Define contracts for the StreamRec pipeline
validated_interactions_contract = DataContract(
    name="validated_interactions",
    required_columns={
        "user_id": "object",
        "item_id": "object",
        "event_type": "object",
        "timestamp": "datetime",
    },
    non_null_columns={"user_id", "item_id", "event_type", "timestamp"},
    min_rows=100_000,
)

training_features_contract = DataContract(
    name="training_features",
    required_columns={
        "user_id": "object",
        "item_id": "object",
        "watch_count_7d": "int",
        "avg_session_length": "float",
        "popularity_score": "float",
        "days_since_release": "int",
    },
    non_null_columns={"user_id", "item_id"},
    min_rows=100_000,
)

27.11 Designing the StreamRec Orchestration Architecture

With the framework comparison, failure handling patterns, and testing strategies established, we can now design the complete orchestration architecture for StreamRec. This architecture integrates the components from Chapters 24 (system design), 25 (data infrastructure), and 26 (distributed training) into a unified, automated pipeline.

Architecture Decision Record: Orchestration Framework Selection

ADR-027: Choice of Dagster for StreamRec Pipeline Orchestration

Status: Accepted

Context: StreamRec's ML platform team (3 ML engineers, 2 backend engineers, 1 data engineer) needs a pipeline orchestration framework for the daily training pipeline and supporting data pipelines. The pipeline produces versioned data and model artifacts, requires backfill capability, and must integrate with the existing feature store (Feast), experiment tracker (MLflow), and model registry.

Options Considered:

  1. Apache Airflow: Industry standard. Largest ecosystem. Team has prior experience. Task-centric model.
  2. Dagster: Asset-centric model. Strong data lineage. Better testing story. Growing ecosystem.
  3. Prefect: Minimal overhead. Python-native. Good for smaller teams. Cloud-first.

Decision: Dagster, for the following reasons:

  • Asset lineage directly supports the requirement to trace from a production model back to its training data, feature set version, and pipeline code. Airflow would require building this lineage tracking separately.
  • Partitioning is first-class in Dagster, simplifying backfill (select partitions to rematerialize) vs. Airflow (trigger individual DAG runs with date parameters).
  • IO Managers decouple computation from storage, enabling the team to swap between local filesystem (development), S3 (staging), and S3 with Delta Lake versioning (production) by changing a single resource configuration.
  • Testability. Assets are pure functions with typed inputs and outputs. The team can write pytest tests that call asset functions directly, without mocking the orchestrator runtime.

Consequences:

  • Smaller community than Airflow; fewer provider integrations available. The team will need to build custom resources for some integrations.
  • Team members with Airflow experience will require 1-2 weeks of ramp-up on Dagster concepts.
  • Dagster Cloud provides managed hosting; self-hosted deployment on Kubernetes is also an option with the dagster-k8s package.

The Full Pipeline DAG

wait_for_upstream_data (sensor)
    │
    ▼
raw_interactions (extract from data lake)
    │
    ▼
validated_interactions (schema + quality checks)
    │
    ▼
training_features (Feast point-in-time join)
    │
    ├───────────────────┐
    ▼                   ▼
retrieval_model    ranking_model
(two-tower, DDP)   (DCN-V2, DDP)
    │                   │
    ▼                   ▼
retrieval_eval     ranking_eval
(Recall@20,        (AUC, LogLoss)
 NDCG@20)
    │                   │
    └───────┬───────────┘
            ▼
    registered_models (MLflow Model Registry)
            │
            ▼
    trigger_canary_deployment (Chapter 29)

SLA and Monitoring

The pipeline must complete within a 6-hour window (2am–8am UTC) to ensure that fresh models are available for the morning traffic peak. SLA monitoring tracks:

| Metric | Threshold | Alert Severity |
| --- | --- | --- |
| Pipeline duration | > 5 hours | P2 (warning) |
| Pipeline duration | > 6 hours (SLA breach) | P0 (page) |
| Any task failure after all retries | n/a | P1 |
| Model quality degradation | NDCG@20 drop > 0.01 | P2 |
| Data validation failure | n/a | P1 |
| Backfill queue | > 3 days | P2 |
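The two duration rows reduce to a threshold check that monitoring code evaluates when the run completes. A minimal sketch, with an illustrative function name and return convention:

```python
from typing import Optional


def classify_duration_alert(duration_hours: float) -> Optional[str]:
    """Map pipeline duration to an alert severity per the SLA table."""
    if duration_hours > 6:
        return "P0"  # SLA breach: page the on-call engineer
    if duration_hours > 5:
        return "P2"  # Warning: the run is approaching the SLA boundary
    return None      # Within SLA: no alert
```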

27.12 Advanced Patterns

Dynamic DAGs

Some pipelines require tasks that are determined at runtime. For example, StreamRec might add a new model type (an LLM-based re-ranker) that should be trained in parallel with the existing models. Dynamic DAGs allow the pipeline to discover which models to train from a configuration file or model registry.

In Dagster, dynamic assets use the @multi_asset decorator or build the asset graph programmatically:

from dagster import multi_asset, AssetOut, Output
from typing import Dict, Any, List


def get_model_configs() -> Dict[str, Dict[str, Any]]:
    """Load model configurations from the model registry.

    Returns:
        Dictionary mapping model names to their training configurations.
    """
    return {
        "retrieval": {"epochs": 10, "batch_size": 4096, "lr": 1e-3},
        "ranking": {"epochs": 5, "batch_size": 2048, "lr": 5e-4},
        "reranker": {"epochs": 3, "batch_size": 512, "lr": 1e-4},
    }


@multi_asset(
    outs={
        name: AssetOut(description=f"Trained {name} model")
        for name in get_model_configs()
    },
)
def trained_models(context, training_features: Any):
    """Train all configured models in the model registry.

    Dynamically produces one asset per model configuration,
    enabling new models to be added without modifying pipeline code.
    """
    configs = get_model_configs()
    results = {}

    for model_name, config in configs.items():
        context.log.info(f"Training {model_name} with config: {config}")
        # Training logic here
        result = {"artifact_path": f"s3://models/{model_name}/", "config": config}
        results[model_name] = result

    return tuple(Output(results[name], output_name=name) for name in configs)

Conditional Execution and Trigger Rules

Not all evaluation failures should block model registration. If the retrieval model passes but the ranking model fails, the retrieval model should still be registered (the existing ranking model remains in production).

Airflow handles this with trigger rules (e.g., TriggerRule.ALL_DONE allows a task to run even if some upstream tasks failed). Dagster handles it with explicit conditional logic in the asset function. The key principle: fail the specific component, not the entire pipeline.
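A framework-agnostic sketch of that principle, with illustrative function and gate names: each model's evaluation gate is checked independently, so a failing ranking model cannot block retrieval-model registration.

```python
from typing import Callable, Dict


def register_passing_models(
    eval_results: Dict[str, dict],
    gates: Dict[str, Callable[[dict], bool]],
    register: Callable[[str], None],
) -> Dict[str, str]:
    """Register each model whose evaluation gate passes.

    A failing gate blocks only that model's registration; the previously
    registered version stays in production, and sibling models proceed.
    """
    outcomes = {}
    for model_name, metrics in eval_results.items():
        if gates[model_name](metrics):
            register(model_name)
            outcomes[model_name] = "registered"
        else:
            outcomes[model_name] = "blocked"
    return outcomes
```

In Dagster this logic lives inside the registration asset; in Airflow, the analogous effect comes from giving each model its own registration task with an appropriate trigger rule.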

SLA-Based Scheduling

Instead of a fixed cron schedule, some organizations use SLA-based scheduling: "The model asset should be no more than 26 hours old." This allows the pipeline to tolerate occasional delays without triggering a false SLA breach. If the pipeline finishes at 4am instead of 3am, the model is still within the 26-hour freshness window, and no alert fires.

In Dagster, freshness policies implement this:

from dagster import asset, FreshnessPolicy

@asset(
    freshness_policy=FreshnessPolicy(
        maximum_lag_minutes=26 * 60,  # 26 hours
        cron_schedule="0 2 * * *",     # Expected materialization at 2am
    ),
)
def retrieval_model(training_features):
    """Train the retrieval model with a 26-hour freshness SLA."""
    pass  # Training logic
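Conceptually, the freshness check reduces to comparing the age of the last materialization against the allowed lag. A minimal sketch of that semantics (our own helper, not Dagster's implementation):

```python
from datetime import datetime, timedelta


def is_stale(last_materialized: datetime,
             now: datetime,
             maximum_lag_minutes: int) -> bool:
    """True if the last materialization is older than the allowed lag."""
    return now - last_materialized > timedelta(minutes=maximum_lag_minutes)
```

A model materialized at 2am yesterday is still fresh at 3am today (25 hours old, within the 26-hour window), so a run that finishes an hour late fires no alert.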

27.13 Pipeline Anti-Patterns

With the patterns covered, it is equally important to catalog the anti-patterns: the recurring mistakes that cause pipeline failures in production.

1. The Monolithic Task. A single task that extracts, validates, transforms, trains, evaluates, and registers. When it fails at minute 90 of a 120-minute run, the entire 90 minutes of work must be repeated. Fix: Decompose into granular, idempotent tasks with checkpointed intermediate outputs.

2. The Hidden State Dependency. A task reads from a shared mutable resource (a database table, a global configuration file) that can change between runs. The pipeline produces different results when rerun even with the same logical date. Fix: Snapshot inputs at the beginning of the run and pass snapshots through the pipeline.

3. The Silent Swallow. A task catches all exceptions and returns a default value instead of failing. Downstream tasks run on degraded data without any signal that something went wrong. Fix: Let tasks fail loudly; use the orchestrator's retry and alerting mechanisms instead of in-task exception suppression.

4. The Shared Mutable Artifact. Multiple tasks write to the same output file or database table. Race conditions, partial writes, and overwrites are inevitable. Fix: Each task writes to a unique, partition-keyed path. A final "commit" task promotes the staged outputs.

5. The Schedule Waterfall. Pipeline B starts at 3am because pipeline A "usually finishes by 2:30am." When pipeline A runs 15 minutes late, pipeline B processes incomplete data. Fix: Use sensors or event-driven triggers instead of fixed schedules.

6. The Backfill Afterthought. The pipeline was designed for daily forward execution and breaks when run for historical dates (hardcoded timestamps, non-partitioned outputs, queries that always fetch "yesterday"). Fix: Design for backfill from day one. Use logical dates, not wall-clock time, for all data references.
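Anti-patterns 4 and 6 share a cure: derive every output path from the run's logical date rather than wall-clock time. A minimal sketch, with an illustrative bucket name and layout:

```python
from datetime import date


def partition_output_path(logical_date: date,
                          dataset: str,
                          version: str = "v1",
                          bucket: str = "s3://streamrec-artifacts") -> str:
    """Deterministic, partition-keyed output path.

    Rerunning (or backfilling) a logical date overwrites exactly that
    date's partition and nothing else; wall-clock time never appears.
    """
    return f"{bucket}/{dataset}/{version}/dt={logical_date.isoformat()}/"
```

Bumping `version` when pipeline logic changes gives backfills their own namespace, so reprocessed history never silently overwrites the outputs production models were trained on.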


27.14 Orchestration as Engineering Discipline

This chapter has covered the orchestration layer of the production ML stack — the component that transforms a collection of scripts into a reliable, automated, self-monitoring pipeline.

The key ideas are:

DAGs enforce execution order. By modeling the pipeline as a directed acyclic graph, the orchestrator can compute valid execution orders, identify parallelism opportunities, and propagate failure states through the dependency chain.

Three frameworks, three philosophies. Airflow (imperative, task-centric) excels in diverse-workload environments with its massive operator ecosystem. Dagster (declarative, asset-centric) excels in data-intensive ML environments where lineage, freshness, and testability are priorities. Prefect (Python-native, minimal-overhead) excels when teams want to orchestrate existing Python code with minimal framework adoption cost. All three implement the same fundamental concepts — DAGs, retries, scheduling, backfill — with different organizing principles.

Idempotency is the cardinal virtue. A task that can be safely rerun is a task that can be safely retried, backfilled, and tested. Partition-based overwrites, deterministic paths, and atomic writes are the primary techniques.

Failure handling is not optional. Retry policies with exponential backoff, structured alerting by severity, dead letter queues for partial failures, and SLA monitoring are the mechanisms that allow pipelines to run unattended. A pipeline without these mechanisms requires a human operator — and a human operator is a single point of failure.

Testing at three levels. Unit tests verify individual task logic. Integration tests verify end-to-end data flow. Contract tests verify inter-stage schema compatibility. All three are necessary; none alone is sufficient.

Production ML = Software Engineering: Pipeline orchestration is where the "software engineering" part of production ML is most visible. The DAG is a software architecture diagram. The retry policy is an error-handling strategy. The data contract is an API specification. The backfill is a database migration. Every concept from software engineering applies, and the tools of the orchestration ecosystem (Airflow, Dagster, Prefect) are the ML-specific instantiations of those concepts.

The pipeline is now orchestrated. Chapter 28 adds the testing and validation infrastructure that ensures the pipeline's outputs meet quality standards before reaching production. Chapter 29 builds the CI/CD system that deploys validated models through canary, shadow, and blue-green deployment patterns.


Summary

ML pipeline orchestration automates the execution, monitoring, and failure recovery of multi-step ML workflows. Pipelines are modeled as directed acyclic graphs (DAGs), where nodes are tasks and edges are dependencies. The DAG structure enforces execution order and enables parallelism.

Apache Airflow uses an imperative, task-centric model: you define tasks and their execution order. It has the largest ecosystem (80+ provider packages) and the deepest production track record, but its task-centric model makes data lineage opaque and testing difficult.

Dagster uses a declarative, asset-centric model: you define the data assets that should exist and their derivation logic. The DAG is inferred from asset dependencies. First-class support for partitioning, IO managers, and metadata makes it well-suited for ML pipelines where data lineage and testability are priorities.

Prefect uses a Python-native model: @flow and @task decorators on standard Python functions. It minimizes framework adoption cost and provides built-in caching, concurrent execution via .submit(), and rich observability artifacts.

Idempotency — the property that rerunning a task produces the same output — enables safe retries and backfills. Partition-based overwrites, atomic writes, and deterministic artifact paths are the primary implementation techniques.

Failure handling uses retry policies with exponential backoff ($d_k = d_0 \cdot 2^{k-1}$), severity-tiered alerting, dead letter queues for partial data failures, and SLA monitoring to ensure pipelines run reliably without human intervention.

Backfill reprocesses historical data intervals after pipeline downtime or logic changes. Sequential, parallel, and prioritized strategies trade off between safety and speed. Version-tagged output paths prevent backfill from overwriting production data.

Pipeline testing operates at three levels: unit tests for individual task logic, integration tests for end-to-end data flow, and contract tests for inter-stage schema compatibility.