Case Study 2: Systematic Prompt Optimization

Overview

In this case study, we build a systematic prompt optimization pipeline that treats prompt engineering as a measurable optimization problem. Starting with a baseline prompt for multi-label text classification, we iteratively improve it through structured experimentation, automated evaluation, self-consistency, and A/B testing. We implement the full optimization loop: define metrics, generate prompt variants, evaluate against a test set, analyze failure modes, and converge on a high-quality prompt.

This case study demonstrates concepts from Sections 23.12 (Evaluating Prompt Quality) and 23.14 (Prompting Decision Framework), showing how to move from ad-hoc prompt tweaking to rigorous, data-driven prompt optimization.

Learning Objectives

  • Implement a systematic prompt optimization loop with measurable metrics.
  • Build an automated evaluation pipeline for prompt variants.
  • Apply self-consistency to improve classification accuracy.
  • Compare prompt variants with A/B tests and check whether the differences are statistically significant.
  • Use an LLM-as-judge for evaluating open-ended outputs.
  • Analyze and address common failure modes in prompt-based systems.

Scenario

You are building a prompt-based system that classifies research paper abstracts into one or more of the following domains: machine_learning, natural_language_processing, computer_vision, robotics, reinforcement_learning, theory. A paper can belong to multiple domains. The system must output a JSON object with the selected domains and a brief justification for each.
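For example, a well-formed response for the first abstract in the test set might look like the object below. The checker is a minimal sketch of the contract stated above: a non-empty list of known domains plus a justification for each. `check_contract` is a hypothetical helper. It is stricter than the parser used later in the pipeline, which silently drops unknown labels instead of reporting them.

```python
import json

# The six labels allowed by the scenario
VALID_DOMAINS = {
    "machine_learning", "natural_language_processing", "computer_vision",
    "robotics", "reinforcement_learning", "theory",
}


def check_contract(raw: str) -> list[str]:
    """Return contract violations for one response (empty list if valid).

    Hypothetical helper: the output must be a JSON object whose "domains"
    list is non-empty and uses known labels, with a non-empty
    justification for every known label.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        return [f"not valid JSON: {exc.msg}"]
    if not isinstance(data, dict):
        return ["top-level value is not an object"]
    domains = data.get("domains")
    if not isinstance(domains, list) or not domains:
        return ["'domains' must be a non-empty list"]
    # isinstance() runs first, so unhashable items never reach the set lookup
    known = [d for d in domains if isinstance(d, str) and d in VALID_DOMAINS]
    problems = [f"unknown domain: {d!r}" for d in domains if d not in known]
    justifications = data.get("justifications", {})
    if not isinstance(justifications, dict):
        return problems + ["'justifications' must be an object"]
    problems += [
        f"missing justification for {d!r}"
        for d in known
        if not justifications.get(d)
    ]
    return problems


example = json.dumps({
    "domains": ["machine_learning", "natural_language_processing"],
    "justifications": {
        "machine_learning": "Proposes an attention mechanism with linear complexity.",
        "natural_language_processing": "Evaluates on translation and summarization.",
    },
})
print(check_contract(example))                     # []
print(check_contract('{"domains": ["biology"]}'))  # ["unknown domain: 'biology'"]
```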

Step 1: Data Structures and Test Set

"""Systematic prompt optimization pipeline.

This module implements an iterative optimization loop for improving
prompt quality through structured experimentation and evaluation.
"""

import json
import random
import re
from dataclasses import dataclass, field
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


torch.manual_seed(42)
random.seed(42)

VALID_DOMAINS = {
    "machine_learning",
    "natural_language_processing",
    "computer_vision",
    "robotics",
    "reinforcement_learning",
    "theory",
}


@dataclass
class ClassificationResult:
    """Result of classifying a single abstract.

    Attributes:
        domains: Set of predicted domain labels.
        justifications: Mapping from domain to justification string.
        raw_output: The raw model output before parsing.
    """

    domains: set[str]
    justifications: dict[str, str]
    raw_output: str


@dataclass
class TestExample:
    """A labeled test example for evaluation.

    Attributes:
        abstract: The paper abstract text.
        true_domains: The ground-truth domain labels.
    """

    abstract: str
    true_domains: set[str]


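# Curated test set with ground-truth labels. It is deliberately tiny for
# illustration; a real evaluation set needs many more examples per domain.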
TEST_SET: list[TestExample] = [
    TestExample(
        abstract=(
            "We propose a novel attention mechanism for Transformer models "
            "that reduces quadratic complexity to linear while maintaining "
            "performance on machine translation and text summarization tasks."
        ),
        true_domains={"machine_learning", "natural_language_processing"},
    ),
    TestExample(
        abstract=(
            "This paper presents a self-supervised learning framework for "
            "object detection that leverages contrastive learning on "
            "unlabeled image data, achieving state-of-the-art results on "
            "COCO and Pascal VOC benchmarks."
        ),
        true_domains={"machine_learning", "computer_vision"},
    ),
    TestExample(
        abstract=(
            "We derive tight sample complexity bounds for learning halfspaces "
            "under the uniform distribution, improving the best known results "
            "by a factor of log(n)."
        ),
        true_domains={"theory", "machine_learning"},
    ),
    TestExample(
        abstract=(
            "Our approach combines model-free reinforcement learning with "
            "sim-to-real transfer to train a quadruped robot to walk over "
            "rough terrain using only onboard sensors."
        ),
        true_domains={"reinforcement_learning", "robotics"},
    ),
    TestExample(
        abstract=(
            "We introduce a new benchmark for evaluating large language "
            "models on multi-hop reasoning tasks, and show that chain-of-"
            "thought prompting significantly improves performance."
        ),
        true_domains={"natural_language_processing", "machine_learning"},
    ),
    TestExample(
        abstract=(
            "This work proposes a diffusion-based generative model for "
            "high-resolution image synthesis that achieves superior FID "
            "scores compared to GANs on ImageNet 256x256."
        ),
        true_domains={"computer_vision", "machine_learning"},
    ),
    TestExample(
        abstract=(
            "We prove that no polynomial-time algorithm can approximate the "
            "maximum clique problem within a factor of n^(1-epsilon) unless "
            "P=NP, resolving a long-standing open question."
        ),
        true_domains={"theory"},
    ),
    TestExample(
        abstract=(
            "Our multi-agent reinforcement learning algorithm enables a team "
            "of drones to cooperatively map an unknown environment while "
            "avoiding collisions, using decentralized communication."
        ),
        true_domains={"reinforcement_learning", "robotics"},
    ),
]
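Before evaluating any prompt, it is worth checking how the labels are distributed. The sketch below counts per-domain support and label cardinality (the mean number of labels per example). `LABEL_SETS` restates the ground truth above so the snippet runs on its own. On this set, machine_learning appears in five of the eight examples and every other domain appears only twice. As a result, a single missed label on any domain other than machine_learning moves that domain's recall by 50 percentage points.

```python
from collections import Counter

# Ground-truth label sets copied from TEST_SET above, in order
LABEL_SETS = [
    {"machine_learning", "natural_language_processing"},
    {"machine_learning", "computer_vision"},
    {"theory", "machine_learning"},
    {"reinforcement_learning", "robotics"},
    {"natural_language_processing", "machine_learning"},
    {"computer_vision", "machine_learning"},
    {"theory"},
    {"reinforcement_learning", "robotics"},
]


def label_stats(label_sets: list[set[str]]) -> tuple[Counter, float]:
    """Return per-domain support and label cardinality (mean labels per example)."""
    support = Counter(label for labels in label_sets for label in labels)
    cardinality = sum(len(labels) for labels in label_sets) / len(label_sets)
    return support, cardinality


support, cardinality = label_stats(LABEL_SETS)
for domain, count in support.most_common():
    print(f"{domain:<30s} {count}")
print(f"label cardinality: {cardinality:.3f}")  # 15 labels / 8 examples = 1.875
```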

Step 2: Prompt Variants

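We define three prompt variants to test systematically: a minimal zero-shot baseline (A), detailed instructions with domain definitions and an explicit output schema (B), and few-shot chain-of-thought (C).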

# Variant A: Minimal zero-shot prompt (baseline)
PROMPT_VARIANT_A = """Classify the following research paper abstract into one \
or more domains: machine_learning, natural_language_processing, \
computer_vision, robotics, reinforcement_learning, theory.

Output a JSON object with a "domains" array and a "justifications" object.

Abstract: {abstract}

JSON:"""


# Variant B: Detailed instructions with output schema
PROMPT_VARIANT_B = """You are a research paper classifier. Classify the \
following abstract into one or more of these domains:

- machine_learning: Papers about ML algorithms, optimization, neural \
architectures, training methods
- natural_language_processing: Papers about text, language models, \
translation, summarization, NLU/NLG
- computer_vision: Papers about images, video, object detection, \
segmentation, generation of visual content
- robotics: Papers about robot control, manipulation, navigation, \
sim-to-real, embodied systems
- reinforcement_learning: Papers about RL algorithms, policies, rewards, \
exploration, multi-agent RL
- theory: Papers about mathematical proofs, complexity bounds, information \
theory, statistical learning theory

Rules:
1. Assign ALL applicable domains (usually 1-3).
2. A paper about "using ML for computer vision" belongs to BOTH domains.
3. Provide a one-sentence justification for each selected domain.

Output format (JSON only, no other text):
{{
  "domains": ["domain1", "domain2"],
  "justifications": {{
    "domain1": "reason...",
    "domain2": "reason..."
  }}
}}

Abstract: {abstract}

JSON:"""


# Variant C: Few-shot with chain-of-thought
PROMPT_VARIANT_C = """Classify research paper abstracts into domains. Think \
through the classification step by step.

Domains: machine_learning, natural_language_processing, computer_vision, \
robotics, reinforcement_learning, theory

Example 1:
Abstract: "We present a new convolutional architecture for semantic \
segmentation that achieves real-time performance on autonomous driving \
datasets."
Reasoning: This paper proposes a neural network architecture (machine \
learning) specifically for semantic segmentation of images (computer vision). \
It targets autonomous driving but does not focus on robot control itself.
{{"domains": ["machine_learning", "computer_vision"], "justifications": \
{{"machine_learning": "Proposes a new convolutional architecture.", \
"computer_vision": "Addresses semantic segmentation of images."}}}}

Example 2:
Abstract: "We analyze the convergence rate of stochastic gradient descent \
for overparameterized neural networks and show it achieves a global minimum \
at a linear rate."
Reasoning: This is a theoretical analysis (theory) of an optimization \
method used in machine learning (machine_learning). It does not target a \
specific application domain.
{{"domains": ["theory", "machine_learning"], "justifications": \
{{"theory": "Provides convergence rate analysis and proofs.", \
"machine_learning": "Analyzes SGD for neural network optimization."}}}}

Now classify:
Abstract: {abstract}
Reasoning:"""


def format_prompt(template: str, abstract: str) -> str:
    """Fill a prompt template with the given abstract.

    Args:
        template: Prompt template with {abstract} placeholder.
        abstract: The paper abstract to classify.

    Returns:
        Formatted prompt string.
    """
    return template.format(abstract=abstract)
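Variants B and C embed literal JSON, so their braces are doubled. `str.format` treats `{{` and `}}` as escaped literal braces, which leaves `{abstract}` as the only replacement field. The short check below uses two hypothetical one-line templates to show the rendered text, and what happens when a brace is left unescaped.

```python
TEMPLATE_OK = 'Output: {{"domains": []}}\nAbstract: {abstract}'
TEMPLATE_BAD = 'Output: {"domains": []}\nAbstract: {abstract}'

rendered = TEMPLATE_OK.format(abstract="We study X.")
print(rendered)
# Output: {"domains": []}
# Abstract: We study X.

raised = False
try:
    TEMPLATE_BAD.format(abstract="We study X.")
except KeyError as exc:
    # '{"domains": []}' is read as a replacement field named '"domains"'
    raised = True
    print("KeyError:", exc)
```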

Step 3: Parsing and Evaluation Metrics

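def parse_classification(raw_output: str) -> Optional[ClassificationResult]:
    """Parse a classification response into structured form.

    Scans the output for the first substring that decodes as a JSON object
    containing a "domains" key. A greedy regex match from the first "{" to
    the last "}" fails when the model keeps generating after its answer
    (for example, by inventing another few-shot example).

    Args:
        raw_output: Raw text output from the model.

    Returns:
        Parsed result, or None if parsing fails.
    """
    decoder = json.JSONDecoder()
    data = None
    # Try each "{" as a start position; raw_decode ignores any trailing text
    for match in re.finditer(r"\{", raw_output):
        try:
            candidate, _ = decoder.raw_decode(raw_output, match.start())
        except json.JSONDecodeError:
            continue
        if isinstance(candidate, dict) and "domains" in candidate:
            data = candidate
            break

    if data is None or not isinstance(data["domains"], list):
        return None

    # Keep only string labels from the allowed set (unhashable items are skipped)
    domains = {d for d in data["domains"] if isinstance(d, str)} & VALID_DOMAINS
    if not domains:
        return None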

    justifications = data.get("justifications", {})
    if not isinstance(justifications, dict):
        justifications = {}

    return ClassificationResult(
        domains=domains,
        justifications={
            k: str(v) for k, v in justifications.items() if k in domains
        },
        raw_output=raw_output,
    )


@dataclass
class PromptMetrics:
    """Evaluation metrics for a prompt variant.

    Attributes:
        variant_name: Name of the prompt variant.
        exact_match_accuracy: Fraction where predicted == true domains exactly.
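        precision: Example-averaged precision (mean over all test examples;
            unparseable outputs count as 0).
        recall: Example-averaged recall, computed the same way.
        f1: Harmonic mean of the example-averaged precision and recall.
        format_compliance: Fraction of outputs that parsed into a valid
            classification (JSON with at least one known domain).
        avg_num_domains: Average number of domains per parsed output.
        failure_examples: Indices of misclassified or unparseable examples.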
    """

    variant_name: str
    exact_match_accuracy: float
    precision: float
    recall: float
    f1: float
    format_compliance: float
    avg_num_domains: float
    failure_examples: list[int] = field(default_factory=list)


def compute_metrics(
    variant_name: str,
    test_set: list[TestExample],
    predictions: list[Optional[ClassificationResult]],
) -> PromptMetrics:
    """Compute evaluation metrics for a set of predictions.

    Args:
        variant_name: Name of the prompt variant being evaluated.
        test_set: The labeled test examples.
        predictions: Model predictions for each test example.

    Returns:
        Comprehensive metrics for the prompt variant.
    """
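    if len(predictions) != len(test_set):
        raise ValueError("predictions must align one-to-one with test_set")
    total = len(test_set)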
    exact_matches = 0
    total_precision = 0.0
    total_recall = 0.0
    valid_outputs = 0
    total_domains_predicted = 0
    failures: list[int] = []

    for i, (example, pred) in enumerate(zip(test_set, predictions)):
        if pred is None:
            failures.append(i)
            continue

        valid_outputs += 1
        total_domains_predicted += len(pred.domains)

        true_set = example.true_domains
        pred_set = pred.domains

        if pred_set == true_set:
            exact_matches += 1

        # Per-example precision and recall
        if len(pred_set) > 0:
            p = len(pred_set & true_set) / len(pred_set)
        else:
            p = 0.0
        if len(true_set) > 0:
            r = len(pred_set & true_set) / len(true_set)
        else:
            r = 0.0

        total_precision += p
        total_recall += r

        if pred_set != true_set:
            failures.append(i)

    avg_p = total_precision / total if total > 0 else 0.0
    avg_r = total_recall / total if total > 0 else 0.0
    f1 = (
        2 * avg_p * avg_r / (avg_p + avg_r) if (avg_p + avg_r) > 0 else 0.0
    )

    return PromptMetrics(
        variant_name=variant_name,
        exact_match_accuracy=exact_matches / total if total > 0 else 0.0,
        precision=avg_p,
        recall=avg_r,
        f1=f1,
        format_compliance=valid_outputs / total if total > 0 else 0.0,
        avg_num_domains=(
            total_domains_predicted / valid_outputs if valid_outputs > 0 else 0.0
        ),
        failure_examples=failures,
    )
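A two-example worked case makes the averaging explicit. `compute_metrics` averages precision and recall over examples and then takes their harmonic mean. A common alternative averages each example's own F1 instead. The two can differ, so a report should state which one it uses. `per_example_prf` is a hypothetical helper.

```python
def per_example_prf(true: set[str], pred: set[str]) -> tuple[float, float, float]:
    """Precision, recall, and F1 for a single multi-label prediction."""
    overlap = len(true & pred)
    p = overlap / len(pred) if pred else 0.0
    r = overlap / len(true) if true else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1


cases = [
    # One domain swapped: machine_learning kept, NLP replaced by computer_vision
    ({"machine_learning", "natural_language_processing"},
     {"machine_learning", "computer_vision"}),
    # One extra domain predicted for a pure theory paper
    ({"theory"}, {"theory", "machine_learning"}),
]
scores = [per_example_prf(t, p) for t, p in cases]
avg_p = sum(s[0] for s in scores) / len(scores)          # (0.5 + 0.5) / 2 = 0.5
avg_r = sum(s[1] for s in scores) / len(scores)          # (0.5 + 1.0) / 2 = 0.75
f1_of_averages = 2 * avg_p * avg_r / (avg_p + avg_r)     # 0.6
mean_of_f1 = sum(s[2] for s in scores) / len(scores)     # (0.5 + 2/3) / 2 = 7/12
print(f"F1 of averaged P/R:  {f1_of_averages:.3f}")      # 0.600
print(f"Mean per-example F1: {mean_of_f1:.3f}")          # 0.583
```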

Step 4: Self-Consistency for Multi-Label Classification

def classify_with_self_consistency(
    abstract: str,
    prompt_template: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    n_samples: int = 5,
    temperature: float = 0.7,
    threshold: float = 0.5,
) -> Optional[ClassificationResult]:
    """Classify with self-consistency over multiple samples.

    For multi-label classification, we use a per-domain voting scheme:
    a domain is included if it appears in at least a `threshold` fraction
    of the samples that parsed successfully.

    Args:
        abstract: The paper abstract to classify.
        prompt_template: The prompt template to use.
        tokenizer: HuggingFace tokenizer.
        model: HuggingFace causal language model.
        n_samples: Number of reasoning paths to sample.
        temperature: Sampling temperature for diversity.
        threshold: Minimum vote fraction (over valid samples) to include a domain.

    Returns:
        Aggregated classification result, or None if all samples fail.
    """
    prompt = format_prompt(prompt_template, abstract)
    # Move inputs to the model's device so generation also works on GPU
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    domain_votes: dict[str, int] = {d: 0 for d in VALID_DOMAINS}
    domain_justifications: dict[str, list[str]] = {d: [] for d in VALID_DOMAINS}
    valid_samples = 0

    for _ in range(n_samples):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=300,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
            )
        # Decode only the newly generated tokens, not the echoed prompt
        raw_output = tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        result = parse_classification(raw_output)
        if result is not None:
            valid_samples += 1
            for domain in result.domains:
                domain_votes[domain] += 1
                if domain in result.justifications:
                    domain_justifications[domain].append(
                        result.justifications[domain]
                    )

    if valid_samples == 0:
        return None

    # Include domains whose vote share meets the threshold
    selected_domains = {
        d for d, v in domain_votes.items() if v / valid_samples >= threshold
    }

    if not selected_domains:
        # Fallback: select the domain with the most votes
        max_votes = max(domain_votes.values())
        selected_domains = {d for d, v in domain_votes.items() if v == max_votes}

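    # Use the most common justification for each domain
    justifications = {}
    for domain in selected_domains:
        candidates = domain_justifications[domain]
        if candidates:
            # max() with list.count picks the most frequent string (first on ties)
            justifications[domain] = max(candidates, key=candidates.count)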
        else:
            justifications[domain] = "Selected by majority vote."

    return ClassificationResult(
        domains=selected_domains,
        justifications=justifications,
        raw_output=f"Self-consistency over {valid_samples}/{n_samples} valid samples",
    )
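The voting rule can be checked without a model. The sketch below applies per-domain majority voting to five hand-written parsed samples. A domain is kept when its vote share reaches `threshold`, with a fallback to the most-voted domain(s). `aggregate_votes` is a hypothetical, model-free restatement of the aggregation step. With five samples and a threshold of 0.5, a domain needs three votes.

```python
from collections import Counter


def aggregate_votes(samples: list[set[str]], threshold: float = 0.5) -> set[str]:
    """Per-domain voting over parsed samples (hypothetical helper).

    Keeps every domain whose vote share is at least `threshold`; if none
    qualifies, falls back to the most-voted domain(s).
    """
    votes = Counter(domain for sample in samples for domain in sample)
    if not votes:
        return set()
    n = len(samples)
    selected = {d for d, v in votes.items() if v / n >= threshold}
    if not selected:
        top = max(votes.values())
        selected = {d for d, v in votes.items() if v == top}
    return selected


samples = [
    {"machine_learning", "natural_language_processing"},
    {"machine_learning"},
    {"machine_learning", "natural_language_processing"},
    {"machine_learning", "computer_vision"},
    {"machine_learning", "natural_language_processing"},
]
# Votes: machine_learning 5/5, natural_language_processing 3/5, computer_vision 1/5
print(sorted(aggregate_votes(samples)))
# ['machine_learning', 'natural_language_processing']
```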

Step 5: A/B Testing Framework

import math


def compute_z_test(
    successes_a: int,
    total_a: int,
    successes_b: int,
    total_b: int,
) -> tuple[float, float]:
    """Compute a two-proportion z-test for A/B testing.

    Args:
        successes_a: Number of successes for variant A.
        total_a: Total trials for variant A.
        successes_b: Number of successes for variant B.
        total_b: Total trials for variant B.

    Returns:
        Tuple of (z_statistic, p_value).
    """
    p_a = successes_a / total_a if total_a > 0 else 0.0
    p_b = successes_b / total_b if total_b > 0 else 0.0

    # Pooled proportion
    p_pool = (successes_a + successes_b) / (total_a + total_b)
    se = math.sqrt(p_pool * (1 - p_pool) * (1 / total_a + 1 / total_b))

    if se == 0:
        return 0.0, 1.0

    z = (p_a - p_b) / se

    # Two-tailed p-value using normal approximation
    p_value = 2 * (1 - _normal_cdf(abs(z)))
    return z, p_value


def _normal_cdf(x: float) -> float:
    """Approximate the standard normal CDF using the error function.

    Args:
        x: Input value.

    Returns:
        Approximate CDF value.
    """
    return 0.5 * (1 + math.erf(x / math.sqrt(2)))


def run_ab_test(
    metrics_a: PromptMetrics,
    metrics_b: PromptMetrics,
    test_size: int,
    alpha: float = 0.05,
) -> dict:
    """Run an A/B test comparing two prompt variants.

    Args:
        metrics_a: Metrics for prompt variant A.
        metrics_b: Metrics for prompt variant B.
        test_size: Number of test examples.
        alpha: Significance level for the test.

    Returns:
        Dictionary with test results and recommendation.
    """
    # Recover success counts; round() avoids float truncation (0.29 * 100 < 29)
    exact_a = round(metrics_a.exact_match_accuracy * test_size)
    exact_b = round(metrics_b.exact_match_accuracy * test_size)

    z_stat, p_value = compute_z_test(exact_a, test_size, exact_b, test_size)

    significant = p_value < alpha
    if significant:
        if metrics_a.exact_match_accuracy > metrics_b.exact_match_accuracy:
            winner = metrics_a.variant_name
        else:
            winner = metrics_b.variant_name
    else:
        winner = "No significant difference"

    return {
        "variant_a": metrics_a.variant_name,
        "variant_b": metrics_b.variant_name,
        "accuracy_a": metrics_a.exact_match_accuracy,
        "accuracy_b": metrics_b.exact_match_accuracy,
        "f1_a": metrics_a.f1,
        "f1_b": metrics_b.f1,
        "z_statistic": z_stat,
        "p_value": p_value,
        "significant": significant,
        "winner": winner,
    }
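The two-proportion z-test treats the variants' results as independent samples. Here, however, both variants are scored on the same abstracts, so the outcomes are paired. A common alternative for paired binary outcomes is McNemar's test. The exact version below is a standard-library sketch, and `mcnemar_exact` is a hypothetical name. It looks only at discordant examples (one variant right, the other wrong) and asks how surprising their split would be under a fair coin. It also makes the sample-size problem concrete: with eight examples, p < 0.05 requires at least six discordant examples that all favor the same variant.

```python
from math import comb


def mcnemar_exact(correct_a: list[bool], correct_b: list[bool]) -> tuple[int, int, float]:
    """Exact (binomial) McNemar test for two variants scored on the same examples.

    Returns (b, c, p_value): b counts examples only A got right, c counts
    examples only B got right. The two-sided p-value is
    min(1, 2 * P(X <= min(b, c))) with X ~ Binomial(b + c, 0.5).
    """
    if len(correct_a) != len(correct_b):
        raise ValueError("both variants must be scored on the same examples")
    b = sum(1 for ra, rb in zip(correct_a, correct_b) if ra and not rb)
    c = sum(1 for ra, rb in zip(correct_a, correct_b) if rb and not ra)
    n = b + c
    tail = sum(comb(n, k) for k in range(min(b, c) + 1)) / 2 ** n
    return b, c, min(1.0, 2 * tail)


# Per-example exact-match outcomes (illustrative, not measured results)
variant_a = [True, False, True, False, False, True, False, False]
variant_b = [True, True, True, True, True, True, True, True]
print(mcnemar_exact(variant_a, variant_b))  # (0, 5, 0.0625)
```

Per-example correctness can be recovered from `PromptMetrics.failure_examples`, for instance `[i not in metrics_a.failure_examples for i in range(test_size)]`, because that list holds every misclassified or unparseable index.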

Step 6: LLM-as-Judge for Justification Quality

JUDGE_PROMPT = """Rate the quality of the following domain classification \
justification on a scale of 1-5.

Scoring rubric:
1 - Completely wrong or irrelevant justification
2 - Partially correct but vague or misleading
3 - Correct but generic (could apply to many papers)
4 - Correct and specific to the paper's content
5 - Correct, specific, and insightful (identifies the key contribution)

Paper abstract: {abstract}
Assigned domain: {domain}
Justification: {justification}

Output ONLY a JSON object:
{{"score": <1-5>, "reasoning": "<brief explanation>"}}

JSON:"""


def build_judge_prompt(
    abstract: str,
    domain: str,
    justification: str,
) -> str:
    """Build an LLM-as-judge prompt for evaluating justification quality.

    Args:
        abstract: The paper abstract.
        domain: The assigned domain label.
        justification: The justification text to evaluate.

    Returns:
        Formatted judge prompt string.
    """
    return JUDGE_PROMPT.format(
        abstract=abstract,
        domain=domain,
        justification=justification,
    )
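The judge's reply still has to be parsed and aggregated. This minimal sketch assumes the judge follows the requested `{"score": ..., "reasoning": ...}` format. Replies that do not parse, or whose score falls outside 1 to 5, are counted as invalid rather than coerced, so judge format failures stay visible. `parse_judge_score` and `summarize_judge` are hypothetical helpers.

```python
import json
from statistics import mean
from typing import Optional


def parse_judge_score(raw: str) -> Optional[int]:
    """Extract an integer score in [1, 5] from a judge reply, else None."""
    start = raw.find("{")
    if start == -1:
        return None
    try:
        # raw_decode parses one JSON value and ignores any trailing text
        data, _ = json.JSONDecoder().raw_decode(raw, start)
    except json.JSONDecodeError:
        return None
    score = data.get("score") if isinstance(data, dict) else None
    # bool is a subclass of int, so exclude it explicitly
    if isinstance(score, int) and not isinstance(score, bool) and 1 <= score <= 5:
        return score
    return None


def summarize_judge(replies: list[str]) -> dict[str, float]:
    """Mean valid score plus the fraction of replies that were usable."""
    scores = [s for s in map(parse_judge_score, replies) if s is not None]
    return {
        "mean_score": mean(scores) if scores else float("nan"),
        "valid_fraction": len(scores) / len(replies) if replies else 0.0,
    }


replies = [
    '{"score": 4, "reasoning": "Cites the linear-attention contribution."}',
    ' {"score": 5, "reasoning": "Identifies the key idea."} trailing text',
    '{"score": 9, "reasoning": "Out of range."}',
    "I would rate this a 3.",
]
print(summarize_judge(replies))  # {'mean_score': 4.5, 'valid_fraction': 0.5}
```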

Step 7: Failure Analysis and Iteration

def analyze_failures(
    test_set: list[TestExample],
    predictions: list[Optional[ClassificationResult]],
    metrics: PromptMetrics,
) -> dict:
    """Analyze failure patterns in predictions.

    Args:
        test_set: The labeled test examples.
        predictions: Model predictions for each example.
        metrics: Computed metrics including failure indices.

    Returns:
        Dictionary describing failure patterns.
    """
    patterns: dict[str, list[int]] = {
        "parse_failure": [],
        "missing_domain": [],
        "extra_domain": [],
        "missing_and_extra": [],
        "complete_mismatch": [],
    }
    missed_domains: dict[str, int] = {}
    extra_domains: dict[str, int] = {}

    for i in metrics.failure_examples:
        pred = predictions[i]
        true_domains = test_set[i].true_domains

        if pred is None:
            patterns["parse_failure"].append(i)
            continue

        missing = true_domains - pred.domains
        extra = pred.domains - true_domains

        if not (pred.domains & true_domains):
            # No overlap at all between predicted and true labels
            patterns["complete_mismatch"].append(i)
        elif missing and extra:
            # Partial overlap, but at least one domain was swapped for another
            patterns["missing_and_extra"].append(i)
        elif missing:
            patterns["missing_domain"].append(i)
        elif extra:
            patterns["extra_domain"].append(i)

        # Tally which domains are most commonly missed or over-predicted
        for d in missing:
            missed_domains[d] = missed_domains.get(d, 0) + 1
        for d in extra:
            extra_domains[d] = extra_domains.get(d, 0) + 1

    def _by_count(counts: dict[str, int]) -> dict[str, int]:
        # Order domains from most to least frequent
        return dict(sorted(counts.items(), key=lambda kv: kv[1], reverse=True))

    return {
        "total_failures": len(metrics.failure_examples),
        "patterns": {k: len(v) for k, v in patterns.items()},
        "most_missed_domains": _by_count(missed_domains),
        "most_over_predicted_domains": _by_count(extra_domains),
    }


def print_optimization_report(
    all_metrics: list[PromptMetrics],
    test_size: int,
) -> None:
    """Print a comparative report across all prompt variants.

    Args:
        all_metrics: List of metrics for each variant tested.
        test_size: Number of test examples.
    """
    print("=" * 70)
    print("PROMPT OPTIMIZATION REPORT")
    print("=" * 70)
    print(
        f"{'Variant':<15s} {'Exact Match':<13s} {'Precision':<11s} "
        f"{'Recall':<9s} {'F1':<8s} {'Format %':<10s}"
    )
    print("-" * 70)
    for m in all_metrics:
        print(
            f"{m.variant_name:<15s} {m.exact_match_accuracy:<13.1%} "
            f"{m.precision:<11.3f} {m.recall:<9.3f} {m.f1:<8.3f} "
            f"{m.format_compliance:<10.1%}"
        )
    print("-" * 70)

    # Identify best variant
    best = max(all_metrics, key=lambda m: m.f1)
    print(f"\nBest variant by F1: {best.variant_name} (F1 = {best.f1:.3f})")
    print("=" * 70)

Step 8: Demonstration

if __name__ == "__main__":
    print("Systematic Prompt Optimization Pipeline")
    print("=" * 50)
    print(f"Test set size: {len(TEST_SET)}")
    print(f"Valid domains: {sorted(VALID_DOMAINS)}")
    print(f"Prompt variants: A (minimal), B (detailed), C (few-shot CoT)")
    print()

    # Show prompt lengths
    sample_abstract = TEST_SET[0].abstract
    for name, template in [
        ("A", PROMPT_VARIANT_A),
        ("B", PROMPT_VARIANT_B),
        ("C", PROMPT_VARIANT_C),
    ]:
        prompt = format_prompt(template, sample_abstract)
        print(f"Variant {name} prompt length: {len(prompt)} chars")

    print()
    print("Optimization loop:")
    print("  1. Evaluate all variants on test set")
    print("  2. Analyze failures of best variant")
    print("  3. Apply self-consistency to improve accuracy")
    print("  4. A/B test improved variant against baseline")
    print("  5. Assess justification quality with LLM-as-judge")
    print("  6. Iterate until target metrics are met")

Key Takeaways

  1. Treat prompt engineering as optimization, not guesswork. Define metrics, create a test set, and iterate systematically.
  2. Detailed instructions tend to outperform minimal prompts. Domain definitions and explicit rules address two weaknesses of a bare prompt: ambiguous domain boundaries and inconsistent output format.
  3. Few-shot CoT is most likely to help on classification tasks that require nuanced reasoning about domain boundaries, such as deciding whether a segmentation paper evaluated on driving data also belongs to robotics.
  4. Self-consistency fits multi-label tasks well: per-domain voting keeps the domains that most samples agree on and drops labels that appear only sporadically.
  5. Failure analysis reveals specific, actionable patterns. Understanding whether errors come from missing domains, extra domains, or parse failures guides the next round of optimization.
  6. Significance testing guards against adopting a variant because of small differences that may be due to random variation. On an eight-example test set, only large, one-sided differences can reach significance, so grow the test set before drawing firm conclusions.