Pexcera

The Vanishing and Exploding Gradient Problems

The vanishing and exploding gradient problems are major challenges in training deep neural networks. They occur when gradients decay or grow exponentially as they flow backward through layers, preventing early layers from learning or causing training to diverge.


Dynamic Range of Gradients

Gradients can scale exponentially with depth, causing training to stall or become numerically unstable.

The Vanishing Gradient Problem

The vanishing gradient problem occurs when the magnitude of gradients decays exponentially as they propagate backward. By the time the gradient reaches the early layers of the network, its value is close to zero.

As a result, the weights of the early layers update extremely slowly, leaving them near their random initial states. This prevents the network from learning basic low-level features, limiting the benefit of deep architectures.

The Exploding Gradient Problem

The exploding gradient problem is the converse: gradients grow exponentially as they propagate backward, producing very large parameter updates during training.

These large updates make the weights oscillate wildly and prevent convergence. In extreme cases, values exceed the floating-point range and overflow to infinity, and subsequent arithmetic produces NaN (Not a Number) values that derail training. Common remedies are gradient clipping, a lower learning rate, and more careful weight initialization.
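The overflow path can be reproduced with ordinary Python floats. This is a toy sketch, assuming a hypothetical scalar gradient that grows tenfold at every layer. Python floats are 64-bit and overflow only above roughly 1.8e308, while the 32-bit floats common in training overflow above roughly 3.4e38, so real networks reach this point much sooner.

```python
import math

def backprop_scalar(growth_per_layer, depth):
    """Toy model (hypothetical): scale a unit gradient by a constant factor per layer."""
    grad = 1.0
    for _ in range(depth):
        grad *= growth_per_layer  # IEEE-754 multiplication overflows to inf instead of raising
    return grad

grad = backprop_scalar(10.0, 400)  # 10**400 exceeds the float64 range, so grad becomes inf
weight = 0.5 - 0.01 * grad         # SGD-style update with learning rate 0.01 gives -inf
next_value = weight + grad         # any later operation combining -inf and +inf yields nan

print(grad, weight, next_value)    # inf -inf nan
```

Once a single NaN appears, it propagates through every computation that touches it, which is why one overflowed gradient can ruin all subsequent updates.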

Mathematical Causes

Gradient instability is driven by repeated matrix multiplications and activation function saturation.

Repeated Matrix Multiplications

During the backward pass, the gradient of the loss with respect to the first layer's weights contains a product of the weight matrices of all subsequent layers (ignoring, for now, the activation derivatives that also enter each factor):

$$\frac{\partial \mathcal{L}}{\partial \mathbf{W}^{[1]}} \propto \prod_{l=2}^L \mathbf{W}^{[l]}$$

If the weights are initialized so that each $\mathbf{W}^{[l]}$ shrinks the vectors it multiplies (all singular values below 1.0), the norm of the product decays exponentially toward zero as $L$ increases. If each matrix stretches vectors (singular values above 1.0), the product grows exponentially, causing gradients to explode.
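To see the exponential effect numerically, the sketch below pushes a gradient vector backward through a stack of random linear layers with no activations. It assumes Gaussian weights with standard deviation `gain / sqrt(width)`, a hypothetical parameterization chosen for illustration, under which each backward multiplication by $\mathbf{W}^{\top}$ scales the squared gradient norm by `gain**2` on average.

```python
import math
import random

def gradient_norm_ratio(gain, depth=30, width=20, seed=0):
    """Return ||g_input|| / ||g_output|| after backpropagating through `depth` random linear layers.

    Each weight entry is drawn from N(0, (gain / sqrt(width))**2), so one backward
    step scales the squared gradient norm by gain**2 in expectation.
    """
    rng = random.Random(seed)
    std = gain / math.sqrt(width)
    grad = [1.0] * width  # upstream gradient arriving at the output layer
    start_norm = math.sqrt(sum(v * v for v in grad))
    for _ in range(depth):
        W = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        # Backward step of a linear layer y = W x: dL/dx = W^T dL/dy
        grad = [sum(W[i][j] * grad[i] for i in range(width)) for j in range(width)]
    return math.sqrt(sum(v * v for v in grad)) / start_norm

for gain in (0.5, 1.0, 2.0):
    print(f"gain={gain}: norm ratio after 30 layers = {gradient_norm_ratio(gain):.3e}")
```

With `gain` below 1 the ratio collapses toward zero, and with `gain` above 1 it blows up; `gain = 1.0` keeps it far closer to 1, which is the motivation for choosing the initialization scale carefully.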

Activation Function Saturation

Vanishing gradients are also caused by saturating activation functions such as sigmoid and tanh, whose derivatives approach zero when their inputs are large in magnitude, whether positive or negative.

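When backpropagating through several saturating layers, the local derivatives are multiplied together. Because the derivative of the sigmoid never exceeds $0.25$ (its value at an input of zero), the activation derivatives alone shrink the gradient by at least a factor of four per layer, even when no unit is saturated. ReLU avoids this in its active region: its derivative is exactly 1 for positive inputs, so it does not shrink the gradient there.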
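The compounding can be checked directly. This sketch isolates the activation-derivative factors only (it ignores the weight matrices) and assumes, for simplicity, that every unit in a 10-layer chain sees the same pre-activation value.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_grad(z):
    # d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z)), maximized at z = 0
    s = sigmoid(z)
    return s * (1.0 - s)

def relu_grad(z):
    # ReLU passes the gradient through unchanged for positive inputs
    return 1.0 if z > 0 else 0.0

depth = 10
best_case_sigmoid = sigmoid_grad(0.0) ** depth  # every unit at its steepest point
saturated_sigmoid = sigmoid_grad(5.0) ** depth  # every unit moderately saturated
active_relu = relu_grad(1.0) ** depth           # every unit in ReLU's active region

print(sigmoid_grad(0.0))    # 0.25
print(best_case_sigmoid)    # 9.5367431640625e-07
print(active_relu)          # 1.0
```

Even in the best case the sigmoid chain scales the gradient by $0.25^{10} = 2^{-20}$, and saturated units make it far smaller, while the ReLU chain leaves it unchanged.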

PyTorch Diagnostics

We can diagnose gradient issues in PyTorch by logging gradient norms during the backward pass.

Simulating Gradient Problems

This PyTorch script illustrates how poor initialization scale leads to gradient collapse in a deep network:

<pre><code class="language-python">import torch
import torch.nn as nn

# Network with 10 linear layers and Sigmoid activations
net = nn.Sequential(*[nn.Sequential(nn.Linear(50, 50), nn.Sigmoid()) for _ in range(10)])

# Initialize weights with standard normal distribution (std = 1.0, too large)
for layer in net.modules():
    if isinstance(layer, nn.Linear):
        nn.init.normal_(layer.weight, mean=0.0, std=1.0)

x = torch.randn(5, 50)
y = net(x)
loss = y.sum()
loss.backward()

# Log gradients of the first and last linear layers
linear_layers = [m for m in net.modules() if isinstance(m, nn.Linear)]
print("First layer weight grad norm:", torch.norm(linear_layers[0].weight.grad).item())
print("Last layer weight grad norm:", torch.norm(linear_layers[-1].weight.grad).item())</code></pre>

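After a single backward pass, the script compares the gradient norms of the first and last linear layers. Because the weights are drawn with a standard deviation of 1.0, the pre-activations are large and most sigmoid units saturate, so each layer contributes a small local derivative. In a typical run, the first layer's gradient norm is noticeably smaller than the last layer's, and the gap compounds with every layer added, illustrating vanishing gradients.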
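Printing only the first and last layers hides how the decay is distributed across depth. The helper below is a hypothetical extension of the script above, not part of the original: it rebuilds the same kind of network, seeds the random generator for reproducibility, and returns the weight-gradient norm of every linear layer so that different activations and initialization scales can be compared.

```python
import torch
import torch.nn as nn

def layer_grad_norms(activation=nn.Sigmoid, std=1.0, depth=10, width=50, seed=0):
    """Hypothetical helper: per-layer weight-gradient norms after one backward pass."""
    torch.manual_seed(seed)  # make the random weights and inputs reproducible
    net = nn.Sequential(*[nn.Sequential(nn.Linear(width, width), activation()) for _ in range(depth)])
    linear_layers = [m for m in net.modules() if isinstance(m, nn.Linear)]
    for layer in linear_layers:
        nn.init.normal_(layer.weight, mean=0.0, std=std)
    x = torch.randn(5, width)
    net(x).sum().backward()
    # Index 0 is the layer closest to the input, index -1 the one closest to the loss
    return [layer.weight.grad.norm().item() for layer in linear_layers]

for i, norm in enumerate(layer_grad_norms()):
    print(f"layer {i}: grad norm = {norm:.3e}")
```

For example, `layer_grad_norms(activation=nn.ReLU, std=(2 / 50) ** 0.5)` builds a ReLU network whose weights use the He-style scale $\sqrt{2/\text{fan\_in}}$, which can be compared layer by layer with the default sigmoid profile.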

Registering Backward Hooks

We can track gradient norms dynamically during training by registering backward hooks on network parameters. These hooks execute a callback function whenever gradients are computed, allowing us to log and inspect gradient flow.
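A minimal sketch of this approach uses PyTorch's `Tensor.register_hook`, which calls the given function with a tensor's gradient each time that gradient is computed. The small network, the `grad_norms` dictionary, and the `make_hook` helper are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(50, 50), nn.Sigmoid(), nn.Linear(50, 50), nn.Sigmoid())

grad_norms = {}  # parameter name -> list of gradient norms, one entry per backward pass

def make_hook(name):
    def hook(grad):
        # Receives the gradient w.r.t. this parameter; returning None leaves it unchanged
        grad_norms.setdefault(name, []).append(grad.norm().item())
    return hook

handles = [p.register_hook(make_hook(name)) for name, p in net.named_parameters()]

for _ in range(3):  # three illustrative training-style backward passes
    net.zero_grad()
    loss = net(torch.randn(5, 50)).sum()
    loss.backward()

for name, norms in grad_norms.items():
    print(name, [f"{n:.2e}" for n in norms])

for h in handles:
    h.remove()  # detach the hooks once logging is no longer needed
```

The `make_hook` factory binds each parameter name when the hook is created; a bare `lambda` referencing `name` inside the comprehension would see only the last name due to Python's late binding.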

This logging helps diagnose gradient instability issues early, allowing us to adjust initializations, learning rates, or optimizer configurations before training diverges.
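Logged norms become actionable once they are compared against limits. The sketch below flags suspicious gradients after `loss.backward()`; the function name `gradient_warnings` and the `low`/`high` thresholds are hypothetical, and sensible values depend on the model and the scale of the loss.

```python
import math
import torch
import torch.nn as nn

def gradient_warnings(model, low=1e-7, high=1e3):
    """Hypothetical check run after loss.backward(); thresholds are illustrative, not standard."""
    warnings = []
    for name, p in model.named_parameters():
        if p.grad is None:
            continue  # parameter did not take part in this backward pass
        norm = p.grad.norm().item()
        if not math.isfinite(norm):
            warnings.append(f"{name}: non-finite gradient norm ({norm}), likely exploding")
        elif norm > high:
            warnings.append(f"{name}: large gradient norm {norm:.2e}, possibly exploding")
        elif norm < low:
            warnings.append(f"{name}: tiny gradient norm {norm:.2e}, possibly vanishing")
    return warnings

torch.manual_seed(0)
model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
print(gradient_warnings(model))         # no warnings for these ordinary-sized gradients
model.weight.grad[0, 0] = float("nan")  # simulate an overflowed gradient entry
print(gradient_warnings(model))
```

In a training loop, a non-empty result could be used to skip the optimizer step or to trigger a change of learning rate or initialization before training diverges.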