Case Study 2: Building an Instruction-Following Model

Overview

In this case study, we build an instruction-following model from a base language model using supervised fine-tuning (SFT). We cover the entire process: curating a diverse instruction dataset, applying quality filters, formatting data with chat templates, training with LoRA using TRL, evaluating instruction-following quality with multiple metrics, and measuring catastrophic forgetting. The goal is to transform a base model into one that reliably follows diverse natural language instructions.

Learning Objectives

  • Curate a diverse, high-quality instruction-tuning dataset.
  • Apply quality filtering to remove low-quality examples.
  • Implement the full SFT pipeline with LoRA using TRL.
  • Evaluate instruction-following with automated and LLM-as-judge metrics.
  • Measure and mitigate catastrophic forgetting through data mixing.

Step 1: Dataset Curation

"""Building an instruction-following model via supervised fine-tuning.

This case study covers dataset curation, quality filtering, SFT training,
and comprehensive evaluation including forgetting analysis.

Requirements:
    pip install torch transformers peft trl datasets
"""

import random
import re
from dataclasses import dataclass, field

import torch
from datasets import Dataset

torch.manual_seed(42)

# ---------------------------------------------------------------------------
# Instruction dataset structure
# ---------------------------------------------------------------------------

TASK_CATEGORIES = [
    "summarization",
    "classification",
    "extraction",
    "generation",
    "reasoning",
    "coding",
    "math",
    "rewriting",
    "question_answering",
    "brainstorming",
]


@dataclass
class InstructionExample:
    """A single instruction-tuning example.

    Attributes:
        task_category: Category of the task.
        instruction: The natural language instruction.
        input_text: Optional input context.
        output: The desired response.
        quality_score: Quality rating (1-5) for filtering.
    """

    task_category: str
    instruction: str
    input_text: str
    output: str
    quality_score: float = 5.0


def create_diverse_dataset() -> list[InstructionExample]:
    """Create a diverse instruction dataset covering all task categories.

    Returns:
        List of InstructionExample objects.
    """
    examples = [
        # Summarization
        InstructionExample(
            task_category="summarization",
            instruction="Summarize the following paragraph in one sentence.",
            input_text=(
                "Machine learning models require large amounts of data "
                "to learn patterns. The quality of data directly affects "
                "model performance. Data augmentation, cleaning, and "
                "feature engineering are critical preprocessing steps."
            ),
            output=(
                "Effective machine learning depends on large quantities "
                "of high-quality data, with preprocessing steps like "
                "augmentation and feature engineering being crucial."
            ),
        ),
        # Classification
        InstructionExample(
            task_category="classification",
            instruction=(
                "Classify the sentiment of the following product review "
                "as positive, negative, or neutral."
            ),
            input_text="The battery life is incredible but the screen is dim.",
            output=(
                "Neutral. The review contains both positive feedback "
                "(battery life) and negative feedback (dim screen)."
            ),
        ),
        # Extraction
        InstructionExample(
            task_category="extraction",
            instruction="Extract all dates mentioned in the following text.",
            input_text=(
                "The project started on January 15, 2024 and the first "
                "milestone was completed on March 1, 2024. The final "
                "deadline is December 31, 2024."
            ),
            output=(
                "1. January 15, 2024 (project start)\n"
                "2. March 1, 2024 (first milestone)\n"
                "3. December 31, 2024 (final deadline)"
            ),
        ),
        # Generation
        InstructionExample(
            task_category="generation",
            instruction=(
                "Write a professional email declining a meeting invitation "
                "due to a scheduling conflict."
            ),
            input_text="",
            output=(
                "Subject: Unable to Attend - Meeting on Thursday\n\n"
                "Dear [Name],\n\n"
                "Thank you for the meeting invitation. Unfortunately, "
                "I have a prior commitment at that time and will not be "
                "able to attend. Could we reschedule, or would it be "
                "possible to receive the meeting notes afterward?\n\n"
                "Best regards,\n[Your Name]"
            ),
        ),
        # Reasoning
        InstructionExample(
            task_category="reasoning",
            instruction="Explain why the conclusion does not follow from the premise.",
            input_text=(
                'Premise: "All dogs are mammals." '
                'Conclusion: "All mammals are dogs."'
            ),
            output=(
                "This is a logical fallacy known as affirming the consequent "
                "(or converse error). While all dogs are indeed mammals, the "
                "reverse is not true. Mammals include cats, whales, humans, "
                "and many other species. The set of dogs is a subset of "
                "mammals, not equal to it."
            ),
        ),
        # Coding
        InstructionExample(
            task_category="coding",
            instruction="Write a Python function to compute the Fibonacci sequence up to n terms.",
            input_text="",
            output=(
                "```python\n"
                "def fibonacci(n: int) -> list[int]:\n"
                '    """Compute Fibonacci sequence up to n terms.\n\n'
                "    Args:\n"
                "        n: Number of terms to generate.\n\n"
                "    Returns:\n"
                "        List of Fibonacci numbers.\n"
                '    """\n'
                "    if n <= 0:\n"
                "        return []\n"
                "    if n == 1:\n"
                "        return [0]\n"
                "    fib = [0, 1]\n"
                "    for _ in range(2, n):\n"
                "        fib.append(fib[-1] + fib[-2])\n"
                "    return fib\n"
                "```"
            ),
        ),
        # Math
        InstructionExample(
            task_category="math",
            instruction="Solve the following equation and show your work.",
            input_text="3x + 7 = 22",
            output=(
                "Step 1: Subtract 7 from both sides.\n"
                "  3x + 7 - 7 = 22 - 7\n"
                "  3x = 15\n\n"
                "Step 2: Divide both sides by 3.\n"
                "  3x / 3 = 15 / 3\n"
                "  x = 5\n\n"
                "The solution is x = 5."
            ),
        ),
        # Rewriting
        InstructionExample(
            task_category="rewriting",
            instruction="Rewrite the following sentence in active voice.",
            input_text="The report was written by the team last Friday.",
            output="The team wrote the report last Friday.",
        ),
        # Question Answering
        InstructionExample(
            task_category="question_answering",
            instruction="Answer the question based on the given context.",
            input_text=(
                "Context: The Great Wall of China stretches over 13,000 "
                "miles. Construction began in the 7th century BC and "
                "continued for centuries. Question: How long is the "
                "Great Wall of China?"
            ),
            output=(
                "According to the context, the Great Wall of China "
                "stretches over 13,000 miles."
            ),
        ),
        # Brainstorming
        InstructionExample(
            task_category="brainstorming",
            instruction="List 5 creative ways to reduce paper waste in an office.",
            input_text="",
            output=(
                "1. Implement a digital-first policy where all documents "
                "are shared electronically by default.\n"
                "2. Set up a paper reuse station where single-sided prints "
                "are collected for scratch paper.\n"
                "3. Install tablet devices in meeting rooms to replace "
                "printed agendas and handouts.\n"
                "4. Create a monthly challenge where teams compete to "
                "minimize their paper consumption.\n"
                "5. Replace paper towels with air dryers and provide "
                "reusable cloth alternatives."
            ),
        ),
    ]
    return examples
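

The ten handcrafted examples above keep the case study self-contained, but a real SFT run needs thousands of examples. The sketch below shows one way to pull additional examples from a public corpus; it assumes the column layout of databricks/databricks-dolly-15k (instruction, context, response, category), so adjust the field names for other sources.

from datasets import load_dataset


def load_public_examples(limit: int = 1000) -> list[InstructionExample]:
    """Load instruction examples from a public dataset (illustrative sketch).

    Assumes the databricks-dolly-15k column layout; adapt the field
    names if you swap in a different corpus.
    """
    ds = load_dataset("databricks/databricks-dolly-15k", split="train")
    examples = []
    for row in ds.select(range(min(limit, len(ds)))):
        examples.append(
            InstructionExample(
                task_category=row["category"],
                instruction=row["instruction"],
                input_text=row["context"],
                output=row["response"],
            )
        )
    return examples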

Step 2: Quality Filtering

@dataclass
class QualityMetrics:
    """Quality metrics for a single instruction example.

    Attributes:
        response_length: Length of the output in characters.
        has_instruction: Whether the instruction is substantive (>10 chars).
        has_output: Whether the output is substantive (>5 chars).
        is_not_repetitive: Whether the output avoids excessive repetition.
        formatting_consistent: Whether formatting is clean.
        overall_pass: Whether the example passes all quality checks.
    """

    response_length: int
    has_instruction: bool
    has_output: bool
    is_not_repetitive: bool
    formatting_consistent: bool
    overall_pass: bool


def compute_quality_metrics(example: InstructionExample) -> QualityMetrics:
    """Compute quality metrics for an instruction example.

    Args:
        example: The instruction example to evaluate.

    Returns:
        Quality metrics for the example.
    """
    response_length = len(example.output)
    has_instruction = len(example.instruction.strip()) > 10
    has_output = len(example.output.strip()) > 5

    # Check for excessive repetition (same trigram repeated 3+ times)
    words = example.output.lower().split()
    trigrams = [
        " ".join(words[i:i + 3]) for i in range(len(words) - 2)
    ]
    trigram_counts = {}
    for tg in trigrams:
        trigram_counts[tg] = trigram_counts.get(tg, 0) + 1
    max_repeat = max(trigram_counts.values()) if trigram_counts else 0
    is_not_repetitive = max_repeat < 3

    # Check formatting: the output should not begin with stray whitespace
    formatting_consistent = example.output == example.output.lstrip()

    overall_pass = all([
        has_instruction,
        has_output,
        is_not_repetitive,
        formatting_consistent,
        response_length >= 10,
        response_length <= 5000,
    ])

    return QualityMetrics(
        response_length=response_length,
        has_instruction=has_instruction,
        has_output=has_output,
        is_not_repetitive=is_not_repetitive,
        formatting_consistent=formatting_consistent,
        overall_pass=overall_pass,
    )


def filter_dataset(
    examples: list[InstructionExample],
    min_quality_score: float = 3.0,
) -> list[InstructionExample]:
    """Filter a dataset based on quality metrics.

    Args:
        examples: List of instruction examples.
        min_quality_score: Minimum quality score to keep.

    Returns:
        Filtered list of examples that pass quality checks.
    """
    filtered = []
    for ex in examples:
        metrics = compute_quality_metrics(ex)
        if metrics.overall_pass and ex.quality_score >= min_quality_score:
            filtered.append(ex)
    return filtered
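

To sanity-check the filter, feed it an example that should fail. The snippet below (the repetitive output is fabricated for the test) confirms that the repetition check rejects it:

bad_example = InstructionExample(
    task_category="generation",
    instruction="Write a short poem about the sea.",
    input_text="",
    output="the sea is blue " * 10,  # the trigram "the sea is" repeats 10x
)
assert not compute_quality_metrics(bad_example).overall_pass
assert filter_dataset([bad_example]) == []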

Step 3: Data Mixing for Forgetting Prevention

def create_mixed_dataset(
    task_examples: list[InstructionExample],
    general_examples: list[InstructionExample],
    task_ratio: float = 0.7,
) -> Dataset:
    """Create a mixed dataset combining task and general data.

    Mixes task-specific data with general instruction data to prevent
    catastrophic forgetting during fine-tuning.

    Args:
        task_examples: Task-specific instruction examples.
        general_examples: General-purpose instruction examples.
        task_ratio: Fraction of the dataset that is task-specific.

    Returns:
        Mixed HuggingFace Dataset in messages format.
    """
    def to_messages(ex: InstructionExample) -> dict:
        """Convert an instruction example to chat messages format."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": ex.instruction + (
                f"\n\n{ex.input_text}" if ex.input_text else ""
            )},
            {"role": "assistant", "content": ex.output},
        ]
        return {"messages": messages, "task_category": ex.task_category}

    task_data = [to_messages(ex) for ex in task_examples]
    general_data = [to_messages(ex) for ex in general_examples]

    # Compute sample sizes
    total = len(task_data) + len(general_data)
    n_task = int(total * task_ratio)
    n_general = total - n_task

    # Sample (with replacement if needed)
    random.seed(42)

    if len(task_data) >= n_task:
        sampled_task = random.sample(task_data, n_task)
    else:
        sampled_task = task_data * (n_task // len(task_data) + 1)
        sampled_task = sampled_task[:n_task]

    if len(general_data) >= n_general:
        sampled_general = random.sample(general_data, n_general)
    else:
        sampled_general = general_data * (n_general // len(general_data) + 1)
        sampled_general = sampled_general[:n_general]

    combined = sampled_task + sampled_general
    random.shuffle(combined)

    return Dataset.from_list(combined)
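

As a quick usage check, the snippet below mixes the handcrafted examples with themselves standing in as general data (purely illustrative; in practice the general pool would come from a broad corpus such as the one loaded in Step 1):

task_pool = create_diverse_dataset()
general_pool = create_diverse_dataset()  # stand-in for a general corpus
mixed = create_mixed_dataset(task_pool, general_pool, task_ratio=0.7)
print(len(mixed))  # 20 rows: 14 task-specific + 6 general
print(mixed[0]["messages"][1]["role"])  # "user"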

Step 4: Training Configuration

from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_training_config(
    output_dir: str = "./instruction-model",
    num_epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 2e-5,
) -> dict:
    """Get the training configuration for SFT.

    Args:
        output_dir: Directory for saving checkpoints.
        num_epochs: Number of training epochs.
        batch_size: Per-device training batch size.
        learning_rate: Peak learning rate.

    Returns:
        Dictionary with all training hyperparameters.
    """
    return {
        "output_dir": output_dir,
        "num_train_epochs": num_epochs,
        "per_device_train_batch_size": batch_size,
        "gradient_accumulation_steps": 4,
        "learning_rate": learning_rate,
        "weight_decay": 0.01,
        "warmup_ratio": 0.05,
        "lr_scheduler_type": "cosine",
        "logging_steps": 10,
        "save_strategy": "epoch",
        "eval_strategy": "epoch",
        "bf16": True,
        "max_seq_length": 2048,
        "gradient_checkpointing": True,
        "lora_config": {
            "r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.05,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
            "task_type": "CAUSAL_LM",
            "bias": "none",
        },
    }
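

The dictionary above is deliberately framework-agnostic. Below is a minimal sketch of wiring it into TRL's SFTTrainer; exact argument names drift between TRL releases (for example, max_seq_length on SFTConfig has been renamed to max_length in newer versions), and Qwen/Qwen2.5-0.5B is only an illustrative small base model, so treat this as a template rather than a drop-in script.

from trl import SFTConfig, SFTTrainer


def train_sft(
    train_dataset: Dataset,
    eval_dataset: Dataset,
    model_name: str = "Qwen/Qwen2.5-0.5B",  # illustrative base model
) -> None:
    """Run SFT with LoRA via TRL (sketch; adjust for your TRL version)."""
    cfg = get_training_config()
    lora_kwargs = cfg.pop("lora_config")
    max_seq_length = cfg.pop("max_seq_length")

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )

    trainer = SFTTrainer(
        model=model,
        # Older TRL accepts max_seq_length here; newer releases call it
        # max_length.
        args=SFTConfig(**cfg, max_seq_length=max_seq_length),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=LoraConfig(**lora_kwargs),
    )
    trainer.train()
    trainer.save_model(cfg["output_dir"])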

Step 5: Evaluation Framework

@dataclass
class InstructionEvalResult:
    """Evaluation results for the instruction-following model.

    Attributes:
        task_accuracy: Per-task accuracy or quality score.
        format_compliance: Rate of properly formatted outputs.
        instruction_following_rate: Rate of correctly following instructions.
        avg_response_length: Average response length in tokens.
        general_benchmark_scores: Scores on general benchmarks.
    """

    task_accuracy: dict[str, float]
    format_compliance: float
    instruction_following_rate: float
    avg_response_length: float
    general_benchmark_scores: dict[str, float] = field(default_factory=dict)


def evaluate_instruction_following(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    eval_examples: list[InstructionExample],
) -> InstructionEvalResult:
    """Evaluate instruction-following quality.

    Args:
        model: The fine-tuned model.
        tokenizer: HuggingFace tokenizer.
        eval_examples: List of evaluation examples.

    Returns:
        Comprehensive evaluation results.
    """
    task_results: dict[str, list[bool]] = {}
    format_pass = 0
    instruction_follow = 0
    total_length = 0
    total = len(eval_examples)

    for ex in eval_examples:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": ex.instruction + (
                f"\n\n{ex.input_text}" if ex.input_text else ""
            )},
        ]
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(formatted, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
                # Some base models define no pad token; fall back to EOS.
                pad_token_id=tokenizer.eos_token_id,
            )
        response = tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        total_length += len(tokenizer.encode(response))

        # Check format compliance (non-empty, reasonable length)
        if len(response.strip()) > 0 and len(response) < 5000:
            format_pass += 1

        # Basic instruction-following check
        if len(response.strip()) > 10:
            instruction_follow += 1

        cat = ex.task_category
        if cat not in task_results:
            task_results[cat] = []
        task_results[cat].append(len(response.strip()) > 10)

    task_accuracy = {
        cat: sum(results) / len(results)
        for cat, results in task_results.items()
    }

    return InstructionEvalResult(
        task_accuracy=task_accuracy,
        format_compliance=format_pass / total if total > 0 else 0.0,
        instruction_following_rate=instruction_follow / total if total > 0 else 0.0,
        avg_response_length=total_length / total if total > 0 else 0.0,
    )
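

The checks above are shallow surface heuristics. For the LLM-as-judge metric named in the learning objectives, a common pattern is to ask a stronger model to grade each response on a 1-5 scale and parse the score out of its reply. A minimal sketch follows; judge_fn is a placeholder for whatever API call or local model you use as the judge.

from typing import Callable

JUDGE_TEMPLATE = (
    "You are grading an AI assistant's response.\n"
    "Instruction: {instruction}\n"
    "Response: {response}\n\n"
    "Rate how well the response follows the instruction on a scale of "
    "1 (ignores it) to 5 (follows it completely). Reply with the number only."
)


def judge_response(
    instruction: str,
    response: str,
    judge_fn: Callable[[str], str],  # placeholder: prompt -> judge reply
) -> float:
    """Score a response from 1 to 5 with an LLM judge (sketch)."""
    reply = judge_fn(
        JUDGE_TEMPLATE.format(instruction=instruction, response=response)
    )
    match = re.search(r"[1-5]", reply)
    # Fall back to the lowest score if the judge's reply is unparseable.
    return float(match.group()) if match else 1.0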


def measure_forgetting(
    base_scores: dict[str, float],
    finetuned_scores: dict[str, float],
) -> dict[str, float]:
    """Measure catastrophic forgetting by comparing benchmark scores.

    Args:
        base_scores: Benchmark scores of the base model.
        finetuned_scores: Benchmark scores of the fine-tuned model.

    Returns:
        Dictionary mapping benchmark names to relative score changes in
        percent (negative values indicate degradation, i.e., forgetting).
    """
    forgetting = {}
    for benchmark in base_scores:
        if benchmark in finetuned_scores:
            base = base_scores[benchmark]
            finetuned = finetuned_scores[benchmark]
            if base > 0:
                forgetting[benchmark] = (finetuned - base) / base * 100
            else:
                forgetting[benchmark] = 0.0
    return forgetting
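

These numbers can serve as a feedback signal for the mixing ratio from Step 3: if any benchmark degrades beyond a tolerance, lower task_ratio and retrain. A sketch of that decision rule follows; the 5% tolerance and 0.1 step are illustrative choices, not recommendations from any reference.

def suggest_task_ratio(
    forgetting: dict[str, float],
    current_ratio: float,
    tolerance_pct: float = -5.0,  # illustrative threshold
    step: float = 0.1,            # illustrative adjustment size
) -> float:
    """Lower the task ratio when forgetting exceeds the tolerance (sketch)."""
    worst = min(forgetting.values(), default=0.0)
    if worst < tolerance_pct:
        return max(0.5, current_ratio - step)
    return current_ratio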

Step 6: Demonstration

if __name__ == "__main__":
    print("=" * 60)
    print("Case Study 2: Building an Instruction-Following Model")
    print("=" * 60)

    # Create and filter dataset
    raw_examples = create_diverse_dataset()
    print(f"\nRaw examples: {len(raw_examples)}")

    filtered = filter_dataset(raw_examples)
    print(f"After filtering: {len(filtered)}")

    # Show task distribution
    print("\nTask distribution:")
    from collections import Counter
    task_counts = Counter(ex.task_category for ex in filtered)
    for task, count in sorted(task_counts.items()):
        print(f"  {task}: {count}")

    # Show quality metrics for first example
    metrics = compute_quality_metrics(filtered[0])
    print(f"\nQuality metrics (example 0):")
    print(f"  Response length: {metrics.response_length}")
    print(f"  Has instruction: {metrics.has_instruction}")
    print(f"  Not repetitive: {metrics.is_not_repetitive}")
    print(f"  Overall pass: {metrics.overall_pass}")

    # Training config
    config = get_training_config()
    print(f"\nTraining config:")
    print(f"  Epochs: {config['num_train_epochs']}")
    print(f"  LR: {config['learning_rate']}")
    print(f"  LoRA rank: {config['lora_config']['r']}")

    # Forgetting measurement demo
    print("\nForgetting measurement demo:")
    base = {"mmlu": 0.65, "hellaswag": 0.78, "arc": 0.72}
    finetuned = {"mmlu": 0.62, "hellaswag": 0.75, "arc": 0.70}
    forgetting = measure_forgetting(base, finetuned)
    for bench, pct in forgetting.items():
        print(f"  {bench}: {pct:+.1f}%")

    print("\nTo run full training, execute with a compatible GPU.")

Key Takeaways

  1. Dataset diversity is paramount. An instruction-following model must see diverse tasks (10+ categories), instruction phrasings, and complexity levels during training.
  2. Quality filtering removes noise. Automated checks for length, repetition, formatting, and content quality significantly improve training data.
  3. Data mixing prevents forgetting. Including 20-50% general-purpose data alongside task-specific data helps maintain the model's broad capabilities.
  4. Evaluation must be multi-dimensional. Task accuracy, format compliance, instruction-following rate, and general benchmark retention each reveal different aspects of model quality.
  5. SFT is necessary but not sufficient for alignment. SFT teaches the model to follow instructions, but it does not teach it to prefer good responses over bad ones; that requires the preference optimization techniques covered in Chapter 25.