Optimizers: SGD with Momentum
SGD with Momentum is a variant of stochastic gradient descent inspired by classical mechanics. It accumulates an exponentially decaying sum of past gradients into a velocity vector. This speeds up progress along directions where successive gradients agree and dampens oscillations along directions where they alternate.
Physics Analogy and Formulation
Momentum models the parameters as a heavy ball rolling down a potential well, whose inertia smooths the sequence of updates.
The Momentum Concept
Standard SGD updates parameters using only the current gradient, so progress is slow in flat regions where gradients are small. Momentum adds a velocity vector that accumulates past gradients, so the optimizer keeps moving in directions where successive gradients agree.
In the physics analogy, the gradient plays the role of gravity pulling the ball downhill, while friction gradually bleeds off its speed. In the update rule, friction corresponds to the decay applied to the velocity at every step. Because the ball carries its accumulated velocity, it can roll across flat plateaus and past shallow saddle points where the local gradient alone would produce only tiny steps.
The Momentum Equations
The parameter update with momentum is governed by two coupled equations:
$$\mathbf{v}_{t+1} = \beta \mathbf{v}_t + \eta \nabla_\theta \mathcal{L}(\theta_t)$$
$$\theta_{t+1} = \theta_t - \mathbf{v}_{t+1}$$
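where $\mathbf{v}_t$ is the velocity vector, $\eta$ is the learning rate, and $\beta \in [0, 1)$ is the momentum decay coefficient (typically 0.9). The coefficient $\beta$ controls how strongly historical gradients influence the current update. A gradient computed $k$ steps ago enters the velocity with weight $\eta \beta^k$, and setting $\beta = 0$ recovers plain SGD.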
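Summing that geometric series shows what $\beta$ does in practice. Starting from $\mathbf{v}_0 = \mathbf{0}$, a gradient that stays constant at $g$ drives the velocity toward $\eta g / (1 - \beta)$, which is ten times the plain SGD step at $\beta = 0.9$. A gradient that flips sign every step only reaches a magnitude of $\eta g / (1 + \beta)$. The minimal one-dimensional sketch below applies the velocity equation to both gradient sequences. The values of $\beta$, $\eta$, and $g$ are illustrative assumptions, not taken from the text.

<pre><code class="language-python">BETA = 0.9    # momentum coefficient (illustrative)
ETA = 0.1     # learning rate (illustrative)
G = 1.0       # gradient magnitude (illustrative)
STEPS = 100


def velocities(grads, beta=BETA, eta=ETA):
    """Apply v_{t+1} = beta * v_t + eta * g_t to a gradient sequence, starting from v_0 = 0."""
    v = 0.0
    history = []
    for g in grads:
        v = beta * v + eta * g
        history.append(v)
    return history


# Same gradient every step (a gentle, consistent slope)
constant = velocities([G] * STEPS)
# Gradient flips sign every step (bouncing across a ravine)
alternating = velocities([G if t % 2 == 0 else -G for t in range(STEPS)])

print(f"constant:    v = {constant[-1]:.4f}   vs  eta*g/(1-beta) = {ETA * G / (1 - BETA):.4f}")
print(f"alternating: |v| = {abs(alternating[-1]):.4f}  vs  eta*g/(1+beta) = {ETA * G / (1 + BETA):.4f}")
</code></pre>

The ratio between the two limits, $(1 + \beta)/(1 - \beta) = 19$ at $\beta = 0.9$, quantifies how much more strongly momentum rewards consistent gradients than alternating ones.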
Oscillation Damping and Advanced Variants
Momentum dampens oscillations across narrow, high-curvature valleys, and a look-ahead variant can further reduce overshoot.
Dampening Cross-Valley Oscillations
In ravines with steep walls, standard SGD bounces between the walls. A step large enough to make progress along the floor overshoots across the valley, so the transverse gradient flips sign from one iteration to the next. Because these transverse gradients alternate in sign, they largely cancel when accumulated in the velocity.
Conversely, the gradient components along the valley floor point consistently toward the minimum, so their contributions to the velocity add up rather than cancel. The net effect is damped transverse oscillation and faster progress along the valley floor.
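The ravine argument can be checked on a hypothetical two-dimensional quadratic $\mathcal{L}(x, y) = \tfrac{1}{2}(a x^2 + b y^2)$. Here $x$ runs along a shallow floor with curvature $a = 0.1$ and $y$ runs across steep walls with curvature $b = 10$. The sketch below runs the momentum equations with $\beta = 0.9$ and with $\beta = 0$ (plain SGD) from the same start and with the same learning rate. It records SGD's transverse gradients to show their sign alternation. All numbers are illustrative choices, not values from the text.

<pre><code class="language-python">def minimize_ravine(beta, eta=0.15, steps=100, a=0.1, b=10.0):
    """Minimize 0.5 * (a * x**2 + b * y**2) with heavy-ball momentum; beta=0 is plain SGD."""
    x, y = 10.0, 1.0          # start far along the floor, slightly up one wall
    vx = vy = 0.0
    transverse_grads = []
    for _ in range(steps):
        gx, gy = a * x, b * y             # gradient at the current point
        transverse_grads.append(gy)
        vx = beta * vx + eta * gx         # velocity along the valley floor
        vy = beta * vy + eta * gy         # velocity across the walls
        x, y = x - vx, y - vy
    loss = 0.5 * (a * x ** 2 + b * y ** 2)
    return x, loss, transverse_grads


sgd_x, sgd_loss, sgd_gy = minimize_ravine(beta=0.0)
mom_x, mom_loss, _ = minimize_ravine(beta=0.9)

print("SGD transverse gradients:", [round(g, 3) for g in sgd_gy[:4]])
print(f"SGD:      x = {sgd_x:.4f}, loss = {sgd_loss:.6f}")
print(f"Momentum: x = {mom_x:.4f}, loss = {mom_loss:.6f}")
</code></pre>

With these settings the transverse gradient under SGD flips sign every step. The gap in final loss comes mostly from the floor: plain SGD closes only $\eta a = 1.5\%$ of the remaining distance per step, while momentum's accumulated velocity covers far more ground. Results depend on $\eta$, $\beta$, and the curvature ratio.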
Nesterov Accelerated Gradient
Nesterov Accelerated Gradient (NAG) is a look-ahead variant that computes the gradient at the predicted future position of the parameters rather than the current position:
$$\mathbf{v}_{t+1} = \beta \mathbf{v}_t + \eta \nabla_\theta \mathcal{L}(\theta_t - \beta \mathbf{v}_t)$$
$$\theta_{t+1} = \theta_t - \mathbf{v}_{t+1}$$
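This look-ahead acts as a braking mechanism. Suppose the accumulated velocity is about to carry the parameters up an ascending slope. The gradient evaluated at the look-ahead point $\theta_t - \beta \mathbf{v}_t$ already reflects the rising loss there, so adding it to the velocity partially cancels the motion before the step is taken. This reduces overshoot.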
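The braking effect can be observed on a hypothetical one-dimensional quadratic $\mathcal{L}(\theta) = \tfrac{1}{2} h \theta^2$ with its minimum at $\theta = 0$. The sketch below runs the classical momentum equations and the NAG equations from $\theta_0 = 1$ with the same illustrative $\beta = 0.9$, $\eta = 0.1$, and $h = 1$. It measures overshoot as the most negative iterate reached.

<pre><code class="language-python">def trajectory(nesterov, beta=0.9, eta=0.1, h=1.0, steps=40):
    """Iterates of classical or Nesterov momentum on L(theta) = 0.5 * h * theta**2."""
    theta, v = 1.0, 0.0
    path = [theta]
    for _ in range(steps):
        # NAG evaluates the gradient at the look-ahead point theta - beta * v
        point = theta - beta * v if nesterov else theta
        grad = h * point
        v = beta * v + eta * grad
        theta = theta - v
        path.append(theta)
    return path


classical = trajectory(nesterov=False)
nag = trajectory(nesterov=True)

# The minimum is at 0, so the most negative iterate measures overshoot
print(f"classical momentum overshoot: {-min(classical):.3f}")
print(f"Nesterov overshoot:           {-min(nag):.3f}")
</code></pre>

For this quadratic and these settings both methods are underdamped. Classical momentum's oscillation shrinks by about $\sqrt{\beta} \approx 0.95$ per step, while NAG's shrinks by about $\sqrt{\beta(1 - \eta h)} = 0.9$, which is why its overshoot is smaller.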
PyTorch Implementation
We can configure momentum directly in PyTorch's SGD optimizer or implement velocity tracking manually.
Configuring Momentum in PyTorch
We can enable standard momentum or Nesterov acceleration in PyTorch by setting the corresponding optimizer parameters:
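<pre><code class="language-python">import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(2, 1)

# Both optimizers are built on the same model only for comparison;
# a real training loop would use just one of them.

# Standard momentum (beta = 0.9)
optimizer_momentum = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Nesterov Accelerated Gradient (NAG): needs a positive momentum and zero dampening
optimizer_nesterov = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

print("Optimizers configured successfully.")
</code></pre>In this code, we initialize both momentum configurations. Setting nesterov=True requires a positive momentum value and zero dampening (the default). Otherwise PyTorch raises a ValueError. PyTorch implements NAG as a reformulated update computed when optimizer.step() is called. It uses the gradient at the current parameters rather than re-evaluating the gradient at the look-ahead point $\theta_t - \beta \mathbf{v}_t$.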
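PyTorch's SGD documentation describes its nesterov=True branch (with zero dampening) as $b \leftarrow \beta b + g$ followed by a step of $\text{lr} \cdot (g + \beta b)$, with $g$ evaluated at the current parameters. The sketch below assumes that description. On a hypothetical one-dimensional quadratic, it checks that these iterates coincide with the look-ahead points $\theta_t - \beta \mathbf{v}_t$ of the NAG equations when lr equals $\eta$ and stays constant. It is pure Python and does not call PyTorch itself.

<pre><code class="language-python">BETA, ETA, H = 0.9, 0.1, 1.0   # illustrative values; loss is 0.5 * H * theta**2
STEPS = 20


def grad(theta):
    return H * theta


# Look-ahead form from the NAG equations (gradient at theta - beta * v)
theta, v = 1.0, 0.0
lookahead_points = []
for _ in range(STEPS):
    lookahead = theta - BETA * v
    lookahead_points.append(lookahead)
    v = BETA * v + ETA * grad(lookahead)
    theta = theta - v

# Reformulated update, our reading of PyTorch's nesterov=True branch:
#   buf = beta * buf + g;  p = p - lr * (g + beta * buf)
p, buf = 1.0, 0.0
torch_style_iterates = []
for _ in range(STEPS):
    torch_style_iterates.append(p)
    g = grad(p)
    buf = BETA * buf + g
    p = p - ETA * (g + BETA * buf)

max_gap = max(abs(a - b) for a, b in zip(lookahead_points, torch_style_iterates))
print(f"largest difference between the two sequences: {max_gap:.2e}")
</code></pre>

The difference is at the level of floating-point rounding. The reformulated parameters can therefore be read as the look-ahead points, which lets a single gradient evaluation per step reproduce NAG.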
Coding a Custom Momentum Optimizer
We can implement a custom optimizer with momentum tracking in PyTorch by managing the velocity state in the optimizer's state dictionary:
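<pre><code class="language-python">import torch
import torch.nn as nn


class CustomMomentumSGD(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, beta=0.9):
        defaults = dict(lr=lr, beta=beta)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        # Optionally re-evaluate the model (with gradients enabled) to obtain the loss
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta = group['beta']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Retrieve or lazily initialize this parameter's velocity buffer
                state = self.state[p]
                if 'velocity' not in state:
                    state['velocity'] = torch.zeros_like(p)
                v = state['velocity']
                # Update velocity in place: v = beta * v + grad
                v.mul_(beta).add_(p.grad)
                # Update parameter in place: p = p - lr * v
                p.add_(v, alpha=-lr)
        return loss


model = nn.Linear(2, 1)
custom_opt = CustomMomentumSGD(model.parameters(), lr=0.01, beta=0.9)
</code></pre>This implementation stores each parameter's velocity in the self.state dictionary, so the velocity persists across calls to step(). It follows the convention of PyTorch's built-in SGD. The velocity accumulates raw gradients ($\mathbf{v} \leftarrow \beta \mathbf{v} + \nabla_\theta \mathcal{L}$), and the learning rate is applied only in the parameter update rather than being folded into $\mathbf{v}$ as in the equations above. The @torch.no_grad() decorator keeps the updates out of the autograd graph. The in-place operations (mul_, add_) reuse existing tensors instead of allocating new ones every step.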
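Keeping the learning rate outside the velocity, as CustomMomentumSGD does with $u \leftarrow \beta u + g$ and $\theta \leftarrow \theta - \text{lr} \cdot u$, matches the formulation with $\eta$ inside the velocity only while the learning rate is constant. In that case $\mathbf{v} = \eta u$ at every step. The pure-Python sketch below checks this on a hypothetical one-dimensional quadratic. It then shows the two trajectories separating once an illustrative schedule halves the learning rate after 10 steps.

<pre><code class="language-python">def article_form(lrs, h=1.0, beta=0.9):
    """v = beta * v + eta * g; theta = theta - v (learning rate inside the velocity)."""
    theta, v, path = 1.0, 0.0, []
    for eta in lrs:
        v = beta * v + eta * h * theta
        theta = theta - v
        path.append(theta)
    return path


def pytorch_style_form(lrs, h=1.0, beta=0.9):
    """u = beta * u + g; theta = theta - lr * u (learning rate outside, as in CustomMomentumSGD)."""
    theta, u, path = 1.0, 0.0, []
    for lr in lrs:
        u = beta * u + h * theta
        theta = theta - lr * u
        path.append(theta)
    return path


def largest_gap(lrs):
    return max(abs(a - b) for a, b in zip(article_form(lrs), pytorch_style_form(lrs)))


constant_lrs = [0.1] * 30
scheduled_lrs = [0.1] * 10 + [0.05] * 20   # hypothetical step decay after 10 steps

print(f"constant lr:  largest gap = {largest_gap(constant_lrs):.2e}")
print(f"lr schedule:  largest gap = {largest_gap(scheduled_lrs):.2e}")
</code></pre>

Under a schedule the two conventions differ. With the learning rate outside, a change in lr immediately rescales the entire accumulated velocity. With $\eta$ inside, past contributions keep the learning rate that was active when they were added.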