Deep Learning for Images with PyTorch
Michal Oleszak
Machine Learning Engineer
Faster R-CNN: an advanced version of R-CNN
from torchvision.models.detection.rpn import AnchorGenerator
# One feature map: 3 sizes x 3 aspect ratios = 9 anchors per location
anchor_generator = AnchorGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))
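To see what these two tuples produce, the sketch below computes the width and height of every base anchor at a single feature-map location. It assumes torchvision's convention that an aspect ratio is height divided by width and that each anchor keeps an area of size squared; `base_anchor_shapes` is a hypothetical helper, not part of torchvision.

```python
import math


def base_anchor_shapes(sizes, aspect_ratios):
    """Hypothetical helper: (width, height) of each anchor at one location.

    Assumes aspect_ratio = height / width, so every anchor of a given size
    keeps an area of size**2 while its shape changes.
    """
    shapes = []
    for ratio in aspect_ratios:
        h_ratio = math.sqrt(ratio)  # height scale factor
        w_ratio = 1.0 / h_ratio     # width scale factor, keeps w * h = size**2
        for size in sizes:
            shapes.append((size * w_ratio, size * h_ratio))
    return shapes


shapes = base_anchor_shapes(sizes=(32, 64, 128), aspect_ratios=(0.5, 1.0, 2.0))
print(len(shapes))  # 9 anchors per feature-map location
for w, h in shapes:
    print(f"{w:6.1f} x {h:6.1f}  (area {w * h:7.0f})")
```

Three sizes times three ratios gives the nine anchors per location that the RPN scores; lengthening either tuple increases that count multiplicatively.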
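from torchvision.ops import MultiScaleRoIAlign
# Pool each proposal into a fixed 7x7 grid; "0" is the name FasterRCNN gives a single-tensor backbone output
roi_pooler = MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)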
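import torch.nn as nn
# RPN: object-vs-background and box-offset losses; R-CNN head: class and box-refinement losses
rpn_cls_criterion = nn.BCEWithLogitsLoss()
rpn_reg_criterion = nn.MSELoss()
rcnn_cls_criterion = nn.CrossEntropyLoss()
rcnn_reg_criterion = nn.MSELoss()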
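The four criteria map onto the two stages: the RPN scores anchors as object or background and regresses their offsets, and the R-CNN head classifies proposals and refines their boxes. The sketch below applies them to random placeholder tensors and sums the results with equal weights, only to show the inputs each criterion expects. torchvision's `FasterRCNN` computes its own losses internally (smooth L1 rather than MSE for box regression, over positive samples only), so this illustrates the idea rather than reproducing torchvision's exact loss.

```python
import torch
from torch import nn

torch.manual_seed(0)

rpn_cls_criterion = nn.BCEWithLogitsLoss()
rpn_reg_criterion = nn.MSELoss()
rcnn_cls_criterion = nn.CrossEntropyLoss()
rcnn_reg_criterion = nn.MSELoss()

# Hypothetical RPN outputs for 8 sampled anchors: one objectness logit each,
# plus 4 box-regression offsets compared against target offsets.
objectness_logits = torch.randn(8)
objectness_labels = torch.randint(0, 2, (8,)).float()  # 1 = object, 0 = background
rpn_deltas, rpn_target_deltas = torch.randn(8, 4), torch.randn(8, 4)

# Hypothetical R-CNN head outputs for 4 proposals and 3 classes (incl. background).
class_logits = torch.randn(4, 3)
class_labels = torch.randint(0, 3, (4,))
box_deltas, box_target_deltas = torch.randn(4, 4), torch.randn(4, 4)

rpn_loss = (rpn_cls_criterion(objectness_logits, objectness_labels)
            + rpn_reg_criterion(rpn_deltas, rpn_target_deltas))
rcnn_loss = (rcnn_cls_criterion(class_logits, class_labels)
             + rcnn_reg_criterion(box_deltas, box_target_deltas))
total_loss = rpn_loss + rcnn_loss  # equal weighting of the four terms
print(total_loss.item())
```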
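import torchvision
from torchvision.models.detection import FasterRCNN
# MobileNetV2 convolutional layers as the backbone; FasterRCNN reads its out_channels attribute
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
backbone.out_channels = 1280
num_classes = 2  # including the background class
model = FasterRCNN(backbone=backbone, num_classes=num_classes, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler)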
# Load pre-trained Faster R-CNN
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
# Define number of classes (including background) and classifier input size
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Replace the model's classifier with one that has the desired number of classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)