Faster R-CNN and Region Proposal Networks

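While Fast R-CNN sped up detection by sharing convolutional features across all proposals, it still relied on slow external algorithms such as Selective Search to generate those proposals. Faster R-CNN removed this bottleneck by introducing the Region Proposal Network (RPN), which shares the backbone's features with the detection head and turns the whole detector into a single network that can be trained end to end.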


The Region Proposal Network (RPN)

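The RPN is a small fully convolutional network (a 3 x 3 convolution followed by two sibling 1 x 1 convolutions) that slides over the backbone's feature map. At each location it predicts, for a set of reference boxes, an objectness score (how likely the box contains an object) and bounding box offsets that refine the box. The highest-scoring refined boxes become the region proposals.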
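The RPN head's output layout follows from the sliding-window design: because the convolutions preserve the feature map's spatial size, every location carries k objectness values and 4k box offsets. This minimal sketch only counts shapes. It assumes torchvision's `RPNHead` convention of one objectness logit per anchor (the original paper uses a two-class softmax, i.e. 2k scores), and the helper name and the 38 × 63 feature-map size are hypothetical.

<pre><code class="language-python">def rpn_head_output_shapes(feat_h, feat_w, num_anchors):
    """Return (objectness_shape, box_delta_shape, total_anchors) for one image.

    Hypothetical helper: assumes one objectness logit per anchor
    (torchvision-style) and four box offsets per anchor.
    """
    objectness_shape = (num_anchors, feat_h, feat_w)     # k channels per location
    box_delta_shape = (4 * num_anchors, feat_h, feat_w)  # 4k channels per location
    total_anchors = feat_h * feat_w * num_anchors        # every anchor gets a score and 4 offsets
    return objectness_shape, box_delta_shape, total_anchors


# Hypothetical 38 x 63 feature map with 9 anchors per location
obj, deltas, total = rpn_head_output_shapes(38, 63, 9)
print(obj, deltas, total)  # (9, 38, 63) (36, 38, 63) 21546</code></pre>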

Anchor Boxes

At each sliding-window location, the RPN evaluates k reference boxes called anchor boxes, all centered on that location. The anchors span predefined scales and aspect ratios (e.g., 3 scales x 3 ratios = 9 anchors per location), so a single feature map can cover objects of different sizes and shapes.
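A minimal pure-Python sketch of anchor generation, assuming the paper's default scales of 128, 256 and 512 (the side length of an equal-area square, i.e. areas of 128², 256² and 512² pixels) and ratios 0.5, 1 and 2, interpreted here as height/width. It also assumes that feature-map cell (i, j) maps to the image-space center ((j + 0.5) · stride, (i + 0.5) · stride); implementations differ on this offset. All function names are hypothetical.

<pre><code class="language-python">import math


def base_anchors(scales=(128, 256, 512), ratios=(0.5, 1.0, 2.0)):
    """Return (w, h) for every scale/ratio pair.

    Each anchor keeps the area scale**2 while its shape follows ratio = h / w.
    """
    sizes = []
    for scale in scales:
        for ratio in ratios:
            w = scale / math.sqrt(ratio)
            h = scale * math.sqrt(ratio)
            sizes.append((w, h))
    return sizes


def anchors_at(cx, cy, sizes):
    """Center each (w, h) at (cx, cy) and return (x1, y1, x2, y2) boxes."""
    return [(cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2) for w, h in sizes]


def all_anchors(feat_h, feat_w, stride, sizes):
    """Tile the same set of anchors over every feature-map location."""
    boxes = []
    for i in range(feat_h):
        for j in range(feat_w):
            # Assumed mapping from cell (i, j) to an image-space center
            cx, cy = (j + 0.5) * stride, (i + 0.5) * stride
            boxes.extend(anchors_at(cx, cy, sizes))
    return boxes


sizes = base_anchors()
print(len(sizes))                         # 9
print(len(all_anchors(2, 3, 16, sizes)))  # 54 (2 * 3 locations * 9 anchors)</code></pre>

Because every location reuses the same nine shapes, the anchors are translation invariant: the same head weights score an object the same way wherever it appears.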

RPN Loss Function

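The RPN is trained with two losses: a binary classification loss for objectness (whether an anchor covers an object or background, decided by its overlap with the ground-truth boxes) and a bounding box regression loss that refines coordinates relative to the anchor. The regression term is computed only for anchors labeled as objects, and λ balances the two terms: L_{RPN} = L_{cls} + \lambda L_{reg}.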
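A minimal pure-Python sketch of how RPN training targets and the loss fit together. It uses a simplified form of the paper's labeling rule: an anchor is positive when its best IoU with a ground-truth box is at least 0.7 and negative when it is below 0.3. The paper additionally marks the highest-IoU anchor for each ground-truth box as positive, which is omitted here. Regression targets use the paper's center/log-size parameterization and L_reg is smooth L1. Both terms are averaged over the sampled anchors, a simplifying choice since the paper and libraries normalize differently. All function names are hypothetical.

<pre><code class="language-python">import math


def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def label_anchor(anchor, gt_boxes, pos_thresh=0.7, neg_thresh=0.3):
    """1 = object, 0 = background, -1 = ignored (excluded from the loss)."""
    best = max((iou(anchor, g) for g in gt_boxes), default=0.0)
    if best >= pos_thresh:
        return 1
    if neg_thresh > best:
        return 0
    return -1


def encode(anchor, gt):
    """Regression targets (tx, ty, tw, th) of a ground-truth box relative to an anchor."""
    aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
    ax, ay = anchor[0] + 0.5 * aw, anchor[1] + 0.5 * ah
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    gx, gy = gt[0] + 0.5 * gw, gt[1] + 0.5 * gh
    return ((gx - ax) / aw, (gy - ay) / ah, math.log(gw / aw), math.log(gh / ah))


def smooth_l1(x, beta=1.0):
    """Smooth L1 penalty: quadratic near zero, linear for |x| >= beta."""
    ax = abs(x)
    if ax >= beta:
        return ax - 0.5 * beta
    return 0.5 * ax * ax / beta


def bce_with_logit(logit, label):
    """Binary cross-entropy for one objectness logit (label 1 = object, 0 = background)."""
    return math.log1p(math.exp(-logit)) if label == 1 else math.log1p(math.exp(logit))


def rpn_loss(samples, lam=1.0):
    """samples: list of (logit, label, pred_deltas, anchor, gt_box) for sampled anchors.

    Only positive anchors (label 1) contribute to the regression term.
    """
    n = len(samples)
    l_cls = sum(bce_with_logit(s[0], s[1]) for s in samples) / n
    l_reg = 0.0
    for _, label, deltas, anchor, gt in samples:
        if label == 1:
            targets = encode(anchor, gt)
            l_reg += sum(smooth_l1(d - t) for d, t in zip(deltas, targets))
    return l_cls + lam * l_reg / n</code></pre>

In a full implementation the logits and deltas come from the RPN head, and a fixed-size minibatch of positive and negative anchors (256 per image in the paper) is sampled before the loss is computed.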

Unified Detection Pipeline

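In Faster R-CNN, the proposals generated by the RPN are passed to an RoI pooling layer, which crops features for each proposal from the same backbone feature map and resizes them to a fixed size. A detection head then performs the final classification and box refinement. The original paper uses RoI pooling; later implementations, including torchvision's, use RoIAlign, which avoids rounding the box coordinates to feature-map cells.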
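A minimal sketch of the original, quantized RoI max pooling on a single-channel feature map stored as nested lists. It roughly follows the rounding-based projection of the original Fast R-CNN layer (box coordinates divided by the stride and rounded) and uses a tiny 2 × 2 output grid for readability; the paper's VGG-16 head pools to 7 × 7, and RoIAlign replaces the rounding with bilinear sampling. The function name and the toy feature map are hypothetical.

<pre><code class="language-python">import math


def roi_max_pool(feature, roi, stride, out_size=2):
    """Quantized RoI max pooling (hypothetical helper, single channel).

    feature: 2D list indexed [row][col]; roi: (x1, y1, x2, y2) in image pixels.
    Returns an out_size x out_size list of bin maxima.
    """
    height, width = len(feature), len(feature[0])
    # Project the box onto feature-map cells; this rounding is what RoIAlign avoids.
    # Note: Python's round() rounds exact halves to the nearest even integer.
    x1, y1, x2, y2 = (round(c / stride) for c in roi)
    roi_w = max(x2 - x1 + 1, 1)
    roi_h = max(y2 - y1 + 1, 1)
    bin_w, bin_h = roi_w / out_size, roi_h / out_size

    def clip(v, hi):
        return min(max(v, 0), hi)

    pooled = []
    for py in range(out_size):
        y_start = clip(y1 + math.floor(py * bin_h), height)
        y_end = clip(y1 + math.ceil((py + 1) * bin_h), height)
        row = []
        for px in range(out_size):
            x_start = clip(x1 + math.floor(px * bin_w), width)
            x_end = clip(x1 + math.ceil((px + 1) * bin_w), width)
            values = [feature[y][x] for y in range(y_start, y_end) for x in range(x_start, x_end)]
            # Bins that fall entirely outside the feature map produce 0
            row.append(max(values) if values else 0)
        pooled.append(row)
    return pooled


feature = [[4 * y + x for x in range(4)] for y in range(4)]  # toy 4 x 4 map
print(roi_max_pool(feature, (0, 0, 48, 48), stride=16))      # [[5, 7], [13, 15]]</code></pre>

Whatever the proposal's size, the output grid is fixed, which is what lets a single fully connected detection head process every proposal.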

Fine-Tuning Faster R-CNN in PyTorch

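Torchvision provides ready-made Faster R-CNN components, including ResNet-50 FPN backbones with ImageNet-pretrained weights, which makes transfer learning to a new set of classes straightforward. The example below builds a two-class detector (background plus one target class) on a pretrained backbone and computes its training losses.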

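<pre><code class="language-python">import torch
from torchvision.models import ResNet50_Weights
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# ResNet-50 backbone with a Feature Pyramid Network (FPN), ImageNet-pretrained
backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT)

# Construct Faster R-CNN; num_classes includes the background class
model = FasterRCNN(backbone, num_classes=2)  # background + target class

# In training mode the model takes (images, targets) and returns a dict of losses
model.train()
images = [torch.rand(3, 300, 300)]  # list of CHW float tensors in [0, 1]
targets = [{'boxes': torch.tensor([[10, 20, 100, 200]], dtype=torch.float32),  # (x1, y1, x2, y2)
            'labels': torch.tensor([1], dtype=torch.int64)}]
loss_dict = model(images, targets)
print(loss_dict.keys())  # loss_classifier, loss_box_reg, loss_objectness, loss_rpn_box_reg

# One SGD step on the summed losses (frozen backbone layers are skipped)
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.005, momentum=0.9)
loss = sum(loss_dict.values())
optimizer.zero_grad()
loss.backward()
optimizer.step()</code></pre>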