Fully Connected Layers at the End of CNNs
Fully connected layers turn the spatially structured feature maps produced by a CNN's convolutional stages into a flat vector of class logits.
Flattening and Logits Mapping
CNNs extract 3D feature tensors that must be flattened into 1D vectors before entering dense layers.
Connecting Conv Output to Dense Input
The output of the final convolutional or pooling layer is a 3D feature tensor of shape \\((C, H, W)\\). To feed this tensor into a fully connected (dense) layer, it must be flattened into a 1D vector of length \\(C \\times H \\times W\\).
Once flattened, this vector is passed through one or more fully connected layers, which compute global combinations of the extracted features. The final dense layer has a width matching the number of target classes, producing raw logits for classification.
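The flatten-then-dense pipeline described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not a reference implementation: the sizes (64 channels, 7x7 maps, a 128-unit hidden layer, 10 classes) and the variable names are assumptions chosen only for the example.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
C, H, W = 64, 7, 7
num_classes = 10

feature_map = torch.randn(8, C, H, W)            # batch of 8 conv feature tensors
flat = torch.flatten(feature_map, start_dim=1)   # keep batch dim, merge C, H, W

hidden = nn.Linear(C * H * W, 128)   # hidden dense layer over all features
head = nn.Linear(128, num_classes)   # final layer: one output per class

logits = head(torch.relu(hidden(flat)))          # raw, unnormalized scores
print(flat.shape)    # torch.Size([8, 3136])
print(logits.shape)  # torch.Size([8, 10])
```

The logits are left unnormalized here; a softmax is typically applied inside the loss function rather than in the model.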
Spatial Resolution Dependence
Because the input size of a fully connected layer is fixed, any change in the input image resolution will change the spatial dimensions \\((H, W)\\) of the final feature map, altering the length of the flattened vector and causing a shape mismatch error.
As a result, networks that flatten into fully connected layers require inputs of a fixed resolution (e.g., 224x224). To handle variable resolutions, many modern networks replace standard flattening with Global Average Pooling before the dense layer.
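To make the dependence concrete, the sketch below measures the flattened length produced by a small feature extractor at two input resolutions. The two-stage extractor and the helper `flattened_length` are hypothetical and exist only for this example.

```python
import torch
import torch.nn as nn

# A small, hypothetical feature extractor: two conv + pool stages.
features = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.MaxPool2d(2),
)

def flattened_length(resolution):
    """Return C * H * W of the final feature map for a square input."""
    with torch.no_grad():
        out = features(torch.zeros(1, 3, resolution, resolution))
    return out.flatten(1).shape[1]

len_224 = flattened_length(224)  # 32 channels * 56 * 56
len_256 = flattened_length(256)  # 32 channels * 64 * 64
print(len_224, len_256)          # 100352 131072
```

A dense layer built for the first length cannot consume the second, which is why the resolution has to be fixed up front.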
Global Average Pooling vs. Fully Connected Layers
Replacing traditional flattening with Global Average Pooling reduces the model's parameter count and helps limit overfitting.
Parameter Reduction
In classic architectures, the fully connected layers at the end of a CNN can hold the large majority of the model's parameters (roughly 124 million of VGG-16's 138 million, about 90%), making them a common source of overfitting and high memory consumption.
Global Average Pooling (GAP) addresses this issue by averaging each channel over its spatial dimensions, reducing the \\((C, H, W)\\) tensor to \\((C, 1, 1)\\). The final dense layer then receives exactly \\(C\\) inputs instead of \\(C \\times H \\times W\\), so its weight count shrinks by a factor of \\(H \\times W\\).
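The saving can be checked by counting parameters directly. The sketch below compares a dense head fed by a flattened tensor with one fed by a GAP output. The sizes (512 channels, 7x7 maps, 1000 classes) are illustrative assumptions, and `count_params` is a helper defined only for this example.

```python
import torch.nn as nn

# Illustrative sizes (assumed): 512 channels, 7x7 feature maps, 1000 classes.
C, H, W, num_classes = 512, 7, 7, 1000

flatten_head = nn.Linear(C * H * W, num_classes)  # dense layer after flattening
gap_head = nn.Linear(C, num_classes)              # dense layer after GAP

def count_params(module):
    """Total number of trainable weights and biases in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

flatten_params = count_params(flatten_head)  # 25088 * 1000 weights + 1000 biases
gap_params = count_params(gap_head)          # 512 * 1000 weights + 1000 biases

print(flatten_params)  # 25089000
print(gap_params)      # 513000
```

The GAP layer itself contributes no parameters, so the whole difference comes from the narrower input to the dense layer.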
Generalization and Overfitting
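GAP has no trainable parameters, so replacing the flatten-plus-dense stage with it leaves far fewer weights that can memorize the training set. Because each channel is averaged over every position, it behaves like a detector for a feature anywhere in the image rather than at one fixed location, which tends to improve generalization and reduce the risk of overfitting.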
Additionally, because GAP collapses feature maps of any spatial size to 1x1, the network can accept inputs of different resolutions during inference without modifying the architecture, provided each input is large enough to pass through the convolution and pooling stages.
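As a quick check of what GAP computes, the sketch below compares `nn.AdaptiveAvgPool2d(1)` against a plain mean over the height and width dimensions for two feature-map sizes. The tensor sizes are arbitrary assumptions for the example.

```python
import torch
import torch.nn as nn

gap = nn.AdaptiveAvgPool2d(1)  # output size 1x1, whatever the input size

for size in (7, 14):  # two arbitrary feature-map resolutions
    fmap = torch.randn(2, 16, size, size)
    pooled = gap(fmap)               # shape [2, 16, 1, 1]
    manual = fmap.mean(dim=(2, 3))   # per-channel spatial mean, shape [2, 16]

    # Both routes give the same per-channel averages.
    assert torch.allclose(pooled.flatten(1), manual, atol=1e-6)
    print(size, tuple(pooled.shape))
```

In both iterations the pooled tensor has the same shape, so a dense layer with 16 inputs works for either resolution.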
PyTorch Implementation
Let's compare the code structure of traditional flattening and modern global average pooling classifiers.
Traditional Flattening and Dense Layer
The code below shows the traditional approach of flattening feature maps and passing them to fully connected layers in PyTorch.
<pre><code class="language-python">import torch
import torch.nn as nn

class TraditionalClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(2, 2)  # Divides spatial dimensions by 2
        )
        # Input: 32x32. After MaxPool: 16x16.
        # Flattened size: 16 channels * 16 * 16 = 4096 features
        self.classifier = nn.Linear(16 * 16 * 16, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # Flatten starting from channel dimension
        return self.classifier(x)

x = torch.randn(2, 3, 32, 32)
model = TraditionalClassifier(num_classes=10)
print("Output shape:", model(x).shape)  # [2, 10]
</code></pre>
In this setup, if we pass a 64x64 input image instead of 32x32, the code will raise a shape mismatch runtime error when executing the fully connected layer, demonstrating the spatial resolution dependence of traditional classifiers.
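The failure described above can be reproduced with a compact stand-in for the classifier. The `nn.Sequential` model below is an assumption made for brevity (it uses `nn.Flatten` instead of a custom `forward`), but it has the same layers and the same fixed 4096-feature dense input.

```python
import torch
import torch.nn as nn

# Compact stand-in: conv -> pool -> flatten -> linear, sized for 32x32 inputs.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 16 * 16, 10),  # expects exactly 4096 features
)

ok = model(torch.randn(2, 3, 32, 32))
print("32x32 output:", ok.shape)

try:
    # A 64x64 input yields 16 * 32 * 32 = 16384 features, not 4096.
    model(torch.randn(2, 3, 64, 64))
    failed = False
except RuntimeError as err:
    failed = True
    print("64x64 failed:", err)
```

The convolution and pooling stages run without complaint on the larger image; the error only surfaces at the dense layer, where the flattened length no longer matches.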
Modern GAP Classifier
Below is the modern approach using Global Average Pooling, which lets the model accept variable input sizes and reduces the parameter count.
<pre><code class="language-python">import torch
import torch.nn as nn

class GAPClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1)  # Collapses spatial dimensions to 1x1
        )
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # Flatten to [batch_size, 16]
        return self.classifier(x)

model = GAPClassifier(num_classes=10)

# Test with two different input resolutions
x1 = torch.randn(2, 3, 32, 32)
x2 = torch.randn(2, 3, 64, 64)
print("x1 Output shape:", model(x1).shape)  # [2, 10]
print("x2 Output shape:", model(x2).shape)  # [2, 10]
</code></pre>
Both outputs return correctly with a shape of [2, 10], verifying that the Global Average Pooling layer successfully handles variable input sizes by collapsing the spatial maps to 1x1 before the dense projection layer.