> "Getting a neural network to learn is not magic---it is engineering. The difference between a model that converges beautifully and one that diverges catastrophically often comes down to a handful of deliberate choices about losses, optimizers, and...
In This Chapter
- 12.1 Loss Functions: Translating Objectives into Gradients
- 12.2 Optimizers: Navigating the Loss Landscape
- 12.3 Learning Rate Schedules
- 12.4 Normalization Layers
- 12.5 Weight Initialization
- 12.6 Gradient Clipping
- 12.7 Mixed Precision Training
- 12.8 The Complete Training Loop
- 12.9 Debugging Training: When Things Go Wrong
- 12.10 Regularization Techniques for Training
- 12.11 Putting It All Together: A Training Recipe
- 12.12 Advanced Training Techniques
- 12.13 Distributed Training: Scaling to Multiple GPUs
- Summary
Chapter 12: Training Deep Networks
"Getting a neural network to learn is not magic---it is engineering. The difference between a model that converges beautifully and one that diverges catastrophically often comes down to a handful of deliberate choices about losses, optimizers, and numerical hygiene."
In Chapter 11, we built neural networks from the ground up---defining layers, activations, and the forward pass. But defining a network and training it successfully are two very different challenges. A well-architected model can still fail spectacularly if you choose the wrong loss function, initialize weights poorly, or set a learning rate that is too large (or too small). This chapter is devoted to the practical art and science of making deep networks learn.
Training a deep network involves orchestrating dozens of interacting components: loss functions that translate your objective into a differentiable signal, optimizers that update parameters based on gradients, learning rate schedules that modulate the pace of learning, normalization layers that stabilize intermediate representations, initialization schemes that set the stage for healthy gradient flow, and numerical techniques like gradient clipping and mixed precision that keep computations well-behaved at scale.
We will cover each of these components in depth, building on the optimization foundations from Chapter 3 and the neural network architectures from Chapter 11. By the end of this chapter, you will have a complete, production-grade training pipeline in PyTorch and---more importantly---the conceptual tools to diagnose and fix problems when training goes wrong.
12.1 Loss Functions: Translating Objectives into Gradients
The loss function (also called the objective function or criterion) is the bridge between what you want your model to do and how gradient descent updates its parameters. Choosing the right loss is not merely a formality; it shapes the entire optimization landscape.
12.1.1 Mean Squared Error (MSE)
For regression tasks, the workhorse loss is the mean squared error:
$$\mathcal{L}_{\text{MSE}} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2$$
where $y_i$ is the true target and $\hat{y}_i$ is the model prediction. The gradient with respect to each prediction is proportional to the residual:
$$\frac{\partial \mathcal{L}_{\text{MSE}}}{\partial \hat{y}_i} = \frac{2}{N}(\hat{y}_i - y_i)$$
This means large errors produce large gradients, pushing the model to correct its worst predictions first. While this is often desirable, it also makes MSE sensitive to outliers. A single data point with $|y_i - \hat{y}_i| = 100$ contributes 10,000 to the loss, dominating the gradient signal.
When to use MSE: Regression tasks where the target distribution is roughly Gaussian and outliers are rare or have been handled during preprocessing.
Alternatives for robustness: The Huber loss (smooth L1) combines MSE for small errors with L1 for large errors, providing outlier robustness while maintaining smoothness near zero:
$$\mathcal{L}_{\text{Huber}}(y, \hat{y}) = \begin{cases} \frac{1}{2}(y - \hat{y})^2 & \text{if } |y - \hat{y}| \leq \delta \\ \delta |y - \hat{y}| - \frac{1}{2}\delta^2 & \text{otherwise} \end{cases}$$
In PyTorch, these are available as torch.nn.MSELoss(), torch.nn.HuberLoss(delta=delta), and torch.nn.SmoothL1Loss(beta=beta), the latter being the Huber loss divided by beta.
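To see the outlier sensitivity concretely, here is a minimal sketch (with made-up tensors) comparing the gradients of the MSE and Huber losses on a batch whose last target is an outlier; the Huber gradient on the outlier is capped, while the MSE gradient grows with the residual.
import torch
import torch.nn as nn
preds = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
targets = torch.tensor([1.1, 1.9, 3.2, 104.0])  # last target is an outlier
mse = nn.MSELoss()(preds, targets)
mse.backward()
print(preds.grad)  # gradient on the outlier element is about -50 and dominates
preds.grad = None
huber = nn.HuberLoss(delta=1.0)(preds, targets)
huber.backward()
print(preds.grad)  # outlier gradient is capped at delta / N = 0.25 in magnitude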
12.1.2 Cross-Entropy Loss
For classification, cross-entropy is the standard choice. Recall from Chapter 5 that for a classification problem with $C$ classes, the cross-entropy loss for a single sample is:
$$\mathcal{L}_{\text{CE}} = -\sum_{c=1}^{C} y_c \log(\hat{p}_c)$$
where $y_c$ is 1 if the true class is $c$ and 0 otherwise, and $\hat{p}_c$ is the model's predicted probability for class $c$. For the common case of hard labels (one-hot), this simplifies to:
$$\mathcal{L}_{\text{CE}} = -\log(\hat{p}_{y^*})$$
where $y^*$ is the true class index. The model is penalized by how confident it is in the wrong answer, which creates strong gradients when predictions are confidently incorrect.
Numerical stability: Computing $\log(\text{softmax}(z))$ naively is numerically dangerous. If any logit $z_c$ is very large, $\exp(z_c)$ overflows; if all logits are very negative, $\sum \exp(z_c)$ underflows. PyTorch's torch.nn.CrossEntropyLoss accepts raw logits (not softmax probabilities) and uses the log-sum-exp trick internally:
$$\log \text{softmax}(z_c) = z_c - \log\left(\sum_{j} \exp(z_j)\right) = z_c - \left(z_{\max} + \log\sum_{j} \exp(z_j - z_{\max})\right)$$
This is why you should never apply softmax before CrossEntropyLoss in PyTorch---doing so loses numerical precision and is a common source of bugs.
Binary cross-entropy for two-class problems (or multi-label classification where each label is independent) uses sigmoid activation:
$$\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log(\sigma(\hat{z}_i)) + (1 - y_i)\log(1 - \sigma(\hat{z}_i))\right]$$
PyTorch provides torch.nn.BCEWithLogitsLoss which combines sigmoid and BCE for numerical stability, analogous to how CrossEntropyLoss combines softmax and negative log-likelihood.
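The logits-versus-probabilities convention trips up many first training runs. A minimal sketch with made-up logits:
import torch
import torch.nn as nn
logits = torch.tensor([[2.0, -1.0, 0.5]])  # raw scores for 3 classes
target = torch.tensor([0])                 # true class index
loss_correct = nn.CrossEntropyLoss()(logits, target)                # pass logits directly
loss_wrong = nn.CrossEntropyLoss()(logits.softmax(dim=-1), target)  # common bug: double softmax
print(loss_correct.item(), loss_wrong.item())  # the values differ; only the first is correct
# Binary / multi-label case: pass raw scores to BCEWithLogitsLoss
scores = torch.tensor([0.8, -1.2])
labels = torch.tensor([1.0, 0.0])
loss_bce = nn.BCEWithLogitsLoss()(scores, labels)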
12.1.3 Focal Loss
Standard cross-entropy treats all correctly classified examples equally, regardless of the model's confidence. In datasets with severe class imbalance---such as object detection where background examples vastly outnumber objects---easy negatives can dominate the loss and gradient.
Focal loss, introduced by Lin et al. (2017), adds a modulating factor:
$$\mathcal{L}_{\text{focal}} = -\alpha_t (1 - \hat{p}_t)^\gamma \log(\hat{p}_t)$$
where $\hat{p}_t$ is the predicted probability for the true class, $\gamma \geq 0$ is the focusing parameter, and $\alpha_t$ is a class-balancing weight. When $\gamma = 0$, focal loss reduces to standard cross-entropy. When $\gamma > 0$, the factor $(1 - \hat{p}_t)^\gamma$ down-weights easy examples (where $\hat{p}_t$ is high) and focuses training on hard examples. Typical values are $\gamma = 2$ and $\alpha = 0.25$.
The gradient of focal loss includes additional terms that reduce the gradient magnitude for well-classified examples:
$$\frac{\partial \mathcal{L}_{\text{focal}}}{\partial z} = -\alpha_t (1 - \hat{p}_t)^\gamma \left[\gamma \hat{p}_t \log(\hat{p}_t) + (\hat{p}_t - 1)\right]$$
Focal loss is not limited to object detection. It is useful whenever you have class imbalance and want the model to focus on the hardest examples.
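PyTorch has no built-in focal loss, so here is a minimal multi-class sketch that follows the formula above; for simplicity it treats the class-balancing weight $\alpha_t$ as a single scalar rather than a per-class weight.
import torch
import torch.nn.functional as F
def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
    """Multi-class focal loss on raw logits (illustrative sketch, not an official API)."""
    log_probs = F.log_softmax(logits, dim=-1)                      # numerically stable
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p_t for the true class
    pt = log_pt.exp()
    loss = -alpha * (1.0 - pt) ** gamma * log_pt                   # modulated cross-entropy
    return loss.mean()
logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))
print(focal_loss(logits, targets))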
12.1.4 Other Important Losses
Several other loss functions appear frequently in deep learning:
- KL Divergence (torch.nn.KLDivLoss): Measures the divergence between two probability distributions. Central to variational autoencoders and knowledge distillation.
- Triplet Loss: Learns embeddings by enforcing that an anchor is closer to a positive example than a negative one by a margin. Used in metric learning and face recognition.
- Contrastive Loss / InfoNCE: The backbone of modern self-supervised learning (SimCLR, CLIP). Pulls positive pairs together and pushes negative pairs apart in embedding space.
- Label Smoothing: Instead of hard one-hot targets, uses $(1 - \epsilon)$ for the true class and $\epsilon / (C-1)$ for others. Reduces overconfidence and can improve generalization. Available via the label_smoothing parameter in torch.nn.CrossEntropyLoss.
12.2 Optimizers: Navigating the Loss Landscape
In Chapter 3, we studied gradient descent and its variants in general terms. Now we put those ideas into practice with PyTorch's optimizer library, focusing on the three optimizers that dominate modern deep learning.
12.2.1 Stochastic Gradient Descent (SGD)
The simplest optimizer updates parameters by stepping in the direction opposite to the gradient:
$$\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t)$$
Plain SGD is rarely used alone because mini-batch gradients are noisy, causing the optimization trajectory to oscillate. Momentum smooths the updates by accumulating an exponentially decaying moving average of past gradients:
$$v_{t+1} = \mu v_t + \nabla_\theta \mathcal{L}(\theta_t)$$ $$\theta_{t+1} = \theta_t - \eta v_{t+1}$$
where $\mu$ is the momentum coefficient (typically 0.9). Momentum helps the optimizer build up velocity along consistent gradient directions while dampening oscillations. Nesterov momentum computes the gradient at the "lookahead" position $\theta_t - \eta \mu v_t$, which often provides better convergence:
optimizer = torch.optim.SGD(
model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4
)
SGD with momentum remains the optimizer of choice for many computer vision tasks (ResNets, Vision Transformers with long training schedules) because it generalizes well, albeit requiring more careful learning rate tuning.
12.2.2 Adam
Adam (Adaptive Moment Estimation) combines momentum with per-parameter adaptive learning rates. It maintains two exponentially decaying averages---the first moment (mean) $m_t$ and the second moment (uncentered variance) $v_t$ of the gradient:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$ $$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
Because $m_t$ and $v_t$ are initialized to zero, they are biased toward zero in early steps. Bias-corrected estimates are:
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
The update rule divides the first moment by the square root of the second moment, giving each parameter its own effective learning rate:
$$\theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Default hyperparameters ($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$) work well across a wide range of tasks. Adam is the default choice for NLP, generative models, and any setting where you want fast convergence with minimal tuning.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))
12.2.3 AdamW: Decoupled Weight Decay
A subtle but important issue with Adam is how it interacts with L2 regularization. When you add an L2 penalty $\frac{\lambda}{2}\|\theta\|^2$ to the loss, Adam's adaptive scaling effectively reduces the regularization strength for parameters with large gradients. Loshchilov and Hutter (2019) proposed decoupled weight decay, which applies regularization directly to the parameters rather than through the gradient:
$$\theta_{t+1} = \theta_t - \eta\left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t\right)$$
This is mathematically different from L2 regularization with Adam and empirically produces better generalization. AdamW is the recommended optimizer for most modern deep learning tasks, particularly in NLP and for training transformers.
optimizer = torch.optim.AdamW(
model.parameters(), lr=3e-4, betas=(0.9, 0.999), weight_decay=0.01
)
12.2.4 Choosing an Optimizer
| Scenario | Recommended Optimizer | Typical LR |
|---|---|---|
| Computer vision (CNNs, long training) | SGD + Nesterov momentum | 0.1 (with schedule) |
| NLP / Transformers | AdamW | 1e-4 to 5e-4 |
| GANs and generative models | Adam | 1e-4 to 2e-4 |
| Fine-tuning pretrained models | AdamW | 1e-5 to 5e-5 |
| Quick prototyping / uncertain | Adam or AdamW | 3e-4 |
12.2.5 Parameter Groups
PyTorch optimizers support parameter groups, allowing different learning rates or weight decay for different parts of the model. This is essential for fine-tuning, where pretrained layers should update slowly:
optimizer = torch.optim.AdamW([
{"params": model.backbone.parameters(), "lr": 1e-5},
{"params": model.head.parameters(), "lr": 1e-3},
], weight_decay=0.01)
12.3 Learning Rate Schedules
The learning rate is arguably the single most important hyperparameter. A rate too high causes divergence; too low causes slow convergence or getting trapped in poor local minima. Learning rate schedules adjust the rate during training to get the best of both worlds.
12.3.1 Step Decay
The simplest schedule reduces the learning rate by a fixed factor at predetermined epochs:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
This multiplies the learning rate by $\gamma = 0.1$ every 30 epochs. Multi-step decay (MultiStepLR) allows specifying exact milestones:
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[60, 120, 160], gamma=0.2
)
Step decay was the standard approach for training ResNets and remains effective, but it requires knowing in advance when to drop the rate.
12.3.2 Cosine Annealing
Cosine annealing smoothly reduces the learning rate following a cosine curve:
$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{t}{T}\pi\right)\right)$$
where:
- $\eta_{\min}$ is the minimum learning rate (typically 1e-6 or 0),
- $\eta_{\max}$ is the initial (maximum) learning rate,
- $t$ is the current step or epoch, and
- $T$ is the total number of steps or epochs.
This schedule starts by reducing the rate slowly, then accelerates the decay in the middle, and slows again near the end. The slow start allows the model to continue learning at a productive rate early on, the fast middle phase pushes the model toward a minimum, and the slow finish enables fine-grained convergence.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_epochs, eta_min=1e-6
)
Cosine annealing consistently performs well and has become the default schedule for many tasks. It requires only two hyperparameters: the initial learning rate and the minimum rate.
Cosine annealing with warm restarts (CosineAnnealingWarmRestarts) periodically resets the learning rate to its initial value, which can help escape local minima and improve exploration. The restart period can increase geometrically (e.g., doubling after each restart), giving the optimizer progressively longer phases to settle into better minima:
# First restart after 10 epochs, then 20, then 40, etc.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=10, T_mult=2, eta_min=1e-6
)
12.3.2.1 Cyclical Learning Rates
Leslie Smith (2017) introduced cyclical learning rates, which oscillate the learning rate between a minimum and maximum value throughout training. The intuition is that periodically increasing the learning rate helps the optimizer escape sharp local minima and find flatter, more generalizable minima.
The triangular policy linearly increases and then decreases the learning rate over each cycle:
scheduler = torch.optim.lr_scheduler.CyclicLR(
optimizer,
base_lr=1e-4,
max_lr=1e-2,
step_size_up=2000, # Half-cycle length in steps
mode="triangular2", # Halve the max_lr after each cycle
cycle_momentum=True, # Cycle momentum inversely
)
The triangular2 mode halves the amplitude with each cycle, combining exploration with convergence. Smith also proposed the LR range test: train the model for one epoch with the learning rate increasing exponentially from a very small value to a very large value. Plot the loss against the learning rate. The optimal range is the interval where the loss decreases most steeply. This technique is a practical way to set the base_lr and max_lr for cyclical schedules, or to find a good initial learning rate for any schedule.
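The LR range test is easy to implement by hand. The sketch below sweeps the learning rate exponentially over a few hundred steps and records the loss; the model, criterion, and train_loader arguments are placeholders for your own objects. Plot losses against lrs and pick the region where the loss falls most steeply.
import copy
import torch
def lr_range_test(model, criterion, train_loader, device,
                  min_lr=1e-7, max_lr=10.0, num_steps=200):
    """Sweep the LR exponentially and record the loss at each step (sketch)."""
    model = copy.deepcopy(model).to(device)         # leave the real model untouched
    optimizer = torch.optim.SGD(model.parameters(), lr=min_lr, momentum=0.9)
    gamma = (max_lr / min_lr) ** (1.0 / num_steps)  # multiplicative LR increase per step
    lrs, losses = [], []
    data_iter = iter(train_loader)
    for _ in range(num_steps):
        try:
            inputs, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            inputs, targets = next(data_iter)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        losses.append(loss.item())
        for group in optimizer.param_groups:        # raise the LR for the next step
            group["lr"] *= gamma
        if losses[-1] > 4 * min(losses):            # stop once the loss blows up
            break
    return lrs, losses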
Connection to loss landscape geometry. Why do cyclical learning rates work? The loss landscape of deep networks is not a smooth bowl---it is a rugged terrain with many local minima, saddle points, and narrow valleys. A high learning rate acts like a high-temperature exploration: the optimizer takes large steps that can jump over barriers between basins. A low learning rate acts like low-temperature exploitation: the optimizer settles into the nearest minimum. By oscillating between the two, cyclical learning rates get the benefits of both.
12.3.3 Linear Warmup
For large models or large batch sizes, starting training at the full learning rate can cause instability. Warmup linearly increases the learning rate from a small value to the target rate over the first few epochs or steps:
$$\eta_t = \eta_{\text{target}} \cdot \frac{t}{T_{\text{warmup}}}$$
where:
- $\eta_{\text{target}}$ is the target peak learning rate,
- $t$ is the current step, and
- $T_{\text{warmup}}$ is the total number of warmup steps.
Why Warmup Is Necessary. At the start of training, the model's weights are random and the loss landscape is poorly conditioned. The gradients computed on the first few batches are noisy and potentially unrepresentative. If the learning rate is too high at this point, the optimizer takes large, poorly-directed steps that can push the model into a bad region of parameter space from which it never recovers. Warmup addresses this by starting with small, conservative steps that allow the model to "calibrate" its gradients. Once the running statistics in Adam/AdamW have accumulated a few hundred steps of gradient information, the adaptive learning rates become reliable and the optimizer can handle the full learning rate.
This phenomenon is especially pronounced with:
- Large batch sizes: Larger batches produce sharper gradient estimates that can cause overshooting
- Transformers and attention models: The softmax in attention creates highly nonlinear dynamics that are unstable without warmup
- Adam/AdamW: The bias correction in early steps can create artificially large updates if the learning rate is already high
Warmup strategies:
- Linear warmup (most common): The learning rate increases linearly from 0 to the target
- Exponential warmup: The learning rate increases exponentially, spending more time at low rates
- Gradual warmup (Goyal et al., 2017): For very large batch training, warmup over 5--10 epochs
A typical transformer training recipe uses 1,000--4,000 warmup steps followed by cosine decay. For fine-tuning pretrained models, shorter warmup (100--500 steps) is usually sufficient.
Warmup is critical for training transformers and is typically combined with a decay schedule afterward. PyTorch does not have a built-in warmup scheduler, but you can compose one using LambdaLR or SequentialLR:
import math
import torch
def warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
"""Create a schedule with linear warmup and cosine decay.
Args:
optimizer: The optimizer to schedule.
warmup_steps: Number of warmup steps.
total_steps: Total number of training steps.
Returns:
A learning rate scheduler.
"""
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
progress = float(current_step - warmup_steps) / float(
max(1, total_steps - warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
12.3.4 One-Cycle Policy
The one-cycle policy (Smith and Topin, 2019) increases the learning rate from a minimum to a maximum during the first phase, then decreases it to well below the minimum during the second phase. Simultaneously, momentum decreases during phase one and increases during phase two:
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=0.01, total_steps=num_epochs * len(train_loader),
pct_start=0.3, anneal_strategy="cos"
)
The one-cycle policy often achieves the best results in fewer epochs and is particularly effective with SGD.
12.3.5 The Training Loop Integration
Schedulers must be called at the right time. For per-epoch schedulers (StepLR, CosineAnnealingLR), call scheduler.step() after each epoch; for per-step schedulers (OneCycleLR, warmup schedules), call it after each batch:
# Per-epoch scheduler
for epoch in range(num_epochs):
train_one_epoch(model, train_loader, optimizer)
scheduler.step() # After each epoch
# Per-step scheduler
for epoch in range(num_epochs):
for batch in train_loader:
loss = train_step(model, batch, optimizer)
scheduler.step() # After each batch
Mixing these up is a common bug---calling a per-epoch scheduler every step results in the learning rate decaying far too quickly.
12.4 Normalization Layers
As data flows through many layers, the distribution of activations can shift dramatically---a phenomenon known as internal covariate shift. Normalization layers stabilize training by constraining activation statistics, enabling higher learning rates and faster convergence.
12.4.1 Batch Normalization
Batch normalization (Ioffe and Szegedy, 2015) normalizes activations across the batch dimension. For a mini-batch $\mathcal{B} = \{x_1, \ldots, x_m\}$ and a given feature channel:
$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}$$
where $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$ are the batch mean and variance. Learnable parameters $\gamma$ (scale) and $\beta$ (shift) allow the network to undo normalization if needed:
$$y_i = \gamma \hat{x}_i + \beta$$
During training, batch statistics ($\mu_\mathcal{B}$, $\sigma_\mathcal{B}^2$) are computed from the current mini-batch, and running averages are maintained via exponential moving average. During inference, the running averages are used for deterministic predictions---this is why calling model.eval() is critical before evaluation.
import torch.nn as nn
# For 2D inputs (images): normalizes over (N, H, W) for each channel
bn = nn.BatchNorm2d(num_features=64)
# For 1D inputs (sequences/vectors): normalizes over (N,) for each feature
bn1d = nn.BatchNorm1d(num_features=256)
Benefits of batch normalization:
- Allows higher learning rates without divergence
- Acts as a mild regularizer (due to batch noise)
- Reduces sensitivity to weight initialization
- Smooths the loss landscape
Limitations:
- Behavior differs between training and evaluation modes (a source of subtle bugs)
- Performance degrades with very small batch sizes (batch statistics become noisy)
- Not ideal for sequence models, where batch statistics mix different sequence lengths
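The train/eval discrepancy is easy to demonstrate: the same input produces different outputs depending on the mode, because training mode normalizes with the current batch statistics while eval mode uses the accumulated running averages. A small sketch:
import torch
import torch.nn as nn
bn = nn.BatchNorm1d(num_features=4)
x = torch.randn(8, 4) * 3.0 + 2.0   # deliberately non-standardized input
bn.train()
out_train = bn(x)   # uses this batch's mean/variance and updates the running statistics
bn.eval()
out_eval = bn(x)    # uses the running averages accumulated so far
print(out_train.mean().item(), out_eval.mean().item())  # the two outputs generally differ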
12.4.2 Layer Normalization
Layer normalization (Ba et al., 2016) normalizes across the feature dimension rather than the batch dimension. For an input $x \in \mathbb{R}^d$:
$$\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}}, \quad \mu = \frac{1}{d}\sum_{j=1}^d x_j, \quad \sigma^2 = \frac{1}{d}\sum_{j=1}^d (x_j - \mu)^2$$
Each sample is normalized independently, so there is no dependence on other samples in the batch.
# Normalizes over the last dimension (or specified normalized_shape)
ln = nn.LayerNorm(normalized_shape=256)
# For transformer inputs with shape (batch, seq_len, d_model)
ln = nn.LayerNorm(normalized_shape=d_model)
Layer normalization is the standard choice for transformers and RNNs because:
- It does not depend on batch size
- Behavior is identical during training and inference
- It works naturally with variable-length sequences
12.4.3 Other Normalization Variants
- Group Normalization (nn.GroupNorm): Divides channels into groups and normalizes within each group. A compromise between batch norm and layer norm. Works well with small batch sizes.
- Instance Normalization (nn.InstanceNorm2d): Normalizes each channel of each sample independently. Used in style transfer.
- RMS Normalization: Simplifies layer norm by removing the mean centering, using only root-mean-square scaling. Used in LLaMA and other modern LLMs for its computational efficiency.
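Recent PyTorch releases ship an nn.RMSNorm module, but the operation is simple enough to write by hand; the sketch below follows the LLaMA-style formulation (a learnable scale, no bias, no mean subtraction).
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root-mean-square normalization (sketch of the LLaMA-style variant)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable scale, no shift
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal RMS of the last dimension; no mean centering
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight
x = torch.randn(2, 10, 512)
print(RMSNorm(512)(x).shape)  # torch.Size([2, 10, 512])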
12.4.4 Where to Place Normalization
The original batch normalization paper placed normalization before the activation function (post-linear, pre-activation). However, modern practice varies:
- Pre-norm (normalize before the sub-layer): Standard in transformers. Stabilizes training of very deep models.
- Post-norm (normalize after the sub-layer): Original ResNet and transformer design. Can achieve slightly better final performance but is harder to train at scale.
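The difference is easiest to see in code. The sketch below wraps a generic sub-layer (attention or feed-forward) in both styles; the class names and the toy feed-forward block are illustrative.
import torch
import torch.nn as nn
class PreNormBlock(nn.Module):
    """x + sublayer(norm(x)): the pre-norm style used in most modern transformers."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))
class PostNormBlock(nn.Module):
    """norm(x + sublayer(x)): the original transformer design."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))
ffn = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
x = torch.randn(4, 16, 256)
print(PreNormBlock(256, ffn)(x).shape, PostNormBlock(256, ffn)(x).shape)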
12.5 Weight Initialization
Before training begins, every weight in your network must be assigned an initial value. This choice is surprisingly consequential: poor initialization can cause gradients to vanish or explode within the first few forward passes, making learning impossible.
12.5.1 The Problem with Bad Initialization
Consider a network with $L$ layers, each computing $h_l = W_l h_{l-1}$ (ignoring biases and activations for simplicity). The output is:
$$h_L = W_L W_{L-1} \cdots W_1 x$$
If each weight matrix has entries with variance too large, the activations grow exponentially with depth. If the variance is too small, they shrink exponentially. The same argument applies to gradients during backpropagation.
Initializing all weights to zero is catastrophic: every neuron computes the same function, receives the same gradient, and stays identical forever. This is the symmetry-breaking problem---initialization must introduce asymmetry.
12.5.2 Xavier (Glorot) Initialization
Glorot and Bengio (2010) derived that to maintain variance of activations across layers with linear or tanh activations, weights should be drawn from:
$$W \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right)$$
or equivalently from a normal distribution with variance $\frac{2}{n_{\text{in}} + n_{\text{out}}}$. This ensures that both forward activations and backward gradients maintain roughly unit variance.
nn.init.xavier_uniform_(layer.weight)
nn.init.xavier_normal_(layer.weight)
Use Xavier initialization with: tanh, sigmoid, and softmax activations.
12.5.3 He (Kaiming) Initialization
ReLU activations zero out roughly half of their inputs, effectively halving the variance at each layer. He et al. (2015) corrected for this by deriving:
$$W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}}}\right)$$
This accounts for the ReLU's variance reduction and is the standard initialization for networks using ReLU-family activations.
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
The mode parameter controls whether to use the input dimension (fan_in, preserves forward pass variance) or output dimension (fan_out, preserves backward pass variance). For most cases, fan_in is preferred.
12.5.4 Initialization in Practice
PyTorch's nn.Linear uses Kaiming uniform by default, which is a reasonable choice for ReLU networks. However, for specific architectures you may want to customize initialization:
def init_weights(module: nn.Module) -> None:
"""Initialize weights using best practices for each layer type.
Args:
module: A PyTorch module to initialize.
"""
if isinstance(module, nn.Linear):
nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
model.apply(init_weights)
Biases are almost always initialized to zero. The exception is LSTM forget gate biases, which are often initialized to 1 to encourage remembering in early training (as we will discuss in Chapter 15).
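As an illustration of the forget-gate trick, the sketch below sets the forget-gate slice of every LSTM bias vector to 1. It relies on PyTorch's gate ordering (input, forget, cell, output) within each bias of shape 4 * hidden_size; the helper name is ours.
import torch
import torch.nn as nn
def init_lstm_forget_bias(lstm: nn.LSTM, value: float = 1.0) -> None:
    """Set the forget-gate bias to `value` (PyTorch orders gates as i, f, g, o)."""
    hidden = lstm.hidden_size
    for name, param in lstm.named_parameters():
        if "bias" in name:
            with torch.no_grad():
                param[hidden:2 * hidden].fill_(value)  # the forget-gate slice
lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=2)
init_lstm_forget_bias(lstm)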
12.5.5 Diagnosing Initialization Problems
You can check whether your initialization is appropriate by examining activation statistics after one forward pass on a real data batch:
def check_activation_stats(model: nn.Module, sample_input: torch.Tensor) -> None:
"""Print activation statistics for each layer to diagnose initialization.
Args:
model: The neural network.
sample_input: A representative batch of input data.
"""
hooks = []
stats = {}
def hook_fn(name):
def hook(module, input, output):
if isinstance(output, torch.Tensor):
stats[name] = {
"mean": output.mean().item(),
"std": output.std().item(),
"dead_frac": (output == 0).float().mean().item(),
}
return hook
for name, module in model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ReLU)):
hooks.append(module.register_forward_hook(hook_fn(name)))
with torch.no_grad():
model(sample_input)
for hook in hooks:
hook.remove()
for name, s in stats.items():
print(f"{name}: mean={s['mean']:.4f}, std={s['std']:.4f}, "
f"dead={s['dead_frac']:.2%}")
What to look for:
- Activation means should be near zero (for zero-centered activations like tanh) or near 0.5 times the standard deviation (for ReLU)
- Activation standard deviations should be roughly constant across layers (not growing or shrinking)
- The dead fraction (fraction of neurons outputting exactly zero) should be below 50% for ReLU layers; anything higher indicates too-aggressive initialization or a learning rate that pushed many neurons into the dead zone
12.6 Gradient Clipping
Even with good initialization and normalization, gradients can occasionally spike---due to outlier batches, sharp loss landscape features, or the inherent instability of recurrent computations. Gradient clipping caps gradients to prevent these spikes from destabilizing training.
12.6.1 Clipping by Norm
The most common approach clips the global gradient norm. If the total norm of all parameter gradients exceeds a threshold, all gradients are scaled down proportionally:
$$g \leftarrow \begin{cases} g & \text{if } \|g\| \leq \tau \\ \frac{\tau}{\|g\|} g & \text{if } \|g\| > \tau \end{cases}$$
This preserves the direction of the gradient while limiting its magnitude:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Call this after loss.backward() and before optimizer.step(). The function returns the original (unclipped) gradient norm, which is useful for monitoring.
12.6.2 Clipping by Value
An alternative clips each gradient element independently to the range $[-\tau, \tau]$:
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
Clipping by value changes the gradient direction and is less commonly used than clipping by norm. It can be useful when you want hard bounds on individual gradient components.
12.6.3 When to Clip
Gradient clipping is essential for:
- Recurrent networks (LSTMs, GRUs): Exploding gradients are an inherent risk in backpropagation through time (as we will see in Chapter 15)
- Transformers: The standard training recipe includes gradient clipping with max_norm=1.0
- Any training run showing gradient spikes: Monitor the gradient norm over time and introduce clipping if you see occasional spikes
Choosing the clip threshold. A common approach is to train for a few hundred steps without clipping, record the gradient norm at each step, and set the threshold to a percentile (e.g., the 95th or 99th percentile) of the observed distribution. This ensures that only abnormally large gradients are clipped, while normal training dynamics are preserved. Typical values:
- max_norm=1.0: Standard for transformers and many NLP models
- max_norm=5.0: Common for RNNs and LSTMs
- max_norm=0.5: More aggressive clipping for unstable training
Monitoring gradient norms. Always log the gradient norm as a training metric. A healthy training run shows a stable or slowly declining gradient norm. Warning signs include:
- Sudden spikes: A single batch caused a very large gradient---clipping handles this
- Steadily increasing gradient norm: The model is entering an unstable region---consider reducing the learning rate
- Consistently clipped gradients: The threshold is too low, or the learning rate is too high---the model is making little progress in the clipped direction
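The percentile approach is easy to automate: train briefly with clipping effectively disabled, record the gradient norm at every step, and take a high percentile as the threshold. A self-contained sketch (function name and defaults are illustrative; it assumes the loader yields at least num_steps batches):
import torch
def estimate_clip_threshold(model, criterion, train_loader, optimizer, device,
                            num_steps=300, percentile=99.0):
    """Record unclipped gradient norms for a few steps and return a percentile (sketch)."""
    model.train()
    norms = []
    data_iter = iter(train_loader)
    for _ in range(num_steps):
        inputs, targets = next(data_iter)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(inputs), targets)
        loss.backward()
        # With an infinite max_norm, clip_grad_norm_ only measures the total norm
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
        norms.append(total_norm.item())
        optimizer.step()
    return torch.quantile(torch.tensor(norms), percentile / 100.0).item()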
Gradient clipping is a safety net, not a fix for fundamentally broken training. If gradients are consistently being clipped, investigate the root cause (learning rate too high, poor initialization, bad data).
12.7 Mixed Precision Training
Modern GPUs have specialized hardware (Tensor Cores on NVIDIA GPUs) that perform float16 (FP16) matrix multiplications much faster than float32 (FP32). Mixed precision training exploits this by performing most computations in FP16 while keeping a master copy of weights in FP32.
12.7.1 Why Mixed Precision Works
FP16 has a smaller dynamic range ($6 \times 10^{-8}$ to $6.5 \times 10^4$) compared to FP32 ($1.2 \times 10^{-38}$ to $3.4 \times 10^{38}$). Naively converting everything to FP16 causes two problems:
1. Underflow: Small gradients (especially in early layers) round to zero
2. Overflow: Large intermediate values exceed the FP16 range
Mixed precision resolves these through two techniques:
- Loss scaling: Multiply the loss by a large factor before backpropagation to shift gradients into the FP16 representable range, then divide gradients by the same factor before the optimizer step
- Master weights in FP32: Weight updates involve adding small gradient values to large weight values. FP16 lacks the precision for this, so an FP32 copy of the weights receives the updates
12.7.2 PyTorch Automatic Mixed Precision (AMP)
PyTorch provides torch.amp (formerly torch.cuda.amp) for seamless mixed precision training:
scaler = torch.amp.GradScaler()
for inputs, targets in train_loader:
optimizer.zero_grad()
# Forward pass in mixed precision
with torch.amp.autocast(device_type="cuda"):
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backward pass with loss scaling
scaler.scale(loss).backward()
# Unscale gradients for clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Optimizer step with automatic scale adjustment
scaler.step(optimizer)
scaler.update()
Key details:
- autocast automatically selects FP16 for operations that benefit (matmuls, convolutions) and keeps FP32 for operations that need precision (reductions, softmax, loss computation)
- GradScaler dynamically adjusts the loss scale factor. If gradients contain inf/nan (overflow), it skips the optimizer step and reduces the scale
- Gradient clipping must happen after scaler.unscale_() and before scaler.step()
12.7.3 Performance Benefits
Mixed precision training typically provides:
- 2-3x speedup on modern NVIDIA GPUs (V100, A100, H100)
- ~50% memory reduction for activations, enabling larger batch sizes
- No loss in model quality when implemented correctly
Understanding the Memory Savings. A model with 100 million parameters in FP32 requires 400 MB for weights alone. In FP16, this drops to 200 MB. But the memory savings for activations are even more impactful: during training, PyTorch stores intermediate activations for the backward pass. For large models and batch sizes, activations can consume 10--100x more memory than parameters. Halving the precision of activations can double the batch size you can fit on a GPU.
BFloat16 (BF16). BFloat16 is an alternative 16-bit format with the same dynamic range as FP32 (8 exponent bits) but reduced precision (7 mantissa bits vs. FP32's 23). Because it has the same dynamic range as FP32, BF16 avoids the need for loss scaling entirely---there is no risk of gradient underflow. This makes BF16 simpler to use than FP16 and is the preferred format on A100+ GPUs and modern TPUs:
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
outputs = model(inputs)
loss = criterion(outputs, targets)
When BF16 is available, you can simplify the training loop by dropping the GradScaler entirely. This is why most modern large language model training uses BF16 rather than FP16.
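Assuming your hardware supports BF16, the training step reduces to the sketch below: no GradScaler, and gradient clipping needs no unscaling (model, criterion, optimizer, and train_loader are the objects defined elsewhere in this chapter).
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    loss.backward()                      # no loss scaling needed with BF16
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()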
When to be cautious with mixed precision. Certain operations are numerically sensitive and should always remain in FP32:
- Softmax and log-softmax (small differences in large logits matter)
- Loss computation (accumulated sums need precision)
- Batch normalization running statistics (small updates to running mean/variance)
- Optimizer state (Adam's moment estimates)
PyTorch's autocast handles most of these automatically, but if you implement custom operations, be mindful of which precision they require.
12.8 The Complete Training Loop
With all components in place, let us build a production-quality training loop. A well-structured training loop is the backbone of any deep learning project.
12.8.1 Anatomy of a Training Step
Each training step follows the same pattern:
- Zero gradients: Clear gradients from the previous step
- Forward pass: Compute model predictions
- Loss computation: Evaluate how far predictions are from targets
- Backward pass: Compute gradients via backpropagation
- Gradient processing: Clip gradients, check for anomalies
- Parameter update: Apply the optimizer step
- Scheduler step: Update the learning rate (if per-step schedule)
def train_one_step(
model: nn.Module,
batch: tuple[torch.Tensor, torch.Tensor],
criterion: nn.Module,
optimizer: torch.optim.Optimizer,
scaler: torch.amp.GradScaler | None = None,
max_grad_norm: float = 1.0,
device: torch.device = torch.device("cuda"),
) -> dict[str, float]:
"""Execute a single training step.
Args:
model: The neural network model.
batch: Tuple of (inputs, targets).
criterion: Loss function.
optimizer: Optimizer instance.
scaler: GradScaler for mixed precision (None for FP32).
max_grad_norm: Maximum gradient norm for clipping.
device: Device to run computations on.
Returns:
Dictionary with loss and gradient norm.
"""
inputs, targets = batch
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad(set_to_none=True) # More efficient than zero_grad()
if scaler is not None:
with torch.amp.autocast(device_type=device.type):
outputs = model(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=max_grad_norm
)
scaler.step(optimizer)
scaler.update()
else:
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=max_grad_norm
)
optimizer.step()
return {"loss": loss.item(), "grad_norm": grad_norm.item()}
Note the use of optimizer.zero_grad(set_to_none=True). By default, zero_grad() fills gradients with zeros, but set_to_none=True sets them to None instead, which is slightly faster and reduces memory.
12.8.2 The Full Epoch Loop
A complete training epoch iterates over the entire dataset:
def train_one_epoch(
model: nn.Module,
train_loader: DataLoader,
criterion: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
scaler: torch.amp.GradScaler | None = None,
device: torch.device = torch.device("cuda"),
epoch: int = 0,
) -> dict[str, float]:
"""Train the model for one epoch.
Args:
model: The neural network.
train_loader: Training data loader.
criterion: Loss function.
optimizer: Optimizer.
scheduler: Learning rate scheduler (per-step).
scaler: GradScaler for mixed precision.
device: Computation device.
epoch: Current epoch number.
Returns:
Dictionary with average metrics for the epoch.
"""
model.train()
running_loss = 0.0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(train_loader):
metrics = train_one_step(
model, (inputs, targets), criterion, optimizer, scaler, device=device
)
running_loss += metrics["loss"]
# Track accuracy for classification (note: this is a second forward pass used only
# for metrics; returning outputs from train_one_step would avoid the extra compute)
with torch.no_grad():
inputs, targets = inputs.to(device), targets.to(device)
with torch.amp.autocast(device_type=device.type, enabled=scaler is not None):
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if scheduler is not None:
scheduler.step()
return {
"loss": running_loss / len(train_loader),
"accuracy": 100.0 * correct / total,
"lr": optimizer.param_groups[0]["lr"],
}
12.8.3 Evaluation Loop
The evaluation loop is simpler but must handle two critical details: setting the model to eval mode and disabling gradient computation:
@torch.no_grad()
def evaluate(
model: nn.Module,
val_loader: DataLoader,
criterion: nn.Module,
device: torch.device = torch.device("cuda"),
) -> dict[str, float]:
"""Evaluate the model on a validation set.
Args:
model: The neural network.
val_loader: Validation data loader.
criterion: Loss function.
device: Computation device.
Returns:
Dictionary with evaluation metrics.
"""
model.eval() # Critical: switches batch norm, dropout to eval mode
running_loss = 0.0
correct = 0
total = 0
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
running_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return {
"loss": running_loss / len(val_loader),
"accuracy": 100.0 * correct / total,
}
Forgetting model.eval() is one of the most common bugs in deep learning code. Batch normalization uses batch statistics instead of running statistics, and dropout remains active, leading to inconsistent and degraded evaluation results.
12.8.4 Checkpointing
Save model state periodically so you can resume training after interruptions and keep the best model based on validation performance:
def save_checkpoint(
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler,
epoch: int,
best_val_acc: float,
filepath: str,
) -> None:
"""Save a training checkpoint.
Args:
model: The neural network.
optimizer: Optimizer state.
scheduler: Scheduler state.
epoch: Current epoch.
best_val_acc: Best validation accuracy so far.
filepath: Path to save the checkpoint.
"""
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"best_val_acc": best_val_acc,
}, filepath)
Always save the optimizer and scheduler state alongside the model---otherwise, resuming training restarts optimization from scratch, losing momentum and learning rate schedule progress.
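The mirror-image loading function is worth writing once. The sketch below assumes the checkpoint layout produced by save_checkpoint above and returns the epoch to resume from:
def load_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    filepath: str,
    device: torch.device,
) -> tuple[int, float]:
    """Restore model, optimizer, and scheduler state; return (start_epoch, best_val_acc)."""
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["epoch"] + 1, checkpoint["best_val_acc"]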
12.8.5 Early Stopping
Early stopping monitors validation performance and stops training when it stops improving, preventing overfitting:
class EarlyStopping:
"""Stop training when validation metric stops improving.
Args:
patience: Number of epochs to wait for improvement.
min_delta: Minimum change to qualify as improvement.
"""
def __init__(self, patience: int = 10, min_delta: float = 0.0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_score: float | None = None
def __call__(self, val_score: float) -> bool:
"""Check if training should stop.
Args:
val_score: Current validation metric (higher is better).
Returns:
True if training should stop.
"""
if self.best_score is None or val_score > self.best_score + self.min_delta:
self.best_score = val_score
self.counter = 0
else:
self.counter += 1
return self.counter >= self.patience
12.9 Debugging Training: When Things Go Wrong
Training a deep network rarely works perfectly on the first attempt. Developing a systematic approach to diagnosing and fixing training problems is one of the most valuable skills in deep learning.
12.9.1 Loss Curves: Your First Diagnostic
The training and validation loss curves contain a wealth of information:
| Pattern | Diagnosis | Action |
|---|---|---|
| Loss decreasing steadily | Healthy training | Continue |
| Loss not decreasing at all | LR too low, bug in loss, or data issue | Increase LR, verify loss computation, check data pipeline |
| Loss oscillating wildly | LR too high | Reduce LR or add gradient clipping |
| Loss explodes (NaN/Inf) | Numerical instability | Check for log(0), reduce LR, add gradient clipping, check data for NaN |
| Training loss low, val loss high | Overfitting | Add regularization, data augmentation, or early stopping |
| Both losses plateau at high value | Underfitting | Increase model capacity, train longer, reduce regularization |
| Validation loss improves then worsens | Classic overfitting | Use best checkpoint via early stopping |
Always log both training and validation losses at every epoch. Logging only training loss hides overfitting.
12.9.2 Gradient Monitoring
Monitoring gradient statistics reveals problems that loss curves cannot:
def log_gradient_stats(model: nn.Module, step: int) -> dict[str, float]:
"""Compute and log gradient statistics for all parameters.
Args:
model: The neural network.
step: Current training step.
Returns:
Dictionary of gradient statistics.
"""
total_norm = 0.0
stats = {}
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.data.norm(2).item()
total_norm += grad_norm ** 2
stats[f"grad_norm/{name}"] = grad_norm
# Check for dead neurons (zero gradients)
zero_frac = (param.grad.data == 0).float().mean().item()
if zero_frac > 0.5:
stats[f"dead_neurons/{name}"] = zero_frac
stats["grad_norm/total"] = total_norm ** 0.5
return stats
Vanishing gradients: If gradient norms in early layers are orders of magnitude smaller than in later layers, gradients are vanishing. Solutions: use residual connections, better initialization, normalization layers, or gradient-friendly activations (ReLU instead of sigmoid).
Exploding gradients: If gradient norms are very large or increasing over time, gradients are exploding. Solutions: gradient clipping, lower learning rate, better initialization.
Dead neurons: If many ReLU neurons have zero gradient consistently, they are "dead"---stuck in the negative regime with zero output and zero gradient. Solutions: use Leaky ReLU or ELU, reduce learning rate, use better initialization.
12.9.3 The Overfit-One-Batch Test
Before training on the full dataset, verify that your model and training pipeline are correct by overfitting a single batch:
def overfit_one_batch(model, train_loader, criterion, optimizer, device, steps=200):
"""Attempt to overfit a single batch as a sanity check.
Args:
model: The neural network.
train_loader: Training data loader.
criterion: Loss function.
optimizer: Optimizer.
device: Computation device.
steps: Number of optimization steps.
"""
model.train()
batch = next(iter(train_loader))
inputs, targets = batch[0].to(device), batch[1].to(device)
for step in range(steps):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if step % 20 == 0:
_, predicted = outputs.max(1)
acc = predicted.eq(targets).float().mean().item()
print(f"Step {step}: loss={loss.item():.4f}, acc={acc:.4f}")
If your model cannot achieve near-zero loss on a single batch, there is a bug in your code---fix that before training on the full dataset.
12.9.4 Common Bugs Checklist
Here is a systematic checklist for debugging training runs:
- Data pipeline: Are labels correctly aligned with inputs? Are inputs normalized correctly? Are there NaN values in the data?
- Model architecture: Is the output dimension correct for your task? Are you applying softmax before CrossEntropyLoss (you should not)?
- Loss function: Is the loss appropriate for your task? Are you passing the right arguments (logits vs. probabilities)?
- Optimizer: Is the learning rate in a reasonable range? Did you pass all model parameters to the optimizer?
- Train/eval mode: Are you calling model.train() before training and model.eval() before evaluation?
- Device consistency: Are model, inputs, and targets all on the same device?
- Gradient flow: Are you calling loss.backward() and optimizer.step() in the right order? Are you calling optimizer.zero_grad() at the start of each step?
- Data leakage: Is your validation set truly separate from training data? Are you accidentally fitting to the validation set?
- Reproducibility: Are you setting random seeds (torch.manual_seed, numpy.random.seed, random.seed)? Note that full reproducibility on GPU requires additional settings; a minimal seeding helper follows this list.
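A minimal seeding helper is sketched below (the function name is ours). Even with these settings, some CUDA kernels remain nondeterministic unless you also enable deterministic algorithms, which costs speed:
import random
import numpy as np
import torch
def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy, and PyTorch RNGs (sketch; full determinism trades away speed)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds all CUDA devices
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True, warn_only=True)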
12.9.5 Logging and Visualization
Use a logging framework to track all metrics over time. TensorBoard integrates well with PyTorch:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/experiment_01")
# In training loop:
writer.add_scalar("train/loss", train_loss, epoch)
writer.add_scalar("train/accuracy", train_acc, epoch)
writer.add_scalar("val/loss", val_loss, epoch)
writer.add_scalar("val/accuracy", val_acc, epoch)
writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)
# Log histograms of weights and gradients
for name, param in model.named_parameters():
writer.add_histogram(f"weights/{name}", param, epoch)
if param.grad is not None:
writer.add_histogram(f"gradients/{name}", param.grad, epoch)
writer.close()
Weights & Biases (wandb) is a popular alternative with cloud-hosted dashboards, experiment comparison, and hyperparameter sweep support.
12.9.6 Training Instabilities: A Deeper Look
Beyond the common bugs listed above, several deeper instabilities can plague training runs. Understanding their root causes helps you fix them faster.
Loss spikes. Occasional sudden increases in the loss, followed by recovery. Common causes:
- A batch containing outlier data points that produce extreme gradients
- Learning rate schedule transitions (e.g., a warmup that ramps up too aggressively)
- Mixed precision overflow (the GradScaler will skip these steps automatically)
If loss spikes are rare and the model recovers, they are usually harmless. If they are frequent, reduce the learning rate or increase gradient clipping.
Training collapse. The loss suddenly jumps to a very high value and never recovers. This is distinct from loss spikes because the model does not recover. Common causes:
- Learning rate too high for the current phase of training
- Numerical instability (e.g., division by zero in a custom loss, or log of a negative number)
- A sudden change in data distribution (e.g., a corrupted data shard in distributed training)
Prevention: use gradient clipping, careful learning rate scheduling, and checkpointing so you can resume from a good state.
Oscillating validation loss. The validation loss oscillates without converging. This often indicates that the learning rate is too high for fine-tuning---the optimizer overshoots the minimum on each step. Reduce the learning rate or switch to a schedule with more aggressive decay.
NaN/Inf in loss or weights. This is a hard failure that requires immediate diagnosis. Common causes:
- Computing log(0) or log(negative) in the loss function
- Division by zero (e.g., in a normalization layer with all-zero inputs)
- FP16 overflow in mixed precision training (check if GradScaler is working correctly)
- Extremely large learning rate
Debug by adding torch.autograd.set_detect_anomaly(True) temporarily. This slows training significantly but pinpoints the exact operation that produced the NaN/Inf.
# Enable anomaly detection for debugging (disable in production)
with torch.autograd.detect_anomaly():
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
12.9.7 Practical Training Recipes
Here are proven recipes for common scenarios, incorporating all the techniques from this chapter.
Recipe 1: Training a CNN for Image Classification (e.g., on ImageNet)
- Optimizer: SGD with momentum 0.9, Nesterov, weight decay 1e-4
- Learning rate: 0.1, reduced by 10x at epochs 30, 60, 90 (or cosine annealing)
- Batch size: 256 across 4 GPUs
- Data augmentation: RandomResizedCrop, RandomHorizontalFlip, ColorJitter
- Normalization: Batch normalization after each convolutional layer
- Training: 90--100 epochs
- Mixup/CutMix for additional regularization in modern recipes
Recipe 2: Training a Transformer for NLP
- Optimizer: AdamW with betas (0.9, 0.98), weight decay 0.01
- Learning rate: Peak 5e-4, linear warmup for 4,000 steps, then cosine decay
- Batch size: Effective 256--2048 (with gradient accumulation)
- Gradient clipping: max_norm=1.0
- Dropout: 0.1 on attention and feed-forward layers
- Mixed precision: BFloat16 (or FP16 with GradScaler)
- Label smoothing: 0.1
Recipe 3: Fine-Tuning a Pretrained Model
- Optimizer: AdamW with weight decay 0.01
- Learning rate: 1e-5 to 5e-5 for pretrained layers, 10x higher for the new head
- Warmup: 100--500 steps
- Epochs: 3--10 (much less than training from scratch)
- Gradient clipping: max_norm=1.0
- Use parameter groups to set different learning rates for backbone vs. head
These recipes are not arbitrary---they represent the accumulated wisdom of thousands of experiments. When starting a new project, begin with the recipe closest to your setting and adjust based on your validation metrics.
12.10 Regularization Techniques for Training
Beyond the architectural choices already discussed (normalization, weight decay), several additional regularization techniques are essential for training deep networks.
12.10.1 Dropout
Dropout (Srivastava et al., 2014) randomly zeroes elements of a layer's output during training with probability $p$, and scales the remaining elements by $\frac{1}{1-p}$ to maintain the expected value:
self.dropout = nn.Dropout(p=0.5)
During evaluation (when model.eval() is called), dropout is disabled and outputs are passed through unchanged. Dropout acts as an ensemble method---at each training step, a different sub-network is trained, and at test time, the full network approximates an average of all sub-networks.
Typical dropout rates:
- 0.5 for fully connected layers (the original paper's recommendation)
- 0.1-0.3 for transformers
- 0.0-0.1 for convolutional layers (spatial dropout is preferred)
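The inverted scaling is visible directly: in training mode roughly a fraction p of the elements are zeroed and the survivors are scaled by 1/(1-p), while in eval mode the layer is the identity. A quick sketch:
import torch
import torch.nn as nn
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)
drop.train()
y = drop(x)
print((y == 0).float().mean().item())  # roughly 0.5 of the elements are zeroed
print(y[y != 0].unique())              # survivors are scaled to 2.0 = 1 / (1 - p)
drop.eval()
print(torch.equal(drop(x), x))         # True: dropout is the identity at evaluation time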
12.10.2 Data Augmentation
Data augmentation artificially expands the training set by applying random transformations to inputs. For image data:
import torchvision.transforms as T
train_transform = T.Compose([
T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(),
T.ColorJitter(brightness=0.2, contrast=0.2),
T.ToTensor(),
T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),
])
Advanced augmentation strategies like Mixup, CutMix, and RandAugment can further improve generalization and are now standard in competitive image classification pipelines.
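As an example of how these strategies work, here is a from-scratch sketch of plain Mixup: blend two examples with a coefficient drawn from a Beta distribution and blend the losses accordingly (the helper name is ours; recent torchvision versions also provide ready-made transforms).
import numpy as np
import torch
def mixup_batch(inputs, targets, alpha=0.2):
    """Return mixed inputs, both target sets, and the mixing coefficient (sketch)."""
    lam = float(np.random.beta(alpha, alpha))
    index = torch.randperm(inputs.size(0), device=inputs.device)
    mixed = lam * inputs + (1.0 - lam) * inputs[index]
    return mixed, targets, targets[index], lam
# Inside the training loop:
# mixed, targets_a, targets_b, lam = mixup_batch(inputs, targets)
# outputs = model(mixed)
# loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)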
12.10.3 Weight Decay Revisited
As discussed in the optimizer section, weight decay penalizes large weights. In AdamW, this is decoupled from the gradient:
$$\theta_{t+1} = (1 - \eta \lambda)\theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Typical weight decay values:
- 0.01: Standard for AdamW in NLP/transformers
- 5e-4: Standard for SGD in computer vision
- 0.0: Often used for biases and normalization layer parameters
It is common practice to exclude biases and normalization parameters from weight decay:
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
if "bias" in name or "norm" in name or "bn" in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = torch.optim.AdamW([
{"params": decay_params, "weight_decay": 0.01},
{"params": no_decay_params, "weight_decay": 0.0},
], lr=3e-4)
12.11 Putting It All Together: A Training Recipe
Here is a complete recipe for training a classification model on a GPU, incorporating all the techniques from this chapter:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
torch.manual_seed(42)
# 1. Model
model = build_model() # Your architecture from Chapter 11
model = model.to(device)
model.apply(init_weights) # Custom initialization
# 2. Loss
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# 3. Optimizer with parameter groups
optimizer = torch.optim.AdamW([
{"params": decay_params, "weight_decay": 0.01},
{"params": no_decay_params, "weight_decay": 0.0},
], lr=3e-4)
# 4. Learning rate schedule: warmup + cosine
total_steps = num_epochs * len(train_loader)
scheduler = warmup_cosine_schedule(optimizer, warmup_steps=500, total_steps=total_steps)
# 5. Mixed precision
scaler = torch.amp.GradScaler()
# 6. Training loop
best_val_acc = 0.0
early_stop = EarlyStopping(patience=15)
for epoch in range(num_epochs):
train_metrics = train_one_epoch(
model, train_loader, criterion, optimizer, scheduler, scaler, device, epoch
)
val_metrics = evaluate(model, val_loader, criterion, device)
print(f"Epoch {epoch}: "
f"train_loss={train_metrics['loss']:.4f}, "
f"train_acc={train_metrics['accuracy']:.2f}%, "
f"val_loss={val_metrics['loss']:.4f}, "
f"val_acc={val_metrics['accuracy']:.2f}%")
if val_metrics["accuracy"] > best_val_acc:
best_val_acc = val_metrics["accuracy"]
save_checkpoint(model, optimizer, scheduler, epoch, best_val_acc, "best.pt")
if early_stop(val_metrics["accuracy"]):
print(f"Early stopping at epoch {epoch}")
break
This recipe is not a rigid template---adapt it to your specific problem. But it captures the key design decisions and best practices that lead to successful training.
12.12 Advanced Training Techniques
12.12.1 Gradient Accumulation
When GPU memory limits your batch size, gradient accumulation simulates larger batches by accumulating gradients over multiple forward passes before stepping:
accumulation_steps = 4  # Effective batch size = batch_size * 4

for i, (inputs, targets) in enumerate(train_loader):
    with torch.amp.autocast(device_type="cuda"):
        outputs = model(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss = loss / accumulation_steps  # Normalize loss

    scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()
Why divide the loss by accumulation_steps? Most PyTorch loss functions default to reduction="mean", so each loss value is already averaged over its mini-batch. When you accumulate gradients over 4 mini-batches, loss.backward() adds gradients to .grad without clearing them. If you do not divide the loss, the accumulated gradients are 4x too large (they represent a sum of per-batch means rather than a mean over the effective batch). Dividing by the accumulation steps corrects this, making gradient accumulation mathematically equivalent to using a single large batch.
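To see the equivalence concretely, the following sketch compares the gradient from a single batch of 16 samples against gradients accumulated over four micro-batches of 4. The toy linear model is chosen purely for illustration:
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.MSELoss()              # reduction="mean" by default
x = torch.randn(16, 10)
y = torch.randn(16, 1)

# Gradient from one "large" batch of 16
model.zero_grad()
criterion(model(x), y).backward()
big_batch_grad = model.weight.grad.clone()

# Gradients accumulated over 4 micro-batches of 4, each loss divided by 4
model.zero_grad()
for chunk_x, chunk_y in zip(x.chunk(4), y.chunk(4)):
    loss = criterion(model(chunk_x), chunk_y) / 4
    loss.backward()                   # gradients are summed into .grad

print(torch.allclose(big_batch_grad, model.weight.grad, atol=1e-6))  # True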
When to use gradient accumulation. Gradient accumulation is essential for:
- Large language models that require batch sizes of 256+ but each sample barely fits in GPU memory
- High-resolution image models where a single batch of 4 images fills the GPU
- Reproducibility when comparing results across different GPU configurations (same effective batch size, different hardware)
The only caveat is that batch normalization layers compute statistics over the physical mini-batch (not the accumulated batch), so the batch statistics may be noisier with gradient accumulation. Group normalization or layer normalization (as we discussed in Section 12.4) sidesteps this issue.
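If batch statistics do become a problem, one option is to swap batch normalization for group normalization before training. A rough sketch follows; the helper batchnorm_to_groupnorm is hypothetical, and it falls back to a single group when the channel count is not divisible by the requested group count:
import torch.nn as nn

def batchnorm_to_groupnorm(module: nn.Module, num_groups: int = 32) -> nn.Module:
    """Recursively replace BatchNorm2d layers with GroupNorm (illustrative helper)."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # Fall back to one group if channels are not divisible by num_groups
            groups = num_groups if child.num_features % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            batchnorm_to_groupnorm(child, num_groups)
    return module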
12.12.1.1 Batch Size Effects on Training
The batch size has profound effects on both training dynamics and generalization:
Gradient noise. A small batch produces a noisy gradient estimate with high variance; a large batch produces a smoother, lower-variance estimate. Surprisingly, moderate gradient noise can be beneficial---it acts as implicit regularization, helping the optimizer escape sharp minima and find flatter minima that generalize better (Keskar et al., 2017).
The large-batch generalization gap. Training with very large batches (without adjusting the learning rate or other hyperparameters) often leads to worse generalization. The optimizer converges to sharp minima that overfit. Mitigations include:
- Linear learning rate scaling: As discussed in Section 12.13.2, scale the learning rate proportionally to batch size
- Longer warmup: Large batches need more warmup steps to stabilize
- LARS/LAMB optimizers: Specialized optimizers that scale the learning rate per layer based on the ratio of weight norm to gradient norm
Training speed vs. generalization trade-off. In practice:
- Batch size 32--128: Good generalization, moderate training speed
- Batch size 256--1024: Faster training, may need learning rate adjustments
- Batch size 4096+: Requires careful tuning (warmup, LR scaling, LAMB optimizer) to avoid generalization loss
A practical rule: start with the largest batch size that fits in GPU memory, use gradient accumulation if needed, and verify that validation performance matches a smaller-batch baseline.
12.12.1.2 Loss Landscape Visualization
Understanding the geometry of the loss landscape helps explain why certain training choices work. Li et al. (2018) introduced a technique for visualizing loss landscapes by plotting the loss along random directions in parameter space:
- Choose two random direction vectors $\boldsymbol{\delta}_1$ and $\boldsymbol{\delta}_2$ in the parameter space
- Normalize them using "filter normalization" (scale each filter's direction to have the same norm as the corresponding filter in the model)
- Compute the loss at positions $\boldsymbol{\theta}^* + \alpha \boldsymbol{\delta}_1 + \beta \boldsymbol{\delta}_2$ for a grid of $(\alpha, \beta)$ values
- Plot the resulting 2D surface
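A rough sketch of this grid evaluation follows. For brevity it normalizes each direction per parameter tensor rather than per filter, and the function names are illustrative rather than a faithful reimplementation of the original code:
import torch

def _random_direction(params):
    """One random direction, rescaled so each tensor matches its parameter's norm."""
    dirs = []
    for p in params:
        d = torch.randn_like(p)
        dirs.append(d * (p.norm() / (d.norm() + 1e-12)))
    return dirs

@torch.no_grad()
def loss_surface(model, criterion, inputs, targets, steps=11, span=1.0):
    """Evaluate the loss on a (steps x steps) grid around the current weights."""
    original = [p.detach().clone() for p in model.parameters()]
    d1, d2 = _random_direction(original), _random_direction(original)
    alphas = torch.linspace(-span, span, steps)
    surface = torch.zeros(steps, steps)
    for i, a in enumerate(alphas):
        for j, b in enumerate(alphas):
            for p, p0, u, v in zip(model.parameters(), original, d1, d2):
                p.copy_(p0 + a * u + b * v)
            surface[i, j] = criterion(model(inputs), targets).item()
    for p, p0 in zip(model.parameters(), original):   # restore original weights
        p.copy_(p0)
    return surface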
These visualizations reveal several important phenomena:
- Skip connections flatten the landscape: ResNets (Chapter 13) have dramatically smoother loss surfaces than equivalent networks without skip connections
- Batch normalization smooths the landscape: Consistent with its empirical benefits for training stability
- Sharp vs. flat minima: Networks that generalize well tend to converge to broader, flatter minima
While you rarely need to generate these visualizations in practice, understanding that the loss landscape is not a smooth bowl but a rugged terrain with sharp valleys, saddle points, and plateaus helps you reason about why learning rate schedules, warmup, and stochastic gradient noise matter.
12.12.2 Exponential Moving Average (EMA)
Maintaining an exponential moving average of model weights often produces a smoother, better-generalizing model:
$$\theta_{\text{EMA}} = \alpha \theta_{\text{EMA}} + (1 - \alpha) \theta$$
with $\alpha$ typically 0.999 or 0.9999. The EMA model is used for evaluation:
class EMAModel:
    """Exponential Moving Average of model parameters.

    Args:
        model: The source model.
        decay: EMA decay rate.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {
            name: param.clone().detach()
            for name, param in model.named_parameters()
        }

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        """Update EMA parameters.

        Args:
            model: The source model with updated parameters.
        """
        for name, param in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(param, alpha=1 - self.decay)

    def apply(self, model: nn.Module) -> None:
        """Copy EMA parameters to the model.

        Args:
            model: The target model.
        """
        for name, param in model.named_parameters():
            param.data.copy_(self.shadow[name])
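A typical usage pattern might look like the following sketch. It relies on the evaluate helper and loaders from the recipe above, and deep-copies the model so the raw training weights stay intact:
import copy

ema = EMAModel(model, decay=0.999)

for inputs, targets in train_loader:
    ...  # forward, backward, optimizer.step() as in the training loop
    ema.update(model)

# Evaluate with the averaged weights without disturbing the training copy
eval_model = copy.deepcopy(model)
ema.apply(eval_model)
val_metrics = evaluate(eval_model, val_loader, criterion, device)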
12.12.3 Stochastic Weight Averaging (SWA)
SWA averages model weights over the last several epochs of training, finding flatter minima that generalize better. PyTorch provides built-in support:
from torch.optim.swa_utils import AveragedModel, SWALR
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=0.001)
# During SWA phase (last N epochs):
swa_model.update_parameters(model)
swa_scheduler.step()
# After training, update batch norm statistics:
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
12.13 Distributed Training: Scaling to Multiple GPUs
When a single GPU is not enough, PyTorch's DistributedDataParallel (DDP) scales training across multiple GPUs with near-linear speedup.
12.13.1 DataParallel vs. DistributedDataParallel
torch.nn.DataParallel (DP) is simpler but slower---it replicates the model on each GPU per forward pass and gathers gradients on a single GPU, creating a bottleneck. DistributedDataParallel (DDP) uses one process per GPU with efficient all-reduce communication:
import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize process group
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Wrap model
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank])

# Use DistributedSampler for the data loader
from torch.utils.data import DistributedSampler

sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=batch_size)

# Remember to set epoch for proper shuffling
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    train_one_epoch(...)
DDP synchronizes gradients automatically during the backward pass using the NCCL backend for GPU-to-GPU communication. The training code remains almost unchanged---just wrap the model and use DistributedSampler.
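Multi-GPU jobs are usually launched with torchrun (for example, torchrun --nproc_per_node=4 train.py, where the script name is illustrative), which starts one process per GPU and sets the LOCAL_RANK, RANK, and WORLD_SIZE environment variables that the initialization code above reads.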
12.13.2 Learning Rate Scaling
When increasing the effective batch size via DDP (or gradient accumulation), you typically need to scale the learning rate as well. The linear scaling rule scales the learning rate proportionally to the batch size:
$$\eta_{\text{new}} = \eta_{\text{base}} \times \frac{B_{\text{new}}}{B_{\text{base}}}$$
This rule works well up to moderate batch sizes but breaks down for very large batches. Warmup becomes even more important when scaling to large batch sizes to stabilize early training.
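A minimal helper implementing the rule (the function name scaled_lr is illustrative):
def scaled_lr(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    """Linear scaling rule: grow the learning rate with the effective batch size."""
    return base_lr * new_batch_size / base_batch_size

# e.g. a baseline of 0.1 at batch size 256, scaled to 8 GPUs x 256 samples each:
print(scaled_lr(0.1, 256, 8 * 256))  # 0.8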
Summary
Training deep networks is a multi-faceted engineering challenge that requires coordinating many interacting components. In this chapter, we covered:
- Loss functions translate your objective into gradients. Use CrossEntropyLoss for classification (never apply softmax first), MSE for regression, and focal loss for imbalanced data.
- Optimizers navigate the loss landscape. AdamW is the default for most tasks; SGD with momentum excels for vision tasks with long training schedules.
- Learning rate schedules modulate the pace of learning. Cosine annealing with warmup is a robust default. The learning rate is the most important hyperparameter.
- Normalization layers stabilize training. Batch normalization for CNNs, layer normalization for transformers.
- Weight initialization sets the stage for gradient flow. Use He initialization for ReLU networks and Xavier for tanh/sigmoid.
- Gradient clipping prevents catastrophic gradient explosions. Essential for transformers and RNNs.
- Mixed precision training provides 2-3x speedup with no quality loss. Use PyTorch AMP with GradScaler.
- The training loop must carefully order operations: zero grad, forward, loss, backward, clip, step, schedule.
- Debugging starts with loss curves and the overfit-one-batch test. Monitor gradient norms and check the common bugs checklist.
Each of these components builds on the mathematical foundations from Chapter 3 and the neural network architecture concepts from Chapter 11. In Chapter 13, we will apply these training techniques to convolutional neural networks for image recognition, and in subsequent chapters to transformers and other advanced architectures.
The key takeaway is this: training is engineering. Success comes not from a single brilliant insight but from the disciplined application of well-understood principles---choosing the right loss, optimizer, and schedule; initializing and normalizing properly; monitoring relentlessly; and debugging systematically.