Deep Learning for Images with PyTorch
Michal Oleszak
Machine Learning Engineer


Faster R-CNN: an advanced version of R-CNN

from torchvision.models.detection.rpn import AnchorGenerator

# 3 anchor sizes x 3 aspect ratios = 9 anchors per feature-map location
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128),),
    aspect_ratios=((0.5, 1.0, 2.0),),
)

from torchvision.ops import MultiScaleRoIAlign

# Pool each region proposal from feature map "0" into a fixed 7x7 grid
roi_pooler = MultiScaleRoIAlign(
    featmap_names=["0"],
    output_size=7,
    sampling_ratio=2,
)

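import torch.nn as nn
import torchvision
from torchvision.models.detection import FasterRCNN

# RPN losses: objectness (object vs. background) and box regression
rpn_cls_criterion = nn.BCEWithLogitsLoss()
rpn_reg_criterion = nn.MSELoss()
# Detection-head (R-CNN) losses: class prediction and box regression
rcnn_cls_criterion = nn.CrossEntropyLoss()
rcnn_reg_criterion = nn.MSELoss()

# MobileNetV2 convolutional features as the backbone; its last layer outputs 1280 channels
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
backbone.out_channels = 1280

num_classes = 2  # example value; includes the background class
model = FasterRCNN(
    backbone=backbone,
    num_classes=num_classes,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)
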
Load pre-trained Faster R-CNN
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
Define the number of classes and the classifier's input size
num_classes = 2  # background + one object class
in_features = model.roi_heads.box_predictor.cls_score.in_features
Replace the model's classifier with one that has the desired number of classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)