Instance segmentation with Mask R-CNN

Deep Learning for Images with PyTorch

Michal Oleszak

Machine Learning Engineer

Faster R-CNN

Faster R-CNN architecture diagram


Mask R-CNN

Mask R-CNN architecture diagram
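Compared with Faster R-CNN, the main addition in the diagram is a mask branch that runs on each region of interest in parallel with the box and class heads. The sketch below is a simplified, hypothetical version of that branch. It assumes the head layout from the original Mask R-CNN paper, which torchvision's ResNet-50-FPN model also follows: four 3x3 convolutions on 14x14 RoIAlign features, a 2x2 transposed convolution that upsamples to 28x28, and a 1x1 convolution that outputs one mask per class. `MaskHeadSketch` and the random input are illustrative and are not torchvision's API; 91 is the number of COCO classes that torchvision's detection models use.

```python
import torch
from torch import nn


class MaskHeadSketch(nn.Module):
    """Simplified Mask R-CNN mask branch (a sketch, not torchvision's class).

    Takes RoI-aligned features of shape (N, in_channels, 14, 14) and
    returns per-class mask logits of shape (N, num_classes, 28, 28).
    """

    def __init__(self, in_channels: int = 256, num_classes: int = 91):
        super().__init__()
        layers = []
        for _ in range(4):
            # Four 3x3 convolutions that keep the 14x14 spatial size
            layers += [
                nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ]
            in_channels = 256
        self.convs = nn.Sequential(*layers)
        # 2x2 transposed convolution with stride 2 upsamples 14x14 -> 28x28
        self.upsample = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        # 1x1 convolution produces one mask logit map per class
        self.predictor = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, roi_features: torch.Tensor) -> torch.Tensor:
        x = self.convs(roi_features)
        x = torch.relu(self.upsample(x))
        return self.predictor(x)


head = MaskHeadSketch()
roi_features = torch.randn(3, 256, 14, 14)  # 3 hypothetical RoIs
mask_logits = head(roi_features)
print(mask_logits.shape)  # torch.Size([3, 91, 28, 28])
```

At inference time, the model keeps only the channel of the predicted class, applies a sigmoid, and resizes the 28x28 mask into the object's bounding box on the full image. This is where the soft masks with values between 0 and 1 come from.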


Pre-trained Mask R-CNN in PyTorch

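import torch
from PIL import Image
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn

model = maskrcnn_resnet50_fpn(weights="DEFAULT")  # COCO-pretrained weights
model.eval()  # switch to inference mode

image = Image.open("cat_and_laptop.jpg").convert("RGB")  # ensure 3 channels
transform = transforms.Compose([transforms.ToTensor()])
image_tensor = transform(image).unsqueeze(0)  # add a batch dimension

with torch.no_grad():  # no gradients needed for inference
    prediction = model(image_tensor)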
  • Import the Mask R-CNN model
  • Load pre-trained model
  • Load test image and transform to tensor

photograph of a cat sitting next to a laptop

  • Pass image tensor to the model
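`transforms.ToTensor()` turns the image into a float tensor of shape (C, H, W) with values rescaled from 0 to 255 into 0 to 1, and `unsqueeze(0)` adds the batch dimension. Below is a torch-only sketch of the same steps, assuming the image is already available as an (H, W, 3) `uint8` tensor; `to_batched_tensor` is a hypothetical helper, not part of torchvision.

```python
import torch


def to_batched_tensor(hwc_uint8: torch.Tensor) -> torch.Tensor:
    """Mimic ToTensor() + unsqueeze(0) for an RGB uint8 image tensor.

    hwc_uint8: shape (H, W, 3), dtype uint8, values 0 to 255
    returns:   shape (1, 3, H, W), dtype float32, values 0.0 to 1.0
    """
    chw = hwc_uint8.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
    scaled = chw.float() / 255.0      # rescale to [0, 1]
    return scaled.unsqueeze(0)        # add batch dimension


# Hypothetical all-white image with height 2 and width 3
fake_image = torch.full((2, 3, 3), 255, dtype=torch.uint8)
batch = to_batched_tensor(fake_image)
print(batch.shape)  # torch.Size([1, 3, 2, 3])
```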

Model outputs

  • Labels

    prediction[0]["labels"]
    
    tensor([
        17, 73, 76, 73, 67, 42, 63, 84, 73, 65,
        17, 73, 73, 73, 84, 72, 76, 76, 17, 15
    ])
    
  • Class names

    print(class_names[17], class_names[73])
    
    cat laptop
    
  • Class probabilities

    prediction[0]["scores"]
    
    tensor([
        0.9981, 0.9672, 0.9061, 0.6893, 0.3729, 
        ..., 
        0.0745, 0.0705, 0.0623, 0.0610, 0.0508
    ])
    
  • Masks

    prediction[0]["masks"]
    
    tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
              ...]]])
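The four outputs are aligned by position: row i of "boxes", "labels", "scores", and "masks" describes the same detection, and detections come sorted by score. The lowest scores sit just above 0.05 because torchvision's detection models discard boxes below their `box_score_thresh` parameter, which defaults to 0.05. Below is a minimal sketch of the usual next step: keep the confident detections and translate their label IDs into names. The 0.5 cut-off is chosen only for illustration. The slide does not show where `class_names` comes from; one plausible source is the COCO category list torchvision attaches to the weights, `MaskRCNN_ResNet50_FPN_Weights.DEFAULT.meta["categories"]`, in which index 17 is "cat" and 73 is "laptop". `summarize_detections` and the synthetic inputs below are hypothetical.

```python
from typing import List, Sequence, Tuple

import torch


def summarize_detections(pred: dict, class_names: Sequence[str],
                         threshold: float = 0.5) -> Tuple[dict, List[str]]:
    """Keep detections scoring above `threshold` and name their classes.

    `pred` is one element of the model's output list: a dict whose tensors
    ("boxes", "labels", "scores", "masks") share the same first dimension.
    """
    keep = pred["scores"] > threshold
    kept = {key: value[keep] for key, value in pred.items()}
    names = [class_names[int(label)] for label in kept["labels"]]
    return kept, names


# Synthetic prediction with 4 detections (made-up values, not model output)
fake_pred = {
    "boxes": torch.rand(4, 4),
    "labels": torch.tensor([3, 1, 2, 1]),
    "scores": torch.tensor([0.99, 0.91, 0.37, 0.05]),
    "masks": torch.rand(4, 1, 8, 8),
}
# Hypothetical, truncated category list; the real COCO list is much longer
toy_class_names = ["__background__", "person", "bicycle", "car"]

kept, names = summarize_detections(fake_pred, toy_class_names)
print(names)                # ['car', 'person']
print(kept["masks"].shape)  # torch.Size([2, 1, 8, 8])
```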
    

Soft masks

  • Unique mask values

    prediction[0]["masks"].unique()
    
    tensor([0.0000e+00, 5.9713e-08, ..., 
            9.9989e-01, 9.9990e-01])
    
  • Mask R-CNN masks:

    • Values between 0 and 1
    • Represent the model's confidence that each pixel belongs to the object
    • Carry more information than binary masks, e.g. how uncertain the model is near object boundaries
    • Can be binarized by thresholding, commonly at 0.5, when a hard mask is needed
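Binarizing takes a single comparison. Below is a minimal sketch, assuming the 0.5 threshold that torchvision's documentation suggests for these masks. `binarize_masks` is a hypothetical helper, and the 2x2 soft mask is made up.

```python
import torch


def binarize_masks(soft_masks: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Turn soft masks of shape (N, 1, H, W) into boolean masks of shape (N, H, W)."""
    return soft_masks[:, 0] >= threshold


# Synthetic soft mask for one object on a 2x2 image
soft = torch.tensor([[[[0.02, 0.61], [0.97, 0.49]]]])
binary = binarize_masks(soft)
print(binary)
print(binary.sum(dim=(1, 2)))  # pixel area per object: tensor([2])
```

Summing each boolean mask over its spatial dimensions gives the object's area in pixels.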

Displaying soft masks

import matplotlib.pyplot as plt

masks = prediction[0]["masks"]
labels = prediction[0]["labels"]

for i in range(2):
    plt.imshow(image)
    plt.imshow(masks[i, 0], cmap="jet", alpha=0.5)  # semi-transparent soft mask
    plt.title(f"Object: {class_names[labels[i]]}")
    plt.show()
  • Extract masks and labels from prediction
  • Iterate over the top two objects (detections are sorted by descending score), plotting the original image for each
  • For each object, plot the semi-transparent mask
  • Add title and display
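`plt.imshow(..., alpha=0.5)` blends the colored soft mask over the whole photo. If you would rather produce the overlay as an image tensor (for example, to save it), you can compute a similar effect by hand. Below is a minimal sketch, assuming a float RGB image tensor of shape (3, H, W) in [0, 1] and a boolean mask such as a thresholded soft mask. `overlay_mask` is a hypothetical helper, and unlike the code above, it colors only the pixels inside the mask.

```python
import torch


def overlay_mask(image: torch.Tensor, mask: torch.Tensor,
                 color=(1.0, 0.0, 0.0), alpha: float = 0.5) -> torch.Tensor:
    """Alpha-blend a solid color into `image` wherever `mask` is True.

    image: float tensor (3, H, W) with values in [0, 1]
    mask:  bool tensor (H, W)
    """
    color_t = torch.tensor(color, dtype=image.dtype).view(3, 1, 1)
    blended = (1 - alpha) * image + alpha * color_t  # blend every pixel
    return torch.where(mask, blended, image)         # keep blend only inside mask


image = torch.zeros(3, 2, 2)  # hypothetical black 2x2 image
mask = torch.tensor([[True, False], [False, False]])
out = overlay_mask(image, mask)
print(out[:, 0, 0])  # tensor([0.5000, 0.0000, 0.0000])
```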

Displaying soft masks

Instance segmentation masks overlaid on the images


Let's practice!

