Case Study 2: Building an Instruction-Following Model
Overview
In this case study, we build an instruction-following model from a base language model using supervised fine-tuning (SFT). We cover the entire process: curating a diverse instruction dataset, applying quality filters, formatting data with chat templates, training with LoRA using TRL, evaluating instruction-following quality with multiple metrics, and measuring catastrophic forgetting. The goal is to transform a base model into one that reliably follows diverse natural language instructions.
Learning Objectives
- Curate a diverse, high-quality instruction-tuning dataset.
- Apply quality filtering to remove low-quality examples.
- Implement the full SFT pipeline with LoRA using TRL.
- Evaluate instruction-following with automated and LLM-as-judge metrics.
- Measure and mitigate catastrophic forgetting through data mixing.
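Throughout the pipeline, a tokenizer's chat template turns structured messages into a single training string. As a concrete illustration, here is a minimal sketch of that conversion using a ChatML-style template; the exact markup varies by model family, so this template is an assumption for illustration, not any particular tokenizer's output:

```python
def render_chatml(messages: list[dict]) -> str:
    """Render chat messages with a ChatML-style template (illustrative only)."""
    parts = []
    for msg in messages:
        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>")
    return "\n".join(parts)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Rewrite in active voice.\n\nThe report was written by the team."},
    {"role": "assistant", "content": "The team wrote the report."},
]
print(render_chatml(messages))
```

In the real pipeline this rendering is delegated to `tokenizer.apply_chat_template`, which reads the template shipped with the model.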
Step 1: Dataset Curation
"""Building an instruction-following model via supervised fine-tuning.
This case study covers dataset curation, quality filtering, SFT training,
and comprehensive evaluation including forgetting analysis.
Requirements:
pip install torch transformers peft trl datasets
"""
from dataclasses import dataclass, field
import torch
from datasets import Dataset
torch.manual_seed(42)
# ---------------------------------------------------------------------------
# Instruction dataset structure
# ---------------------------------------------------------------------------
TASK_CATEGORIES = [
"summarization",
"classification",
"extraction",
"generation",
"reasoning",
"coding",
"math",
"rewriting",
"question_answering",
"brainstorming",
]
@dataclass
class InstructionExample:
"""A single instruction-tuning example.
Attributes:
task_category: Category of the task.
instruction: The natural language instruction.
input_text: Optional input context.
output: The desired response.
quality_score: Quality rating (1-5) for filtering.
"""
task_category: str
instruction: str
input_text: str
output: str
quality_score: float = 5.0
def create_diverse_dataset() -> list[InstructionExample]:
"""Create a diverse instruction dataset covering all task categories.
Returns:
List of InstructionExample objects.
"""
examples = [
# Summarization
InstructionExample(
task_category="summarization",
instruction="Summarize the following paragraph in one sentence.",
input_text=(
"Machine learning models require large amounts of data "
"to learn patterns. The quality of data directly affects "
"model performance. Data augmentation, cleaning, and "
"feature engineering are critical preprocessing steps."
),
output=(
"Effective machine learning depends on large quantities "
"of high-quality data, with preprocessing steps like "
"augmentation and feature engineering being crucial."
),
),
# Classification
InstructionExample(
task_category="classification",
instruction=(
"Classify the sentiment of the following product review "
"as positive, negative, or neutral."
),
input_text="The battery life is incredible but the screen is dim.",
output=(
"Neutral. The review contains both positive feedback "
"(battery life) and negative feedback (dim screen)."
),
),
# Extraction
InstructionExample(
task_category="extraction",
instruction="Extract all dates mentioned in the following text.",
input_text=(
"The project started on January 15, 2024 and the first "
"milestone was completed on March 1, 2024. The final "
"deadline is December 31, 2024."
),
output=(
"1. January 15, 2024 (project start)\n"
"2. March 1, 2024 (first milestone)\n"
"3. December 31, 2024 (final deadline)"
),
),
# Generation
InstructionExample(
task_category="generation",
instruction=(
"Write a professional email declining a meeting invitation "
"due to a scheduling conflict."
),
input_text="",
output=(
"Subject: Unable to Attend - Meeting on Thursday\n\n"
"Dear [Name],\n\n"
"Thank you for the meeting invitation. Unfortunately, "
"I have a prior commitment at that time and will not be "
"able to attend. Could we reschedule, or would it be "
"possible to receive the meeting notes afterward?\n\n"
"Best regards,\n[Your Name]"
),
),
# Reasoning
InstructionExample(
task_category="reasoning",
instruction="Explain why the conclusion does not follow from the premise.",
input_text=(
'Premise: "All dogs are mammals." '
'Conclusion: "All mammals are dogs."'
),
output=(
"This is a logical fallacy known as affirming the consequent "
"(or converse error). While all dogs are indeed mammals, the "
"reverse is not true. Mammals include cats, whales, humans, "
"and many other species. The set of dogs is a subset of "
"mammals, not equal to it."
),
),
# Coding
InstructionExample(
task_category="coding",
instruction="Write a Python function to compute the Fibonacci sequence up to n terms.",
input_text="",
output=(
"```python\n"
"def fibonacci(n: int) -> list[int]:\n"
' """Compute Fibonacci sequence up to n terms.\n\n'
" Args:\n"
" n: Number of terms to generate.\n\n"
" Returns:\n"
" List of Fibonacci numbers.\n"
' """\n'
" if n <= 0:\n"
" return []\n"
" if n == 1:\n"
" return [0]\n"
" fib = [0, 1]\n"
" for _ in range(2, n):\n"
" fib.append(fib[-1] + fib[-2])\n"
" return fib\n"
"```"
),
),
# Math
InstructionExample(
task_category="math",
instruction="Solve the following equation and show your work.",
input_text="3x + 7 = 22",
output=(
"Step 1: Subtract 7 from both sides.\n"
" 3x + 7 - 7 = 22 - 7\n"
" 3x = 15\n\n"
"Step 2: Divide both sides by 3.\n"
" 3x / 3 = 15 / 3\n"
" x = 5\n\n"
"The solution is x = 5."
),
),
# Rewriting
InstructionExample(
task_category="rewriting",
instruction="Rewrite the following sentence in active voice.",
input_text="The report was written by the team last Friday.",
output="The team wrote the report last Friday.",
),
# Question Answering
InstructionExample(
task_category="question_answering",
instruction="Answer the question based on the given context.",
input_text=(
"Context: The Great Wall of China stretches over 13,000 "
"miles. Construction began in the 7th century BC and "
"continued for centuries. Question: How long is the "
"Great Wall of China?"
),
output=(
"According to the context, the Great Wall of China "
"stretches over 13,000 miles."
),
),
# Brainstorming
InstructionExample(
task_category="brainstorming",
instruction="List 5 creative ways to reduce paper waste in an office.",
input_text="",
output=(
"1. Implement a digital-first policy where all documents "
"are shared electronically by default.\n"
"2. Set up a paper reuse station where single-sided prints "
"are collected for scratch paper.\n"
"3. Install tablet devices in meeting rooms to replace "
"printed agendas and handouts.\n"
"4. Create a monthly challenge where teams compete to "
"minimize their paper consumption.\n"
"5. Replace paper towels with air dryers and provide "
"reusable cloth alternatives."
),
),
]
return examples
Step 2: Quality Filtering
@dataclass
class QualityMetrics:
"""Quality metrics for a single instruction example.
Attributes:
response_length: Length of the output in characters.
        has_instruction: Whether the instruction is substantive (over 10 characters).
        has_output: Whether the output is non-trivial (over 5 characters).
is_not_repetitive: Whether the output avoids excessive repetition.
formatting_consistent: Whether formatting is clean.
overall_pass: Whether the example passes all quality checks.
"""
response_length: int
has_instruction: bool
has_output: bool
is_not_repetitive: bool
formatting_consistent: bool
overall_pass: bool
def compute_quality_metrics(example: InstructionExample) -> QualityMetrics:
"""Compute quality metrics for an instruction example.
Args:
example: The instruction example to evaluate.
Returns:
Quality metrics for the example.
"""
response_length = len(example.output)
has_instruction = len(example.instruction.strip()) > 10
has_output = len(example.output.strip()) > 5
# Check for excessive repetition (same trigram repeated 3+ times)
words = example.output.lower().split()
trigrams = [
" ".join(words[i:i + 3]) for i in range(len(words) - 2)
]
trigram_counts = {}
for tg in trigrams:
trigram_counts[tg] = trigram_counts.get(tg, 0) + 1
max_repeat = max(trigram_counts.values()) if trigram_counts else 0
is_not_repetitive = max_repeat < 3
    # Check formatting: output should not carry stray leading or trailing whitespace
    formatting_consistent = example.output == example.output.strip()
overall_pass = all([
has_instruction,
has_output,
is_not_repetitive,
formatting_consistent,
response_length >= 10,
response_length <= 5000,
])
return QualityMetrics(
response_length=response_length,
has_instruction=has_instruction,
has_output=has_output,
is_not_repetitive=is_not_repetitive,
formatting_consistent=formatting_consistent,
overall_pass=overall_pass,
)
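The trigram-repetition heuristic is the least obvious of these checks, so it helps to exercise it in isolation. The snippet below re-implements the same counting logic as a standalone function so its behavior is easy to verify (a minimal sketch mirroring `compute_quality_metrics`, not a separate API):

```python
def max_trigram_repeats(text: str) -> int:
    """Return the highest occurrence count of any word trigram in the text."""
    words = text.lower().split()
    counts: dict[str, int] = {}
    for i in range(len(words) - 2):
        tg = " ".join(words[i:i + 3])
        counts[tg] = counts.get(tg, 0) + 1
    return max(counts.values()) if counts else 0

clean = "The team wrote the report last Friday and shared it widely."
degenerate = "very good product very good product very good product indeed"
print(max_trigram_repeats(clean))       # 1 -- every trigram is unique
print(max_trigram_repeats(degenerate))  # 3 -- "very good product" repeats
```

With the threshold used above (`max_repeat < 3`), the degenerate output would be rejected while the clean one passes.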
def filter_dataset(
examples: list[InstructionExample],
min_quality_score: float = 3.0,
) -> list[InstructionExample]:
"""Filter a dataset based on quality metrics.
Args:
examples: List of instruction examples.
min_quality_score: Minimum quality score to keep.
Returns:
Filtered list of examples that pass quality checks.
"""
filtered = []
for ex in examples:
metrics = compute_quality_metrics(ex)
if metrics.overall_pass and ex.quality_score >= min_quality_score:
filtered.append(ex)
return filtered
Step 3: Data Mixing for Forgetting Prevention
def create_mixed_dataset(
task_examples: list[InstructionExample],
general_examples: list[InstructionExample],
task_ratio: float = 0.7,
) -> Dataset:
"""Create a mixed dataset combining task and general data.
Mixes task-specific data with general instruction data to prevent
catastrophic forgetting during fine-tuning.
Args:
task_examples: Task-specific instruction examples.
general_examples: General-purpose instruction examples.
task_ratio: Fraction of the dataset that is task-specific.
Returns:
Mixed HuggingFace Dataset in messages format.
"""
def to_messages(ex: InstructionExample) -> dict:
"""Convert an instruction example to chat messages format."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": ex.instruction + (
f"\n\n{ex.input_text}" if ex.input_text else ""
)},
{"role": "assistant", "content": ex.output},
]
return {"messages": messages, "task_category": ex.task_category}
task_data = [to_messages(ex) for ex in task_examples]
general_data = [to_messages(ex) for ex in general_examples]
# Compute sample sizes
total = len(task_data) + len(general_data)
n_task = int(total * task_ratio)
n_general = total - n_task
# Sample (with replacement if needed)
import random
random.seed(42)
if len(task_data) >= n_task:
sampled_task = random.sample(task_data, n_task)
else:
sampled_task = task_data * (n_task // len(task_data) + 1)
sampled_task = sampled_task[:n_task]
if len(general_data) >= n_general:
sampled_general = random.sample(general_data, n_general)
else:
sampled_general = general_data * (n_general // len(general_data) + 1)
sampled_general = sampled_general[:n_general]
combined = sampled_task + sampled_general
random.shuffle(combined)
return Dataset.from_list(combined)
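The sampling arithmetic in `create_mixed_dataset` is easiest to see on concrete numbers. This standalone sketch computes the target split and upsamples an undersized pool by duplication, following the same strategy as above (an illustrative helper, not part of the pipeline API):

```python
import random

def plan_mix(n_task_pool: int, n_general_pool: int, task_ratio: float) -> tuple[int, int]:
    """Compute how many task and general examples a mixed dataset needs."""
    total = n_task_pool + n_general_pool
    n_task = int(total * task_ratio)
    return n_task, total - n_task

def upsample(pool: list, n: int, seed: int = 42) -> list:
    """Sample n items, duplicating the pool first if it is too small."""
    random.seed(seed)
    if len(pool) >= n:
        return random.sample(pool, n)
    repeated = pool * (n // len(pool) + 1)
    return repeated[:n]

# 10 task examples and 10 general examples with a 70/30 mix:
n_task, n_general = plan_mix(10, 10, task_ratio=0.7)
print(n_task, n_general)  # 14 6 -- the task pool must be upsampled
print(len(upsample(list(range(10)), n_task)))  # 14
```

Note that upsampling by duplication means some task examples appear twice per epoch; with very small pools, consider lowering `task_ratio` instead.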
Step 4: Training Configuration
from peft import LoraConfig, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_training_config(
output_dir: str = "./instruction-model",
num_epochs: int = 3,
batch_size: int = 4,
learning_rate: float = 2e-5,
) -> dict:
"""Get the training configuration for SFT.
Args:
output_dir: Directory for saving checkpoints.
num_epochs: Number of training epochs.
batch_size: Per-device training batch size.
learning_rate: Peak learning rate.
Returns:
Dictionary with all training hyperparameters.
"""
return {
"output_dir": output_dir,
"num_train_epochs": num_epochs,
"per_device_train_batch_size": batch_size,
"gradient_accumulation_steps": 4,
"learning_rate": learning_rate,
"weight_decay": 0.01,
"warmup_ratio": 0.05,
"lr_scheduler_type": "cosine",
"logging_steps": 10,
"save_strategy": "epoch",
"eval_strategy": "epoch",
"bf16": True,
"max_seq_length": 2048,
"gradient_checkpointing": True,
"lora_config": {
"r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
"task_type": "CAUSAL_LM",
"bias": "none",
},
}
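To build intuition for the LoRA settings above (`r=16` on the four attention projections), here is a back-of-the-envelope count of trainable parameters. The hidden size of 2048 and 24 layers are assumptions for illustration; for a real model, read these values from its config:

```python
def lora_trainable_params(
    r: int, hidden_size: int, n_layers: int, n_target_modules: int
) -> int:
    """Estimate LoRA trainable parameters for square attention projections.

    Each adapted (hidden x hidden) weight gains two low-rank factors:
    A of shape (r, hidden) and B of shape (hidden, r).
    """
    per_module = 2 * r * hidden_size
    return per_module * n_target_modules * n_layers

# r=16, hidden_size=2048, 24 layers, 4 target modules (q/k/v/o projections)
n = lora_trainable_params(r=16, hidden_size=2048, n_layers=24, n_target_modules=4)
print(f"{n:,} trainable parameters")  # 6,291,456
```

For a base model with roughly a billion parameters, that is well under 1% of the weights being updated, which is what makes LoRA fine-tuning feasible on a single GPU.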
Step 5: Evaluation Framework
@dataclass
class InstructionEvalResult:
"""Evaluation results for the instruction-following model.
Attributes:
task_accuracy: Per-task accuracy or quality score.
format_compliance: Rate of properly formatted outputs.
instruction_following_rate: Rate of correctly following instructions.
avg_response_length: Average response length in tokens.
general_benchmark_scores: Scores on general benchmarks.
"""
task_accuracy: dict[str, float]
format_compliance: float
instruction_following_rate: float
avg_response_length: float
general_benchmark_scores: dict[str, float] = field(default_factory=dict)
def evaluate_instruction_following(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
eval_examples: list[InstructionExample],
) -> InstructionEvalResult:
"""Evaluate instruction-following quality.
Args:
model: The fine-tuned model.
tokenizer: HuggingFace tokenizer.
eval_examples: List of evaluation examples.
Returns:
Comprehensive evaluation results.
"""
task_results: dict[str, list[bool]] = {}
format_pass = 0
instruction_follow = 0
total_length = 0
total = len(eval_examples)
for ex in eval_examples:
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": ex.instruction + (
f"\n\n{ex.input_text}" if ex.input_text else ""
)},
]
formatted = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,  # greedy decoding for reproducible evaluation
            )
response = tokenizer.decode(
output_ids[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
)
total_length += len(tokenizer.encode(response))
# Check format compliance (non-empty, reasonable length)
if len(response.strip()) > 0 and len(response) < 5000:
format_pass += 1
# Basic instruction-following check
if len(response.strip()) > 10:
instruction_follow += 1
cat = ex.task_category
if cat not in task_results:
task_results[cat] = []
task_results[cat].append(len(response.strip()) > 10)
task_accuracy = {
cat: sum(results) / len(results)
for cat, results in task_results.items()
}
return InstructionEvalResult(
task_accuracy=task_accuracy,
format_compliance=format_pass / total if total > 0 else 0.0,
instruction_following_rate=instruction_follow / total if total > 0 else 0.0,
avg_response_length=total_length / total if total > 0 else 0.0,
)
def measure_forgetting(
base_scores: dict[str, float],
finetuned_scores: dict[str, float],
) -> dict[str, float]:
"""Measure catastrophic forgetting by comparing benchmark scores.
Args:
base_scores: Benchmark scores of the base model.
finetuned_scores: Benchmark scores of the fine-tuned model.
Returns:
Dictionary mapping benchmark names to forgetting percentages
(negative means degradation).
"""
forgetting = {}
for benchmark in base_scores:
if benchmark in finetuned_scores:
base = base_scores[benchmark]
finetuned = finetuned_scores[benchmark]
if base > 0:
forgetting[benchmark] = (finetuned - base) / base * 100
else:
forgetting[benchmark] = 0.0
return forgetting
Step 6: Demonstration
if __name__ == "__main__":
print("=" * 60)
print("Case Study 2: Building an Instruction-Following Model")
print("=" * 60)
# Create and filter dataset
raw_examples = create_diverse_dataset()
print(f"\nRaw examples: {len(raw_examples)}")
filtered = filter_dataset(raw_examples)
print(f"After filtering: {len(filtered)}")
# Show task distribution
print("\nTask distribution:")
from collections import Counter
task_counts = Counter(ex.task_category for ex in filtered)
for task, count in sorted(task_counts.items()):
print(f" {task}: {count}")
# Show quality metrics for first example
metrics = compute_quality_metrics(filtered[0])
    print("\nQuality metrics (example 0):")
print(f" Response length: {metrics.response_length}")
print(f" Has instruction: {metrics.has_instruction}")
print(f" Not repetitive: {metrics.is_not_repetitive}")
print(f" Overall pass: {metrics.overall_pass}")
# Training config
config = get_training_config()
    print("\nTraining config:")
print(f" Epochs: {config['num_train_epochs']}")
print(f" LR: {config['learning_rate']}")
print(f" LoRA rank: {config['lora_config']['r']}")
# Forgetting measurement demo
print("\nForgetting measurement demo:")
base = {"mmlu": 0.65, "hellaswag": 0.78, "arc": 0.72}
finetuned = {"mmlu": 0.62, "hellaswag": 0.75, "arc": 0.70}
forgetting = measure_forgetting(base, finetuned)
for bench, pct in forgetting.items():
print(f" {bench}: {pct:+.1f}%")
print("\nTo run full training, execute with a compatible GPU.")
Key Takeaways
- Dataset diversity is paramount. An instruction-following model must see diverse tasks (10+ categories), instruction phrasings, and complexity levels during training.
- Quality filtering removes noise. Automated checks for length, repetition, formatting, and content quality significantly improve training data.
- Data mixing prevents forgetting. Including 20-50% general-purpose data alongside task-specific data helps maintain the model's broad capabilities.
- Evaluation must be multi-dimensional. Task accuracy, format compliance, instruction-following rate, and general benchmark retention each reveal different aspects of model quality.
- SFT is necessary but not sufficient for alignment. SFT teaches the model to follow instructions, but it does not teach it to prefer good responses over bad ones; that requires the preference optimization techniques covered in Chapter 25.