Backpropagation Through Time (BPTT)
Training recurrent networks requires calculating gradients across both network layers and time. Backpropagation Through Time (BPTT) unrolls the network and propagates errors backward, though we often truncate this pass to keep compute and memory costs manageable.
BPTT Gradient Flow
BPTT unrolls the RNN across the sequence length and treats it as a deep feedforward network with one layer per time step, where every layer shares the same weights. It then computes parameter gradients by summing the contributions from every step.
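A minimal sketch of the unrolling follows. It assumes a vanilla tanh RNN written with raw tensors. The names W_xh, W_hh and b_h, the sizes, and the squared-error per-step loss are all illustrative assumptions. Each loop iteration adds one layer to the computation graph, but every layer reuses the same parameter tensors. A single backward() call therefore accumulates one gradient contribution per step into each shared tensor's .grad.

<pre><code class="language-python">import torch

torch.manual_seed(0)

# Hypothetical sizes, for illustration only
T, input_size, hidden_size = 5, 3, 4

# One set of parameters, reused ("shared") at every time step
W_xh = torch.randn(hidden_size, input_size, requires_grad=True)
W_hh = torch.randn(hidden_size, hidden_size, requires_grad=True)
b_h = torch.zeros(hidden_size, requires_grad=True)

xs = torch.randn(T, input_size)        # one input vector per step
targets = torch.randn(T, hidden_size)  # one illustrative target per step

h = torch.zeros(hidden_size)
loss = torch.zeros(())
for t in range(T):
    # Each iteration adds one "layer" to the unrolled graph,
    # but every layer uses the same W_xh, W_hh and b_h
    h = torch.tanh(W_xh @ xs[t] + W_hh @ h + b_h)
    loss = loss + ((h - targets[t]) ** 2).sum()  # per-step loss L_t

# One backward pass through the unrolled graph; autograd sums the
# contributions from all T uses of W_hh into W_hh.grad
loss.backward()
print(W_hh.grad.shape)  # torch.Size([4, 4])
</code></pre>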
Summing Shared Gradients
Assuming the total loss is the sum of the per-step losses, L = \sum_{t=1}^T L_t, the gradient of L with respect to a weight matrix shared across all steps (e.g., W_{hh}) is the sum of the per-step gradients: \frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \frac{\partial L_t}{\partial W_{hh}}.
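Expanding one per-step term with the chain rule shows where the weight sharing enters. W_{hh} is applied at every step k \le t, so \frac{\partial L_t}{\partial W_{hh}} collects one contribution per step. In this sketch, \frac{\partial^+ h_k}{\partial W_{hh}} denotes the immediate derivative of h_k with h_{k-1} held fixed. This is a notational assumption, not taken from the text above. For k = t, the middle factor is the identity:

\frac{\partial L_t}{\partial W_{hh}} = \sum_{k=1}^{t} \frac{\partial L_t}{\partial h_t} \frac{\partial h_t}{\partial h_k} \frac{\partial^+ h_k}{\partial W_{hh}}

The middle factor \frac{\partial h_t}{\partial h_k} is the product of per-step Jacobians discussed under The Differentiable Chain.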
The Differentiable Chain
The gradient with respect to h_t depends on the gradient arriving from h_{t+1}. Carrying an error signal from step t back to an earlier step k therefore multiplies one Jacobian per intervening step along the chain of hidden states: \frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^t \frac{\partial h_j}{\partial h_{j-1}}.
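The product formula can be checked numerically. This sketch assumes a bias-free tanh RNN, h_j = \tanh(W_{hh} h_{j-1} + W_{xh} x_j), for which each factor is \frac{\partial h_j}{\partial h_{j-1}} = \mathrm{diag}(1 - h_j^2) W_{hh}. It computes the Jacobian of h_t with respect to h_k in two ways and compares them:

* directly, with torch.autograd.functional.jacobian;
* as an explicit product of per-step Jacobians.

The sizes, the choice k = 1 and t = 4, and the helper names step and h_t_from_h_k are illustrative assumptions.

<pre><code class="language-python">import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)

# Hypothetical bias-free tanh RNN: h_j = tanh(W_hh h_{j-1} + W_xh x_j)
hidden_size, input_size, k, t = 4, 3, 1, 4
W_hh = 0.5 * torch.randn(hidden_size, hidden_size)
W_xh = torch.randn(hidden_size, input_size)
xs = torch.randn(t + 1, input_size)

def step(h_prev, x):
    return torch.tanh(W_hh @ h_prev + W_xh @ x)

def h_t_from_h_k(h_k):
    # Run steps k+1 .. t starting from h_k
    h = h_k
    for j in range(k + 1, t + 1):
        h = step(h, xs[j])
    return h

h_k = torch.randn(hidden_size)

# Left side: differentiate h_t with respect to h_k directly
J_direct = jacobian(h_t_from_h_k, h_k)

# Right side: multiply per-step Jacobians dh_j/dh_{j-1} = diag(1 - h_j^2) W_hh
J_product = torch.eye(hidden_size)
h = h_k
for j in range(k + 1, t + 1):
    h = step(h, xs[j])
    J_step = torch.diag(1 - h ** 2) @ W_hh
    J_product = J_step @ J_product  # newest step's Jacobian goes on the left

print(torch.allclose(J_direct, J_product, atol=1e-5))  # True
</code></pre>

The matrix product runs right to left in time. The Jacobian of the most recent step is multiplied on the left.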
Truncated BPTT
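For long sequences (e.g., thousands of steps), full BPTT is computationally expensive. Every weight update requires a forward and a backward pass over the entire sequence. All intermediate activations must also be stored for the backward pass, so memory grows linearly with sequence length and can cause out-of-memory (OOM) errors. Truncated BPTT bounds these costs.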
Mitigating Memory Costs
Truncated BPTT processes the sequence in chunks (e.g., 50 steps). For each chunk, we run the forward pass, propagate errors backward only within that chunk, and update the weights. The chunk's final hidden state then becomes the next chunk's initial state, but gradients are not allowed to flow back through it into earlier chunks.
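Two consecutive chunks are enough to show the effect of cutting the graph at a chunk boundary. This sketch assumes a small nn.GRU, random 50-step chunks, and out.sum() as a stand-in loss, all hypothetical. The helper grad_reaches_chunk1 is also illustrative. It checks whether the gradient of a loss computed on the second chunk reaches the first chunk's inputs, with and without detaching the carried hidden state.

<pre><code class="language-python">import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical setup: a small GRU and two consecutive 50-step chunks
rnn = nn.GRU(input_size=3, hidden_size=8, batch_first=True)
chunk1 = torch.randn(1, 50, 3, requires_grad=True)
chunk2 = torch.randn(1, 50, 3, requires_grad=True)

def grad_reaches_chunk1(truncate):
    chunk1.grad = None
    _, h = rnn(chunk1)        # forward through the first chunk
    if truncate:
        h = h.detach()        # keep the value, drop the gradient history
    out, _ = rnn(chunk2, h)   # the second chunk starts from h
    out.sum().backward()      # stand-in loss on the second chunk only
    return chunk1.grad is not None

print(grad_reaches_chunk1(truncate=False))  # True: full BPTT reaches chunk 1
print(grad_reaches_chunk1(truncate=True))   # False: gradient stops at the boundary
</code></pre>

Detaching keeps the hidden state's value, so the second chunk still receives the context from the first. Only the gradient path is removed.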
PyTorch Implementation Concept
To implement truncated BPTT in PyTorch, detach the hidden state tensor between chunks. This breaks the gradient history while keeping the state's value.
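<pre><code class="language-python"># Iterate over chunks of a long sequence.
# Assumes rnn (e.g. nn.RNN, nn.GRU, or nn.LSTM), criterion, optimizer and
# dataset_loader are defined elsewhere, and that dataset_loader yields
# consecutive chunks in order (no shuffling), so each state carries over.
state = None
for x_chunk, y_chunk in dataset_loader:
    # Forward pass, starting from the state carried over from the last chunk
    output, state = rnn(x_chunk, state)

    # Detach state to prevent backpropagating into the previous chunk
    if isinstance(state, tuple):
        # For LSTMs, state is (h, c)
        state = (state[0].detach(), state[1].detach())
    else:
        # For RNNs/GRUs, state is h
        state = state.detach()

    loss = criterion(output, y_chunk)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
</code></pre>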
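The loop above leaves several objects undefined. The sketch below supplies them so the loop can run end to end, under purely illustrative assumptions:

* random data standing in for one long sequence;
* an nn.LSTM with a hypothetical nn.Linear readout named head;
* MSE loss and plain SGD;
* torch.split in place of dataset_loader, cutting the sequence into consecutive 50-step chunks.

It calls optimizer.zero_grad() before backward(). This is equivalent to clearing gradients after step(), as the loop above does.

<pre><code class="language-python">import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical task: predict a scalar target at every step of one long
# sequence (random data, for illustration only)
seq_len, chunk_len, input_size, hidden_size = 1000, 50, 3, 16
x = torch.randn(1, seq_len, input_size)
y = torch.randn(1, seq_len, 1)

rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
head = nn.Linear(hidden_size, 1)  # maps hidden states to predictions
criterion = nn.MSELoss()
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

state = None
losses = []
# torch.split slices the time dimension into consecutive chunks, in order
for x_chunk, y_chunk in zip(torch.split(x, chunk_len, dim=1),
                            torch.split(y, chunk_len, dim=1)):
    output, state = rnn(x_chunk, state)
    state = tuple(s.detach() for s in state)  # LSTM state is (h, c)
    loss = criterion(head(output), y_chunk)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(len(losses))  # 20: one weight update per 50-step chunk
</code></pre>

torch.split returns chunks in order, so the state passed to each chunk is the actual final state of the previous one. A loader that shuffled chunks would break this continuity.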