Case Study 1: Fine-Tuning Llama for Code Generation

Overview

In this case study, we fine-tune an open-weight Llama model for Python code generation using QLoRA. Starting from an instruction-tuned base model, we prepare a code instruction dataset, configure 4-bit quantization with LoRA, train with the TRL library, evaluate the generated code for syntactic and stylistic quality, and compare the base and fine-tuned models on the same prompts. The case study walks through the complete pipeline from data preparation to a merged adapter ready for deployment.

Learning Objectives

  • Prepare a code instruction dataset in the appropriate chat format.
  • Configure QLoRA with BitsAndBytes for memory-efficient fine-tuning.
  • Train using TRL's SFTTrainer with response-only loss masking.
  • Evaluate code generation quality with syntax, style, and simple functional-correctness checks.
  • Diagnose and address common fine-tuning issues.

Step 1: Dataset Preparation

We create a code instruction dataset in conversation format, covering diverse Python tasks.

"""Fine-tuning a language model for Python code generation using QLoRA.

This case study demonstrates the full pipeline: dataset preparation,
QLoRA configuration, training, and evaluation.

Requirements:
    pip install torch transformers peft trl datasets bitsandbytes accelerate
"""

from dataclasses import dataclass

import torch
from datasets import Dataset

torch.manual_seed(42)


@dataclass
class CodeExample:
    """A single code instruction-response pair.

    Attributes:
        instruction: Natural language description of the coding task.
        code: The Python solution.
        explanation: Brief explanation of the approach.
    """

    instruction: str
    code: str
    explanation: str


def create_code_dataset() -> Dataset:
    """Create a code instruction dataset for fine-tuning.

    Returns:
        HuggingFace Dataset with 'messages' column in chat format.
    """
    examples = [
        CodeExample(
            instruction="Write a Python function to check if a string is a palindrome.",
            code=(
                'def is_palindrome(s: str) -> bool:\n'
                '    """Check if a string is a palindrome.\n\n'
                '    Args:\n'
                '        s: Input string to check.\n\n'
                '    Returns:\n'
                '        True if the string is a palindrome.\n'
                '    """\n'
                '    cleaned = s.lower().replace(" ", "")\n'
                '    return cleaned == cleaned[::-1]'
            ),
            explanation="Normalize the string by lowering case and removing spaces, then compare with its reverse.",
        ),
        CodeExample(
            instruction="Write a Python function to find the two numbers in a list that sum to a target value.",
            code=(
                'def two_sum(nums: list[int], target: int) -> tuple[int, int]:\n'
                '    """Find two numbers that sum to the target.\n\n'
                '    Args:\n'
                '        nums: List of integers.\n'
                '        target: Target sum value.\n\n'
                '    Returns:\n'
                '        Tuple of indices of the two numbers.\n\n'
                '    Raises:\n'
                '        ValueError: If no solution exists.\n'
                '    """\n'
                '    seen: dict[int, int] = {}\n'
                '    for i, num in enumerate(nums):\n'
                '        complement = target - num\n'
                '        if complement in seen:\n'
                '            return (seen[complement], i)\n'
                '        seen[num] = i\n'
                '    raise ValueError("No two sum solution found.")'
            ),
            explanation="Use a hash map for O(n) lookup of complements.",
        ),
        CodeExample(
            instruction="Write a Python function to flatten a nested list of arbitrary depth.",
            code=(
                'def flatten(nested: list) -> list:\n'
                '    """Flatten a nested list of arbitrary depth.\n\n'
                '    Args:\n'
                '        nested: A potentially nested list.\n\n'
                '    Returns:\n'
                '        A flat list with all elements.\n'
                '    """\n'
                '    result: list = []\n'
                '    for item in nested:\n'
                '        if isinstance(item, list):\n'
                '            result.extend(flatten(item))\n'
                '        else:\n'
                '            result.append(item)\n'
                '    return result'
            ),
            explanation="Recursively process each element, extending the result for lists and appending for atoms.",
        ),
    ]

    messages_list = []
    for ex in examples:
        messages_list.append({
            "messages": [
                {
                    "role": "system",
                    "content": "You are a Python coding assistant. Write clean, "
                    "well-documented Python code with type hints and Google-style "
                    "docstrings.",
                },
                {"role": "user", "content": ex.instruction},
                {
                    "role": "assistant",
                    "content": f"{ex.explanation}\n\n```python\n{ex.code}\n```",
                },
            ]
        })

    return Dataset.from_list(messages_list)
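
During training, SFTTrainer renders the "messages" column with the tokenizer's chat template before tokenizing. A quick sanity check is to render one example yourself. This is a minimal sketch; it assumes you have access to the gated meta-llama/Llama-3.1-8B-Instruct tokenizer on the Hugging Face Hub.

from transformers import AutoTokenizer

# Render the first example with the model's chat template to verify formatting.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
sample_messages = create_code_dataset()[0]["messages"]
print(tokenizer.apply_chat_template(sample_messages, tokenize=False))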

Step 2: QLoRA Configuration

from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def setup_qlora_model(
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
    lora_rank: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    target_modules: list[str] | None = None,
) -> tuple[AutoModelForCausalLM, AutoTokenizer, LoraConfig]:
    """Set up a model with QLoRA configuration.

    Args:
        model_name: HuggingFace model identifier.
        lora_rank: Rank for LoRA decomposition.
        lora_alpha: Alpha scaling factor for LoRA.
        lora_dropout: Dropout probability for LoRA layers.
        target_modules: List of module names to apply LoRA to.

    Returns:
        Tuple of (model, tokenizer, lora_config).
    """
    if target_modules is None:
        target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]

    # 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # LoRA config
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
        task_type=TaskType.CAUSAL_LM,
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    return model, tokenizer, lora_config
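
It is worth sanity-checking the figure reported by print_trainable_parameters. Each targeted weight matrix of shape (d_out, d_in) gains two LoRA factors with rank * (d_in + d_out) trainable parameters. The sketch below estimates the total using the published Llama-3.1-8B dimensions (hidden size 4096, MLP width 14336, 32 layers, grouped-query attention with 1024-wide K/V projections); treat these dimensions as an assumption about the architecture rather than something the code above verifies.

def estimate_lora_params(rank: int = 16, num_layers: int = 32) -> int:
    """Rough count of trainable LoRA parameters for the target modules above."""
    # (d_in, d_out) for each targeted projection in one decoder layer.
    shapes = [
        (4096, 4096),   # q_proj
        (4096, 1024),   # k_proj (grouped-query attention)
        (4096, 1024),   # v_proj
        (4096, 4096),   # o_proj
        (4096, 14336),  # gate_proj
        (4096, 14336),  # up_proj
        (14336, 4096),  # down_proj
    ]
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * num_layers


print(f"Estimated trainable parameters: {estimate_lora_params():,}")  # ~42M, ~0.5% of 8B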

Step 3: Training with TRL

from trl import SFTConfig, SFTTrainer


def train_code_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataset: Dataset,
    output_dir: str = "./code-llama-qlora",
    num_epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 2e-5,
    max_seq_length: int = 1024,
) -> SFTTrainer:
    """Train the model using SFTTrainer with QLoRA.

    Args:
        model: The PEFT-wrapped model.
        tokenizer: HuggingFace tokenizer.
        dataset: Training dataset with 'messages' column.
        output_dir: Directory to save checkpoints.
        num_epochs: Number of training epochs.
        batch_size: Per-device batch size.
        learning_rate: Learning rate for AdamW optimizer.
        max_seq_length: Maximum sequence length for training.

    Returns:
        The trained SFTTrainer instance.
    """
    sft_config = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        max_seq_length=max_seq_length,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        dataset_text_field=None,  # Using messages format
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    trainer.train()
    return trainer
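
The configuration above computes the loss over the entire rendered conversation, prompt included. To train only on the assistant responses (the response-only loss masking named in the learning objectives), older TRL releases provide DataCollatorForCompletionOnlyLM; some newer releases instead expose an assistant_only_loss option on SFTConfig, so check the documentation for your installed version. The sketch below assumes a TRL version that ships the collator and assumes the shown response_template string matches how the Llama 3.1 chat template marks the start of an assistant turn.

from trl import DataCollatorForCompletionOnlyLM

# Hedged sketch: mask every token before the assistant header so the loss is
# computed only on response tokens. The template string is an assumption about
# the Llama 3.1 chat format.
response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)

# Inside train_code_model, the trainer would then be constructed as:
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=collator,
)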

Step 4: Evaluation

import re


def evaluate_code_generation(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_prompts: list[str],
) -> dict[str, float]:
    """Evaluate code generation quality.

    Args:
        model: The fine-tuned model.
        tokenizer: HuggingFace tokenizer.
        test_prompts: List of coding task descriptions.

    Returns:
        Dictionary with evaluation metrics.
    """
    total = len(test_prompts)
    syntactically_valid = 0
    has_docstring = 0
    has_type_hints = 0

    for prompt in test_prompts:
        messages = [
            {"role": "system", "content": "You are a Python coding assistant."},
            {"role": "user", "content": prompt},
        ]
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The chat template already inserts special tokens, so skip adding them again.
        inputs = tokenizer(
            formatted, return_tensors="pt", add_special_tokens=False
        ).to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs, max_new_tokens=512, temperature=0.1, do_sample=True
            )

        response = tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # Extract code block
        code_match = re.search(r"```python\n(.*?)```", response, re.DOTALL)
        if code_match:
            code = code_match.group(1)
        else:
            code = response

        # Check syntax validity
        try:
            compile(code, "<string>", "exec")
            syntactically_valid += 1
        except SyntaxError:
            pass

        # Check for docstring
        if '"""' in code or "'''" in code:
            has_docstring += 1

        # Check for type hints (rough heuristic; ": " can also match docstring text)
        if "->" in code or ": " in code:
            has_type_hints += 1

    return {
        "syntax_validity": syntactically_valid / total,
        "docstring_rate": has_docstring / total,
        "type_hint_rate": has_type_hints / total,
    }
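
The metrics above capture syntax and style, not behavior. A lightweight functional-correctness check in the spirit of pass@1 executes the generated code and runs task-specific assertions against it. The sketch below is illustrative: the test table is hypothetical, and exec runs untrusted model output, so it should only ever be used inside a sandboxed environment.

def passes_tests(
    code: str, func_name: str, tests: list[tuple[tuple, object]]
) -> bool:
    """Return True if the generated function passes all (args, expected) tests.

    Warning: exec runs untrusted model output; only call this in a sandbox.
    """
    namespace: dict = {}
    try:
        exec(code, namespace)
        func = namespace[func_name]
        return all(func(*args) == expected for args, expected in tests)
    except Exception:
        return False


# Hypothetical usage: check a generated is_palindrome implementation.
palindrome_tests = [(("racecar",), True), (("hello",), False)]
print(passes_tests("def is_palindrome(s):\n    return s == s[::-1]", "is_palindrome", palindrome_tests))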

Step 5: Demonstration

if __name__ == "__main__":
    print("=" * 60)
    print("Case Study 1: Fine-Tuning for Code Generation")
    print("=" * 60)

    # Create dataset
    dataset = create_code_dataset()
    print(f"\nDataset size: {len(dataset)} examples")
    print(f"First example roles: {[m['role'] for m in dataset[0]['messages']]}")

    # Show dataset structure
    print("\nDataset structure:")
    for i, example in enumerate(dataset):
        user_msg = example["messages"][1]["content"]
        print(f"  Example {i}: {user_msg[:60]}...")

    print("\nQLoRA Configuration:")
    print("  - Quantization: NF4 with double quantization")
    print("  - LoRA rank: 16, alpha: 32")
    print("  - Target modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj")
    print("  - Training: 3 epochs, lr=2e-5, cosine schedule")
    print("  - Memory estimate: ~6 GB for 7B model")

    print("\nTo run the full pipeline, execute with a compatible GPU.")

Key Takeaways

  1. QLoRA enables fine-tuning large models on consumer hardware by storing the base model in 4-bit precision while training LoRA parameters in higher precision.
  2. Chat template formatting is critical for instruction-tuned models. Using the wrong template leads to significant performance degradation.
  3. Response-only loss masking focuses the training signal on the code generation task rather than learning to reproduce prompts.
  4. Code-specific evaluation metrics (syntax validity, docstring presence, type hints) provide more actionable feedback than generic perplexity.
  5. Gradient checkpointing trades extra compute for a large reduction in activation memory, which is usually needed to fit training on a single consumer GPU.