.. _sec_semantic_segmentation:

Semantic Segmentation and the Dataset
=====================================

When discussing object detection tasks in
:numref:`sec_bbox`--:numref:`sec_rcnn`, rectangular bounding boxes are used
to label and predict objects in images. This section discusses the problem of
*semantic segmentation*, which focuses on how to divide an image into regions
belonging to different semantic classes. Different from object detection,
semantic segmentation recognizes and understands what is in an image at the
pixel level: both its labeling and its prediction of semantic regions are
pixel-level. :numref:`fig_segmentation` shows the labels of the dog, cat, and
background of an image in semantic segmentation. Compared with object
detection, the pixel-level borders labeled in semantic segmentation are
obviously more fine-grained.

.. _fig_segmentation:

.. figure:: ../img/segmentation.svg

   Labels of the dog, cat, and background of the image in semantic
   segmentation.

Image Segmentation and Instance Segmentation
--------------------------------------------

There are two other important tasks in computer vision that are similar to
semantic segmentation, namely image segmentation and instance segmentation.
We briefly distinguish them from semantic segmentation as follows.

-  *Image segmentation* divides an image into several constituent regions.
   The methods for this type of problem usually make use of the correlation
   between pixels in the image. They need no label information about image
   pixels during training, and they cannot guarantee that the segmented
   regions will have the semantics that we hope to obtain during prediction.
   Taking the image in :numref:`fig_segmentation` as input, image segmentation
   may divide the dog into two regions: one covers the mouth and eyes, which
   are mainly black, and the other covers the rest of the body, which is
   mainly yellow.
-  *Instance segmentation* is also called *simultaneous detection and
   segmentation*. It studies how to recognize the pixel-level regions of each
   object instance in an image. Different from semantic segmentation, instance
   segmentation needs to distinguish not only semantics, but also different
   object instances. For example, if there are two dogs in the image, instance
   segmentation needs to distinguish which of the two dogs a pixel belongs to.

The Pascal VOC2012 Semantic Segmentation Dataset
------------------------------------------------

One of the most important semantic segmentation datasets is `Pascal VOC2012
`__. In the following, we take a look at this dataset.

.. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline import os from mxnet import gluon, image, np, npx from d2l import mxnet as d2l npx.set_np() .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline import os import torch import torchvision from d2l import torch as d2l .. raw:: html
.. raw:: html
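Before downloading real data, it may help to make the pixel-level nature of
segmentation labels concrete. The following toy sketch is purely illustrative
(plain ``numpy`` with made-up values; it is not part of the Pascal VOC pipeline
below): it contrasts a single bounding box with a per-pixel label map for the
same tiny image.

.. code:: python

    import numpy as np

    # A made-up 6x6 image containing one "dog" (class index 1) on background (0).
    # Object detection would label it with one rectangle, while semantic
    # segmentation assigns a class index to every single pixel.
    bbox = (1, 2, 4, 4)  # illustrative tightest box: (row_min, col_min, row_max, col_max)

    pixel_label = np.zeros((6, 6), dtype=np.int64)
    pixel_label[1:4, 2:5] = 1   # the coarse body of the "dog"
    pixel_label[4, 3] = 1       # one extra pixel that no tidy rectangle captures
    print(pixel_label)          # per-pixel classes are finer-grained than `bbox`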
The tar file of the dataset is about 2 GB, so it may take a while to download the file. The extracted dataset is located at ``../data/VOCdevkit/VOC2012``. .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save d2l.DATA_HUB['voc2012'] = (d2l.DATA_URL + 'VOCtrainval_11-May-2012.tar', '4e443f8a2eca6b1dac8a6c57641b67dd40621a49') voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012') .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Downloading ../data/VOCtrainval_11-May-2012.tar from http://d2l-data.s3-accelerate.amazonaws.com/VOCtrainval_11-May-2012.tar... .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save d2l.DATA_HUB['voc2012'] = (d2l.DATA_URL + 'VOCtrainval_11-May-2012.tar', '4e443f8a2eca6b1dac8a6c57641b67dd40621a49') voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012') .. raw:: html
.. raw:: html
After entering the path ``../data/VOCdevkit/VOC2012``, we can see the
different components of the dataset. The ``ImageSets/Segmentation`` path
contains text files that specify training and test samples, while the
``JPEGImages`` and ``SegmentationClass`` paths store the input image and label
for each example, respectively. Each label here is itself an image, with the
same size as the input image it labels. Moreover, pixels with the same color
in any label image belong to the same semantic class. The following defines
the ``read_voc_images`` function to read all the input images and labels into
memory.

.. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def read_voc_images(voc_dir, is_train=True): """Read all VOC feature and label images.""" txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation', 'train.txt' if is_train else 'val.txt') with open(txt_fname, 'r') as f: images = f.read().split() features, labels = [], [] for i, fname in enumerate(images): features.append(image.imread(os.path.join( voc_dir, 'JPEGImages', f'{fname}.jpg'))) labels.append(image.imread(os.path.join( voc_dir, 'SegmentationClass', f'{fname}.png'))) return features, labels train_features, train_labels = read_voc_images(voc_dir, True) .. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    #@save
    def read_voc_images(voc_dir, is_train=True):
        """Read all VOC feature and label images."""
        txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',
                                 'train.txt' if is_train else 'val.txt')
        mode = torchvision.io.image.ImageReadMode.RGB
        with open(txt_fname, 'r') as f:
            images = f.read().split()
        features, labels = [], []
        for i, fname in enumerate(images):
            features.append(torchvision.io.read_image(os.path.join(
                voc_dir, 'JPEGImages', f'{fname}.jpg')))
            labels.append(torchvision.io.read_image(os.path.join(
                voc_dir, 'SegmentationClass', f'{fname}.png'), mode))
        return features, labels

    train_features, train_labels = read_voc_images(voc_dir, True)

.. raw:: html
.. raw:: html
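As a quick sanity check (illustrative only, assuming the read above
succeeded), each label image should have exactly the same height and width as
its input image, since every pixel carries a class label:

.. code:: python

    # Features and labels are paired one-to-one and share their spatial size.
    # Note that the channel axis comes first for the PyTorch tensors above
    # and last for the MXNet arrays.
    print(len(train_features), len(train_labels))
    print(train_features[0].shape, train_labels[0].shape)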
We draw the first five input images and their labels. In the label images, white and black represent borders and background, respectively, while the other colors correspond to different classes. .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python n = 5 imgs = train_features[0:n] + train_labels[0:n] d2l.show_images(imgs, 2, n); .. figure:: output_semantic-segmentation-and-dataset_23ff18_30_0.png .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python n = 5 imgs = train_features[0:n] + train_labels[0:n] imgs = [img.permute(1,2,0) for img in imgs] d2l.show_images(imgs, 2, n); .. figure:: output_semantic-segmentation-and-dataset_23ff18_33_0.png .. raw:: html
.. raw:: html
Next, we enumerate the RGB color values and class names for all the labels in this dataset. .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] #@save VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor'] .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] #@save VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor'] .. raw:: html
.. raw:: html
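Since the two lists above are aligned index by index, it can be convenient to
look up the color that draws a given class. The small helper below is only an
illustration and is not part of the functions saved for later chapters:

.. code:: python

    # Hypothetical convenience mapping from class name to its label color.
    name2color = dict(zip(VOC_CLASSES, VOC_COLORMAP))
    print(name2color['aeroplane'])  # [128, 0, 0]
    print(name2color['dog'])        # [64, 0, 128]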
With the two constants defined above, we can conveniently find the class index for each pixel in a label. We define the ``voc_colormap2label`` function to build the mapping from the above RGB color values to class indices, and the ``voc_label_indices`` function to map any RGB values to their class indices in this Pascal VOC2012 dataset. .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def voc_colormap2label(): """Build the mapping from RGB to class indices for VOC labels.""" colormap2label = np.zeros(256 ** 3) for i, colormap in enumerate(VOC_COLORMAP): colormap2label[ (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i return colormap2label #@save def voc_label_indices(colormap, colormap2label): """Map any RGB values in VOC labels to their class indices.""" colormap = colormap.astype(np.int32) idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256 + colormap[:, :, 2]) return colormap2label[idx] .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def voc_colormap2label(): """Build the mapping from RGB to class indices for VOC labels.""" colormap2label = torch.zeros(256 ** 3, dtype=torch.long) for i, colormap in enumerate(VOC_COLORMAP): colormap2label[ (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i return colormap2label #@save def voc_label_indices(colormap, colormap2label): """Map any RGB values in VOC labels to their class indices.""" colormap = colormap.permute(1, 2, 0).numpy().astype('int32') idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256 + colormap[:, :, 2]) return colormap2label[idx] .. raw:: html
.. raw:: html
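The mapping above flattens an RGB triple into a single integer by base-256
positional encoding, :math:`(R \times 256 + G) \times 256 + B`. As a worked
check (not additional library code), the ``aeroplane`` color ``[128, 0, 0]``
lands at offset :math:`(128 \times 256 + 0) \times 256 + 0 = 8388608`, where
``voc_colormap2label`` stores class index 1:

.. code:: python

    # Worked example of the base-256 encoding used by voc_colormap2label.
    r, g, b = VOC_COLORMAP[1]            # [128, 0, 0], the 'aeroplane' color
    flat = (r * 256 + g) * 256 + b
    print(flat)                          # 8388608
    print(voc_colormap2label()[flat])    # 1, i.e., VOC_CLASSES[1]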
For example, in the first example image, the class index for the front part of the airplane is 1, while the background index is 0. .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python y = voc_label_indices(train_labels[0], voc_colormap2label()) y[105:115, 130:140], VOC_CLASSES[1] .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output (array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.], [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.], [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]]), 'aeroplane') .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python y = voc_label_indices(train_labels[0], voc_colormap2label()) y[105:115, 130:140], VOC_CLASSES[1] .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output (tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]), 'aeroplane') .. raw:: html
.. raw:: html
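Going in the other direction is just as simple and becomes useful later when
visualizing predictions: indexing ``VOC_COLORMAP`` with a map of class indices
recovers a color image. The sketch below uses a tiny made-up index map so that
it is self-contained; it is an illustration rather than a helper used
elsewhere in the book.

.. code:: python

    import numpy as np

    # Inverse mapping: class indices -> RGB colors, via fancy indexing.
    colormap = np.array(VOC_COLORMAP, dtype=np.uint8)  # shape (21, 3)
    indices = np.array([[0, 0, 1],
                        [0, 1, 1]])                    # toy 2x3 index map
    rgb = colormap[indices]                            # shape (2, 3, 3)
    print(rgb[0, 0], rgb[0, 2])                        # background, aeroplane colors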
Data Preprocessing
~~~~~~~~~~~~~~~~~~

In previous experiments, such as those in
:numref:`sec_alexnet`--:numref:`sec_googlenet`, images were rescaled to fit
the model's required input shape. In semantic segmentation, however, doing so
would require rescaling the predicted pixel classes back to the original shape
of the input image. Such rescaling may be inaccurate, especially for segmented
regions with different classes. To avoid this issue, we crop the image to a
*fixed* shape instead of rescaling. Specifically, using random cropping from
image augmentation, we crop the same area of the input image and of the label.

.. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def voc_rand_crop(feature, label, height, width): """Randomly crop both feature and label images.""" feature, rect = image.random_crop(feature, (width, height)) label = image.fixed_crop(label, *rect) return feature, label imgs = [] for _ in range(n): imgs += voc_rand_crop(train_features[0], train_labels[0], 200, 300) d2l.show_images(imgs[::2] + imgs[1::2], 2, n); .. figure:: output_semantic-segmentation-and-dataset_23ff18_66_0.png .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def voc_rand_crop(feature, label, height, width): """Randomly crop both feature and label images.""" rect = torchvision.transforms.RandomCrop.get_params( feature, (height, width)) feature = torchvision.transforms.functional.crop(feature, *rect) label = torchvision.transforms.functional.crop(label, *rect) return feature, label imgs = [] for _ in range(n): imgs += voc_rand_crop(train_features[0], train_labels[0], 200, 300) imgs = [img.permute(1, 2, 0) for img in imgs] d2l.show_images(imgs[::2] + imgs[1::2], 2, n); .. figure:: output_semantic-segmentation-and-dataset_23ff18_69_0.png .. raw:: html
.. raw:: html
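Because the same crop rectangle is applied to the feature and to its label,
the two stay pixel-aligned after cropping. A quick, purely illustrative check:

.. code:: python

    # One random 200x300 crop keeps the feature and its label the same size.
    # (The channel axis is last for MXNet, first for PyTorch.)
    f, l = voc_rand_crop(train_features[0], train_labels[0], 200, 300)
    print(f.shape, l.shape)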
Custom Semantic Segmentation Dataset Class
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We define a custom semantic segmentation dataset class ``VOCSegDataset`` by
inheriting the ``Dataset`` class provided by high-level APIs. By implementing
the ``__getitem__`` function, we can arbitrarily access the input image
indexed as ``idx`` in the dataset and the class index of each pixel in this
image. Since some images in the dataset are smaller than the output size of
random cropping, these examples are filtered out by a custom ``filter``
function. In addition, we define the ``normalize_image`` function to
standardize the values of the three RGB channels of input images.

.. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save class VOCSegDataset(gluon.data.Dataset): """A customized dataset to load the VOC dataset.""" def __init__(self, is_train, crop_size, voc_dir): self.rgb_mean = np.array([0.485, 0.456, 0.406]) self.rgb_std = np.array([0.229, 0.224, 0.225]) self.crop_size = crop_size features, labels = read_voc_images(voc_dir, is_train=is_train) self.features = [self.normalize_image(feature) for feature in self.filter(features)] self.labels = self.filter(labels) self.colormap2label = voc_colormap2label() print('read ' + str(len(self.features)) + ' examples') def normalize_image(self, img): return (img.astype('float32') / 255 - self.rgb_mean) / self.rgb_std def filter(self, imgs): return [img for img in imgs if ( img.shape[0] >= self.crop_size[0] and img.shape[1] >= self.crop_size[1])] def __getitem__(self, idx): feature, label = voc_rand_crop(self.features[idx], self.labels[idx], *self.crop_size) return (feature.transpose(2, 0, 1), voc_label_indices(label, self.colormap2label)) def __len__(self): return len(self.features) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save class VOCSegDataset(torch.utils.data.Dataset): """A customized dataset to load the VOC dataset.""" def __init__(self, is_train, crop_size, voc_dir): self.transform = torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.crop_size = crop_size features, labels = read_voc_images(voc_dir, is_train=is_train) self.features = [self.normalize_image(feature) for feature in self.filter(features)] self.labels = self.filter(labels) self.colormap2label = voc_colormap2label() print('read ' + str(len(self.features)) + ' examples') def normalize_image(self, img): return self.transform(img.float() / 255) def filter(self, imgs): return [img for img in imgs if ( img.shape[1] >= self.crop_size[0] and img.shape[2] >= self.crop_size[1])] def __getitem__(self, idx): feature, label = voc_rand_crop(self.features[idx], self.labels[idx], *self.crop_size) return (feature, voc_label_indices(label, self.colormap2label)) def __len__(self): return len(self.features) .. raw:: html
.. raw:: html
Reading the Dataset
~~~~~~~~~~~~~~~~~~~

We use the custom ``VOCSegDataset`` class to create instances of the training
set and the test set, respectively. Suppose that we specify that the output
shape of randomly cropped images is :math:`320\times 480`. Below we can view
the numbers of examples that are retained in the training set and the test
set.

.. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python crop_size = (320, 480) voc_train = VOCSegDataset(True, crop_size, voc_dir) voc_test = VOCSegDataset(False, crop_size, voc_dir) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output read 1114 examples read 1078 examples .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python crop_size = (320, 480) voc_train = VOCSegDataset(True, crop_size, voc_dir) voc_test = VOCSegDataset(False, crop_size, voc_dir) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output read 1114 examples read 1078 examples .. raw:: html
.. raw:: html
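Indexing an instance directly shows what ``__getitem__`` returns for one
example: a standardized feature with channels first and a
:math:`320\times 480` map of per-pixel class indices. The snippet below is
only a sanity check on the objects created above:

.. code:: python

    # Fetch one randomly cropped training example.
    feature, label = voc_train[0]
    print(feature.shape)   # (3, 320, 480): standardized RGB channels
    print(label.shape)     # (320, 480): one class index per pixel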
Setting the batch size to 64, we define the data iterator for the training
set. Let us print the shape of the first minibatch. Different from image
classification and object detection, labels here are three-dimensional
tensors.

.. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python batch_size = 64 train_iter = gluon.data.DataLoader(voc_train, batch_size, shuffle=True, last_batch='discard', num_workers=d2l.get_dataloader_workers()) for X, Y in train_iter: print(X.shape) print(Y.shape) break .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output (64, 3, 320, 480) (64, 320, 480) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python batch_size = 64 train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True, drop_last=True, num_workers=d2l.get_dataloader_workers()) for X, Y in train_iter: print(X.shape) print(Y.shape) break .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output torch.Size([64, 3, 320, 480]) torch.Size([64, 320, 480]) .. raw:: html
.. raw:: html
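The three-dimensional label tensor pairs naturally with four-dimensional
per-pixel class scores. As a preview (not part of this section's pipeline;
training and its loss are covered in later sections), the sketch below, which
assumes the PyTorch tab and the minibatch ``Y`` printed above, feeds random
scores for the 21 Pascal VOC classes into a per-pixel cross-entropy loss just
to show that the shapes line up:

.. code:: python

    import torch
    from torch.nn import functional as F

    # Illustrative only: random per-pixel scores for 21 classes, compared with
    # the first two labels of the minibatch above.
    num_classes = len(VOC_CLASSES)                  # 21
    scores = torch.randn(2, num_classes, 320, 480)  # (batch, classes, H, W)
    loss = F.cross_entropy(scores, Y[:2])           # targets have shape (batch, H, W)
    print(loss)                                     # one scalar, averaged over pixels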
Putting All Things Together ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Finally, we define the following ``load_data_voc`` function to download and read the Pascal VOC2012 semantic segmentation dataset. It returns data iterators for both the training and test datasets. .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def load_data_voc(batch_size, crop_size): """Load the VOC semantic segmentation dataset.""" voc_dir = d2l.download_extract('voc2012', os.path.join( 'VOCdevkit', 'VOC2012')) num_workers = d2l.get_dataloader_workers() train_iter = gluon.data.DataLoader( VOCSegDataset(True, crop_size, voc_dir), batch_size, shuffle=True, last_batch='discard', num_workers=num_workers) test_iter = gluon.data.DataLoader( VOCSegDataset(False, crop_size, voc_dir), batch_size, last_batch='discard', num_workers=num_workers) return train_iter, test_iter .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python #@save def load_data_voc(batch_size, crop_size): """Load the VOC semantic segmentation dataset.""" voc_dir = d2l.download_extract('voc2012', os.path.join( 'VOCdevkit', 'VOC2012')) num_workers = d2l.get_dataloader_workers() train_iter = torch.utils.data.DataLoader( VOCSegDataset(True, crop_size, voc_dir), batch_size, shuffle=True, drop_last=True, num_workers=num_workers) test_iter = torch.utils.data.DataLoader( VOCSegDataset(False, crop_size, voc_dir), batch_size, drop_last=True, num_workers=num_workers) return train_iter, test_iter .. raw:: html
.. raw:: html
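A typical call, equivalent to the steps walked through above and shown here
only as a usage reminder, creates both iterators in one line (it will re-read
the dataset and print the example counts again):

.. code:: python

    train_iter, test_iter = load_data_voc(batch_size=64, crop_size=(320, 480))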
Summary
-------

-  Semantic segmentation recognizes and understands what is in an image at
   the pixel level by dividing the image into regions belonging to different
   semantic classes.
-  One of the most important semantic segmentation datasets is Pascal VOC2012.
-  In semantic segmentation, since the input image and its label correspond
   one-to-one at the pixel level, the input image is randomly cropped to a
   fixed shape rather than rescaled.

Exercises
---------

1. How can semantic segmentation be applied in autonomous vehicles and in
   medical image diagnostics? Can you think of other applications?
2. Recall the descriptions of data augmentation in
   :numref:`sec_image_augmentation`. Which of the image augmentation methods
   used in image classification would be infeasible to apply in semantic
   segmentation?

.. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html