In This Chapter
- Introduction: Learning Through Interaction
- 36.1 The Reinforcement Learning Problem
- 36.2 Markov Decision Processes
- 36.3 Value Functions and the Bellman Equations
- 36.4 Tabular Methods: Q-Learning and SARSA
- 36.5 Deep Q-Networks (DQN)
- 36.6 Policy Gradient Methods
- 36.7 Actor-Critic Methods
- 36.8 Proximal Policy Optimization (PPO)
- 36.9 RL for Language Models: RLHF and GRPO
- 36.10 Working with Gymnasium Environments
- 36.11 Challenges in Reinforcement Learning
- 36.12 Multi-Agent RL (Brief Introduction)
- 36.13 Practical Tips for RL Engineering
- 36.14 Training Loop: Putting It All Together
- Summary
Chapter 36: Reinforcement Learning for AI Engineers
Introduction: Learning Through Interaction
Every other learning paradigm we have studied in this book shares a common structure: you have data, you have labels (or at least unlabeled data with inherent structure), and you optimize a model to fit that data. Reinforcement learning (RL) breaks this mold entirely. In RL, an agent learns by interacting with an environment, receiving rewards or penalties for its actions, and gradually discovering strategies that maximize long-term cumulative reward. There is no dataset of correct answers handed to the learner. Instead, the agent must explore, experiment, and learn from the consequences of its own decisions.
This distinction is not merely academic. Reinforcement learning has produced some of the most dramatic achievements in modern AI: DeepMind's AlphaGo defeating the world champion in Go, OpenAI Five playing Dota 2 at a professional level, and---most relevant to contemporary AI engineering---reinforcement learning from human feedback (RLHF) aligning large language models with human preferences. If you have used ChatGPT, Claude, or any modern conversational AI, you have interacted with a system shaped by reinforcement learning.
This chapter provides a rigorous yet practical introduction to reinforcement learning for AI engineers working across the full spectrum of modern applications. We will build from the mathematical foundations of Markov Decision Processes through classical algorithms like Q-learning, then progress to deep RL methods including DQN, policy gradients, and PPO. We will connect these ideas to the RLHF techniques introduced in Chapter 25 and cover newer approaches like GRPO. Throughout, we will use Gymnasium (the maintained successor to OpenAI Gym) for practical implementations.
Prerequisites
Before diving in, you should be comfortable with:
- Neural network training with PyTorch (Chapters 4-7)
- Basic probability and statistics (Chapter 2)
- Optimization concepts (gradient descent, loss functions)
- Familiarity with RLHF concepts from Chapter 25 is helpful but not required
36.1 The Reinforcement Learning Problem
36.1.1 What Makes RL Different
Consider three learning scenarios:
- Supervised learning: Given images and labels, learn to classify images.
- Unsupervised learning: Given unlabeled data, discover structure.
- Reinforcement learning: Given an environment, learn to act to maximize reward.
The key differences in RL are:
- No supervisor: The agent receives only a scalar reward signal, not the correct action.
- Delayed reward: The consequences of an action may not be apparent for many steps.
- Sequential decisions: Actions affect future states and future rewards.
- Exploration vs. exploitation: The agent must balance trying new actions (exploration) with using known good actions (exploitation).
36.1.2 The Agent-Environment Interface
The RL framework consists of two entities:
- Agent: The learner and decision-maker.
- Environment: Everything the agent interacts with.
At each discrete time step $t$:
1. The agent observes state $s_t \in \mathcal{S}$
2. The agent selects action $a_t \in \mathcal{A}$
3. The environment transitions to state $s_{t+1}$
4. The agent receives reward $r_{t+1} \in \mathbb{R}$
This interaction loop continues until a terminal state is reached (in episodic tasks) or indefinitely (in continuing tasks).
import gymnasium as gym
import torch
import numpy as np

torch.manual_seed(42)
np.random.seed(42)

# The agent-environment loop in Gymnasium
env = gym.make("CartPole-v1")
state, info = env.reset(seed=42)
# Seed the action space too, so that the random actions are reproducible
env.action_space.seed(42)

total_reward = 0.0
done = False
while not done:
    # Agent selects an action (here, random)
    action = env.action_space.sample()
    # Environment responds with next state, reward, termination signals
    next_state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    total_reward += reward
    state = next_state

print(f"Episode reward: {total_reward}")
env.close()
36.2 Markov Decision Processes
36.2.1 Formal Definition
The mathematical framework underlying RL is the Markov Decision Process (MDP). An MDP is defined by the tuple $(\mathcal{S}, \mathcal{A}, P, R, \gamma)$:
- $\mathcal{S}$: State space (finite or continuous)
- $\mathcal{A}$: Action space (finite or continuous)
- $P(s' | s, a)$: Transition probability function---the probability of transitioning to state $s'$ given state $s$ and action $a$
- $R(s, a, s')$: Reward function---the immediate reward received
- $\gamma \in [0, 1)$: Discount factor---how much future rewards are valued relative to immediate rewards
The Markov property states that the future depends only on the current state, not on the history of how we arrived there:
$$P(s_{t+1} | s_t, a_t, s_{t-1}, a_{t-1}, \ldots) = P(s_{t+1} | s_t, a_t)$$
This property is both a simplification and a constraint. Many real-world problems are not truly Markov, but we can often make them approximately Markov by including enough information in the state representation.
36.2.2 Policies
A policy $\pi$ defines the agent's behavior. It maps states to actions (or distributions over actions):
- Deterministic policy: $a = \pi(s)$
- Stochastic policy: $a \sim \pi(a | s)$
The goal of RL is to find an optimal policy $\pi^*$ that maximizes expected cumulative reward.
36.2.3 Returns and Discounting
The return $G_t$ is the cumulative discounted reward from time step $t$:
$$G_t = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \cdots = \sum_{k=0}^{\infty} \gamma^k r_{t+k+1}$$
The discount factor $\gamma$ serves multiple purposes:
- Mathematical: Ensures the sum converges for infinite-horizon problems.
- Practical: Encodes a preference for sooner rewards over later ones.
- Behavioral: $\gamma = 0$ makes the agent myopic; $\gamma \to 1$ makes it far-sighted.
Typical values range from 0.95 to 0.99, which corresponds to an effective horizon of roughly $1/(1-\gamma)$, or about 20 to 100 steps.
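The effect of $\gamma$ on the return can be checked with a few lines of plain Python. This is a minimal sketch: the function name `discounted_return` and the reward values are illustrative assumptions, and the list is read as $[r_{t+1}, r_{t+2}, \ldots]$ for a finite episode.

```python
from typing import Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """Compute G_t for a finite reward list [r_{t+1}, r_{t+2}, ...]."""
    g = 0.0
    # Work backwards so that each step applies G = r + gamma * G_next.
    for r in reversed(rewards):
        g = r + gamma * g
    return g


rewards = [1.0, 0.0, 2.0]  # illustrative values, not taken from the text

# gamma = 0: only the immediate reward counts (myopic agent).
print(discounted_return(rewards, gamma=0.0))
# gamma = 0.9: later rewards still count, but are scaled by 0.9 per step.
print(discounted_return(rewards, gamma=0.9))
```

Computing the sum backwards avoids recomputing powers of $\gamma$ and is the same recursion that the Bellman equations formalize.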
36.3 Value Functions and the Bellman Equations
36.3.1 State-Value Function
The state-value function $V^\pi(s)$ measures how good it is to be in state $s$ under policy $\pi$:
$$V^\pi(s) = \mathbb{E}_\pi[G_t | S_t = s] = \mathbb{E}_\pi\left[\sum_{k=0}^{\infty} \gamma^k r_{t+k+1} \bigg| S_t = s\right]$$
36.3.2 Action-Value Function
The action-value function $Q^\pi(s, a)$ measures how good it is to take action $a$ in state $s$ under policy $\pi$:
$$Q^\pi(s, a) = \mathbb{E}_\pi[G_t | S_t = s, A_t = a]$$
The relationship between $V$ and $Q$ is:
$$V^\pi(s) = \sum_{a \in \mathcal{A}} \pi(a|s) Q^\pi(s, a)$$
36.3.3 The Bellman Equations
The Bellman equations express a recursive relationship: the value of a state equals the immediate reward plus the discounted value of the next state.
Bellman Expectation Equation (for a given policy $\pi$):
$$V^\pi(s) = \sum_{a} \pi(a|s) \sum_{s'} P(s'|s,a) \left[R(s,a,s') + \gamma V^\pi(s')\right]$$
$$Q^\pi(s,a) = \sum_{s'} P(s'|s,a) \left[R(s,a,s') + \gamma \sum_{a'} \pi(a'|s') Q^\pi(s',a')\right]$$
Bellman Optimality Equation (for the optimal policy):
$$V^*(s) = \max_{a} \sum_{s'} P(s'|s,a) \left[R(s,a,s') + \gamma V^*(s')\right]$$
$$Q^*(s,a) = \sum_{s'} P(s'|s,a) \left[R(s,a,s') + \gamma \max_{a'} Q^*(s',a')\right]$$
These equations are the foundation of nearly every RL algorithm. If we know $Q^*$, the optimal policy is simply $\pi^*(s) = \arg\max_a Q^*(s,a)$.
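When the model is known, the Bellman optimality equation can be turned directly into an algorithm by applying it repeatedly as an update (value iteration). The sketch below assumes a hypothetical two-state, two-action MDP; the `transitions` table, its rewards, and the helper names are all made up for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical MDP: transitions[s][a] = [(probability, next_state, reward), ...]
transitions: Dict[int, Dict[int, List[Tuple[float, int, float]]]] = {
    0: {
        0: [(1.0, 0, 0.0)],                 # stay in state 0, no reward
        1: [(0.8, 1, 1.0), (0.2, 0, 0.0)],  # try to move to state 1
    },
    1: {
        0: [(1.0, 1, 2.0)],                 # stay in state 1, reward 2
        1: [(1.0, 0, 0.0)],                 # go back to state 0
    },
}


def q_from_v(v: Dict[int, float], s: int, a: int, gamma: float) -> float:
    """One-step lookahead: sum over s' of P(s'|s,a) [R + gamma V(s')]."""
    return sum(p * (r + gamma * v[s2]) for p, s2, r in transitions[s][a])


def value_iteration(gamma: float = 0.9, tol: float = 1e-10) -> Dict[int, float]:
    """Apply the Bellman optimality backup until the values stop changing."""
    v = {s: 0.0 for s in transitions}
    while True:
        new_v = {
            s: max(q_from_v(v, s, a, gamma) for a in transitions[s])
            for s in transitions
        }
        delta = max(abs(new_v[s] - v[s]) for s in v)
        v = new_v
        if delta < tol:
            return v


def greedy_policy(v: Dict[int, float], gamma: float = 0.9) -> Dict[int, int]:
    """Extract pi*(s) = argmax_a Q*(s, a) from the converged values."""
    return {
        s: max(transitions[s], key=lambda a: q_from_v(v, s, a, gamma))
        for s in transitions
    }


v_star = value_iteration()
print(v_star, greedy_policy(v_star))
```

For this toy problem the fixed point can be checked by hand: staying in state 1 forever is worth $2 / (1 - 0.9) = 20$, and the greedy policy moves from state 0 toward state 1.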
36.3.4 Intuition Behind the Bellman Equations
Think of the Bellman equation as a consistency condition. If you know how good the future will be (the value of the next state), you can figure out how good the present is (the value of the current state) by adding the immediate reward. This recursive decomposition allows us to break a complex, long-horizon decision problem into a sequence of one-step problems.
36.4 Tabular Methods: Q-Learning and SARSA
36.4.1 Temporal Difference Learning
When we do not know the transition dynamics $P(s'|s,a)$, we cannot solve the Bellman equations directly. Instead, we use temporal difference (TD) learning, which updates value estimates based on observed transitions.
The TD update rule for state values:
$$V(s_t) \leftarrow V(s_t) + \alpha \left[r_{t+1} + \gamma V(s_{t+1}) - V(s_t)\right]$$
where $\alpha$ is the learning rate and $\delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$ is the TD error.
Intuition: The TD error measures the surprise -- the difference between what we expected (our current value estimate) and what we actually observed (the immediate reward plus the discounted estimate of the next state's value). If $\delta_t > 0$, the outcome was better than expected, so we increase the value estimate. If $\delta_t < 0$, the outcome was worse than expected.
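The update rule and the sign of the TD error can be seen on a single transition. The following is a minimal sketch with a dictionary as the value table; the state names, the numbers, and the `terminal` flag are illustrative assumptions.

```python
from typing import Dict


def td0_update(
    values: Dict[str, float],
    s: str,
    r: float,
    s_next: str,
    alpha: float,
    gamma: float,
    terminal: bool = False,
) -> float:
    """Apply one TD(0) update to values[s] in place and return the TD error."""
    # A terminal next state contributes no future value.
    bootstrap = 0.0 if terminal else values[s_next]
    delta = r + gamma * bootstrap - values[s]
    values[s] += alpha * delta
    return delta


values = {"A": 0.0, "B": 1.0}  # hypothetical initial estimates
delta = td0_update(values, "A", r=0.5, s_next="B", alpha=0.1, gamma=0.9)
# delta is positive here: the outcome was better than V(A) predicted,
# so V(A) moves up by alpha * delta.
print(delta, values["A"])
```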
TD learning sits between two extremes. Monte Carlo methods wait until the end of an episode to compute actual returns and update value estimates; these have zero bias but high variance because individual returns can vary significantly. Dynamic programming methods use the complete model of the environment to compute exact expectations; these have zero variance but require a model. TD learning uses a one-step sample of the return and bootstraps (uses its own estimate for the future), giving it lower variance than Monte Carlo and no need for a model, at the price of bias whenever the current value estimates are inaccurate.
The relationship between these approaches can be expressed through n-step returns:
$$G_t^{(n)} = r_{t+1} + \gamma r_{t+2} + \cdots + \gamma^{n-1} r_{t+n} + \gamma^n V(s_{t+n})$$
When $n = 1$, we get TD(0) -- the standard one-step TD update. When $n = \infty$ (or the episode terminates), we get Monte Carlo. The parameter $n$ controls the bias-variance trade-off: larger $n$ reduces bias (we use more actual rewards) but increases variance (we sum more random variables). This idea generalizes to the TD($\lambda$) algorithm, which takes an exponentially weighted average of all n-step returns, and later to GAE (Section 36.7.3).
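The n-step return can be written as one small function that covers both ends of the spectrum. This sketch assumes a finite episode stored as two lists, with `rewards[k]` holding $r_{k+1}$ and `values[k]` holding $V(s_k)$; the function name and the sample numbers are illustrative.

```python
from typing import Sequence


def n_step_return(
    rewards: Sequence[float],
    values: Sequence[float],
    t: int,
    n: int,
    gamma: float,
) -> float:
    """n-step return from step t for an episode of length len(rewards)."""
    episode_len = len(rewards)
    end = min(t + n, episode_len)
    # Sum up to n actual rewards, discounted relative to step t.
    g = sum(gamma ** (k - t) * rewards[k] for k in range(t, end))
    # Bootstrap from the value estimate only if the episode has not ended.
    if end < episode_len:
        g += gamma ** (end - t) * values[end]
    return g


rewards = [1.0, 0.0, 2.0]   # r_1, r_2, r_3 (illustrative)
values = [0.5, 0.4, 0.3]    # V(s_0), V(s_1), V(s_2) (illustrative)

print(n_step_return(rewards, values, t=0, n=1, gamma=0.9))   # TD(0) target
print(n_step_return(rewards, values, t=0, n=2, gamma=0.9))   # two-step target
print(n_step_return(rewards, values, t=0, n=10, gamma=0.9))  # Monte Carlo return
```

Any `n` that reaches the end of the episode yields the Monte Carlo return, because no bootstrap term remains.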
36.4.2 Q-Learning
Q-learning (Watkins, 1989) is an off-policy TD control algorithm that directly learns $Q^*$:
$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t)\right]$$
Key properties:
- Off-policy: It learns about the optimal policy regardless of the policy being followed.
- Model-free: It does not require knowledge of transition dynamics.
- Convergence: In the tabular setting, it is guaranteed to converge to $Q^*$ provided every state-action pair continues to be visited and the learning rate decays appropriately (for example, $\sum_t \alpha_t = \infty$ and $\sum_t \alpha_t^2 < \infty$).
36.4.3 SARSA
SARSA (State-Action-Reward-State-Action) is the on-policy counterpart:
$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[r_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)\right]$$
The difference is subtle but important: Q-learning uses $\max_{a'} Q(s_{t+1}, a')$ (the best possible next action), while SARSA uses $Q(s_{t+1}, a_{t+1})$ (the action actually taken). This makes SARSA more conservative -- it accounts for the exploration policy.
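The two update targets differ in a single term, which is easiest to see side by side. The sketch below uses made-up Q-values for the next state in which one action is risky; the function names are illustrative.

```python
from typing import Sequence


def q_learning_target(q_next: Sequence[float], r: float, gamma: float) -> float:
    """Off-policy target: bootstrap from the best next action."""
    return r + gamma * max(q_next)


def sarsa_target(
    q_next: Sequence[float], a_next: int, r: float, gamma: float
) -> float:
    """On-policy target: bootstrap from the action actually taken next."""
    return r + gamma * q_next[a_next]


# Hypothetical Q(s', .): action 0 is safe, action 1 is risky.
q_next = [1.0, -5.0]
r, gamma = -1.0, 0.9

# Q-learning ignores which action the behavior policy takes next.
print(q_learning_target(q_next, r, gamma))
# If exploration happened to pick the risky action, SARSA's target reflects it.
print(sarsa_target(q_next, a_next=1, r=r, gamma=gamma))
```

When the next action is the greedy one, the two targets coincide; they differ only on exploratory steps.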
Worked Example: Cliff Walking. Consider a grid world where an agent must walk from a start position to a goal, with a cliff along the bottom edge. Falling off the cliff gives a large negative reward.
- Q-learning learns the optimal path (walking along the edge of the cliff) because it evaluates the greedy policy, which never falls. But during training, the epsilon-greedy behavior policy occasionally falls off, incurring large penalties.
- SARSA learns a safer path (walking further from the cliff) because it evaluates the policy it is actually following, which includes random exploration steps that might walk off the cliff.
This illustrates a general principle: Q-learning (off-policy) is optimistic -- it learns about the best possible behavior even while behaving suboptimally. SARSA (on-policy) is realistic -- it learns about the policy it actually follows, including its mistakes.
| Property | Q-Learning | SARSA |
|---|---|---|
| Policy learned | Optimal (greedy) | The current behavior policy |
| On/Off-policy | Off-policy | On-policy |
| Safety | Optimistic (ignores exploration risk) | Conservative (accounts for exploration) |
| Convergence | To $Q^*$ regardless of policy | To $Q^\pi$ for current policy |
| Use case | When execution will be greedy | When execution includes randomness |
36.4.4 Exploration Strategies
The agent needs to explore to discover good actions, but also exploit known good actions. Common strategies:
Epsilon-greedy: $$a_t = \begin{cases} \arg\max_a Q(s_t, a) & \text{with probability } 1 - \epsilon \\ \text{random action} & \text{with probability } \epsilon \end{cases}$$
Boltzmann (softmax) exploration: $$\pi(a|s) = \frac{\exp(Q(s,a) / \tau)}{\sum_{a'} \exp(Q(s,a') / \tau)}$$
where $\tau$ is the temperature parameter.
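Upper Confidence Bound (UCB): $$a_t = \arg\max_a \left[Q(s_t, a) + c\sqrt{\frac{\ln t}{N(s_t, a)}}\right]$$
where $N(s_t, a)$ counts how many times action $a$ has been taken in state $s_t$, $t$ is the current time step, and $c > 0$ scales the exploration bonus, which shrinks as an action is tried more often.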
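The Boltzmann and UCB rules can be sketched in plain Python as follows. The function names and sample numbers are illustrative, and trying each untried action once before applying the UCB formula is an assumption made here to avoid dividing by a zero count.

```python
import math
from typing import List, Sequence


def boltzmann_probs(q_values: Sequence[float], tau: float) -> List[float]:
    """Softmax over Q-values with temperature tau."""
    # Subtracting the max leaves the result unchanged but avoids overflow.
    q_max = max(q_values)
    exps = [math.exp((q - q_max) / tau) for q in q_values]
    total = sum(exps)
    return [e / total for e in exps]


def ucb_action(
    q_values: Sequence[float], counts: Sequence[int], t: int, c: float
) -> int:
    """Pick the action with the highest Q-value plus exploration bonus."""
    # Assumption: any action that has never been tried is selected first.
    for a, n in enumerate(counts):
        if n == 0:
            return a
    scores = [
        q + c * math.sqrt(math.log(t) / n) for q, n in zip(q_values, counts)
    ]
    return max(range(len(scores)), key=scores.__getitem__)


q = [1.0, 2.0]
print(boltzmann_probs(q, tau=1000.0))  # high temperature: close to uniform
print(boltzmann_probs(q, tau=0.01))    # low temperature: close to greedy

# Action 1 has a slightly lower estimate but has barely been tried.
print(ucb_action([1.0, 0.9], counts=[100, 1], t=101, c=1.0))
```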
import numpy as np

np.random.seed(42)


class QLearningAgent:
    """Tabular Q-learning agent.

    Args:
        n_states: Number of discrete states.
        n_actions: Number of discrete actions.
        learning_rate: Step size for Q-value updates.
        gamma: Discount factor for future rewards.
        epsilon: Exploration rate for epsilon-greedy policy.
        epsilon_decay: Multiplicative decay for epsilon per episode.
        epsilon_min: Minimum value for epsilon.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        learning_rate: float = 0.1,
        gamma: float = 0.99,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.01,
    ) -> None:
        self.q_table = np.zeros((n_states, n_actions))
        self.lr = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.n_actions = n_actions

    def select_action(self, state: int) -> int:
        """Select action using epsilon-greedy policy.

        Args:
            state: Current discrete state.

        Returns:
            Selected action index.
        """
        if np.random.random() < self.epsilon:
            return int(np.random.randint(self.n_actions))
        return int(np.argmax(self.q_table[state]))

    def update(
        self, state: int, action: int, reward: float, next_state: int, done: bool
    ) -> float:
        """Update Q-value using the Q-learning update rule.

        Args:
            state: Current state.
            action: Action taken.
            reward: Reward received.
            next_state: Next state observed.
            done: Whether the episode ended.

        Returns:
            The TD error.
        """
        # Terminal transitions have no future value to bootstrap from
        target = reward
        if not done:
            target += self.gamma * np.max(self.q_table[next_state])
        td_error = target - self.q_table[state, action]
        self.q_table[state, action] += self.lr * td_error
        return float(td_error)

    def decay_epsilon(self) -> None:
        """Decay epsilon after each episode."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
36.5 Deep Q-Networks (DQN)
36.5.1 Why Go Deep?
Tabular Q-learning works well when the state and action spaces are small and discrete. But many interesting problems have large or continuous state spaces: a raw Atari frame of $210 \times 160$ pixels with a 128-color palette admits on the order of $10^{70,000}$ distinct pixel configurations, far more than any table could enumerate. We need function approximation.
The idea is simple: replace the Q-table with a neural network $Q(s, a; \theta)$ parameterized by weights $\theta$. The network takes a state as input and outputs Q-values for all actions.
36.5.2 The DQN Algorithm
The seminal DQN paper (Mnih et al., 2015) introduced two key innovations that made deep Q-learning stable:
Experience Replay: Store transitions $(s, a, r, s', \text{done})$ in a replay buffer $\mathcal{D}$ of capacity $N$ and sample mini-batches uniformly at random for training. This addresses two critical issues:
- Breaking temporal correlations. Consecutive transitions in an episode are highly correlated (state $s_{t+1}$ is similar to $s_t$). Training on correlated batches leads to poor convergence because the gradient estimates are biased toward the current region of state space. Random sampling from a large buffer decorrelates the training data.
- Data reuse. Each transition can be used for multiple gradient updates, which often makes replay-based methods substantially more sample-efficient than on-policy methods that discard data after one update.
The replay buffer implements a FIFO queue: when full, the oldest transition is removed to make room for new data. The uniform sampling probability for each transition is $P(i) = 1/|\mathcal{D}|$. Prioritized experience replay improves on this by sampling proportional to the TD error: $P(i) \propto |\delta_i|^\alpha$, where $\alpha$ controls how much prioritization is used. This focuses learning on the most surprising transitions but requires importance sampling corrections to avoid bias.
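Proportional prioritization and its importance sampling correction can be sketched in a few lines. This is an illustration under stated assumptions rather than a production buffer: the small constant `eps` (so that zero-error transitions can still be sampled), the exponent `beta` on the correction weights, the default values, and the function names are all choices made for this example.

```python
import random
from typing import List, Sequence, Tuple


def sampling_probabilities(
    td_errors: Sequence[float], alpha: float, eps: float = 1e-6
) -> List[float]:
    """P(i) proportional to (|delta_i| + eps)^alpha; alpha = 0 is uniform."""
    priorities = [(abs(d) + eps) ** alpha for d in td_errors]
    total = sum(priorities)
    return [p / total for p in priorities]


def sample_batch(
    td_errors: Sequence[float],
    batch_size: int,
    alpha: float = 0.6,
    beta: float = 0.4,
    seed: int = 0,
) -> Tuple[List[int], List[float]]:
    """Sample indices by priority and return importance sampling weights."""
    rng = random.Random(seed)
    probs = sampling_probabilities(td_errors, alpha)
    n = len(td_errors)
    indices = rng.choices(range(n), weights=probs, k=batch_size)
    # Weight (N * P(i))^(-beta) down-weights transitions sampled more often
    # than uniform; dividing by the batch maximum keeps weights in (0, 1].
    weights = [(n * probs[i]) ** (-beta) for i in indices]
    max_w = max(weights)
    return indices, [w / max_w for w in weights]


td_errors = [0.1, 0.1, 0.1, 5.0]  # hypothetical stored TD errors
print(sampling_probabilities(td_errors, alpha=0.6))
print(sample_batch(td_errors, batch_size=8))
```

In a training step, each sampled transition's loss would be multiplied by its weight before averaging.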
Target Network: Maintain a separate target network $Q(s, a; \theta^-)$ with parameters $\theta^-$ that are periodically copied from the main network. Without a target network, the loss function is:
$$\mathcal{L} = (r + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta))^2$$
The problem is that the target $r + \gamma \max_{a'} Q(s', a'; \theta)$ changes with every gradient step, creating a moving target that can cause oscillation or divergence. The target network provides a stable target that is updated less frequently (every $C$ steps), decoupling the target from the optimization.
The loss function:
$$\mathcal{L}(\theta) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \left[\left(r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta)\right)^2\right]$$
36.5.3 DQN Improvements
Several improvements to vanilla DQN have been developed:
Double DQN (van Hasselt et al., 2016): Standard DQN uses the max operator both to select and evaluate the next action, which leads to systematic overestimation. The intuition is simple: if our Q-value estimates are noisy (and they always are), taking the max over noisy estimates produces a value that is biased upward. Double DQN decouples selection from evaluation:
$$y = r + \gamma Q(s', \arg\max_{a'} Q(s', a'; \theta); \theta^-)$$
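The policy network $\theta$ selects the best action, but the target network $\theta^-$ evaluates it. Because the two networks' estimation errors are not perfectly correlated, an action that the policy network overrates is less likely to be overrated by the target network as well, which significantly reduces the overestimation.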
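The difference between the two targets is easiest to see on a single transition. The sketch below uses plain lists in place of network outputs; the Q-values are made up, and `q_online_next` and `q_target_next` stand for $Q(s', \cdot; \theta)$ and $Q(s', \cdot; \theta^-)$ respectively.

```python
from typing import Sequence


def dqn_target(
    r: float, gamma: float, q_target_next: Sequence[float], done: bool
) -> float:
    """Standard DQN: the target network both selects and evaluates."""
    if done:
        return r
    return r + gamma * max(q_target_next)


def double_dqn_target(
    r: float,
    gamma: float,
    q_online_next: Sequence[float],
    q_target_next: Sequence[float],
    done: bool,
) -> float:
    """Double DQN: online network selects, target network evaluates."""
    if done:
        return r
    best_action = max(range(len(q_online_next)), key=q_online_next.__getitem__)
    return r + gamma * q_target_next[best_action]


# Hypothetical noisy estimates for three actions in the next state.
q_online_next = [1.0, 2.0, 1.5]
q_target_next = [1.2, 1.1, 2.5]

print(dqn_target(1.0, 0.5, q_target_next, done=False))
print(double_dqn_target(1.0, 0.5, q_online_next, q_target_next, done=False))
```

Because the evaluated action is not necessarily the target network's own maximizer, the Double DQN target can never exceed the standard DQN target for the same transition.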
Dueling DQN (Wang et al., 2016): Decomposes the Q-function into a state-value component and a state-dependent action advantage:
$$Q(s, a; \theta) = V(s; \theta_v) + A(s, a; \theta_a) - \frac{1}{|\mathcal{A}|}\sum_{a'} A(s, a'; \theta_a)$$
The advantage function $A(s, a)$ measures how much better action $a$ is compared to the average action in state $s$. Subtracting the mean advantage ensures identifiability (otherwise $V$ and $A$ could shift by a constant). This decomposition is beneficial when many states have similar values regardless of the action taken -- the network can learn the state value once rather than relearning it for each action.
Prioritized Experience Replay: Samples transitions proportional to their TD error magnitude $|\delta_i|$, focusing learning on surprising or important transitions. Transitions with large TD errors represent situations where the agent's predictions were most wrong, so they contain the most information for learning.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


class ReplayBuffer:
    """Experience replay buffer for DQN.

    Args:
        capacity: Maximum number of transitions to store.
    """

    def __init__(self, capacity: int = 100000) -> None:
        # A bounded deque drops the oldest transition once it is full (FIFO)
        self.buffer: deque = deque(maxlen=capacity)

    def push(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> None:
        """Store a transition in the buffer.

        Args:
            state: Current state observation.
            action: Action taken.
            reward: Reward received.
            next_state: Next state observation.
            done: Whether the episode terminated.
        """
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> Tuple:
        """Sample a random mini-batch of transitions.

        Args:
            batch_size: Number of transitions to sample.

        Returns:
            Tuple of batched states, actions, rewards, next_states, dones.
        """
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.float32),
        )

    def __len__(self) -> int:
        return len(self.buffer)


class DQNetwork(nn.Module):
    """Deep Q-Network with dueling architecture option.

    Args:
        state_dim: Dimension of the state space.
        action_dim: Number of possible actions.
        hidden_dim: Number of hidden units per layer.
        dueling: Whether to use dueling architecture.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int = 128,
        dueling: bool = False,
    ) -> None:
        super().__init__()
        self.dueling = dueling
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        if dueling:
            self.value_stream = nn.Linear(hidden_dim, 1)
            self.advantage_stream = nn.Linear(hidden_dim, action_dim)
        else:
            self.output_layer = nn.Linear(hidden_dim, action_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass computing Q-values for all actions.

        Args:
            x: State tensor of shape (batch_size, state_dim).

        Returns:
            Q-values of shape (batch_size, action_dim).
        """
        features = self.feature(x)
        if self.dueling:
            value = self.value_stream(features)
            advantage = self.advantage_stream(features)
            # Subtract the mean advantage so V and A are identifiable
            q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        else:
            q_values = self.output_layer(features)
        return q_values
36.5.4 Training the DQN
class DQNAgent:
    """Deep Q-Network agent with experience replay and target network.

    Args:
        state_dim: Dimension of state observations.
        action_dim: Number of discrete actions.
        hidden_dim: Hidden layer size.
        lr: Learning rate.
        gamma: Discount factor.
        epsilon_start: Initial exploration rate.
        epsilon_end: Final exploration rate.
        epsilon_decay: Decay steps for epsilon.
        buffer_size: Replay buffer capacity.
        batch_size: Training batch size.
        target_update_freq: Steps between target network updates.
        dueling: Whether to use dueling architecture.
        double_dqn: Whether to use double DQN.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int = 128,
        lr: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 10000,
        buffer_size: int = 100000,
        batch_size: int = 64,
        target_update_freq: int = 1000,
        dueling: bool = False,
        double_dqn: bool = False,
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.action_dim = action_dim
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.double_dqn = double_dqn

        # Epsilon schedule
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.steps_done = 0

        # Networks
        self.policy_net = DQNetwork(
            state_dim, action_dim, hidden_dim, dueling
        ).to(self.device)
        self.target_net = DQNetwork(
            state_dim, action_dim, hidden_dim, dueling
        ).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)

    @property
    def epsilon(self) -> float:
        """Current epsilon value based on linear decay schedule."""
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            max(0, 1 - self.steps_done / self.epsilon_decay)

    def select_action(self, state: np.ndarray) -> int:
        """Select action using epsilon-greedy policy.

        Args:
            state: Current state observation.

        Returns:
            Selected action index.
        """
        self.steps_done += 1
        if random.random() < self.epsilon:
            return random.randrange(self.action_dim)
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.policy_net(state_tensor)
            return q_values.argmax(dim=1).item()

    def train_step(self) -> float:
        """Perform one training step on a mini-batch from the replay buffer.

        Returns:
            The mean loss value.
        """
        if len(self.buffer) < self.batch_size:
            return 0.0

        states, actions, rewards, next_states, dones = self.buffer.sample(
            self.batch_size
        )
        states_t = torch.FloatTensor(states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)

        # Current Q-values
        current_q = self.policy_net(states_t).gather(
            1, actions_t.unsqueeze(1)
        ).squeeze(1)

        # Target Q-values
        with torch.no_grad():
            if self.double_dqn:
                # Use policy net to select actions, target net to evaluate
                next_actions = self.policy_net(next_states_t).argmax(dim=1)
                next_q = self.target_net(next_states_t).gather(
                    1, next_actions.unsqueeze(1)
                ).squeeze(1)
            else:
                next_q = self.target_net(next_states_t).max(dim=1)[0]
            # Zero out the bootstrap term for terminal transitions
            target_q = rewards_t + self.gamma * next_q * (1 - dones_t)

        loss = nn.functional.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Update target network
        if self.steps_done % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        return loss.item()
36.6 Policy Gradient Methods
36.6.1 From Value-Based to Policy-Based Methods
Value-based methods like DQN learn a value function and derive a policy from it. Policy gradient methods take a fundamentally different approach: they parameterize the policy directly as $\pi_\theta(a|s)$ and optimize it by gradient ascent on the expected return.
Why use policy gradients?
- Continuous action spaces: Value-based methods require maximizing over actions, which is intractable for continuous spaces.
- Stochastic policies: Policy gradients naturally learn stochastic policies, which can be optimal in partially observable environments.
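- Better convergence properties: Because policy gradient methods perform gradient ascent directly on the expected return, they converge to at least a local optimum under standard step-size conditions, whereas value-based methods combined with function approximation can oscillate or diverge.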
36.6.2 The Policy Gradient Theorem
The objective is to maximize the expected return:
$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)] = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T} \gamma^t r_{t+1}\right]$$
The policy gradient theorem (Sutton et al., 1999) gives us:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t | s_t) \cdot G_t\right]$$
where $G_t = \sum_{k=t}^{T} \gamma^{k-t} r_{k+1}$ is the return from time step $t$, using the same indexing as Section 36.2.3 ($r_{t+1}$ is the reward that follows action $a_t$).
The intuition is elegant: increase the probability of actions that led to high returns and decrease the probability of actions that led to low returns.
36.6.3 REINFORCE
The simplest policy gradient algorithm is REINFORCE (Williams, 1992):
- Collect a complete episode using $\pi_\theta$.
- For each time step, compute the return $G_t$.
- Update the policy: $\theta \leftarrow \theta + \alpha \sum_t \nabla_\theta \log \pi_\theta(a_t | s_t) G_t$.
A critical improvement is adding a baseline $b(s_t)$ to reduce variance without introducing bias:
$$\nabla_\theta J(\theta) = \mathbb{E}\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t | s_t) (G_t - b(s_t))\right]$$
A common choice for the baseline is the state-value function $V(s_t)$, giving the advantage $A_t = G_t - V(s_t)$.
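The claim that a baseline changes the variance but not the expected gradient can be verified exactly on a small case. The sketch below assumes a hypothetical two-armed bandit (a one-step problem, so the return is just the reward) with a softmax policy over logits, for which the score function is $\nabla_\theta \log \pi(a) = \text{onehot}(a) - \pi$. All names and numbers are illustrative.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gradient(
    pi: Sequence[float], action: int, reward: float, baseline: float
) -> List[float]:
    """Single-sample estimate: grad log pi(a) * (reward - baseline)."""
    return [
        ((1.0 if j == action else 0.0) - pi[j]) * (reward - baseline)
        for j in range(len(pi))
    ]


def expected_gradient(
    logits: Sequence[float], rewards: Sequence[float], baseline: float
) -> List[float]:
    """Exact expectation of the estimator, enumerating all actions."""
    pi = softmax(logits)
    grad = [0.0] * len(pi)
    for a, p in enumerate(pi):
        g = sample_gradient(pi, a, rewards[a], baseline)
        grad = [acc + p * gj for acc, gj in zip(grad, g)]
    return grad


def gradient_variance(
    logits: Sequence[float], rewards: Sequence[float], baseline: float
) -> float:
    """Expected squared distance of the estimator from its mean."""
    pi = softmax(logits)
    mean = expected_gradient(logits, rewards, baseline)
    var = 0.0
    for a, p in enumerate(pi):
        g = sample_gradient(pi, a, rewards[a], baseline)
        var += p * sum((gj - mj) ** 2 for gj, mj in zip(g, mean))
    return var


logits = [0.0, 0.0]     # uniform policy over two arms
rewards = [10.0, 11.0]  # deterministic reward per arm (made up)

print(expected_gradient(logits, rewards, baseline=0.0))
print(expected_gradient(logits, rewards, baseline=10.5))  # same expectation
print(gradient_variance(logits, rewards, baseline=0.0))
print(gradient_variance(logits, rewards, baseline=10.5))  # much smaller
```

The expectation is unchanged because $\sum_a \pi(a) \nabla_\theta \log \pi(a) = 0$, so any term that does not depend on the action averages out.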
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
from typing import List, Tuple

torch.manual_seed(42)


class PolicyNetwork(nn.Module):
    """Simple policy network for discrete action spaces.

    Args:
        state_dim: Dimension of state observations.
        action_dim: Number of discrete actions.
        hidden_dim: Hidden layer size.
    """

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int = 128
    ) -> None:
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute action logits.

        Args:
            x: State tensor of shape (batch_size, state_dim).

        Returns:
            Action logits of shape (batch_size, action_dim).
        """
        return self.network(x)

    def get_action(self, state: np.ndarray) -> Tuple[int, torch.Tensor]:
        """Sample action from the policy and return log probability.

        Args:
            state: Current state observation.

        Returns:
            Tuple of (action, log_probability); the log probability has shape (1,).
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        logits = self.forward(state_tensor)
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action)


class REINFORCE:
    """REINFORCE policy gradient algorithm with optional baseline.

    Args:
        state_dim: Dimension of state observations.
        action_dim: Number of discrete actions.
        lr: Learning rate.
        gamma: Discount factor.
        use_baseline: Whether to use a value function baseline.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        lr: float = 1e-3,
        gamma: float = 0.99,
        use_baseline: bool = True,
    ) -> None:
        self.gamma = gamma
        self.use_baseline = use_baseline
        self.policy = PolicyNetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        if use_baseline:
            self.value_net = nn.Sequential(
                nn.Linear(state_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 128),
                nn.ReLU(),
                nn.Linear(128, 1),
            )
            self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr)

        # Episode storage
        self.log_probs: List[torch.Tensor] = []
        self.rewards: List[float] = []
        self.states: List[np.ndarray] = []

    def select_action(self, state: np.ndarray) -> int:
        """Select an action and store the log probability.

        Args:
            state: Current state observation.

        Returns:
            Selected action.
        """
        action, log_prob = self.policy.get_action(state)
        self.log_probs.append(log_prob)
        # Placeholder: the caller is expected to overwrite self.rewards[-1]
        # with the reward returned by env.step for this action
        self.rewards.append(0.0)
        self.states.append(state)
        return action

    def compute_returns(self) -> torch.Tensor:
        """Compute discounted returns for each time step.

        Returns:
            Tensor of (normalized) returns for each step in the episode.
        """
        returns = []
        G = 0.0
        for r in reversed(self.rewards):
            G = r + self.gamma * G
            returns.insert(0, G)
        returns_tensor = torch.FloatTensor(returns)
        # Normalize returns for stability
        if len(returns) > 1:
            returns_tensor = (returns_tensor - returns_tensor.mean()) / (
                returns_tensor.std() + 1e-8
            )
        return returns_tensor

    def update(self) -> float:
        """Update the policy (and baseline) using collected episode data.

        Returns:
            The policy loss value.
        """
        returns = self.compute_returns()
        # Each stored log-prob has shape (1,). Concatenate to shape (T,) so the
        # product with the (T,) returns is elementwise; stacking would give
        # shape (T, 1) and silently broadcast the product to (T, T).
        log_probs = torch.cat(self.log_probs)

        if self.use_baseline:
            states_tensor = torch.FloatTensor(np.array(self.states))
            # squeeze(-1) keeps shape (T,) even for single-step episodes
            values = self.value_net(states_tensor).squeeze(-1)
            advantages = returns - values.detach()

            # Update value network
            value_loss = nn.functional.mse_loss(values, returns)
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

            policy_loss = -(log_probs * advantages).mean()
        else:
            policy_loss = -(log_probs * returns).mean()

        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

        # Clear episode data
        self.log_probs = []
        self.rewards = []
        self.states = []

        return policy_loss.item()
36.7 Actor-Critic Methods
36.7.1 Bridging Value and Policy Methods
REINFORCE uses complete episodes and Monte Carlo returns, resulting in high variance. Actor-critic methods address this by combining:
- Actor: The policy $\pi_\theta(a|s)$ that selects actions.
- Critic: A value function $V_\phi(s)$ (or $Q_\phi(s,a)$) that evaluates actions.
The critic provides a lower-variance estimate of the advantage, enabling online (step-by-step) learning rather than waiting for complete episodes.
36.7.2 Advantage Actor-Critic (A2C)
The advantage function $A(s, a) = Q(s, a) - V(s)$ tells us how much better action $a$ is than the policy's average behavior in state $s$. In practice, we estimate the advantage with the one-step TD error:
$$\hat{A}_t = r_{t+1} + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)$$
The update rules:
- Critic: Minimize $\mathcal{L}_\text{critic} = \frac{1}{2}\hat{A}_t^2$
- Actor: Maximize $J_\text{actor} = \log \pi_\theta(a_t|s_t) \hat{A}_t$
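For a single transition, the two update rules can be sketched as follows. This is an illustrative, dependency-free sketch rather than a full training step: the function name `a2c_step` and its arguments are hypothetical, and in a real implementation `log_prob` and `value` would be tensors so that an autograd library can differentiate the two losses.

```python
from typing import Tuple


def a2c_step(
    log_prob: float,
    value: float,
    next_value: float,
    reward: float,
    terminated: bool,
    gamma: float = 0.99,
) -> Tuple[float, float, float]:
    """Return (advantage, critic_loss, actor_loss) for one transition."""
    # Bootstrap from V(s_{t+1}) unless the episode terminated at this step.
    bootstrap = 0.0 if terminated else gamma * next_value
    advantage = reward + bootstrap - value  # one-step TD error, used as A_hat
    critic_loss = 0.5 * advantage ** 2
    # The actor maximizes log_prob * advantage, so its loss is the negative.
    # The advantage is treated as a constant (no gradient flows through it).
    actor_loss = -log_prob * advantage
    return advantage, critic_loss, actor_loss


adv, critic_loss, actor_loss = a2c_step(
    log_prob=-0.5, value=1.0, next_value=2.0, reward=1.0,
    terminated=False, gamma=0.9,
)
print(adv, critic_loss, actor_loss)
```

With these numbers the TD error is $1 + 0.9 \cdot 2 - 1 = 1.8$: the action turned out better than the critic expected, so minimizing the actor loss raises its log-probability.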
36.7.3 Generalized Advantage Estimation (GAE)
GAE (Schulman et al., 2015b) provides a smooth interpolation between low-bias, high-variance estimates and high-bias, low-variance estimates using a parameter $\lambda$:
$$\hat{A}_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}$$
where $\delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$ is the one-step TD error.
To understand GAE, consider what happens at different values of $\lambda$:
- $\lambda = 0$: $\hat{A}_t = \delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$. This is the one-step TD estimate. It has low variance (only one random step) but high bias (it relies entirely on the critic's estimate of $V(s_{t+1})$, which may be inaccurate).
- $\lambda = 1$: $\hat{A}_t = \sum_{l=0}^{\infty} \gamma^l \delta_{t+l} = G_t - V(s_t)$. This reduces to the Monte Carlo advantage estimate. It has zero bias (it uses actual returns) but high variance (it sums many random rewards).
- $\lambda \in (0, 1)$: The exponentially decaying weights create a principled blend. In practice, $\lambda = 0.95$ is a strong default that gives most of the variance reduction of short bootstrapping while retaining most of the low bias of longer returns.
Worked Example. Consider a 5-step trajectory with rewards $[1, 0, 0, 0, 10]$ and value estimates $V = [5, 4, 3, 2, 1]$, with $\gamma = 0.99$. The TD errors are:
$$\delta_0 = 1 + 0.99 \cdot 4 - 5 = -0.04$$
$$\delta_1 = 0 + 0.99 \cdot 3 - 4 = -1.03$$
$$\delta_2 = 0 + 0.99 \cdot 2 - 3 = -1.02$$
$$\delta_3 = 0 + 0.99 \cdot 1 - 2 = -1.01$$
$$\delta_4 = 10 + 0 - 1 = 9.0$$
With $\lambda = 0$, the advantage at time 0 is just $\delta_0 = -0.04$ -- the agent does not foresee the big reward at time 4. With $\lambda = 0.95$, GAE at time 0 incorporates the large TD error 4 steps away, weighted by $(\gamma\lambda)^4 = 0.9405^4 \approx 0.78$. Summing all five weighted TD errors gives $\hat{A}_0 \approx 4.29$, a positive advantage that correctly signals this is a good state to be in. This ability to propagate future reward information backward through time, while controlling variance, is what makes GAE so effective in practice.
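The worked example can be checked numerically. The sketch below is a dependency-free version of the GAE recursion; the helper name `gae_advantages` is ours, and it assumes the episode terminates after the fifth reward, so the value after the last step is taken to be 0. It should report $\hat{A}_0 \approx -0.04$ for $\lambda = 0$, $\approx 4.29$ for $\lambda = 0.95$, and $\approx 5.61$ for $\lambda = 1$, the last being the Monte Carlo advantage $G_0 - V(s_0)$.

```python
from typing import List


def gae_advantages(
    rewards: List[float],
    values: List[float],
    gamma: float,
    lam: float,
) -> List[float]:
    """GAE for one episode that terminates after the last reward."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # V(s_{t+1}) is 0 after the final step because the episode has ended.
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        # Recursive form of sum_l (gamma * lam)^l * delta_{t+l}.
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages


rewards = [1.0, 0.0, 0.0, 0.0, 10.0]
values = [5.0, 4.0, 3.0, 2.0, 1.0]

for lam in (0.0, 0.95, 1.0):
    adv = gae_advantages(rewards, values, gamma=0.99, lam=lam)
    print(f"lambda={lam}: A_0 = {adv[0]:.2f}")
```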
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
from typing import List, Tuple
torch.manual_seed(42)
class ActorCritic(nn.Module):
"""Actor-Critic network with shared feature extractor.
Args:
state_dim: Dimension of state observations.
action_dim: Number of discrete actions.
hidden_dim: Hidden layer size.
"""
def __init__(
self, state_dim: int, action_dim: int, hidden_dim: int = 256
) -> None:
super().__init__()
# Shared feature extractor
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
)
# Actor head
self.actor = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
)
# Critic head
self.critic = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass computing both policy logits and state value.
Args:
x: State tensor of shape (batch_size, state_dim).
Returns:
Tuple of (action_logits, state_value).
"""
features = self.shared(x)
action_logits = self.actor(features)
state_value = self.critic(features)
return action_logits, state_value
def get_action_and_value(
self, state: np.ndarray
) -> Tuple[int, torch.Tensor, torch.Tensor]:
"""Sample action and compute log prob and value.
Args:
state: Current state observation.
Returns:
Tuple of (action, log_probability, state_value).
"""
state_tensor = torch.FloatTensor(state).unsqueeze(0)
logits, value = self.forward(state_tensor)
dist = Categorical(logits=logits)
action = dist.sample()
return action.item(), dist.log_prob(action), value.squeeze()
def compute_gae(
rewards: List[float],
values: List[float],
next_value: float,
gamma: float = 0.99,
gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
"""Compute Generalized Advantage Estimation.
Args:
rewards: List of rewards for each time step.
values: List of value estimates for each time step.
next_value: Value estimate for the state after the last step.
gamma: Discount factor.
gae_lambda: GAE lambda parameter.
Returns:
Tuple of (advantages, returns).
"""
advantages = []
gae = 0.0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_val = next_value
else:
next_val = values[t + 1]
delta = rewards[t] + gamma * next_val - values[t]
gae = delta + gamma * gae_lambda * gae
advantages.insert(0, gae)
returns = [a + v for a, v in zip(advantages, values)]
return advantages, returns
36.8 Proximal Policy Optimization (PPO)
36.8.1 The Trust Region Idea
A fundamental challenge in policy gradient methods is choosing the right step size. Too small, and learning is slow. Too large, and the policy can change catastrophically, destroying performance.
Trust Region Policy Optimization (TRPO) (Schulman et al., 2015a) addressed this by constraining the KL divergence between old and new policies. However, TRPO is complex to implement due to its use of conjugate gradients and line search.
36.8.2 The PPO Objective
Proximal Policy Optimization (PPO) (Schulman et al., 2017) achieves similar results with a much simpler objective. It uses a clipped surrogate objective:
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[\min\left(r_t(\theta) \hat{A}_t, \;\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t\right)\right]$$
where $r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_\text{old}}(a_t | s_t)}$ is the probability ratio and $\epsilon$ is the clipping parameter (typically 0.2).
The intuition: if the advantage is positive (good action), we want to increase its probability, but we clip the ratio so it does not increase too much. If the advantage is negative (bad action), we decrease its probability, but again clip to prevent too large a change.
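To make the clipping behavior concrete, the following sketch evaluates the per-sample clipped surrogate for a few ratios. It is a scalar illustration only; the function name `clipped_surrogate` is hypothetical, and a real implementation would operate on batched tensors.

```python
def clipped_surrogate(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    return min(ratio * advantage, clipped_ratio * advantage)


for advantage in (1.0, -1.0):
    for ratio in (0.5, 0.8, 1.0, 1.2, 1.5):
        value = clipped_surrogate(ratio, advantage)
        print(f"A={advantage:+.0f}  r={ratio:.1f}  objective={value:+.2f}")
```

For a positive advantage the objective stops growing once the ratio exceeds $1 + \epsilon$, so there is no incentive to push the probability further up. For a negative advantage it flattens once the ratio falls below $1 - \epsilon$. In the opposite directions the `min` keeps the unclipped term, so moves that make the objective worse are still penalized in full.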
36.8.3 Full PPO Objective
The complete PPO objective combines three terms:
$$L(\theta) = L^{\text{CLIP}}(\theta) - c_1 L^{\text{VF}}(\theta) + c_2 S[\pi_\theta]$$
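where:
- $L^{\text{VF}}(\theta) = (V_\theta(s_t) - V_t^{\text{target}})^2$ is the value function loss
- $S[\pi_\theta] = -\sum_a \pi_\theta(a|s) \log \pi_\theta(a|s)$ is an entropy bonus encouraging exploration
- $c_1, c_2$ are coefficients (typically $c_1 = 0.5, c_2 = 0.01$)
This objective is maximized. The implementation below minimizes its negative, which is why the value loss is added and the entropy term is subtracted in the code.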
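As a small illustration of the entropy term $S[\pi_\theta]$, the sketch below compares a uniform policy with a nearly deterministic one over four actions. The probabilities are made up for illustration, and the helper name `entropy` is ours.

```python
import math
from typing import Sequence


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy S = -sum_a p(a) * log p(a), in nats."""
    # Terms with p = 0 contribute nothing (the limit of p * log p is 0).
    return -sum(p * math.log(p) for p in probs if p > 0.0)


uniform = [0.25, 0.25, 0.25, 0.25]
peaked = [0.97, 0.01, 0.01, 0.01]
print(f"uniform: {entropy(uniform):.3f}  peaked: {entropy(peaked):.3f}")
```

The uniform policy attains the maximum entropy $\log 4$, while the peaked policy is close to zero. Adding $c_2 S[\pi_\theta]$ to the objective therefore rewards the policy for keeping some probability mass on alternative actions.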
36.8.4 PPO Implementation
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
from typing import Dict, List, Tuple
torch.manual_seed(42)
class PPOMemory:
"""Rollout buffer for PPO.
Stores trajectories from multiple episodes for batch training.
"""
def __init__(self) -> None:
self.states: List[np.ndarray] = []
self.actions: List[int] = []
self.log_probs: List[float] = []
self.rewards: List[float] = []
self.values: List[float] = []
self.dones: List[bool] = []
def store(
self,
state: np.ndarray,
action: int,
log_prob: float,
reward: float,
value: float,
done: bool,
) -> None:
"""Store a single transition.
Args:
state: State observation.
action: Action taken.
log_prob: Log probability of the action under the current policy.
reward: Reward received.
value: Value estimate of the state.
done: Whether the episode ended.
"""
self.states.append(state)
self.actions.append(action)
self.log_probs.append(log_prob)
self.rewards.append(reward)
self.values.append(value)
self.dones.append(done)
def clear(self) -> None:
"""Clear all stored data."""
self.states = []
self.actions = []
self.log_probs = []
self.rewards = []
self.values = []
self.dones = []
def get_batches(
self, batch_size: int
) -> List[np.ndarray]:
"""Generate random mini-batch indices.
Args:
batch_size: Size of each mini-batch.
Returns:
List of arrays containing random indices for each batch.
"""
n = len(self.states)
indices = np.arange(n)
np.random.shuffle(indices)
batches = [
indices[i:i + batch_size]
for i in range(0, n, batch_size)
]
return batches
class PPOAgent:
"""Proximal Policy Optimization agent.
Args:
state_dim: Dimension of state observations.
action_dim: Number of discrete actions.
lr: Learning rate.
gamma: Discount factor.
gae_lambda: GAE lambda parameter.
clip_epsilon: PPO clipping parameter.
entropy_coeff: Entropy bonus coefficient.
value_coeff: Value loss coefficient.
n_epochs: Number of optimization epochs per rollout.
batch_size: Mini-batch size for optimization.
"""
def __init__(
self,
state_dim: int,
action_dim: int,
lr: float = 3e-4,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_epsilon: float = 0.2,
entropy_coeff: float = 0.01,
value_coeff: float = 0.5,
n_epochs: int = 4,
batch_size: int = 64,
) -> None:
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
self.gamma = gamma
self.gae_lambda = gae_lambda
self.clip_epsilon = clip_epsilon
self.entropy_coeff = entropy_coeff
self.value_coeff = value_coeff
self.n_epochs = n_epochs
self.batch_size = batch_size
self.actor_critic = ActorCritic(
state_dim, action_dim
).to(self.device)
self.optimizer = optim.Adam(
self.actor_critic.parameters(), lr=lr
)
self.memory = PPOMemory()
def select_action(
self, state: np.ndarray
) -> Tuple[int, float, float]:
"""Select action from the current policy.
Args:
state: Current state observation.
Returns:
Tuple of (action, log_probability, state_value).
"""
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
with torch.no_grad():
logits, value = self.actor_critic(state_tensor)
dist = Categorical(logits=logits)
action = dist.sample()
log_prob = dist.log_prob(action)
return action.item(), log_prob.item(), value.item()
def compute_gae(self, last_value: float) -> Tuple[np.ndarray, np.ndarray]:
"""Compute GAE advantages and returns.
Args:
last_value: Value estimate for the final next state.
Returns:
Tuple of (advantages, returns) arrays.
"""
rewards = self.memory.rewards
values = self.memory.values
dones = self.memory.dones
advantages = np.zeros(len(rewards), dtype=np.float32)
gae = 0.0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_value = last_value
else:
next_value = values[t + 1]
next_non_terminal = 1.0 - float(dones[t])
delta = (
rewards[t] + self.gamma * next_value * next_non_terminal
- values[t]
)
gae = delta + self.gamma * self.gae_lambda * next_non_terminal * gae
advantages[t] = gae
returns = advantages + np.array(values, dtype=np.float32)
return advantages, returns
def update(self, last_value: float) -> Dict[str, float]:
"""Update the policy and value function using PPO.
Args:
last_value: Value estimate for the final next state.
Returns:
Dictionary of training metrics.
"""
advantages, returns = self.compute_gae(last_value)
# Convert to tensors
states = torch.FloatTensor(
np.array(self.memory.states)
).to(self.device)
actions = torch.LongTensor(self.memory.actions).to(self.device)
old_log_probs = torch.FloatTensor(
self.memory.log_probs
).to(self.device)
advantages_t = torch.FloatTensor(advantages).to(self.device)
returns_t = torch.FloatTensor(returns).to(self.device)
# Normalize advantages
advantages_t = (advantages_t - advantages_t.mean()) / (
advantages_t.std() + 1e-8
)
total_policy_loss = 0.0
total_value_loss = 0.0
total_entropy = 0.0
n_updates = 0
for _ in range(self.n_epochs):
batches = self.memory.get_batches(self.batch_size)
for batch_indices in batches:
batch_states = states[batch_indices]
batch_actions = actions[batch_indices]
batch_old_log_probs = old_log_probs[batch_indices]
batch_advantages = advantages_t[batch_indices]
batch_returns = returns_t[batch_indices]
# Forward pass
logits, values = self.actor_critic(batch_states)
dist = Categorical(logits=logits)
new_log_probs = dist.log_prob(batch_actions)
entropy = dist.entropy().mean()
# Policy loss with clipping
ratio = torch.exp(new_log_probs - batch_old_log_probs)
surr1 = ratio * batch_advantages
surr2 = (
torch.clamp(
ratio,
1 - self.clip_epsilon,
1 + self.clip_epsilon,
)
* batch_advantages
)
policy_loss = -torch.min(surr1, surr2).mean()
# Value loss
value_loss = nn.functional.mse_loss(
values.squeeze(), batch_returns
)
# Combined loss
loss = (
policy_loss
+ self.value_coeff * value_loss
- self.entropy_coeff * entropy
)
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(
self.actor_critic.parameters(), max_norm=0.5
)
self.optimizer.step()
total_policy_loss += policy_loss.item()
total_value_loss += value_loss.item()
total_entropy += entropy.item()
n_updates += 1
self.memory.clear()
return {
"policy_loss": total_policy_loss / n_updates,
"value_loss": total_value_loss / n_updates,
"entropy": total_entropy / n_updates,
}
36.8.5 Why PPO Dominates
PPO has become the most widely used RL algorithm for several reasons:
- Simple to implement (compared to TRPO, which requires second-order optimization with conjugate gradients)
- Good sample efficiency relative to other on-policy methods
- Fairly robust to hyperparameter choices -- the same defaults often work across very different tasks
- Works well across diverse domains: continuous control, discrete games, language model alignment
- Easily parallelizable by running multiple environment instances to collect rollouts
It is the algorithm behind the original RLHF pipelines for language models and game-playing agents like OpenAI Five, and it is widely used for robotic control. When in doubt about which RL algorithm to use, PPO is usually a good starting point.
36.9 RL for Language Models: RLHF and GRPO
36.9.1 Connecting RL to Language Models
Chapter 25 introduced RLHF as a technique for aligning language models. Let us now understand it through the RL lens we have developed in this chapter.
In the RLHF framework, text generation is framed as a sequential decision-making problem:
- State: The prompt $x$ concatenated with all tokens generated so far $(y_1, \ldots, y_{t-1})$. This grows with each step.
- Action: The next token $y_t$ to generate, selected from the vocabulary $\mathcal{V}$ (typically 30,000-100,000 tokens).
- Policy: The language model $\pi_\theta(y_t | x, y_1, \ldots, y_{t-1})$.
- Reward: Provided by a learned reward model $R_\phi$ trained on human preferences, applied at the end of generation.
- Episode: A complete response, from the first generated token to the end-of-sequence token.
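The mapping from text generation to an episode can be sketched with a toy example. Everything here is a hypothetical stand-in: `VOCAB` is a four-token vocabulary, `uniform_policy` plays the role of $\pi_\theta$, and `toy_reward` plays the role of the reward model $R_\phi$. The point is only to show the shape of the data: one state and one action per token, with a single non-zero reward at the end.

```python
import random
from typing import Callable, List, Tuple

VOCAB = ["yes", "no", "maybe", "<eos>"]  # hypothetical toy vocabulary


def uniform_policy(state: List[str], rng: random.Random) -> str:
    """Stand-in for pi_theta: ignores the state and samples uniformly."""
    return rng.choice(VOCAB)


def toy_reward(prompt: List[str], response: List[str]) -> float:
    """Stand-in for the reward model R_phi: scores the whole response."""
    return 1.0 if "yes" in response else 0.0


def rollout(
    prompt: List[str],
    reward_fn: Callable[[List[str], List[str]], float] = toy_reward,
    max_new_tokens: int = 8,
    seed: int = 0,
) -> Tuple[List[str], List[float]]:
    """Generate one episode; return (response_tokens, per_step_rewards)."""
    rng = random.Random(seed)
    response: List[str] = []
    rewards: List[float] = []
    for _ in range(max_new_tokens):
        state = prompt + response            # state: prompt plus tokens so far
        action = uniform_policy(state, rng)  # action: the next token
        response.append(action)
        rewards.append(0.0)                  # no reward for intermediate tokens
        if action == "<eos>":
            break
    # The reward model scores the complete response once, at the end.
    rewards[-1] = reward_fn(prompt, response)
    return response, rewards


response, rewards = rollout(["Is", "RL", "useful", "?"])
print(response, rewards)
```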
This is a challenging RL setting: the action space is enormous (the entire vocabulary), the state space is effectively infinite (all possible token sequences), and the reward is extremely sparse (only received at the end of generation). In addition, the learned reward model is only an imperfect proxy for human preferences. This is why the objective includes a KL penalty term: without it, the policy could exploit quirks in the reward model to produce high-reward but nonsensical text (a phenomenon known as reward hacking).
The objective is:
$$\max_\theta \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_\theta(\cdot|x)} \left[R_\phi(x, y) - \beta \text{KL}(\pi_\theta(\cdot|x) \| \pi_\text{ref}(\cdot|x))\right]$$
The KL divergence term prevents the policy from deviating too far from the reference model (typically the supervised fine-tuned model), avoiding reward hacking.
36.9.2 PPO for RLHF
The standard RLHF pipeline uses PPO:
- Generate: Sample responses $y \sim \pi_\theta(\cdot|x)$ for prompts $x$.
- Score: Compute rewards $R_\phi(x, y)$ using the reward model.
- Update: Apply PPO to update $\pi_\theta$ to maximize rewards while staying close to $\pi_\text{ref}$.
import torch
import torch.nn as nn
from typing import Tuple
torch.manual_seed(42)
def compute_rlhf_reward(
    reward_model_score: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_coeff: float = 0.1,
) -> torch.Tensor:
    """Compute per-token RLHF rewards with a KL penalty.
    Args:
        reward_model_score: Scalar reward from the reward model.
        log_probs: Log probabilities of the sampled tokens under the
            current policy, shape (T,).
        ref_log_probs: Log probabilities of the same tokens under the
            reference policy, shape (T,).
        kl_coeff: Coefficient for the KL penalty term.
    Returns:
        Per-token rewards of shape (T,): the KL penalty at every token,
        plus the reward model score at the last token.
    """
    # Per-token log-ratio: a single-sample estimate of the KL divergence.
    # Rewards are a training signal, not part of the graph, so detach them.
    kl_estimate = (log_probs - ref_log_probs).detach()
    # The KL penalty is applied at every token
    rewards = -kl_coeff * kl_estimate
    # The reward model score is added at the last token only
    rewards[-1] += reward_model_score
    return rewards
36.9.3 Direct Preference Optimization (DPO)
DPO (Rafailov et al., 2023) bypasses the reward model entirely, optimizing the policy directly from preference data:
$$\mathcal{L}_{\text{DPO}}(\theta) = -\mathbb{E}_{(x, y_w, y_l)} \left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_\text{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_\text{ref}(y_l|x)}\right)\right]$$
where $y_w$ is the preferred response and $y_l$ is the rejected response. DPO is simpler and more stable than PPO-based RLHF but may have limitations in settings where the reward model captures nuances that pairwise comparisons miss.
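The DPO loss for a single preference pair can be sketched directly from the formula. This is a scalar, dependency-free illustration: the function name `dpo_loss` and its argument names are ours, and each log-probability is assumed to be the sum of per-token log-probabilities for the full response.

```python
import math


def dpo_loss(
    policy_logp_chosen: float,
    policy_logp_rejected: float,
    ref_logp_chosen: float,
    ref_logp_rejected: float,
    beta: float = 0.1,
) -> float:
    """DPO loss for one (prompt, chosen, rejected) triple."""
    chosen_logratio = policy_logp_chosen - ref_logp_chosen
    rejected_logratio = policy_logp_rejected - ref_logp_rejected
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)) = log(1 + exp(-margin)), written in a form
    # that does not overflow for large |margin|.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Policy identical to the reference: the margin is 0 and the loss is log 2.
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
# Policy has shifted probability toward the chosen response: lower loss.
print(dpo_loss(-8.0, -14.0, -10.0, -12.0))
```

The loss depends only on how much the policy has moved relative to the reference on each response, which is how DPO keeps the policy anchored without an explicit KL term.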
36.9.4 Group Relative Policy Optimization (GRPO)
GRPO (Shao et al., 2024) is a newer approach that eliminates the need for a separate value function (critic). Instead, it estimates advantages by comparing multiple responses to the same prompt within a group:
- For each prompt $x$, generate $G$ responses $\{y_1, \ldots, y_G\}$.
- Score each response with the reward model: $r_i = R_\phi(x, y_i)$.
- Normalize rewards within the group: $\hat{r}_i = \frac{r_i - \text{mean}(\{r_j\})}{\text{std}(\{r_j\})}$.
- Use the normalized rewards as advantages in a PPO-like objective.
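The group normalization step above can be sketched in a few lines. The helper name `group_advantages` is ours, and the sketch uses the population standard deviation from the standard library, which is an assumption; implementations that use the sample standard deviation will give slightly different scales.

```python
from statistics import mean, pstdev
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Normalize rewards within one prompt's group of G responses."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption of this sketch
    # eps guards against division by zero when all rewards are equal.
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four responses to one prompt, scored by a verifier as correct (1) or not (0).
print(group_advantages([1.0, 0.0, 0.0, 1.0]))
# If every response gets the same reward, all advantages are zero.
print(group_advantages([1.0, 1.0, 1.0, 1.0]))
```

The second call shows a practical consequence: a group in which every response receives the same reward contributes no learning signal, because the group mean already explains every reward.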
The GRPO objective:
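$$\mathcal{L}_{\text{GRPO}}(\theta) = \mathbb{E}_{x, \{y_i\}} \left[\frac{1}{G} \sum_{i=1}^{G} \min\left(r_i(\theta) \hat{r}_i, \;\text{clip}(r_i(\theta), 1-\epsilon, 1+\epsilon) \hat{r}_i\right) - \beta \text{KL}(\pi_\theta \| \pi_\text{ref})\right]$$
Here $r_i(\theta) = \frac{\pi_\theta(y_i | x)}{\pi_{\theta_\text{old}}(y_i | x)}$ is the probability ratio for response $y_i$ (not to be confused with the scalar reward $r_i$ above), $\hat{r}_i$ is the group-normalized reward that plays the role of the advantage, and the objective is maximized.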
Advantages of GRPO:
- No critic network needed, reducing memory and compute.
- The group normalization provides a natural baseline.
- Particularly effective for reasoning tasks where verifiable rewards are available.
import torch
import torch.nn.functional as F
from typing import List
torch.manual_seed(42)
def grpo_loss(
policy_log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
ref_log_probs: torch.Tensor,
rewards: torch.Tensor,
clip_epsilon: float = 0.2,
kl_coeff: float = 0.01,
) -> torch.Tensor:
"""Compute GRPO loss for a group of responses to a single prompt.
Args:
policy_log_probs: Log probs under current policy, shape (G, T).
old_log_probs: Log probs under old policy, shape (G, T).
ref_log_probs: Log probs under reference policy, shape (G, T).
rewards: Scalar rewards for each response, shape (G,).
clip_epsilon: Clipping parameter for the ratio.
kl_coeff: KL divergence penalty coefficient.
Returns:
Scalar loss value.
"""
# Normalize rewards within the group
normalized_rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# Compute per-sequence log probability ratios
seq_log_ratios = (policy_log_probs - old_log_probs).sum(dim=-1) # (G,)
ratios = torch.exp(seq_log_ratios)
# Clipped surrogate objective
surr1 = ratios * normalized_rewards
surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * normalized_rewards
policy_loss = -torch.min(surr1, surr2).mean()
# KL divergence penalty
kl_div = (policy_log_probs - ref_log_probs).sum(dim=-1).mean()
total_loss = policy_loss + kl_coeff * kl_div
return total_loss
36.10 Working with Gymnasium Environments
36.10.1 The Gymnasium API
Gymnasium (formerly OpenAI Gym) provides a standardized interface for RL environments. The core API is simple:
import gymnasium as gym
# Create an environment
env = gym.make("CartPole-v1", render_mode="rgb_array")
# Reset to get initial observation
obs, info = env.reset(seed=42)
# Step through the environment
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
# Check spaces
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
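A full episode loop built on the `reset`/`step` signatures shown above might look like the following sketch. To keep it runnable without installing anything, it uses a hypothetical stub class, `CoinFlipEnv`, that only mimics the Gymnasium method signatures; with Gymnasium installed, an environment returned by `gym.make` could be passed to `run_episode` instead.

```python
import random
from typing import Any, Callable, Dict, Optional, Tuple


class CoinFlipEnv:
    """Hypothetical stub that mimics the Gymnasium reset/step signatures."""

    def __init__(self, max_steps: int = 20) -> None:
        self.max_steps = max_steps
        self._rng = random.Random()
        self._t = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[int, Dict[str, Any]]:
        if seed is not None:
            self._rng.seed(seed)
        self._t = 0
        return 0, {}

    def step(self, action: int) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
        self._t += 1
        terminated = self._rng.random() < 0.1  # episode ends "naturally"
        truncated = self._t >= self.max_steps  # time limit reached
        return self._t, 1.0, terminated, truncated, {}


def run_episode(env: Any, policy: Callable[[Any], int], seed: int = 0) -> float:
    """Run one episode and return the undiscounted sum of rewards."""
    obs, _ = env.reset(seed=seed)
    total_reward = 0.0
    while True:
        obs, reward, terminated, truncated, _ = env.step(policy(obs))
        total_reward += reward
        # Either flag ends the episode; only `terminated` marks a state
        # from which no further reward is possible.
        if terminated or truncated:
            return total_reward


print(run_episode(CoinFlipEnv(), policy=lambda obs: 0))
```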
36.10.2 Common Environments
| Environment | Observation | Actions | Challenge |
|---|---|---|---|
| CartPole-v1 | 4D continuous | 2 discrete | Balance a pole |
| LunarLander-v3 | 8D continuous | 4 discrete | Land softly |
| MountainCar-v0 | 2D continuous | 3 discrete | Build momentum |
| Pendulum-v1 | 3D continuous | 1D continuous | Swing up and balance |
| Acrobot-v1 | 6D continuous | 3 discrete | Swing up double pendulum |
36.10.3 Creating Custom Environments
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from typing import Any, Dict, Optional, Tuple
np.random.seed(42)
class GridWorldEnv(gym.Env):
"""A simple grid world environment for RL experimentation.
The agent starts at a random position and must reach the goal
while avoiding obstacles.
Args:
grid_size: Size of the square grid.
n_obstacles: Number of obstacle cells.
"""
metadata = {"render_modes": ["human", "rgb_array"]}
def __init__(
self,
grid_size: int = 5,
n_obstacles: int = 3,
render_mode: Optional[str] = None,
) -> None:
super().__init__()
self.grid_size = grid_size
self.n_obstacles = n_obstacles
self.render_mode = render_mode
# 4 actions: up, right, down, left
self.action_space = spaces.Discrete(4)
# Observation: agent position (row, col) as flattened index
self.observation_space = spaces.Discrete(grid_size * grid_size)
self._action_to_direction = {
0: np.array([-1, 0]), # up
1: np.array([0, 1]), # right
2: np.array([1, 0]), # down
3: np.array([0, -1]), # left
}
self.agent_pos: np.ndarray = np.array([0, 0])
self.goal_pos: np.ndarray = np.array([grid_size - 1, grid_size - 1])
self.obstacles: list = []
def _pos_to_obs(self, pos: np.ndarray) -> int:
"""Convert 2D position to flat observation index."""
return int(pos[0] * self.grid_size + pos[1])
def reset(
self,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
) -> Tuple[int, Dict[str, Any]]:
"""Reset environment to initial state.
Args:
seed: Random seed for reproducibility.
options: Additional options (unused).
Returns:
Tuple of (initial_observation, info_dict).
"""
super().reset(seed=seed)
# Place agent at start
self.agent_pos = np.array([0, 0])
# Place random obstacles
self.obstacles = []
while len(self.obstacles) < self.n_obstacles:
pos = self.np_random.integers(0, self.grid_size, size=2)
if (not np.array_equal(pos, self.agent_pos) and
not np.array_equal(pos, self.goal_pos)):
self.obstacles.append(pos.copy())
return self._pos_to_obs(self.agent_pos), {}
def step(
self, action: int
) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
"""Take one step in the environment.
Args:
action: Direction to move (0=up, 1=right, 2=down, 3=left).
Returns:
Tuple of (observation, reward, terminated, truncated, info).
"""
direction = self._action_to_direction[action]
new_pos = self.agent_pos + direction
# Check boundaries
if (0 <= new_pos[0] < self.grid_size and
0 <= new_pos[1] < self.grid_size):
# Check obstacles
hit_obstacle = any(
np.array_equal(new_pos, obs) for obs in self.obstacles
)
if not hit_obstacle:
self.agent_pos = new_pos
# Check goal
terminated = np.array_equal(self.agent_pos, self.goal_pos)
reward = 1.0 if terminated else -0.01
return self._pos_to_obs(self.agent_pos), reward, terminated, False, {}
36.11 Challenges in Reinforcement Learning
36.11.1 Sparse Rewards
Many real-world tasks have sparse rewards: the agent receives zero reward for most of its trajectory and only gets a signal at success or failure. This makes credit assignment extremely difficult.
Solutions:
- Reward shaping: Add intermediate rewards that guide the agent. Must be done carefully to avoid unintended behaviors.
- Curiosity-driven exploration: Reward the agent for visiting novel states, using prediction error as an intrinsic reward.
- Hindsight Experience Replay (HER): Retroactively relabel failed trajectories by pretending the achieved state was the goal.
36.11.2 Sample Efficiency
RL algorithms, especially on-policy methods like PPO, often require millions of environment interactions on nontrivial tasks. This is acceptable in simulation but prohibitive in the real world.
Solutions:
- Off-policy learning: Reuse past data (DQN, SAC).
- Model-based RL: Learn a model of the environment and plan using it.
- Transfer learning: Pre-train on related tasks or simulations.
- Offline RL: Learn from fixed datasets without environment interaction.
36.11.3 Exploration
The exploration-exploitation trade-off is fundamental. Under-exploration leads to suboptimal policies; over-exploration wastes time on bad actions.
Advanced exploration strategies:
- Count-based exploration: Bonus rewards for rarely-visited states.
- Posterior sampling (Thompson Sampling): Maintain a posterior over values (or models), draw a sample from it, and act greedily with respect to that sample.
- Go-Explore: Archive interesting states and return to them for further exploration.
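Count-based exploration is the simplest of these to sketch. The class below adds an intrinsic bonus that shrinks as a state is revisited. The form $\beta / \sqrt{N(s)}$ is one common choice rather than the only one, the class name `CountBonus` is ours, and the sketch assumes discrete, hashable states (large or continuous state spaces need density models or hashing instead).

```python
import math
from collections import Counter
from typing import Hashable


class CountBonus:
    """Intrinsic reward beta / sqrt(N(s)) for discrete (hashable) states."""

    def __init__(self, beta: float = 0.1) -> None:
        self.beta = beta
        self.counts: Counter = Counter()

    def __call__(self, state: Hashable) -> float:
        # Count the visit first, so the first visit gets a bonus of beta.
        self.counts[state] += 1
        return self.beta / math.sqrt(self.counts[state])


bonus = CountBonus(beta=1.0)
print([round(bonus("s0"), 3) for _ in range(4)])  # decays with each visit
print(round(bonus("s1"), 3))                      # a new state gets the full bonus
```

In a training loop the bonus would be added to the environment reward before the transition is stored.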
36.11.4 Stability and Reproducibility
Deep RL is notoriously unstable:
- Small hyperparameter changes can cause large performance differences.
- Random seeds dramatically affect results.
- Performance can collapse after apparently converging.
Best practices:
- Run multiple seeds and report aggregate statistics.
- Use established hyperparameter settings as starting points.
- Monitor for performance collapse and use checkpointing.
- Log extensively: rewards, losses, gradients, entropy.
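Reporting aggregate statistics across seeds can be as simple as the sketch below. The helper name `summarize_runs` is ours, and the numbers are hypothetical values chosen for illustration, not results from any experiment.

```python
from statistics import mean, stdev
from typing import Dict, List


def summarize_runs(final_returns: Dict[int, float]) -> Dict[str, float]:
    """Aggregate one final-performance number per random seed."""
    values: List[float] = list(final_returns.values())
    return {
        "n_seeds": float(len(values)),
        "mean": mean(values),
        "std": stdev(values) if len(values) > 1 else 0.0,
        "min": min(values),
        "max": max(values),
    }


# Hypothetical final average returns for five seeds (illustration only).
results = {0: 480.0, 1: 500.0, 2: 130.0, 3: 495.0, 4: 500.0}
print(summarize_runs(results))
```

In this made-up example one seed fails badly, and the minimum and standard deviation expose that failure where the best run alone would hide it.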
36.11.5 Sim-to-Real Transfer
Policies trained in simulation often fail when deployed in the real world due to the sim-to-real gap.
Solutions:
- Domain randomization: Vary simulation parameters during training.
- System identification: Calibrate the simulator to match reality.
- Fine-tuning: Adapt the policy with a small amount of real-world data.
36.12 Multi-Agent RL (Brief Introduction)
When multiple agents interact in a shared environment, RL becomes significantly more complex:
- Cooperative: Agents share a common reward and must coordinate (e.g., team sports, multi-robot systems).
- Competitive: Agents have directly opposing goals (e.g., two-player games, adversarial settings).
- Mixed: Both cooperative and competitive elements coexist (e.g., economic markets).
Key challenges:
- Non-stationarity: Each agent's environment changes as other agents learn.
- Credit assignment: Which agent deserves credit for the team's success?
- Communication: Should agents communicate, and how?
Approaches include independent learning (each agent uses standard RL, ignoring others), centralized training with decentralized execution (CTDE), and emergent communication protocols.
36.13 Practical Tips for RL Engineering
36.13.1 Debugging RL
Debugging RL is notoriously difficult because failures are silent -- the agent simply learns a bad policy rather than crashing with an error. Here are systematic strategies, expanding on the debugging techniques we discussed in Chapter 5 for general ML code:
- Start simple: Test on easy environments first (CartPole). If your algorithm cannot solve CartPole in under 50,000 steps, there is very likely a bug. Do not move to harder environments until simple ones work.
- Verify the environment: Check that observations, rewards, and terminations are correct. A surprisingly common bug is incorrect reward shaping that inadvertently incentivizes the wrong behavior. Print out several episodes of random actions and manually verify that rewards make sense.
- Monitor everything: Log reward curves, loss values, entropy, gradient norms, episode lengths, and the distribution of Q-values or advantages. Specific warning signs:
  - Entropy collapses to zero early: The policy has become deterministic prematurely. Increase the entropy coefficient.
  - Q-values diverge to large magnitudes: The target network may not be updating, or the learning rate is too high.
  - Policy loss oscillates wildly: The clipping parameter in PPO may be too large, or advantages are not normalized.
  - Episode length is constant: The agent may be stuck in a loop or immediately terminating. Visualize a few episodes.
- Check for common bugs:
  - Reward scale issues (normalize rewards or use reward clipping)
  - Observation preprocessing errors (forgetting to normalize observations to zero mean, unit variance)
  - Off-by-one errors in advantage or TD target computation (the bootstrap value for terminal states must be zero)
  - Gradient explosion/vanishing (always use gradient clipping in RL)
- The sanity check ladder: Verify your implementation in this order:
  - Can your network overfit a single transition? (If not, the network or loss is wrong.)
  - Does the agent learn a trivial environment? (If not, the training loop has a bug.)
  - Does the agent match known benchmark results on standard environments? (If not, hyperparameters or subtle implementation details are wrong.)
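The terminal-state bootstrap bug from the list above is cheap to guard against with a tiny unit test. The sketch below uses a hypothetical helper, `td_target`, with inline assertions. It also reflects a convention that many implementations adopt (an assumption here, not something this chapter's training loop does): only a true termination zeroes the bootstrap, while a time-limit truncation still bootstraps from the next state's value.

```python
def td_target(
    reward: float,
    next_value: float,
    terminated: bool,
    gamma: float = 0.99,
) -> float:
    """One-step TD target r + gamma * V(s'), with V(s') = 0 at terminal states."""
    # Only a true termination zeroes the bootstrap. For a time-limit
    # truncation, the caller passes terminated=False and still supplies V(s').
    bootstrap = 0.0 if terminated else gamma * next_value
    return reward + bootstrap


# Terminal transition: the target is just the reward.
assert td_target(1.0, 5.0, terminated=True) == 1.0
# Non-terminal transition: the target includes the discounted next value.
assert abs(td_target(1.0, 5.0, terminated=False) - 5.95) < 1e-9
print("terminal bootstrap checks passed")
```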
36.13.2 Hyperparameter Tuning
Key hyperparameters and their typical ranges:
| Parameter | Typical Range | Notes |
|---|---|---|
| Learning rate | 1e-4 to 3e-4 | PPO sweet spot |
| Gamma | 0.95 to 0.999 | Higher for longer horizons |
| GAE lambda | 0.9 to 0.99 | Trade-off bias/variance |
| Clip epsilon | 0.1 to 0.3 | 0.2 is standard |
| Entropy coeff | 0.001 to 0.05 | Decrease over time if needed |
| Batch size | 32 to 2048 | Larger is often better |
| Number of epochs | 3 to 10 | Passes over each rollout; too many can overfit to stale data |
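The note about decreasing the entropy coefficient over time can be implemented with a simple schedule. The sketch below uses linear interpolation. The helper name `linear_schedule` and the start and end values are illustrative assumptions picked from the ranges in the table, not recommended settings.

```python
def linear_schedule(start: float, end: float, progress: float) -> float:
    """Linearly interpolate from `start` (progress=0) to `end` (progress=1)."""
    progress = min(max(progress, 0.0), 1.0)  # clamp to [0, 1]
    return start + (end - start) * progress


total_timesteps = 100_000
for step in (0, 50_000, 100_000):
    coeff = linear_schedule(0.01, 0.001, step / total_timesteps)
    print(f"step={step:>7}  entropy_coeff={coeff:.4f}")
```

The same helper could be reused for a learning-rate schedule by recomputing the value once per rollout and writing it into the optimizer.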
36.13.3 When to Use RL
RL is powerful but not always the right tool. Use RL when:
- The problem naturally involves sequential decisions.
- A reward signal (even sparse) can be defined.
- Simulation is available or data collection is feasible.
- Supervised alternatives are insufficient.
Do not use RL when:
- Supervised learning can solve the problem.
- The reward function is unclear or hard to specify.
- Sample efficiency is critical and you have limited interaction budget.
- The environment is too dangerous for exploration.
36.14 Training Loop: Putting It All Together
import gymnasium as gym
import torch
import numpy as np
from typing import Dict, List

torch.manual_seed(42)
np.random.seed(42)


def train_ppo(
    env_name: str = "CartPole-v1",
    total_timesteps: int = 100000,
    rollout_length: int = 2048,
    n_epochs: int = 4,
    batch_size: int = 64,
    lr: float = 3e-4,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    clip_epsilon: float = 0.2,
    print_freq: int = 10,
) -> Dict[str, List[float]]:
    """Train a PPO agent on a Gymnasium environment.

    Args:
        env_name: Name of the Gymnasium environment.
        total_timesteps: Total training time steps.
        rollout_length: Number of steps per rollout.
        n_epochs: PPO update epochs per rollout.
        batch_size: Mini-batch size.
        lr: Learning rate.
        gamma: Discount factor.
        gae_lambda: GAE lambda.
        clip_epsilon: PPO clipping parameter.
        print_freq: How often to print progress (in rollouts).

    Returns:
        Dictionary containing training history.
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = PPOAgent(
        state_dim=state_dim,
        action_dim=action_dim,
        lr=lr,
        gamma=gamma,
        gae_lambda=gae_lambda,
        clip_epsilon=clip_epsilon,
        n_epochs=n_epochs,
        batch_size=batch_size,
    )

    history = {
        "episode_rewards": [],
        "policy_losses": [],
        "value_losses": [],
    }

    state, _ = env.reset(seed=42)
    episode_reward = 0.0
    timesteps_done = 0
    rollout_count = 0

    while timesteps_done < total_timesteps:
        # Collect rollout
        for _ in range(rollout_length):
            action, log_prob, value = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.memory.store(state, action, log_prob, reward, value, done)
            episode_reward += reward
            timesteps_done += 1
            if done:
                history["episode_rewards"].append(episode_reward)
                episode_reward = 0.0
                state, _ = env.reset()
            else:
                state = next_state

        # Compute last value for GAE
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(
                agent.device
            )
            _, last_value = agent.actor_critic(state_tensor)
            last_value = last_value.item()

        # Update policy
        metrics = agent.update(last_value)
        history["policy_losses"].append(metrics["policy_loss"])
        history["value_losses"].append(metrics["value_loss"])

        rollout_count += 1
        if rollout_count % print_freq == 0:
            recent_rewards = history["episode_rewards"][-10:]
            avg_reward = np.mean(recent_rewards) if recent_rewards else 0.0
            print(
                f"Timestep {timesteps_done}/{total_timesteps} | "
                f"Avg Reward (last 10): {avg_reward:.1f} | "
                f"Policy Loss: {metrics['policy_loss']:.4f} | "
                f"Entropy: {metrics['entropy']:.4f}"
            )

    env.close()
    return history


# To run: history = train_ppo()
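Per-episode rewards from a run like this are usually noisy, so a smoothed curve tends to be easier to read than the raw list. The following is a minimal sketch of one way to do that, assuming the `history` dictionary returned by `train_ppo`; the helper name `moving_average` and the window size are illustrative choices, not part of the chapter's code, and the reward list here is a stand-in for `history["episode_rewards"]`.

```python
import numpy as np


def moving_average(values, window=10):
    """Trailing moving average of `values` (hypothetical helper).

    Returns an empty array when there are fewer than `window` points.
    """
    values = np.asarray(values, dtype=float)
    if values.size < window:
        return np.array([])
    kernel = np.ones(window) / window
    # mode="valid" keeps only positions where the full window fits.
    return np.convolve(values, kernel, mode="valid")


# Stand-in for history["episode_rewards"] from a real training run.
episode_rewards = [10.0, 20.0, 30.0, 40.0, 50.0]
smoothed = moving_average(episode_rewards, window=2)
print(smoothed)  # [15. 25. 35. 45.]
```

In practice you would pass `history["episode_rewards"]` and a larger window (for example 20 to 100 episodes), then plot the result against the episode index.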
Summary
Reinforcement learning provides a fundamentally different paradigm for training AI systems---one based on interaction and reward rather than static datasets. In this chapter, we covered:
- MDP formulation: States, actions, transitions, rewards, and policies.
- Value functions: $V(s)$ and $Q(s,a)$, connected by the Bellman equations.
- Tabular methods: Q-learning and SARSA for small state spaces.
- Deep Q-Networks: Function approximation with experience replay and target networks.
- Policy gradients: REINFORCE and the policy gradient theorem.
- Actor-Critic methods: Combining value and policy learning with GAE.
- PPO: The clipped surrogate objective behind one of the most widely used policy-gradient algorithms.
- RL for LLMs: RLHF, DPO, and GRPO for aligning language models.
- Practical considerations: Environments, debugging, and when to use RL.
The connection between RL and language model alignment represents one of the most impactful applications of RL in modern AI. Understanding these foundations will prepare you to work with alignment techniques, train game-playing agents, and tackle sequential decision problems across domains.
Key Equations to Remember
| Concept | Equation |
|---|---|
| Return | $G_t = \sum_{k=0}^{\infty} \gamma^k r_{t+k+1}$ |
| Bellman optimality equation | $Q^*(s,a) = \sum_{s'} P(s' \mid s,a) [R + \gamma \max_{a'} Q^*(s',a')]$ |
| Q-learning update | $Q(s,a) \leftarrow Q(s,a) + \alpha [r + \gamma \max_{a'} Q(s',a') - Q(s,a)]$ |
| Policy gradient | $\nabla_\theta J = \mathbb{E}[\nabla_\theta \log \pi_\theta(a \mid s) \hat{A}]$ |
| PPO objective | $L = \mathbb{E}[\min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t)]$, where $r_t(\theta)$ is the probability ratio, not the reward |
| GAE | $\hat{A}_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}$ |
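To make three of these formulas concrete, here is a small numerical sketch of the return, GAE, and the PPO clipped term on toy inputs. It assumes a single rollout segment with no terminal states, takes `rewards[t]` to be the reward received after acting at step `t`, and uses the common one-step TD error $\delta_t = r + \gamma V(s_{t+1}) - V(s_t)$; the function names are illustrative and are not the chapter's `PPOAgent` internals.

```python
import numpy as np


def discounted_return(rewards, gamma):
    """Return from the first step: sum_k gamma^k * rewards[k]."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


def gae(rewards, values, last_value, gamma, lam):
    """GAE advantages for one segment without terminal states.

    values[t] approximates V(s_t); last_value approximates V(s_T)
    and is used to bootstrap beyond the end of the segment.
    """
    values_ext = np.append(np.asarray(values, dtype=float), last_value)
    advantages = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error (assumed definition of delta_t).
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


def ppo_clip_term(ratio, advantage, eps):
    """Per-sample clipped surrogate: min(r * A, clip(r, 1-eps, 1+eps) * A)."""
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps)
    return np.minimum(ratio * advantage, clipped * advantage)


rewards = [1.0, 1.0, 1.0]
print(discounted_return(rewards, gamma=0.5))  # 1.75

# With V = 0 everywhere, lambda = 1 recovers discounted returns-to-go,
# while lambda = 0 reduces to the one-step TD errors.
zeros = [0.0, 0.0, 0.0]
print(gae(rewards, zeros, 0.0, gamma=0.5, lam=1.0))  # [1.75 1.5  1.  ]
print(gae(rewards, zeros, 0.0, gamma=0.5, lam=0.0))  # [1. 1. 1.]

# A ratio above 1 + eps earns no extra credit when the advantage is positive.
print(ppo_clip_term(np.array([1.5]), np.array([1.0]), eps=0.2))  # [1.2]
```

The two GAE calls illustrate the bias-variance dial: $\lambda = 1$ uses full Monte Carlo returns, $\lambda = 0$ trusts the value estimates after a single step, and values in between interpolate.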
The Broader Perspective
Reinforcement learning occupies a unique position in AI: it is arguably the closest we have come to a general framework for learning intelligent behavior from experience. While supervised learning requires human-provided labels and unsupervised learning finds patterns in static data, RL learns from the consequences of actions, much as humans and animals learn in the real world. As we discussed in Chapter 1 on the landscape of AI engineering, this ability to learn from interaction rather than instruction is what makes RL essential for problems that cannot be solved by pattern matching alone.
In the next chapter, we turn to graph neural networks, exploring how to learn from structured, relational data that does not fit neatly into the grid-like formats that standard neural networks expect.