Case Study 2: Aligning a Language Model with DPO

Overview

In this case study, we implement the complete DPO alignment pipeline: starting from an SFT model, preparing preference data, implementing the DPO loss from scratch, training with monitoring, and evaluating the aligned model. We also implement length-controlled DPO and compare alignment quality before and after training.

Learning Objectives

  • Implement the DPO loss function from first principles.
  • Compute sequence-level log probabilities for the policy and reference models.
  • Train a DPO model with proper monitoring of implicit rewards and accuracy.
  • Implement length-controlled DPO to prevent verbosity bias.
  • Evaluate alignment quality with multiple metrics.

Step 1: DPO Loss Implementation

"""Aligning a language model with Direct Preference Optimization.

Implements DPO from scratch: loss function, log-probability computation,
training loop, monitoring, and evaluation.

Requirements:
    pip install torch transformers
"""

from dataclasses import dataclass

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(42)


def compute_log_probs(
    model: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
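    """Compute per-sequence log probabilities under a frozen model.

    The forward pass runs under torch.no_grad(), so use this for the
    reference model or for evaluation; policy log probabilities that feed
    the loss must be computed with gradients enabled.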

    Args:
        model: Causal language model.
        input_ids: Token IDs of shape (batch, seq_len).
        attention_mask: Attention mask of shape (batch, seq_len).
        labels: Target token IDs of shape (batch, seq_len).
            Positions with -100 are ignored.

    Returns:
        Per-sequence log probabilities of shape (batch,).
    """
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    # Shift logits and labels for next-token prediction
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    shift_mask = (shift_labels != -100).float()

    # Compute per-token log probabilities
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(
        2, shift_labels.clamp(min=0).unsqueeze(-1)
    ).squeeze(-1)

    # Mask padding and sum over sequence
    token_log_probs = token_log_probs * shift_mask
    sequence_log_probs = token_log_probs.sum(dim=-1)

    return sequence_log_probs


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Compute the DPO loss.

    loss = -log sigmoid(beta * [log(pi(y_w|x) / pi_ref(y_w|x))
                                - log(pi(y_l|x) / pi_ref(y_l|x))]),
    averaged over the batch.

    Args:
        policy_chosen_logps: Log probs of chosen under policy (batch,).
        policy_rejected_logps: Log probs of rejected under policy (batch,).
        ref_chosen_logps: Log probs of chosen under reference (batch,).
        ref_rejected_logps: Log probs of rejected under reference (batch,).
        beta: KL regularization strength.

    Returns:
        Tuple of (loss, metrics_dict) where metrics_dict contains batch
        means of chosen_rewards, rejected_rewards, reward_margin, and
        accuracy.
    """
    # Implicit rewards
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)

    # DPO loss
    reward_margin = chosen_rewards - rejected_rewards
    loss = -F.logsigmoid(reward_margin).mean()

    # Metrics
    accuracy = (chosen_rewards > rejected_rewards).float().mean()

    metrics = {
        "chosen_rewards": chosen_rewards.detach().mean(),
        "rejected_rewards": rejected_rewards.detach().mean(),
        "reward_margin": reward_margin.detach().mean(),
        "accuracy": accuracy.detach(),
    }

    return loss, metrics


def length_controlled_dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    chosen_lengths: torch.Tensor,
    rejected_lengths: torch.Tensor,
    beta: float = 0.1,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Compute length-controlled DPO loss.

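    Divides each policy/reference log-ratio by the response length, so the
    implicit reward is a per-token average rather than a sum. This reduces
    the advantage long responses get from accumulating log-ratio over more
    tokens (verbosity bias). Because per-token rewards are much smaller in
    magnitude, a larger beta than in standard DPO may be needed.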

    Args:
        policy_chosen_logps: Log probs of chosen under policy.
        policy_rejected_logps: Log probs of rejected under policy.
        ref_chosen_logps: Log probs of chosen under reference.
        ref_rejected_logps: Log probs of rejected under reference.
        chosen_lengths: Number of response tokens for chosen.
        rejected_lengths: Number of response tokens for rejected.
        beta: KL regularization strength.

    Returns:
        Tuple of (loss, metrics_dict).
    """
    # Length-normalized implicit rewards (clamp guards against empty responses)
    chosen_rewards = beta * (
        (policy_chosen_logps - ref_chosen_logps)
        / chosen_lengths.clamp(min=1).float()
    )
    rejected_rewards = beta * (
        (policy_rejected_logps - ref_rejected_logps)
        / rejected_lengths.clamp(min=1).float()
    )

    reward_margin = chosen_rewards - rejected_rewards
    loss = -F.logsigmoid(reward_margin).mean()

    accuracy = (chosen_rewards > rejected_rewards).float().mean()

    metrics = {
        "chosen_rewards": chosen_rewards.detach().mean(),
        "rejected_rewards": rejected_rewards.detach().mean(),
        "reward_margin": reward_margin.detach().mean(),
        "accuracy": accuracy.detach(),
    }

    return loss, metrics
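
Because the DPO loss depends on the four log probabilities only through the margin between the two log-ratios, it can be checked by hand. The sketch below is a scalar re-derivation in plain Python (the helper dpo_pair_loss is hypothetical, not part of the case study) using the same numbers as the synthetic demo in Step 5, plus a made-up pair showing how length normalization can flip which response is favored.

```python
import math


def dpo_pair_loss(pc: float, pr: float, rc: float, rr: float, beta: float) -> float:
    """Scalar DPO loss for one pair (hypothetical helper): -log sigmoid(beta * margin)."""
    margin = (pc - rc) - (pr - rr)  # chosen log-ratio minus rejected log-ratio
    # -log sigmoid(z) == log(1 + exp(-z)); log1p keeps precision for small z
    return math.log1p(math.exp(-beta * margin))


# Policy == reference (as at initialization): every log-ratio is 0, the margin
# is 0, and the loss is log 2 whatever beta is.
init_loss = dpo_pair_loss(-10.0, -15.0, -10.0, -15.0, beta=0.1)

# Synthetic demo values: each chosen log-ratio is +1 and each rejected one is -1,
# so beta * margin = 0.1 * 2 = 0.2 for all three pairs.
demo_pairs = [(-10.0, -15.0, -11.0, -14.0), (-12.0, -14.0, -13.0, -13.0),
              (-8.0, -13.0, -9.0, -12.0)]
demo_losses = [dpo_pair_loss(*p, beta=0.1) for p in demo_pairs]

# Made-up length example: chosen has log-ratio +6 over 60 tokens, rejected has
# +3 over 10 tokens. Summed rewards favor chosen; per-token rewards favor rejected.
summed_margin = 6.0 - 3.0
per_token_margin = 6.0 / 60 - 3.0 / 10

print(f"init loss = {init_loss:.4f}")
print(f"demo loss = {sum(demo_losses) / len(demo_losses):.4f}")
print(f"margins   = {summed_margin:+.2f} (summed), {per_token_margin:+.2f} (per token)")
```

The initial value of log 2 ≈ 0.6931 is a handy sanity check when the policy starts as a copy of the reference: a first-step loss far from it suggests the two models or their inputs differ.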

Step 2: Preference Data Preparation

@dataclass
class DPOExample:
    """A DPO training example.

    Attributes:
        prompt: The input prompt.
        chosen: The preferred response.
        rejected: The dispreferred response.
    """

    prompt: str
    chosen: str
    rejected: str


DPO_TRAINING_DATA = [
    DPOExample(
        prompt="How do I sort a list in Python?",
        chosen=(
            "You can sort a list in Python using the `sorted()` function "
            "or the `.sort()` method. `sorted()` returns a new sorted list "
            "while `.sort()` modifies the list in-place.\n\n"
            "```python\nnumbers = [3, 1, 4, 1, 5]\n"
            "sorted_numbers = sorted(numbers)  # [1, 1, 3, 4, 5]\n"
            "numbers.sort()  # modifies in-place\n```"
        ),
        rejected=(
            "There are many ways to sort things in Python. You could "
            "use sort or sorted or even write your own sorting algorithm "
            "like bubble sort or merge sort. Sorting is a fundamental "
            "operation in computer science."
        ),
    ),
    DPOExample(
        prompt="What causes rain?",
        chosen=(
            "Rain forms through the water cycle. Water evaporates from "
            "surfaces, rises as water vapor, cools at higher altitudes, "
            "and condenses around tiny particles to form cloud droplets. "
            "When these droplets combine and become heavy enough, they "
            "fall as rain."
        ),
        rejected=(
            "Rain is when water falls from the sky. It happens because "
            "of clouds. Clouds have water in them and sometimes the "
            "water falls down."
        ),
    ),
    DPOExample(
        prompt="Should I invest all my savings in cryptocurrency?",
        chosen=(
            "I would not recommend investing all your savings in "
            "cryptocurrency. Cryptocurrencies are highly volatile and "
            "can lose significant value quickly. Financial advisors "
            "generally recommend diversifying investments across asset "
            "classes and only investing money you can afford to lose in "
            "high-risk assets. Consider consulting a certified financial "
            "advisor for personalized advice."
        ),
        rejected=(
            "Cryptocurrency is the future of finance! Bitcoin and "
            "Ethereum have made many people wealthy. You should "
            "definitely go all-in. The prices will only go up from here."
        ),
    ),
]


def prepare_dpo_batch(
    examples: list[DPOExample],
    tokenizer: AutoTokenizer,
    max_length: int = 512,
) -> dict[str, torch.Tensor]:
    """Prepare a batch of DPO examples for training.

    Args:
        examples: List of DPO examples.
        tokenizer: HuggingFace tokenizer.
        max_length: Maximum sequence length.

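    Returns:
        Dictionary with tokenized chosen and rejected sequences, labels
        that score only the response tokens (prompt and padding set to
        -100), and per-example response lengths.
    """
    if tokenizer.pad_token_id is None:
        # GPT-2-style tokenizers ship without a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    def encode(prompt: str, response: str) -> tuple[list[int], list[int]]:
        # Tokenize prompt and response separately so the boundary is exact;
        # the shared prompt is masked because DPO scores only the response.
        prompt_ids = tokenizer(f"{prompt}\n\n")["input_ids"]
        response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]
        ids = (prompt_ids + response_ids)[:max_length]
        labels = ([-100] * len(prompt_ids) + response_ids)[:max_length]
        return ids, labels

    def collate(pairs: list[tuple[list[int], list[int]]]) -> dict[str, torch.Tensor]:
        # Right-pad to the longest sequence; padded positions get label -100.
        width = max(len(ids) for ids, _ in pairs)
        input_ids, attention_mask, labels = [], [], []
        for ids, lbl in pairs:
            pad = width - len(ids)
            input_ids.append(ids + [tokenizer.pad_token_id] * pad)
            attention_mask.append([1] * len(ids) + [0] * pad)
            labels.append(lbl + [-100] * pad)
        labels_t = torch.tensor(labels)
        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
            "labels": labels_t,
            "lengths": (labels_t != -100).sum(dim=-1),
        }

    chosen = collate([encode(ex.prompt, ex.chosen) for ex in examples])
    rejected = collate([encode(ex.prompt, ex.rejected) for ex in examples])

    return {
        "chosen_input_ids": chosen["input_ids"],
        "chosen_attention_mask": chosen["attention_mask"],
        "chosen_labels": chosen["labels"],
        "chosen_lengths": chosen["lengths"],
        "rejected_input_ids": rejected["input_ids"],
        "rejected_attention_mask": rejected["attention_mask"],
        "rejected_labels": rejected["labels"],
        "rejected_lengths": rejected["lengths"],
    }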
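
DPO log probabilities are normally summed over response tokens only. The prompt is identical in both halves of a pair, so its contribution cancels in the margin, but it would still inflate each implicit reward and distort any per-token length normalization. The sketch below applies prompt and padding masks to hand-made token ids and random logits (no tokenizer; the helper names are hypothetical) and checks the result against explicit indexing, which makes the one-position shift easy to see.

```python
import torch
import torch.nn.functional as F

IGNORE = -100  # label value meaning "do not score this position"


def response_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor,
                    prompt_len: int) -> torch.Tensor:
    """Copy input_ids and mask the shared prompt prefix and right padding."""
    labels = input_ids.clone()
    labels[:, :prompt_len] = IGNORE       # prompt: identical across the pair
    labels[attention_mask == 0] = IGNORE  # padding: not real tokens
    return labels


def sequence_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum log p(token_t | tokens_<t) over positions whose label is not IGNORE."""
    shift_logits, shift_labels = logits[:, :-1, :], labels[:, 1:]
    mask = (shift_labels != IGNORE).float()
    token_lps = F.log_softmax(shift_logits, dim=-1).gather(
        2, shift_labels.clamp(min=0).unsqueeze(-1)
    ).squeeze(-1)
    return (token_lps * mask).sum(dim=-1)


torch.manual_seed(0)
# Two sequences sharing the 3-token prompt [5, 6, 7]; the second is right-padded.
input_ids = torch.tensor([[5, 6, 7, 1, 2, 3],
                          [5, 6, 7, 4, 9, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 0]])
labels = response_labels(input_ids, attention_mask, prompt_len=3)
lengths = (labels != IGNORE).sum(dim=-1)  # response tokens per sequence

logits = torch.randn(2, 6, 10)  # stand-in for model outputs, vocabulary of 10
logps = sequence_logps(logits, labels)
print(labels.tolist(), lengths.tolist())
```

The logit at position prompt_len - 1 predicts the first response token, so after the shift the number of scored positions equals lengths, which is the quantity a length-normalized loss divides by.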

Step 3: Training Loop with Monitoring

@dataclass
class DPOTrainingMetrics:
    """Aggregated metrics from DPO training.

    Attributes:
        epoch: Training epoch number.
        loss: Average DPO loss.
        chosen_reward: Average implicit reward for chosen responses.
        rejected_reward: Average implicit reward for rejected responses.
        reward_margin: Average margin between chosen and rejected.
        accuracy: Fraction where chosen reward > rejected reward.
    """

    epoch: int
    loss: float
    chosen_reward: float
    rejected_reward: float
    reward_margin: float
    accuracy: float


def train_dpo(
    policy_model: AutoModelForCausalLM,
    ref_model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    training_data: list[DPOExample],
    beta: float = 0.1,
    learning_rate: float = 5e-6,
    num_epochs: int = 3,
    device: str = "cpu",
) -> list[DPOTrainingMetrics]:
    """Train a model with DPO.

    Args:
        policy_model: The model to train (initialized from SFT).
        ref_model: The frozen reference model (SFT checkpoint).
        tokenizer: HuggingFace tokenizer.
        training_data: List of preference examples.
        beta: DPO beta parameter (KL regularization strength).
        learning_rate: Learning rate for the optimizer.
        num_epochs: Number of training epochs.
        device: Device for training.

    Returns:
        List of training metrics per epoch.
    """
    policy_model = policy_model.to(device)
    ref_model = ref_model.to(device)
    # Freeze the reference: it only scores sequences and must never be updated.
    ref_model.eval()
    ref_model.requires_grad_(False)

    optimizer = torch.optim.AdamW(policy_model.parameters(), lr=learning_rate)

    history: list[DPOTrainingMetrics] = []

    # The toy dataset fits in one batch, so tokenize once; each "epoch" below
    # is a single full-batch optimizer step.
    batch = prepare_dpo_batch(training_data, tokenizer)
    batch = {k: v.to(device) for k, v in batch.items()}

    # Reference log probs never change (frozen model, fixed batch): compute once.
    with torch.no_grad():
        ref_chosen_logps = compute_log_probs(
            ref_model,
            batch["chosen_input_ids"],
            batch["chosen_attention_mask"],
            batch["chosen_labels"],
        )
        ref_rejected_logps = compute_log_probs(
            ref_model,
            batch["rejected_input_ids"],
            batch["rejected_attention_mask"],
            batch["rejected_labels"],
        )

    def compute_logps_with_grad(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Same computation as compute_log_probs, but keeps the autograd graph."""
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        mask = (shift_labels != -100).float()
        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_lps = log_probs.gather(
            2, shift_labels.clamp(min=0).unsqueeze(-1)
        ).squeeze(-1)
        return (token_lps * mask).sum(dim=-1)

    for epoch in range(num_epochs):
        policy_model.train()

        # Policy forward passes with gradients enabled (one pass per side).
        outputs_chosen = policy_model(
            input_ids=batch["chosen_input_ids"],
            attention_mask=batch["chosen_attention_mask"],
        )
        outputs_rejected = policy_model(
            input_ids=batch["rejected_input_ids"],
            attention_mask=batch["rejected_attention_mask"],
        )
        policy_c_lps = compute_logps_with_grad(
            outputs_chosen.logits, batch["chosen_labels"]
        )
        policy_r_lps = compute_logps_with_grad(
            outputs_rejected.logits, batch["rejected_labels"]
        )

        loss, metrics = dpo_loss(
            policy_c_lps, policy_r_lps,
            ref_chosen_logps, ref_rejected_logps,
            beta=beta,
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
        optimizer.step()

        epoch_metrics = DPOTrainingMetrics(
            epoch=epoch + 1,
            loss=loss.item(),
            chosen_reward=metrics["chosen_rewards"].item(),
            rejected_reward=metrics["rejected_rewards"].item(),
            reward_margin=metrics["reward_margin"].item(),
            accuracy=metrics["accuracy"].item(),
        )
        history.append(epoch_metrics)

        print(
            f"Epoch {epoch + 1}: loss={epoch_metrics.loss:.4f}, "
            f"margin={epoch_metrics.reward_margin:.4f}, "
            f"acc={epoch_metrics.accuracy:.1%}"
        )

    return history
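
The DPO objective implicitly keeps the policy close to the reference, but train_dpo does not log the KL divergence listed among the monitoring metrics. A common Monte Carlo estimate samples outputs from the policy and averages log pi(y) - log pi_ref(y). The sketch below shows that estimator on two toy categorical distributions (hypothetical stand-ins for whole responses) so it can be checked against the exact KL; for a language model, y would be a sampled completion and the log probabilities would be sequence sums. The samples must come from the policy: averaging the log-ratio over the fixed preference pairs does not estimate the KL.

```python
import torch
from torch.distributions import Categorical, kl_divergence

torch.manual_seed(0)

# Toy "policy" and "reference" over 5 outcomes (stand-ins for whole responses).
policy = Categorical(logits=torch.tensor([2.0, 1.0, 0.5, 0.0, -1.0]))
reference = Categorical(logits=torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0]))

# Exact KL(policy || reference), computable here only because the space is tiny.
exact_kl = kl_divergence(policy, reference)

# Monte Carlo estimate: draw from the *policy* and average the log-ratio.
# Individual log-ratios can be negative; only their mean estimates the KL.
samples = policy.sample((20000,))
log_ratio = policy.log_prob(samples) - reference.log_prob(samples)
mc_kl = log_ratio.mean()

print(f"exact KL = {exact_kl.item():.4f}, MC estimate = {mc_kl.item():.4f}")
```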

Step 4: Evaluation

def evaluate_alignment(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_prompts: list[str],
    device: str = "cpu",
) -> list[dict[str, str]]:
    """Sample responses to test prompts for qualitative alignment review.

    Args:
        model: The aligned model.
        tokenizer: HuggingFace tokenizer.
        test_prompts: List of prompts to test.
        device: Device for inference.

    Returns:
        List of dicts with prompt and generated response.
    """
    model.eval()
    results = []

    for prompt in test_prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                do_sample=True,
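                # Explicit None check: a valid pad_token_id of 0 is falsy.
                pad_token_id=(
                    tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id
                ),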
            )
        response = tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        results.append({"prompt": prompt, "response": response.strip()})

    return results
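
evaluate_alignment returns raw generations, which support qualitative review but are not a metric. A minimal sketch of two numeric checks that complement them: implicit-reward accuracy and mean margin on held-out preference pairs (scored with the trained policy and the frozen reference, as during training), and mean response length before versus after training as a verbosity check. The function summarize_alignment and the sample inputs are hypothetical, made up for illustration.

```python
def summarize_alignment(
    chosen_rewards: list[float],
    rejected_rewards: list[float],
    responses_before: list[str],
    responses_after: list[str],
) -> dict[str, float]:
    """Held-out preference metrics plus a crude verbosity check (hypothetical helper).

    chosen_rewards / rejected_rewards are per-pair implicit rewards,
    beta * (log pi - log pi_ref), on pairs that were not used for training.
    """
    margins = [c - r for c, r in zip(chosen_rewards, rejected_rewards)]

    def avg_words(texts: list[str]) -> float:
        # Whitespace word count as a length proxy; token counts work too.
        return sum(len(t.split()) for t in texts) / len(texts)

    return {
        "heldout_accuracy": sum(m > 0 for m in margins) / len(margins),
        "heldout_margin": sum(margins) / len(margins),
        "mean_words_before": avg_words(responses_before),
        "mean_words_after": avg_words(responses_after),
    }


# Made-up inputs, purely to show the call.
report = summarize_alignment(
    chosen_rewards=[0.30, 0.05, -0.10],
    rejected_rewards=[-0.20, 0.10, -0.40],
    responses_before=["Use sorted().", "Rain forms when vapor condenses."],
    responses_after=[
        "Use sorted() for a new list or .sort() in place.",
        "Water vapor cools, condenses into droplets, and falls as rain.",
    ],
)
print(report)
```

Held-out accuracy far below training accuracy suggests the policy is memorizing the small preference set, and a large jump in mean length alongside rising accuracy is the verbosity pattern that length-controlled DPO targets.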

Step 5: Demonstration

if __name__ == "__main__":
    print("=" * 60)
    print("Case Study 2: Aligning with DPO")
    print("=" * 60)

    print(f"\nTraining examples: {len(DPO_TRAINING_DATA)}")
    for i, ex in enumerate(DPO_TRAINING_DATA):
        print(f"  Example {i}: {ex.prompt[:50]}...")

    print("\nDPO Loss Components:")
    print("  1. Compute log P(y_w|x) and log P(y_l|x) under policy")
    print("  2. Compute log P(y_w|x) and log P(y_l|x) under reference")
    print("  3. Implicit rewards: beta * (log pi - log pi_ref)")
    print("  4. Loss: -log sigma(r_w - r_l)")

    print("\nMonitoring Metrics:")
    print("  - Chosen rewards (ideally increase; may drift down in practice)")
    print("  - Rejected rewards (should decrease)")
    print("  - Reward margin (should increase)")
    print("  - Accuracy (should approach 1.0 on training pairs)")
    print("  - KL divergence (should remain bounded; not logged by train_dpo)")

    # Demonstrate loss computation with synthetic values
    print("\n--- Synthetic Loss Demo ---")
    policy_c = torch.tensor([-10.0, -12.0, -8.0])
    policy_r = torch.tensor([-15.0, -14.0, -13.0])
    ref_c = torch.tensor([-11.0, -13.0, -9.0])
    ref_r = torch.tensor([-14.0, -13.0, -12.0])

    loss, metrics = dpo_loss(policy_c, policy_r, ref_c, ref_r, beta=0.1)
    print(f"  Loss: {loss.item():.4f}")
    print(f"  Chosen reward: {metrics['chosen_rewards'].item():.4f}")
    print(f"  Rejected reward: {metrics['rejected_rewards'].item():.4f}")
    print(f"  Margin: {metrics['reward_margin'].item():.4f}")
    print(f"  Accuracy: {metrics['accuracy'].item():.1%}")

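    print(
        "\nFor full DPO training, load an SFT checkpoint twice (trainable "
        "policy and frozen reference) and call train_dpo(); a GPU is "
        "recommended for anything larger than a small model."
    )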
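
The demo stops at synthetic numbers. To watch the whole gradient path (frozen reference scored once, policy log probs with gradients, DPO loss, optimizer step) without downloading a checkpoint, here is a self-contained sketch on a hypothetical bigram "language model" over 8 token ids. It mirrors the structure of train_dpo but is not the case study's code, and fitting two hand-made pairs says nothing about real alignment quality.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class BigramLM(nn.Module):
    """Hypothetical toy LM: next-token logits depend only on the current token."""

    def __init__(self, vocab_size: int = 8, dim: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(input_ids))  # (batch, seq_len, vocab)


def response_logps(model: nn.Module, ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    """Sum of next-token log probs over response tokens (prompt excluded)."""
    log_probs = F.log_softmax(model(ids)[:, :-1], dim=-1)
    token_lps = log_probs.gather(2, ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    # token_lps[:, t] scores ids[:, t + 1]; the response starts at prompt_len.
    return token_lps[:, prompt_len - 1:].sum(dim=-1)


# Two preference pairs, each a 1-token prompt (id 1) plus a 2-token response.
chosen = torch.tensor([[1, 2, 3], [1, 6, 7]])
rejected = torch.tensor([[1, 4, 5], [1, 4, 7]])

policy = BigramLM()
reference = copy.deepcopy(policy).eval().requires_grad_(False)  # frozen copy
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-2)
beta = 0.1

with torch.no_grad():  # frozen reference + fixed data: score once
    ref_c = response_logps(reference, chosen, prompt_len=1)
    ref_r = response_logps(reference, rejected, prompt_len=1)

losses = []
for step in range(100):
    margin = beta * (
        (response_logps(policy, chosen, 1) - ref_c)
        - (response_logps(policy, rejected, 1) - ref_r)
    )
    loss = -F.logsigmoid(margin).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss {losses[0]:.4f} -> {losses[-1]:.4f}, margins {margin.detach().tolist()}")
```

The first loss equals log 2 because the policy starts as an exact copy of the reference and the toy model has no dropout.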

Key Takeaways

  1. DPO is fundamentally simpler than PPO-based RLHF. It replaces the separately trained reward model and the PPO loop with a single supervised, classification-style loss computed from preference pairs and their log probabilities under the policy and the frozen reference model.
  2. The DPO $\beta$ parameter is among the most important hyperparameters. Lower values (e.g. 0.1) weaken the implicit KL penalty and allow more deviation from the reference; higher values (e.g. 0.5) keep the policy closer to it. Start low and increase $\beta$ if generations degrade.
  3. Monitor implicit rewards, not just loss. The reward margin should grow and the rejected reward should fall. The chosen reward ideally rises, but in practice it often drifts down as well, so treat the margin as the primary signal; a shrinking margin or sudden swings usually indicate unstable training.
  4. Length-controlled DPO mitigates verbosity bias by dividing each policy/reference log-ratio by the response length. Because the summed log-ratio accumulates over every response token, standard DPO can reward length itself, and models trained with it tend to grow more verbose.
  5. The reference model must remain frozen throughout training. It provides the anchor that prevents the policy from drifting too far from sensible behavior.
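
Takeaway 5 can be enforced in code rather than left to convention. The sketch below uses a tiny torch.nn.Linear as a hypothetical stand-in for an SFT checkpoint: it clones the policy into a reference, freezes the clone with requires_grad_(False), hands only the policy's parameters to the optimizer, and checks after one update that the reference weights are bit-for-bit unchanged.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

policy = nn.Linear(4, 4)  # hypothetical stand-in for an SFT-initialized policy
reference = copy.deepcopy(policy)  # starts from identical weights
reference.eval()
reference.requires_grad_(False)  # frozen: autograd produces no grads for it

snapshot = [p.detach().clone() for p in reference.parameters()]
# Only the policy's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-2)

x = torch.randn(8, 4)
# A loss that mixes policy and reference outputs, as the DPO loss does.
loss = (policy(x) - reference(x) - 1.0).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

ref_unchanged = all(torch.equal(a, b) for a, b in zip(snapshot, reference.parameters()))
policy_moved = not torch.equal(policy.weight, reference.weight)
print(f"reference unchanged: {ref_unchanged}, policy updated: {policy_moved}")
```

Hugging Face causal LMs are torch.nn.Module subclasses, so the same pattern (a deep copy or a second load of the checkpoint, then requires_grad_(False)) applies to them unchanged.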