Building a Custom PyTorch nn.Module

The nn.Module is the base class for all neural network components in PyTorch, providing the template for registering learnable parameters and defining forward propagation.


Anatomy of a PyTorch Module

A custom model subclassing nn.Module registers its layers in the constructor and defines the data flow in the forward method.

Initializing Layers and Forward Pass

When subclassing nn.Module, the constructor __init__ must call super().__init__() to initialize the underlying PyTorch base class. In the constructor, you define the learnable layers and parameters of the network. Any object assigned as an attribute that inherits from nn.Module (such as nn.Linear or nn.Conv2d) is automatically registered as a submodule, and its parameters are added to the model's parameters list.

The data flow is defined in the forward method. This method accepts input tensors and specifies the sequence of layer applications, mathematical calculations, and activation functions. Crucially, you should not call forward(x) directly; instead, call the model instance as a function (e.g., model(x)). This goes through nn.Module.__call__, which runs any registered forward pre-hooks and forward hooks around forward. Calling forward directly silently skips those hooks. Gradient tracking itself is handled by autograd and works either way, so the reason to prefer model(x) is hook execution, not gradient correctness.
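The difference between calling the instance and calling forward can be seen with a forward hook. The following is a minimal sketch; TinyNet and the hook body are hypothetical names chosen for illustration, while register_forward_hook is the standard nn.Module method.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical module used only for this demo
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
calls = []

# A forward hook is called as hook(module, inputs, output) after forward() runs.
# list.append returns None, so the hook leaves the output unchanged.
handle = model.register_forward_hook(
    lambda module, inputs, output: calls.append(tuple(output.shape))
)

x = torch.randn(3, 4)
out_call = model(x)            # goes through __call__, so the hook fires
out_direct = model.forward(x)  # bypasses __call__, so the hook is skipped

handle.remove()  # detach the hook once it is no longer needed
print(calls)     # [(3, 2)]: only the model(x) call was recorded
```

Both calls return the same values here; what differs is that only model(x) runs the hook.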

Parameters, Buffers, and the State Dict

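Under the hood, nn.Module manages learnable tensors through the nn.Parameter class. A parameter is a tensor subclass with requires_grad=True by default; when assigned as a module attribute it is registered automatically, returned by parameters(), and updated by any optimizer it is passed to. If you need a tensor that is part of the model's state but is not learned by gradient descent (such as the running mean in Batch Normalization), register it as a buffer using register_buffer. Buffers are saved in the state_dict and follow the module through calls such as .to(device), but they are not returned by parameters().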
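As a rough illustration of the parameter/buffer split, the sketch below defines a hypothetical RunningMean layer (not a PyTorch class) with one learnable parameter and one buffer, then lists where each one shows up. The momentum update is a simplified stand-in for what Batch Normalization does, not a copy of it.

```python
import torch
import torch.nn as nn


class RunningMean(nn.Module):  # hypothetical layer for illustration
    def __init__(self, num_features, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        # Learnable: registered as a parameter, seen by optimizers.
        self.shift = nn.Parameter(torch.zeros(num_features))
        # Stateful but not learnable: registered as a buffer.
        self.register_buffer("running_mean", torch.zeros(num_features))

    def forward(self, x):
        if self.training:
            # Update the buffer outside of autograd.
            with torch.no_grad():
                batch_mean = x.mean(dim=0)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
        return x - self.running_mean + self.shift


layer = RunningMean(3)
param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
state_keys = set(layer.state_dict().keys())

print(param_names)   # ['shift']
print(buffer_names)  # ['running_mean']

# Modules start in training mode, so this call updates the buffer:
# 0.9 * 0 + 0.1 * 1 = 0.1 for every feature.
layer(torch.ones(4, 3))
```

The buffer appears in the state_dict alongside the parameter, but only shift would be handed to an optimizer built from layer.parameters().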

The collection of all registered parameters and persistent buffers is exposed through state_dict(), which returns an ordered dictionary mapping each name to its tensor. Saving a model is typically done by serializing this state_dict with torch.save. Loading is a two-step process: torch.load reads the dictionary back from disk, and load_state_dict copies its tensors into an already constructed model of the same architecture. Because only tensors are stored, not the class definition, this allows for clean model sharing and deployment.
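A save/load round trip might look like the sketch below. It writes to an in-memory io.BytesIO object instead of a file path so that it runs without touching disk; with a real file you would pass a path such as "model.pt" (a hypothetical name) to torch.save and torch.load. The weights_only=True argument assumes a reasonably recent PyTorch release.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Serialize only the state_dict, not the module object itself.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild the architecture first, then copy the saved tensors into it.
restored = nn.Linear(4, 2)
state = torch.load(buffer, weights_only=True)
result = restored.load_state_dict(state)

# load_state_dict reports any keys that did not line up.
print(result.missing_keys, result.unexpected_keys)  # [] []
```

By default load_state_dict is strict, so a mismatch between the saved keys and the model's keys raises an error instead of loading partially.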

Building a Multi-Layer Perceptron Class

Let's write a custom Multi-Layer Perceptron (MLP) with dropout, activation functions, and fully commented tensor shapes.

MLP Architecture Implementation

Here is a complete, modular subclass of nn.Module representing an MLP with a hidden layer, batch normalization, and dropout.

<pre><code class="language-python">import torch
import torch.nn as nn


class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Define layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.25)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Input shape: [batch_size, input_dim]
        x = self.fc1(x)        # [batch_size, hidden_dim]
        x = self.bn(x)         # [batch_size, hidden_dim]
        x = self.relu(x)       # [batch_size, hidden_dim]
        x = self.dropout(x)    # [batch_size, hidden_dim]
        logits = self.fc2(x)   # [batch_size, output_dim]
        return logits


# Instantiate model
model = CustomMLP(input_dim=10, hidden_dim=32, output_dim=2)
print(model)</code></pre>

In this architecture, the linear layers transform the feature dimensions, the batch normalization layer stabilizes activations, and the dropout layer randomly drops activations to prevent overfitting. Commenting on the tensor dimensions at each step is a best practice that helps prevent dimension mismatch errors.

Inspecting Parameters and Submodules

We can inspect the parameters of our custom module using the `parameters()` or `named_parameters()` methods. This is useful for verifying that all weights are registered and checking their shapes. It also allows developers to freeze specific parameters during transfer learning by setting `requires_grad = False` on selected parameters.
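A short sketch of both uses follows. To stay self-contained it builds a small stand-in model with `nn.Sequential` and named layers (`fc1`, `relu`, `fc2` are illustrative names, not the MLP class above), prints each parameter's shape, and then freezes the first layer.

```python
from collections import OrderedDict

import torch.nn as nn

# Stand-in model with named submodules (assumed layout for illustration).
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(10, 32)),
    ("relu", nn.ReLU()),
    ("fc2", nn.Linear(32, 2)),
]))

# named_parameters() yields (qualified_name, parameter) pairs.
shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)

# Freeze the first layer, as one might for a pretrained feature extractor.
for param in model.fc1.parameters():
    param.requires_grad = False

trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(trainable)  # ['fc2.weight', 'fc2.bias']
```

When freezing layers, it is common to pass only the still-trainable parameters to the optimizer, for example by filtering on `requires_grad`.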

Submodules can be accessed using `children()` or `modules()`. The `modules()` method recursively traverses the whole tree, yielding the module itself first and then every nested submodule, whereas `children()` only returns the immediate submodules of the module it is called on. This distinction is important when implementing custom parameter initializations or weight decay exclusions across specific layers, where you usually want to reach every layer regardless of nesting depth.
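The difference is easiest to see on a model with one level of nesting. In this sketch, `Block` and `Net` are hypothetical classes defined only for the demonstration.

```python
import torch.nn as nn


class Block(nn.Module):  # hypothetical nested block
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))


class Net(nn.Module):  # hypothetical top-level model
    def __init__(self):
        super().__init__()
        self.block = Block()
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.block(x))


net = Net()

# children(): only the direct attributes of Net.
child_types = [type(m).__name__ for m in net.children()]
print(child_types)   # ['Block', 'Linear']

# modules(): Net itself, then every descendant, depth-first.
module_types = [type(m).__name__ for m in net.modules()]
print(module_types)  # ['Net', 'Block', 'Linear', 'ReLU', 'Linear']
```

A loop over `net.modules()` with an `isinstance(m, nn.Linear)` check would therefore reach both linear layers, while the same loop over `net.children()` would miss the one inside `Block`.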

Execution Modes and Custom Parameters

Managing model states and registering custom tensor variables are essential for implementing advanced neural architectures.

Train vs. Eval Modes

Certain layers like nn.BatchNorm1d and nn.Dropout behave differently during training and evaluation. Dropout is active during training, where it randomly zeroes activations to regularize the network, and becomes a no-op during evaluation so that predictions are deterministic. BatchNorm normalizes with the statistics of the current batch during training while updating its running mean and variance; during evaluation it stops updating them and normalizes with the stored running statistics instead.

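Calling model.train() sets the model and all its submodules to training mode, enabling dropout and batch-statistic normalization. Calling model.eval() switches them to evaluation mode, which disables dropout and makes BatchNorm use its stored running statistics. Forgetting to call eval() before testing leaves dropout active and makes BatchNorm outputs depend on the composition of each batch, so the same input can yield different predictions from one call to the next. Note that eval() only changes layer behavior; it does not turn off gradient tracking, which is what torch.no_grad() is for.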
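The effect of the two modes can be checked directly on standalone layers. The sketch below is illustrative only; the tensor sizes and the seed are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# --- Dropout ---
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 1000)

drop.train()
train_out = drop(x)  # entries are either 0 or scaled by 1 / (1 - p) = 2

drop.eval()
eval_out = drop(x)   # identity in evaluation mode

# --- BatchNorm ---
bn = nn.BatchNorm1d(4)
batch = torch.randn(16, 4) + 3.0  # batch with a clearly non-zero mean

bn.train()
bn(batch)                                  # updates running_mean / running_var
mean_after_train = bn.running_mean.clone()

bn.eval()
bn(batch)                                  # uses the running stats, no update
mean_after_eval = bn.running_mean.clone()

print(torch.equal(eval_out, x))                        # True
print(torch.equal(mean_after_train, mean_after_eval))  # True
```

In training mode the surviving dropout activations are scaled up so that the expected value of each activation matches evaluation mode.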

Defining Custom Parameters Manually

While pre-built layers like nn.Linear handle their own weights, you can also define custom learnable parameters manually using nn.Parameter. This is useful for implementing custom layers or attention mechanisms where weights are not tied to standard linear transformations.

<pre><code class="language-python">import torch
import torch.nn as nn


class CustomScalingLayer(nn.Module):
    def __init__(self, size):
        super().__init__()
        # Register a custom learnable weight vector
        self.scale = nn.Parameter(torch.ones(size))

    def forward(self, x):
        # Element-wise multiplication
        return x * self.scale</code></pre>

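By wrapping the tensor in `nn.Parameter`, PyTorch registers it as a learnable parameter, ensuring that it appears in `model.parameters()` and in the `state_dict`, and that an optimizer built from `model.parameters()` updates it. Without this wrapper, a tensor assigned as a plain attribute is not registered at all: it is invisible to `parameters()`, is not saved in the `state_dict`, is not moved by `.to(device)`, and never receives optimizer updates.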
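The contrast can be verified with two nearly identical layers. In this sketch, `WithParameter` and `WithPlainTensor` are hypothetical names; the only difference between them is the `nn.Parameter` wrapper.

```python
import torch
import torch.nn as nn


class WithParameter(nn.Module):  # hypothetical: scale is registered
    def __init__(self, size):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))

    def forward(self, x):
        return x * self.scale


class WithPlainTensor(nn.Module):  # hypothetical: scale is just an attribute
    def __init__(self, size):
        super().__init__()
        self.scale = torch.ones(size)

    def forward(self, x):
        return x * self.scale


registered = WithParameter(3)
unregistered = WithPlainTensor(3)

n_registered = len(list(registered.parameters()))      # 1
n_unregistered = len(list(unregistered.parameters()))  # 0

# One SGD step on the registered layer: d(sum(x * scale)) / d(scale) = x,
# so scale becomes 1 - 0.1 * [1, 2, 3] = [0.9, 0.8, 0.7].
optimizer = torch.optim.SGD(registered.parameters(), lr=0.1)
loss = registered(torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()
optimizer.step()
print(registered.scale.data)
```

Building an optimizer from `unregistered.parameters()` would fail with an empty parameter list error, which is often the first visible symptom of a forgotten wrapper.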