Introduction to image segmentation

Deep Learning for Images with PyTorch

Michal Oleszak

Machine Learning Engineer

Image segmentation

  • Image segmentation partitions an image into multiple segments at the pixel level
  • Each pixel in an image is assigned to a particular segment
  • Three types of segmentation:
    • Semantic segmentation
    • Instance segmentation
    • Panoptic segmentation

Semantic segmentation

semantic segmentation

  • Each pixel is classified into a class
  • All pixels belonging to the same class are treated equally, so two cats in one image share the same label
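A semantic segmentation label can be pictured as a 2D map holding one class ID per pixel. The minimal sketch below uses a made-up 4x4 map (the class IDs 0 = background, 1 = cat, 2 = dog are assumptions for illustration) and derives one boolean mask per class:

```python
import torch

# Hypothetical 4x4 semantic label map: 0 = background, 1 = cat, 2 = dog.
# Every pixel holds exactly one class ID; two cats would both be labeled 1.
semantic_map = torch.tensor([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [2, 2, 0, 1],
    [2, 2, 0, 0],
])

# One boolean mask per class ID present in the map
class_masks = {int(c): semantic_map == c for c in semantic_map.unique()}

for class_id, class_mask in class_masks.items():
    print(class_id, int(class_mask.sum()))  # number of pixels in each class
```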

Instance segmentation

instance segmentation

  • Distinguishes between different instances of the same class
  • The background is often not segmented
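Instance annotations are often stored as one binary mask per object plus a class label for each mask. This is a sketch under that assumption, with two hypothetical cats in a 4x4 image; merging the masks into a semantic map shows how the distinction between instances is lost:

```python
import torch

H, W = 4, 4
# Hypothetical instance annotation: two cats, stored as one binary mask each
instance_masks = torch.zeros(2, H, W, dtype=torch.bool)
instance_masks[0, 0:2, 0:2] = True  # cat #1 in the top-left corner
instance_masks[1, 2:4, 2:4] = True  # cat #2 in the bottom-right corner
instance_labels = torch.tensor([1, 1])  # both instances have class 1 ("cat")

# Per-pixel instance IDs: 0 = unlabeled background, k = k-th instance
instance_map = torch.zeros(H, W, dtype=torch.long)
for k, inst_mask in enumerate(instance_masks, start=1):
    instance_map[inst_mask] = k

# Collapsing to a semantic map: both cats end up with the same label
semantic_map = torch.zeros(H, W, dtype=torch.long)
for label, inst_mask in zip(instance_labels, instance_masks):
    semantic_map[inst_mask] = label
```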

Panoptic segmentation

panoptic segmentation

  • Combines semantic and instance segmentation
  • Assigns a unique label to each instance of an object
  • Classifies the background at the pixel level
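One simple way to represent a panoptic label is to pack the class ID and the instance ID into a single integer per pixel. The encoding `class_id * OFFSET + instance_id` and the value `OFFSET = 1000` below are assumptions chosen for illustration, not a format prescribed by this course; background ("stuff") pixels get instance ID 0, while each object ("thing") gets its own ID:

```python
import torch

OFFSET = 1000  # assumed: fewer than 1000 instances per class

# Hypothetical per-pixel class IDs for a 3x4 image: 0 = sky (stuff), 1 = cat (thing)
semantic = torch.tensor([
    [0, 0, 0, 0],
    [1, 1, 0, 1],
    [1, 1, 0, 1],
])
# Hypothetical per-pixel instance IDs: 0 for stuff, 1..N for separate cats
instance = torch.tensor([
    [0, 0, 0, 0],
    [1, 1, 0, 2],
    [1, 1, 0, 2],
])

# Every segment gets a unique ID: the sky is one segment, each cat its own
panoptic = semantic * OFFSET + instance

# Decode the combined IDs back into class and instance components
decoded_class = panoptic // OFFSET
decoded_instance = panoptic % OFFSET
```

Packing both IDs into one integer keeps the label a single 2D map, which is convenient for storage; the trade-off is the fixed upper bound on instances per class.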

Data annotations

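import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

# Load the photo and its annotation mask
image = Image.open("images/British_Shorthair_36.jpg")
mask = Image.open("annots/British_Shorthair_36.png")

# Convert both to tensors with shape [channels, height, width]
transform = transforms.Compose([transforms.ToTensor()])
image_tensor = transform(image)
mask_tensor = transform(mask)
print(f"Image shape: {image_tensor.shape}\nMask shape: {mask_tensor.shape}")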
    Image shape: torch.Size([3, 333, 500])
    Mask shape: torch.Size([1, 333, 500])

British shorthair cat photograph


Understanding the mask

  • Dataset documentation:

    Pixel annotations:
      1: Foreground
      2: Background
      3: Not classified

  • Unique mask values:

    mask_tensor.unique()
    
    tensor([0.0039, 0.0078, 0.0118])
    
  • ToTensor() divides the pixel values by 255:

    • 1 / 255 = 0.0039 - foreground (object)
    • 2 / 255 = 0.0078 - background
    • 3 / 255 = 0.0118 - unclassified

Creating a binary mask

binary_mask = torch.where(
    mask_tensor == 1/255, 
    torch.tensor(1.0),
    torch.tensor(0.0),
)


to_pil_image = transforms.ToPILImage()
mask = to_pil_image(binary_mask)
plt.imshow(mask)
plt.show()

segmentation mask

  • torch.where() takes three arguments:
    • The condition
    • The value to use where the condition is met
    • The value to use elsewhere
  • Convert the mask to a PIL image
  • Display the mask image
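Comparing floats such as `mask_tensor == 1/255` for exact equality relies on both sides rounding identically. A slightly more defensive variant, sketched here on a hypothetical 2x3 mask, first converts back to integer labels and then compares; `(labels == 1).float()` is an equivalent shorthand for the `torch.where()` call:

```python
import torch

# Hypothetical mask tensor as produced by ToTensor(): labels divided by 255
mask_tensor = torch.tensor([[[1, 2, 3], [3, 2, 1]]], dtype=torch.float32) / 255

# Recover integer labels before comparing, to avoid float-equality surprises
labels = (mask_tensor * 255).round()
binary_mask = torch.where(labels == 1, torch.tensor(1.0), torch.tensor(0.0))

# Equivalent shorthand: cast the boolean comparison to float
binary_mask_alt = (labels == 1).float()
```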

Segmenting the object

object_tensor = image_tensor * binary_mask


to_pil_image = transforms.ToPILImage()
object_image = to_pil_image(object_tensor)
plt.imshow(object_image)
plt.show()

segmented image

  • Multiply the image by the binary mask; the single-channel mask is broadcast across all three color channels
  • Convert the object tensor to a PIL image
  • Display the object image
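The multiplication works because PyTorch broadcasts the mask of shape [1, H, W] over the image of shape [3, H, W], so background pixels become black (0.0). The sketch below shows this on a hypothetical 2x2 image; the white-background variant with `torch.where()` is an illustrative addition, not part of the original slide:

```python
import torch

# Hypothetical 3-channel 2x2 image (all values 0.5) and a 1-channel binary mask
image_tensor = torch.full((3, 2, 2), 0.5)
binary_mask = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # shape [1, 2, 2]

# The mask's single channel broadcasts over the three color channels
object_tensor = image_tensor * binary_mask
print(object_tensor.shape)  # torch.Size([3, 2, 2])

# Variant: keep object pixels, paint the background white (1.0) instead of black
object_on_white = torch.where(
    binary_mask.bool(), image_tensor, torch.ones_like(image_tensor)
)
```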

Let's practice!
