Writing a Complete PyTorch Training Loop

A training loop has to manage model modes, backpropagate the loss, update parameters, and evaluate validation metrics. It is the iterative procedure that actually trains a neural network.


Training Loop Anatomy

The training loop nests a batch iteration inside an epoch iteration, and each batch performs one gradient step on the model's parameters.

Mode Setting and Phase Loops

The training script contains two loops: an outer epoch loop, where each epoch is one full pass over the dataset, and an inner batch loop that iterates over mini-batches. At the start of each epoch, we set the model to training mode using model.train(). Repeating the call every epoch matters once a validation phase switches the model to evaluation mode.

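This setting is critical because layers like Dropout and Batch Normalization behave differently during training and evaluation. In training mode, Dropout randomly zeroes activations, and Batch Normalization normalizes with the current batch's statistics while updating its running estimates. In evaluation mode, Dropout is disabled and Batch Normalization uses the stored running statistics.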
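The mode switch can be observed directly on standalone layers. The following is a minimal sketch, assuming toy inputs and layer sizes chosen only for illustration; it is not part of the training script.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# --- Dropout: active in training mode, a no-op in evaluation mode ---
dropout = nn.Dropout(p=0.5)   # toy layer, chosen for illustration
x = torch.ones(1000)

dropout.train()
train_out = dropout(x)        # elements are zeroed at random; survivors are scaled by 1 / (1 - p)

dropout.eval()
eval_out = dropout(x)         # input passes through unchanged

num_zeroed = int((train_out == 0).sum())
print(f"zeroed in train mode: {num_zeroed} of {x.numel()}")
print(f"eval output equals input: {torch.equal(eval_out, x)}")

# --- BatchNorm: running statistics are updated only in training mode ---
bn = nn.BatchNorm1d(3)
batch = torch.randn(8, 3) + 5.0   # toy batch with a mean near 5

bn.train()
bn(batch)
mean_after_train = bn.running_mean.clone()   # moved from 0 toward the batch mean

bn.eval()
bn(batch)
mean_after_eval = bn.running_mean.clone()    # unchanged by the eval-mode forward pass

print(f"running mean frozen in eval mode: {torch.equal(mean_after_train, mean_after_eval)}")
```

Forgetting model.eval() before validation leaves Dropout active, which makes validation metrics noisy and usually pessimistic.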

The Optimization Sequence

For each batch, we execute a fixed sequence of steps: clearing old gradients using optimizer.zero_grad(), running the forward pass to compute the loss, computing gradients using loss.backward(), and updating parameters using optimizer.step().

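Clearing gradients is necessary because PyTorch accumulates gradients by default: each call to loss.backward() adds to the .grad attribute of every parameter rather than overwriting it. Without resetting, gradients from the current batch would be added to those of previous batches, leading to incorrect updates.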
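The accumulation behavior is easy to verify on a single tensor. This is a small sketch, assuming a toy parameter `w` and an SGD optimizer used only for its zero_grad() call:

```python
import torch

# Toy parameter and optimizer, chosen for illustration only.
w = torch.tensor([1.0, 2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# First backward pass: the gradient of sum(3 * w) is 3 for each element.
(3 * w).sum().backward()
first_grad = w.grad.clone()

# Second backward pass without clearing: the new gradient is added to the old one.
(3 * w).sum().backward()
accumulated_grad = w.grad.clone()

# Clearing first restores the correct per-batch gradient.
optimizer.zero_grad()
(3 * w).sum().backward()
reset_grad = w.grad.clone()

print("first:      ", first_grad)
print("accumulated:", accumulated_grad)
print("after reset:", reset_grad)
```

The same accumulation is sometimes used deliberately, by calling backward() on several small batches before a single optimizer.step(), to simulate a larger batch size.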

Evaluation and Performance Tracking

Evaluating validation metrics during training helps track model generalization and detect overfitting.

The Validation Phase

At the end of each training epoch, we run a validation loop to evaluate the model on held-out data that it was not trained on. We switch the model to evaluation mode using model.eval() and wrap the computations in a torch.no_grad() block.

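This phase measures generalization by tracking validation loss and accuracy without updating weights. The two calls do different jobs: model.eval() switches layer behavior, while torch.no_grad() stops autograd from recording operations, which saves memory and computation.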
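A validation pass along these lines might look like the sketch below. The `validate` helper, the random held-out tensors, and the small model are assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Assumed setup: a small binary classifier and random held-out data.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
criterion = nn.BCEWithLogitsLoss()
x_val = torch.randn(40, 10)
y_val = torch.randint(0, 2, (40, 1)).float()
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=20)


def validate(model, loader, criterion):
    """Return the average per-sample loss over the loader (hypothetical helper)."""
    model.eval()                    # switch Dropout / BatchNorm to evaluation behavior
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():           # autograd records nothing inside this block
        for batch_x, batch_y in loader:
            logits = model(batch_x)
            loss = criterion(logits, batch_y)
            # Weight by batch size so a smaller final batch is averaged correctly.
            total_loss += loss.item() * len(batch_x)
            total_samples += len(batch_x)
    return total_loss / total_samples


val_loss = validate(model, val_loader, criterion)
print(f"Validation loss: {val_loss:.4f}")
```

Because the helper leaves the model in evaluation mode, the next training epoch needs to call model.train() again before processing batches.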

Logging and Early Stopping

Logging metrics across epochs provides diagnostic information. We track both training and validation losses, comparing their trends to monitor for overfitting.

If the validation loss starts to rise while the training loss continues to fall, the model is likely overfitting. We can implement early stopping: save a checkpoint whenever the validation loss improves, and stop training once it has failed to improve for a set number of epochs, before generalization degrades further.
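One common way to implement this is a patience counter. The sketch below assumes a hypothetical `EarlyStopping` helper and a made-up sequence of validation losses; in a real run, each value would come from that epoch's validation loop.

```python
import copy
import torch.nn as nn


class EarlyStopping:
    """Hypothetical helper: keep the best weights, stop after `patience` epochs without improvement."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def step(self, val_loss, model):
        """Record this epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(model.state_dict())  # in-memory checkpoint
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


model = nn.Linear(10, 1)   # stand-in model for illustration
stopper = EarlyStopping(patience=2)

# Made-up validation losses, for illustration only.
illustrative_val_losses = [0.70, 0.62, 0.58, 0.60, 0.63, 0.61]

stopped_at = None
for epoch, val_loss in enumerate(illustrative_val_losses):
    # (one epoch of training and validation would run here)
    if stopper.step(val_loss, model):
        stopped_at = epoch
        break

# Restore the weights from the best epoch before using the model.
model.load_state_dict(stopper.best_state)
print(f"stopped at epoch {stopped_at}, best validation loss {stopper.best_loss:.2f}")
```

A larger patience tolerates noisy validation curves at the cost of extra epochs. To persist the checkpoint across runs, the saved state dict can be written to disk with torch.save.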

PyTorch Implementation

We can write a complete, runnable PyTorch training loop on synthetic binary classification data.

Coding a Full Training Loop

This script runs a complete training loop for a small binary classification model:

<pre><code class="language-python">import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# 1. Setup Model, Data, and Optimizer
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

# Create dummy train dataset (100 samples)
x_train, y_train = torch.randn(100, 10), torch.randint(0, 2, (100, 1)).float()
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=20, shuffle=True)

# 2. Training Loop
for epoch in range(5):
    model.train()  # Set model to training mode
    epoch_loss = 0.0

    for batch_x, batch_y in train_loader:
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        predictions = model(batch_x)
        loss = criterion(predictions, batch_y)

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

        epoch_loss += loss.item() * len(batch_x)

    print(f"Epoch {epoch} | Loss: {epoch_loss / len(x_train):.4f}")</code></pre>

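In this code, we run a 5-epoch training loop. For each batch, we zero the gradients, run the forward pass, compute the BCE loss on the logits, backpropagate, and update the parameters. Each batch loss is weighted by its batch size before being summed, so dividing by the dataset size gives the average per-sample loss for the epoch.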

Metrics Tracking and Printing

To evaluate performance, we can compute accuracy during the loop. Since our loss function (BCEWithLogitsLoss) expects raw logits, the model has no final sigmoid layer, so we pass the output through a sigmoid function and threshold at 0.5 to obtain binary class predictions. This is equivalent to thresholding the logits at 0.

Comparing these predictions with target labels provides the batch accuracy, which we accumulate and print to track model performance.
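The accuracy computation can be sketched as follows. The logits and labels here are made-up values for illustration; in the training script they would be the model outputs and batch_y for each batch.

```python
import torch

# Made-up (logits, targets) batches, for illustration only.
batches = [
    (torch.tensor([[2.0], [-1.0], [0.5], [-3.0]]), torch.tensor([[1.0], [0.0], [0.0], [0.0]])),
    (torch.tensor([[1.5], [-0.2]]), torch.tensor([[1.0], [1.0]])),
]

correct, total = 0, 0
with torch.no_grad():
    for logits, targets in batches:
        probs = torch.sigmoid(logits)        # map logits to probabilities
        preds = (probs > 0.5).float()        # threshold to get class 0 or 1
        correct += int((preds == targets).sum())
        total += targets.numel()

accuracy = correct / total
print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
```

Accumulating raw counts and dividing once at the end avoids the bias that averaging per-batch accuracies would introduce when batches differ in size.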