Case Study 1: Fine-Tuning Llama for Code Generation
Overview
In this case study, we fine-tune an open-source language model for Python code generation using QLoRA. Starting from a base model, we prepare a code instruction dataset, configure 4-bit quantization with LoRA, train using the TRL library, evaluate generated code for syntactic validity and documentation quality, and compare those metrics before and after fine-tuning. The case study demonstrates the complete fine-tuning pipeline from data preparation through evaluation and adapter export.
Learning Objectives
- Prepare a code instruction dataset in the appropriate chat format.
- Configure QLoRA with BitsAndBytes for memory-efficient fine-tuning.
- Train using TRL's SFTTrainer with response-only loss masking.
- Evaluate code generation quality with structural and functional correctness checks.
- Diagnose and address common fine-tuning issues.
Step 1: Dataset Preparation
We create a small, illustrative code instruction dataset in conversation (chat) format. A real fine-tuning run would use hundreds to thousands of examples covering diverse Python tasks, but the structure is the same.
"""Fine-tuning a language model for Python code generation using QLoRA.
This case study demonstrates the full pipeline: dataset preparation,
QLoRA configuration, training, and evaluation.
Requirements:
pip install torch transformers peft trl datasets bitsandbytes
"""
from dataclasses import dataclass
import torch
from datasets import Dataset
torch.manual_seed(42)
@dataclass
class CodeExample:
"""A single code instruction-response pair.
Attributes:
instruction: Natural language description of the coding task.
code: The Python solution.
explanation: Brief explanation of the approach.
"""
instruction: str
code: str
explanation: str
def create_code_dataset() -> Dataset:
"""Create a code instruction dataset for fine-tuning.
Returns:
HuggingFace Dataset with 'messages' column in chat format.
"""
examples = [
CodeExample(
instruction="Write a Python function to check if a string is a palindrome.",
code=(
'def is_palindrome(s: str) -> bool:\n'
' """Check if a string is a palindrome.\n\n'
' Args:\n'
' s: Input string to check.\n\n'
' Returns:\n'
' True if the string is a palindrome.\n'
' """\n'
' cleaned = s.lower().replace(" ", "")\n'
' return cleaned == cleaned[::-1]'
),
explanation="Normalize the string by lowering case and removing spaces, then compare with its reverse.",
),
CodeExample(
instruction="Write a Python function to find the two numbers in a list that sum to a target value.",
code=(
'def two_sum(nums: list[int], target: int) -> tuple[int, int]:\n'
' """Find two numbers that sum to the target.\n\n'
' Args:\n'
' nums: List of integers.\n'
' target: Target sum value.\n\n'
' Returns:\n'
' Tuple of indices of the two numbers.\n\n'
' Raises:\n'
' ValueError: If no solution exists.\n'
' """\n'
' seen: dict[int, int] = {}\n'
' for i, num in enumerate(nums):\n'
' complement = target - num\n'
' if complement in seen:\n'
' return (seen[complement], i)\n'
' seen[num] = i\n'
' raise ValueError("No two sum solution found.")'
),
explanation="Use a hash map for O(n) lookup of complements.",
),
CodeExample(
instruction="Write a Python function to flatten a nested list of arbitrary depth.",
code=(
'def flatten(nested: list) -> list:\n'
' """Flatten a nested list of arbitrary depth.\n\n'
' Args:\n'
' nested: A potentially nested list.\n\n'
' Returns:\n'
' A flat list with all elements.\n'
' """\n'
' result: list = []\n'
' for item in nested:\n'
' if isinstance(item, list):\n'
' result.extend(flatten(item))\n'
' else:\n'
' result.append(item)\n'
' return result'
),
explanation="Recursively process each element, extending the result for lists and appending for atoms.",
),
]
messages_list = []
for ex in examples:
messages_list.append({
"messages": [
{
"role": "system",
"content": "You are a Python coding assistant. Write clean, "
"well-documented Python code with type hints and Google-style "
"docstrings.",
},
{"role": "user", "content": ex.instruction},
{
"role": "assistant",
"content": f"{ex.explanation}\n\n```python\n{ex.code}\n```",
},
]
})
return Dataset.from_list(messages_list)
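To sanity-check the format, it helps to render one example through the model's chat template and confirm the special tokens look right. Below is a minimal sketch, assuming the tokenizer from Step 2 has already been loaded; the exact tokens depend on the model's chat template.
# Sketch: inspect how one training example is rendered by the chat template.
# Assumes `tokenizer` has been loaded as in Step 2.
dataset = create_code_dataset()
rendered = tokenizer.apply_chat_template(
    dataset[0]["messages"],
    tokenize=False,
    add_generation_prompt=False,
)
print(rendered)  # Shows the special tokens the model expects around each role.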
Step 2: QLoRA Configuration
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
def setup_qlora_model(
model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
lora_rank: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.05,
target_modules: list[str] | None = None,
) -> tuple[AutoModelForCausalLM, AutoTokenizer, LoraConfig]:
"""Set up a model with QLoRA configuration.
Args:
model_name: HuggingFace model identifier.
lora_rank: Rank for LoRA decomposition.
lora_alpha: Alpha scaling factor for LoRA.
lora_dropout: Dropout probability for LoRA layers.
target_modules: List of module names to apply LoRA to.
Returns:
Tuple of (model, tokenizer, lora_config).
"""
if target_modules is None:
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
]
# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
# LoRA config
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules,
task_type=TaskType.CAUSAL_LM,
bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
return model, tokenizer, lora_config
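The print_trainable_parameters() call reports how small the trainable footprint is. As a rough cross-check, the count can be estimated from the layer shapes. The sketch below assumes the published Llama 3.1 8B dimensions (hidden size 4096, MLP size 14336, 32 layers, 1024-dimensional key/value projections under grouped-query attention), so treat the result as approximate.
# Rough estimate of trainable LoRA parameters for the targeted projections.
# Layer dimensions are assumptions based on Llama 3.1 8B's published config.
def estimate_lora_params(rank: int = 16) -> int:
    hidden, mlp, kv, layers = 4096, 14336, 1024, 32
    shapes = [
        (hidden, hidden),  # q_proj
        (hidden, kv),      # k_proj
        (hidden, kv),      # v_proj
        (hidden, hidden),  # o_proj
        (hidden, mlp),     # gate_proj
        (hidden, mlp),     # up_proj
        (mlp, hidden),     # down_proj
    ]
    # Each adapter adds an A matrix (in_dim x rank) and a B matrix (rank x out_dim).
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * layers

print(f"~{estimate_lora_params() / 1e6:.1f}M trainable parameters")  # ~41.9M, about 0.5% of 8B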
Step 3: Training with TRL
from trl import SFTConfig, SFTTrainer
def train_code_model(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
dataset: Dataset,
output_dir: str = "./code-llama-qlora",
num_epochs: int = 3,
batch_size: int = 4,
learning_rate: float = 2e-5,
max_seq_length: int = 1024,
) -> SFTTrainer:
"""Train the model using SFTTrainer with QLoRA.
Args:
model: The PEFT-wrapped model.
tokenizer: HuggingFace tokenizer.
dataset: Training dataset with 'messages' column.
output_dir: Directory to save checkpoints.
num_epochs: Number of training epochs.
batch_size: Per-device batch size.
learning_rate: Learning rate for AdamW optimizer.
max_seq_length: Maximum sequence length for training.
Returns:
The trained SFTTrainer instance.
"""
sft_config = SFTConfig(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=4,
learning_rate=learning_rate,
weight_decay=0.01,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
logging_steps=10,
save_strategy="epoch",
bf16=True,
max_seq_length=max_seq_length,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
        # The 'messages' column is detected as a conversational dataset and
        # formatted with the tokenizer's chat template automatically.
)
trainer = SFTTrainer(
model=model,
args=sft_config,
train_dataset=dataset,
processing_class=tokenizer,
)
trainer.train()
return trainer
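The learning objectives call for response-only loss masking, but as written the trainer computes loss over the whole formatted conversation, prompts included. One common approach is TRL's DataCollatorForCompletionOnlyLM, sketched below using the tokenizer from Step 2. The response_template string is an assumption based on the Llama 3 chat format and should be checked against the tokenizer's rendered output; the collator also requires packing to stay disabled, which is the default above.
# Sketch: compute loss only on assistant responses by masking prompt tokens.
# The response_template is an assumed Llama 3 marker; verify it against the
# rendered chat template before relying on it.
from trl import DataCollatorForCompletionOnlyLM

response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)

# Then pass it to the trainer alongside the existing arguments:
# trainer = SFTTrainer(model=model, args=sft_config, train_dataset=dataset,
#                      processing_class=tokenizer, data_collator=collator)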
Step 4: Evaluation
import re
def evaluate_code_generation(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
test_prompts: list[str],
) -> dict[str, float]:
"""Evaluate code generation quality.
Args:
model: The fine-tuned model.
tokenizer: HuggingFace tokenizer.
test_prompts: List of coding task descriptions.
Returns:
Dictionary with evaluation metrics.
"""
total = len(test_prompts)
syntactically_valid = 0
has_docstring = 0
has_type_hints = 0
for prompt in test_prompts:
messages = [
{"role": "system", "content": "You are a Python coding assistant."},
{"role": "user", "content": prompt},
]
formatted = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs, max_new_tokens=512, temperature=0.1, do_sample=True
)
response = tokenizer.decode(
output_ids[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
)
# Extract code block
code_match = re.search(r"```python\n(.*?)```", response, re.DOTALL)
if code_match:
code = code_match.group(1)
else:
code = response
# Check syntax validity
try:
compile(code, "<string>", "exec")
syntactically_valid += 1
except SyntaxError:
pass
# Check for docstring
if '"""' in code or "'''" in code:
has_docstring += 1
        # Check for type hints (rough heuristic: "->" catches return
        # annotations, while ": " also matches dict literals and docstrings)
if "->" in code or ": " in code:
has_type_hints += 1
return {
"syntax_validity": syntactically_valid / total,
"docstring_rate": has_docstring / total,
"type_hint_rate": has_type_hints / total,
}
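The metrics above check structure rather than behavior. A lightweight functional-correctness check in the spirit of pass@1 runs generated code against hand-written assertions. The sketch below uses hypothetical task/test pairs and, like any execution of model-generated code, should only be run in a sandboxed environment.
def passes_tests(code: str, test_code: str) -> bool:
    """Return True if generated code passes the given assert-based tests.

    Warning: this executes untrusted model output; run it only in a sandbox.
    """
    namespace: dict = {}
    try:
        exec(code, namespace)       # Define the generated function(s).
        exec(test_code, namespace)  # Run the assertions against them.
        return True
    except Exception:
        return False

# Hypothetical example pairing a generated solution with its tests:
generated = "def add(a: int, b: int) -> int:\n    return a + b"
tests = "assert add(2, 3) == 5\nassert add(-1, 1) == 0"
print(passes_tests(generated, tests))  # True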
Step 5: Demonstration
if __name__ == "__main__":
print("=" * 60)
print("Case Study 1: Fine-Tuning for Code Generation")
print("=" * 60)
# Create dataset
dataset = create_code_dataset()
print(f"\nDataset size: {len(dataset)} examples")
print(f"First example roles: {[m['role'] for m in dataset[0]['messages']]}")
# Show dataset structure
print("\nDataset structure:")
for i, example in enumerate(dataset):
user_msg = example["messages"][1]["content"]
print(f" Example {i}: {user_msg[:60]}...")
print("\nQLoRA Configuration:")
print(" - Quantization: NF4 with double quantization")
print(" - LoRA rank: 16, alpha: 32")
print(" - Target modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj")
print(" - Training: 3 epochs, lr=2e-5, cosine schedule")
    print(" - Memory estimate: ~6 GB of GPU memory for the 8B model weights in 4-bit")
print("\nTo run the full pipeline, execute with a compatible GPU.")
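For deployment, only the LoRA adapter needs to be saved; it can later be merged into a full-precision copy of the base model for serving. A minimal sketch follows, assuming the trainer and tokenizer from the earlier steps are in scope; merging directly into a 4-bit quantized base is not supported, so the base model is reloaded in bf16 first.
# Sketch: save the adapter, then merge it into a bf16 base model for serving.
# Assumes `trainer` and `tokenizer` from the earlier steps are in scope and
# reuses the imports from Step 2.
from peft import PeftModel

trainer.save_model("./code-llama-qlora")  # Saves only the LoRA adapter weights.

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
merged = PeftModel.from_pretrained(base, "./code-llama-qlora").merge_and_unload()
merged.save_pretrained("./code-llama-merged")
tokenizer.save_pretrained("./code-llama-merged")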
Key Takeaways
- QLoRA enables fine-tuning large models on consumer hardware by storing the base model in 4-bit precision while training LoRA parameters in higher precision.
- Chat template formatting is critical for instruction-tuned models. Using the wrong template leads to significant performance degradation.
- Response-only loss masking focuses the training signal on the code generation task rather than learning to reproduce prompts.
- Code-specific evaluation metrics (syntax validity, docstring presence, type hints) provide more actionable feedback than generic perplexity.
- Gradient checkpointing trades extra compute for lower activation memory and is often necessary to fit training on a single consumer GPU.