Case Study 2: Aligning a Language Model with DPO
Overview
In this case study, we implement the complete DPO alignment pipeline: starting from an SFT model, preparing preference data, implementing the DPO loss from scratch, training with monitoring, and evaluating the aligned model. We also implement length-controlled DPO and compare alignment quality before and after training.
Learning Objectives
- Implement the DPO loss function from first principles.
- Compute sequence-level log probabilities for the policy and reference models.
- Train a DPO model with proper monitoring of implicit rewards and accuracy.
- Implement length-controlled DPO to prevent verbosity bias.
- Evaluate alignment quality with multiple metrics.
Step 1: DPO Loss Implementation
"""Aligning a language model with Direct Preference Optimization.
Implements DPO from scratch: loss function, log-probability computation,
training loop, monitoring, and evaluation.
Requirements:
pip install torch transformers peft trl datasets
"""
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
torch.manual_seed(42)
def compute_log_probs(
model: AutoModelForCausalLM,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""Compute per-sequence log probabilities under a model.
Args:
model: Causal language model.
input_ids: Token IDs of shape (batch, seq_len).
attention_mask: Attention mask of shape (batch, seq_len).
labels: Target token IDs of shape (batch, seq_len).
Positions with -100 are ignored.
    Returns:
        Per-sequence log probabilities of shape (batch,).
    Note:
        Gradients flow through the model; wrap calls for the frozen
        reference model in torch.no_grad().
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    # Shift logits and labels for next-token prediction
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    shift_mask = (shift_labels != -100).float()
    # Per-token log probabilities of the target tokens
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(
        2, shift_labels.clamp(min=0).unsqueeze(-1)
    ).squeeze(-1)
    # Zero out ignored positions and sum over the sequence
    token_log_probs = token_log_probs * shift_mask
    sequence_log_probs = token_log_probs.sum(dim=-1)
    return sequence_log_probs
def dpo_loss(
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
beta: float = 0.1,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Compute the DPO loss.
Loss = -log(sigma(beta * (log(pi/pi_ref)(y_w) - log(pi/pi_ref)(y_l))))
Args:
policy_chosen_logps: Log probs of chosen under policy (batch,).
policy_rejected_logps: Log probs of rejected under policy (batch,).
ref_chosen_logps: Log probs of chosen under reference (batch,).
ref_rejected_logps: Log probs of rejected under reference (batch,).
beta: KL regularization strength.
Returns:
Tuple of (loss, metrics_dict) where metrics_dict contains
chosen_rewards, rejected_rewards, and accuracy.
"""
# Implicit rewards
chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
# DPO loss
reward_margin = chosen_rewards - rejected_rewards
loss = -F.logsigmoid(reward_margin).mean()
# Metrics
accuracy = (chosen_rewards > rejected_rewards).float().mean()
metrics = {
"chosen_rewards": chosen_rewards.detach().mean(),
"rejected_rewards": rejected_rewards.detach().mean(),
"reward_margin": reward_margin.detach().mean(),
"accuracy": accuracy.detach(),
}
return loss, metrics
def length_controlled_dpo_loss(
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
chosen_lengths: torch.Tensor,
rejected_lengths: torch.Tensor,
beta: float = 0.1,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Compute length-controlled DPO loss.
Normalizes log probabilities by sequence length to prevent
verbosity bias.
Args:
policy_chosen_logps: Log probs of chosen under policy.
policy_rejected_logps: Log probs of rejected under policy.
ref_chosen_logps: Log probs of chosen under reference.
ref_rejected_logps: Log probs of rejected under reference.
chosen_lengths: Number of response tokens for chosen.
rejected_lengths: Number of response tokens for rejected.
beta: KL regularization strength.
Returns:
Tuple of (loss, metrics_dict).
"""
# Length-normalized implicit rewards
chosen_rewards = beta * (
(policy_chosen_logps - ref_chosen_logps) / chosen_lengths.float()
)
rejected_rewards = beta * (
(policy_rejected_logps - ref_rejected_logps) / rejected_lengths.float()
)
reward_margin = chosen_rewards - rejected_rewards
loss = -F.logsigmoid(reward_margin).mean()
accuracy = (chosen_rewards > rejected_rewards).float().mean()
metrics = {
"chosen_rewards": chosen_rewards.detach().mean(),
"rejected_rewards": rejected_rewards.detach().mean(),
"reward_margin": reward_margin.detach().mean(),
"accuracy": accuracy.detach(),
}
return loss, metrics
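Before moving on, it is worth checking that length normalization behaves as intended. The snippet below is a minimal sketch using made-up log probabilities and lengths; the function name demo_length_control and all values are illustrative, not part of the pipeline above.
def demo_length_control() -> None:
    """Sanity-check sketch: compare standard and length-controlled DPO
    losses on synthetic values where rejected responses are twice as long.
    """
    policy_c = torch.tensor([-10.0, -12.0, -8.0])
    policy_r = torch.tensor([-15.0, -14.0, -13.0])
    ref_c = torch.tensor([-11.0, -13.0, -9.0])
    ref_r = torch.tensor([-14.0, -13.0, -12.0])
    chosen_lens = torch.tensor([40, 45, 35])
    rejected_lens = torch.tensor([80, 90, 70])
    std_loss, _ = dpo_loss(policy_c, policy_r, ref_c, ref_r, beta=0.1)
    lc_loss, lc_metrics = length_controlled_dpo_loss(
        policy_c, policy_r, ref_c, ref_r,
        chosen_lens, rejected_lens, beta=0.1,
    )
    # Per-token margins are much smaller after normalization, so the two
    # losses are not comparable in magnitude; the point is that long
    # responses no longer earn higher implicit rewards just for length.
    print(f"standard loss: {std_loss.item():.4f}, "
          f"length-controlled loss: {lc_loss.item():.4f}, "
          f"length-controlled margin: {lc_metrics['reward_margin'].item():.4f}")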
Step 2: Preference Data Preparation
@dataclass
class DPOExample:
"""A DPO training example.
Attributes:
prompt: The input prompt.
chosen: The preferred response.
rejected: The dispreferred response.
"""
prompt: str
chosen: str
rejected: str
DPO_TRAINING_DATA = [
DPOExample(
prompt="How do I sort a list in Python?",
chosen=(
"You can sort a list in Python using the `sorted()` function "
"or the `.sort()` method. `sorted()` returns a new sorted list "
"while `.sort()` modifies the list in-place.\n\n"
"```python\nnumbers = [3, 1, 4, 1, 5]\n"
"sorted_numbers = sorted(numbers) # [1, 1, 3, 4, 5]\n"
"numbers.sort() # modifies in-place\n```"
),
rejected=(
"There are many ways to sort things in Python. You could "
"use sort or sorted or even write your own sorting algorithm "
"like bubble sort or merge sort. Sorting is a fundamental "
"operation in computer science."
),
),
DPOExample(
prompt="What causes rain?",
chosen=(
"Rain forms through the water cycle. Water evaporates from "
"surfaces, rises as water vapor, cools at higher altitudes, "
"and condenses around tiny particles to form cloud droplets. "
"When these droplets combine and become heavy enough, they "
"fall as rain."
),
rejected=(
"Rain is when water falls from the sky. It happens because "
"of clouds. Clouds have water in them and sometimes the "
"water falls down."
),
),
DPOExample(
prompt="Should I invest all my savings in cryptocurrency?",
chosen=(
"I would not recommend investing all your savings in "
"cryptocurrency. Cryptocurrencies are highly volatile and "
"can lose significant value quickly. Financial advisors "
"generally recommend diversifying investments across asset "
"classes and only investing money you can afford to lose in "
"high-risk assets. Consider consulting a certified financial "
"advisor for personalized advice."
),
rejected=(
"Cryptocurrency is the future of finance! Bitcoin and "
"Ethereum have made many people wealthy. You should "
"definitely go all-in. The prices will only go up from here."
),
),
]
def prepare_dpo_batch(
examples: list[DPOExample],
tokenizer: AutoTokenizer,
max_length: int = 512,
) -> dict[str, torch.Tensor]:
"""Prepare a batch of DPO examples for training.
Args:
examples: List of DPO examples.
tokenizer: HuggingFace tokenizer.
max_length: Maximum sequence length.
Returns:
Dictionary with tokenized chosen and rejected sequences.
"""
chosen_texts = [f"{ex.prompt}\n\n{ex.chosen}" for ex in examples]
rejected_texts = [f"{ex.prompt}\n\n{ex.rejected}" for ex in examples]
chosen_enc = tokenizer(
chosen_texts, truncation=True, max_length=max_length,
padding="max_length", return_tensors="pt",
)
rejected_enc = tokenizer(
rejected_texts, truncation=True, max_length=max_length,
padding="max_length", return_tensors="pt",
)
    # Create labels, masking padding with -100. Note: prompt tokens are
    # left in the labels here; because chosen and rejected share the same
    # prompt, the prompt terms cancel in the DPO reward margin, so the loss
    # is unaffected. Production pipelines typically also mask the prompt
    # with -100 so implicit rewards reflect only the response tokens.
    chosen_labels = chosen_enc["input_ids"].clone()
    chosen_labels[chosen_enc["attention_mask"] == 0] = -100
    rejected_labels = rejected_enc["input_ids"].clone()
    rejected_labels[rejected_enc["attention_mask"] == 0] = -100
return {
"chosen_input_ids": chosen_enc["input_ids"],
"chosen_attention_mask": chosen_enc["attention_mask"],
"chosen_labels": chosen_labels,
"rejected_input_ids": rejected_enc["input_ids"],
"rejected_attention_mask": rejected_enc["attention_mask"],
"rejected_labels": rejected_labels,
}
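The length-controlled loss in Step 1 expects per-example response lengths, which prepare_dpo_batch does not return. One option, sketched below with a helper introduced here for illustration (response_lengths), is to count the label positions that are not masked with -100; with the batch as prepared above this counts prompt plus response tokens, and it counts response tokens only if you additionally mask the prompt.
def response_lengths(labels: torch.Tensor) -> torch.Tensor:
    """Count the label positions that contribute to the sequence log prob."""
    return (labels != -100).sum(dim=-1)
# Example usage (assumes a tokenizer is already loaded):
# batch = prepare_dpo_batch(DPO_TRAINING_DATA, tokenizer)
# chosen_lengths = response_lengths(batch["chosen_labels"])
# rejected_lengths = response_lengths(batch["rejected_labels"])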
Step 3: Training Loop with Monitoring
@dataclass
class DPOTrainingMetrics:
"""Aggregated metrics from DPO training.
Attributes:
epoch: Training epoch number.
loss: Average DPO loss.
chosen_reward: Average implicit reward for chosen responses.
rejected_reward: Average implicit reward for rejected responses.
reward_margin: Average margin between chosen and rejected.
accuracy: Fraction where chosen reward > rejected reward.
"""
epoch: int
loss: float
chosen_reward: float
rejected_reward: float
reward_margin: float
accuracy: float
def train_dpo(
policy_model: AutoModelForCausalLM,
ref_model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
training_data: list[DPOExample],
beta: float = 0.1,
learning_rate: float = 5e-6,
num_epochs: int = 3,
device: str = "cpu",
) -> list[DPOTrainingMetrics]:
"""Train a model with DPO.
Args:
policy_model: The model to train (initialized from SFT).
ref_model: The frozen reference model (SFT checkpoint).
tokenizer: HuggingFace tokenizer.
training_data: List of preference examples.
beta: DPO beta parameter (KL regularization strength).
learning_rate: Learning rate for the optimizer.
num_epochs: Number of training epochs.
device: Device for training.
Returns:
List of training metrics per epoch.
"""
policy_model = policy_model.to(device)
ref_model = ref_model.to(device)
ref_model.eval()
optimizer = torch.optim.AdamW(policy_model.parameters(), lr=learning_rate)
history: list[DPOTrainingMetrics] = []
for epoch in range(num_epochs):
policy_model.train()
batch = prepare_dpo_batch(training_data, tokenizer)
batch = {k: v.to(device) for k, v in batch.items()}
        # Log probs under the policy (gradients flow through these)
        policy_chosen_logps = compute_log_probs(
            policy_model,
            batch["chosen_input_ids"],
            batch["chosen_attention_mask"],
            batch["chosen_labels"],
        )
        policy_rejected_logps = compute_log_probs(
            policy_model,
            batch["rejected_input_ids"],
            batch["rejected_attention_mask"],
            batch["rejected_labels"],
        )
        # Log probs under the frozen reference (no gradients)
        with torch.no_grad():
            ref_chosen_logps = compute_log_probs(
                ref_model,
                batch["chosen_input_ids"],
                batch["chosen_attention_mask"],
                batch["chosen_labels"],
            )
            ref_rejected_logps = compute_log_probs(
                ref_model,
                batch["rejected_input_ids"],
                batch["rejected_attention_mask"],
                batch["rejected_labels"],
            )
        loss, metrics = dpo_loss(
            policy_chosen_logps, policy_rejected_logps,
            ref_chosen_logps, ref_rejected_logps,
            beta=beta,
        )
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
optimizer.step()
epoch_metrics = DPOTrainingMetrics(
epoch=epoch + 1,
loss=loss.item(),
chosen_reward=metrics["chosen_rewards"].item(),
rejected_reward=metrics["rejected_rewards"].item(),
reward_margin=metrics["reward_margin"].item(),
accuracy=metrics["accuracy"].item(),
)
history.append(epoch_metrics)
print(
f"Epoch {epoch + 1}: loss={epoch_metrics.loss:.4f}, "
f"margin={epoch_metrics.reward_margin:.4f}, "
f"acc={epoch_metrics.accuracy:.1%}"
)
return history
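To run the loop end to end, the policy and the frozen reference are normally initialized from the same SFT checkpoint. The sketch below uses "gpt2" purely as a placeholder for that checkpoint; any causal LM should work as long as the tokenizer has (or is given) a pad token.
def run_dpo_training(model_name: str = "gpt2") -> list[DPOTrainingMetrics]:
    """Sketch: load policy and reference from one checkpoint and train.
    The default "gpt2" is a placeholder; substitute your SFT checkpoint.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    policy = AutoModelForCausalLM.from_pretrained(model_name)
    reference = AutoModelForCausalLM.from_pretrained(model_name)
    for param in reference.parameters():
        param.requires_grad_(False)  # the reference stays frozen
    return train_dpo(
        policy_model=policy,
        ref_model=reference,
        tokenizer=tokenizer,
        training_data=DPO_TRAINING_DATA,
        beta=0.1,
    )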
Step 4: Evaluation
def evaluate_alignment(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
test_prompts: list[str],
device: str = "cpu",
) -> list[dict[str, str]]:
"""Generate responses and evaluate alignment quality.
Args:
model: The aligned model.
tokenizer: HuggingFace tokenizer.
test_prompts: List of prompts to test.
device: Device for inference.
Returns:
List of dicts with prompt and generated response.
"""
model.eval()
results = []
for prompt in test_prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
response = tokenizer.decode(
output_ids[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
)
results.append({"prompt": prompt, "response": response.strip()})
return results
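A simple way to judge the effect of alignment is to generate from the SFT model and the DPO model on the same prompts and compare. The helper and prompts below are illustrative (compare_before_after is not part of the pipeline above); in practice you would use a held-out evaluation set and, ideally, pairwise judgments.
def compare_before_after(
    sft_model: AutoModelForCausalLM,
    dpo_model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    device: str = "cpu",
) -> None:
    """Sketch: print SFT and DPO responses side by side for a few prompts."""
    test_prompts = [
        "How do I reverse a string in Python?",
        "Is it safe to look directly at the sun?",
    ]
    before = evaluate_alignment(sft_model, tokenizer, test_prompts, device)
    after = evaluate_alignment(dpo_model, tokenizer, test_prompts, device)
    for b, a in zip(before, after):
        print(f"Prompt: {b['prompt']}")
        print(f"  SFT:  {b['response'][:200]}")
        print(f"  DPO:  {a['response'][:200]}")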
Step 5: Demonstration
if __name__ == "__main__":
print("=" * 60)
print("Case Study 2: Aligning with DPO")
print("=" * 60)
print(f"\nTraining examples: {len(DPO_TRAINING_DATA)}")
for i, ex in enumerate(DPO_TRAINING_DATA):
print(f" Example {i}: {ex.prompt[:50]}...")
print("\nDPO Loss Components:")
print(" 1. Compute log P(y_w|x) and log P(y_l|x) under policy")
print(" 2. Compute log P(y_w|x) and log P(y_l|x) under reference")
print(" 3. Implicit rewards: beta * (log pi - log pi_ref)")
print(" 4. Loss: -log sigma(r_w - r_l)")
print("\nMonitoring Metrics:")
print(" - Chosen rewards (should increase)")
print(" - Rejected rewards (should decrease)")
print(" - Reward margin (should increase)")
print(" - Accuracy (should approach 1.0)")
print(" - KL divergence (should remain bounded)")
# Demonstrate loss computation with synthetic values
print("\n--- Synthetic Loss Demo ---")
policy_c = torch.tensor([-10.0, -12.0, -8.0])
policy_r = torch.tensor([-15.0, -14.0, -13.0])
ref_c = torch.tensor([-11.0, -13.0, -9.0])
ref_r = torch.tensor([-14.0, -13.0, -12.0])
loss, metrics = dpo_loss(policy_c, policy_r, ref_c, ref_r, beta=0.1)
print(f" Loss: {loss.item():.4f}")
print(f" Chosen reward: {metrics['chosen_rewards'].item():.4f}")
print(f" Rejected reward: {metrics['rejected_rewards'].item():.4f}")
print(f" Margin: {metrics['reward_margin'].item():.4f}")
print(f" Accuracy: {metrics['accuracy'].item():.1%}")
print("\nTo run full DPO training, execute with a compatible GPU.")
Key Takeaways
- DPO is fundamentally simpler than RLHF. It replaces the separately trained reward model and the PPO stage with a single supervised loss computed from preference pairs and log probabilities under the policy and reference models (the objective is written out after this list).
- The DPO $\beta$ parameter is the most important hyperparameter. Lower values (e.g., 0.1) allow more deviation from the reference; higher values (e.g., 0.5) are more conservative. Start around 0.1 and increase $\beta$ if the aligned model degrades or drifts too far from the reference.
- Monitor implicit rewards, not just the loss. The rejected reward should decrease and the reward margin should grow, with accuracy rising toward 1.0; the chosen reward ideally increases, though in practice it may also drift down slightly as long as the margin keeps widening. If these trends diverge from expectations, training is likely unstable.
- Length-controlled DPO prevents verbosity bias by normalizing implicit rewards by response length. Without this, models tend to become increasingly verbose.
- The reference model must remain frozen throughout training. It provides the anchor that prevents the policy from drifting too far from sensible behavior.
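For reference, the objective that the code in Step 1 implements, with the implicit reward defined as $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$, is
$$\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\left[\log \sigma\big(\hat{r}_\theta(x, y_w) - \hat{r}_\theta(x, y_l)\big)\right],$$
where $y_w$ is the chosen response, $y_l$ the rejected response, and $\sigma$ the logistic sigmoid.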