.. _sec_fashion_mnist:

The Image Classification Dataset
================================
One of the most widely used datasets for image classification is the
MNIST dataset :cite:`LeCun.Bottou.Bengio.ea.1998`. While it had a good
run as a benchmark dataset, even simple models by today's standards
achieve classification accuracy over 95%, making it unsuitable for
distinguishing between stronger models and weaker ones. Today, MNIST
serves more as a sanity check than as a benchmark. To up the ante just
a bit, we will focus our discussion in the coming sections on the
qualitatively similar, but comparatively complex Fashion-MNIST dataset
:cite:`Xiao.Rasul.Vollgraf.2017`, which was released in 2017.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   %matplotlib inline
   import sys
   from mxnet import gluon
   from d2l import mxnet as d2l

   d2l.use_svg_display()
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   %matplotlib inline
   import torch
   import torchvision
   from torch.utils import data
   from torchvision import transforms
   from d2l import torch as d2l

   d2l.use_svg_display()
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   %matplotlib inline
   import tensorflow as tf
   from d2l import tensorflow as d2l

   d2l.use_svg_display()
Reading the Dataset
-------------------
We can download and read the Fashion-MNIST dataset into memory via the
framework's built-in functions.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   mnist_train = gluon.data.vision.FashionMNIST(train=True)
   mnist_test = gluon.data.vision.FashionMNIST(train=False)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # `ToTensor` converts the image data from PIL type to 32-bit floating point
   # tensors. It divides all numbers by 255 so that all pixel values are
   # between 0 and 1
   trans = transforms.ToTensor()
   mnist_train = torchvision.datasets.FashionMNIST(
       root="../data", train=True, transform=trans, download=True)
   mnist_test = torchvision.datasets.FashionMNIST(
       root="../data", train=False, transform=trans, download=True)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
Fashion-MNIST consists of images from 10 categories, each represented by
6000 images in the training dataset and by 1000 in the test dataset. A
*test dataset* (or *test set*) is used for evaluating model performance
and not for training. Consequently, the training set and the test set
contain 60000 and 10000 images, respectively.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   len(mnist_train), len(mnist_test)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (60000, 10000)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   len(mnist_train), len(mnist_test)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (60000, 10000)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   len(mnist_train[0]), len(mnist_test[0])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (60000, 10000)
The height and width of each input image are both 28 pixels. Note that
the dataset consists of grayscale images, whose number of channels is 1.
For brevity, throughout this book we store the shape of any image with
height :math:`h` and width :math:`w` pixels as :math:`h \times w` or
(:math:`h`, :math:`w`).
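Note that the three frameworks report this shape in different layouts, as the
cells below show: MXNet and TensorFlow store images channels-last, while
PyTorch's ``ToTensor`` produces a channels-first tensor. A small sketch of the
conversion between the two layouts, using NumPy for illustration (NumPy is
not imported elsewhere in this section, so this snippet is self-contained):

.. code:: python

   import numpy as np

   # A dummy 28x28 grayscale image stored channels-last, i.e. (h, w, c),
   # as in MXNet and TensorFlow
   img_hwc = np.zeros((28, 28, 1), dtype=np.float32)

   # Move the channel axis to the front to obtain PyTorch's
   # channels-first layout, i.e. (c, h, w)
   img_chw = img_hwc.transpose(2, 0, 1)

   print(img_hwc.shape)  # (28, 28, 1)
   print(img_chw.shape)  # (1, 28, 28)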
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   mnist_train[0][0].shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (28, 28, 1)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   mnist_train[0][0].shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   torch.Size([1, 28, 28])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   mnist_train[0][0].shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (28, 28)
The images in Fashion-MNIST are associated with the following
categories: t-shirt, trousers, pullover, dress, coat, sandal, shirt,
sneaker, bag, and ankle boot. The following function converts numeric
label indices into their names in text.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def get_fashion_mnist_labels(labels):  #@save
       """Return text labels for the Fashion-MNIST dataset."""
       text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                      'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
       return [text_labels[int(i)] for i in labels]
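As a quick sanity check, we can convert a few indices by hand (the snippet
below restates the function so that it runs on its own):

.. code:: python

   def get_fashion_mnist_labels(labels):
       """Return text labels for the Fashion-MNIST dataset."""
       text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                      'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
       return [text_labels[int(i)] for i in labels]

   # `int(i)` lets indices arrive as plain ints or as 0-dim tensors
   print(get_fashion_mnist_labels([0, 3, 9]))
   # ['t-shirt', 'dress', 'ankle boot']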
We can now create a function to visualize these examples.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
       """Plot a list of images."""
       figsize = (num_cols * scale, num_rows * scale)
       _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
       axes = axes.flatten()
       for i, (ax, img) in enumerate(zip(axes, imgs)):
           ax.imshow(img.asnumpy())
           ax.axes.get_xaxis().set_visible(False)
           ax.axes.get_yaxis().set_visible(False)
           if titles:
               ax.set_title(titles[i])
       return axes
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
       """Plot a list of images."""
       figsize = (num_cols * scale, num_rows * scale)
       _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
       axes = axes.flatten()
       for i, (ax, img) in enumerate(zip(axes, imgs)):
           if torch.is_tensor(img):
               # Tensor Image
               ax.imshow(img.numpy())
           else:
               # PIL Image
               ax.imshow(img)
           ax.axes.get_xaxis().set_visible(False)
           ax.axes.get_yaxis().set_visible(False)
           if titles:
               ax.set_title(titles[i])
       return axes
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
       """Plot a list of images."""
       figsize = (num_cols * scale, num_rows * scale)
       _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
       axes = axes.flatten()
       for i, (ax, img) in enumerate(zip(axes, imgs)):
           ax.imshow(img.numpy())
           ax.axes.get_xaxis().set_visible(False)
           ax.axes.get_yaxis().set_visible(False)
           if titles:
               ax.set_title(titles[i])
       return axes
Here are the images and their corresponding labels (in text) for the
first few examples in the training dataset.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   X, y = mnist_train[:18]

   print(X.shape)
   show_images(X.squeeze(axis=-1), 2, 9, titles=get_fashion_mnist_labels(y));

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (18, 28, 28, 1)

.. figure:: output_image-classification-dataset_e45669_65_1.svg
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
   show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

.. figure:: output_image-classification-dataset_e45669_68_0.svg
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   X = tf.constant(mnist_train[0][:18])
   y = tf.constant(mnist_train[1][:18])
   show_images(X, 2, 9, titles=get_fashion_mnist_labels(y));

.. figure:: output_image-classification-dataset_e45669_71_0.svg
Reading a Minibatch
-------------------
To make our life easier when reading from the training and test sets, we
use the built-in data iterator rather than creating one from scratch.
Recall that at each iteration, a data iterator reads a minibatch of
``batch_size`` examples. We also randomly shuffle the examples for the
training data iterator.
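To make concrete what such an iterator does, here is a minimal from-scratch
sketch in plain Python (the name ``minibatch_iter`` is ours, not part of any
framework). It reshuffles the indices on every pass and yields ``batch_size``
examples at a time; the built-in loaders add worker processes, prefetching,
and tensor batching on top of this idea:

.. code:: python

   import random

   def minibatch_iter(features, labels, batch_size, shuffle=True):
       """Yield (features, labels) minibatches, reshuffling each pass."""
       indices = list(range(len(features)))
       if shuffle:
           random.shuffle(indices)  # a fresh order every epoch
       for start in range(0, len(indices), batch_size):
           batch = indices[start:start + batch_size]
           yield ([features[i] for i in batch],
                  [labels[i] for i in batch])

   # Toy data: 10 "images" with labels 0..9
   X = [[i] for i in range(10)]
   y = list(range(10))
   for Xb, yb in minibatch_iter(X, y, batch_size=4, shuffle=False):
       print(len(Xb), yb)
   # With shuffle=False the batches have sizes 4, 4, 2, in index order

Note that the last batch is smaller when ``batch_size`` does not divide the
dataset size; the built-in loaders behave the same way unless told to drop it.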
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   batch_size = 256

   def get_dataloader_workers():  #@save
       """Use 4 processes to read the data except for Windows."""
       return 0 if sys.platform.startswith('win') else 4

   # `ToTensor` converts the image data from uint8 to 32-bit floating point. It
   # divides all numbers by 255 so that all pixel values are between 0 and 1
   transformer = gluon.data.vision.transforms.ToTensor()
   train_iter = gluon.data.DataLoader(mnist_train.transform_first(transformer),
                                      batch_size, shuffle=True,
                                      num_workers=get_dataloader_workers())
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   batch_size = 256

   def get_dataloader_workers():  #@save
       """Use 4 processes to read the data."""
       return 4

   train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                num_workers=get_dataloader_workers())
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   batch_size = 256
   train_iter = tf.data.Dataset.from_tensor_slices(
       mnist_train).batch(batch_size).shuffle(len(mnist_train[0]))
Let us look at the time it takes to read the training data.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   timer = d2l.Timer()
   for X, y in train_iter:
       continue
   f'{timer.stop():.2f} sec'

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   '2.47 sec'
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   timer = d2l.Timer()
   for X, y in train_iter:
       continue
   f'{timer.stop():.2f} sec'

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   '2.28 sec'
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   timer = d2l.Timer()
   for X, y in train_iter:
       continue
   f'{timer.stop():.2f} sec'

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   '0.29 sec'
Putting All Things Together
---------------------------
Now we define the ``load_data_fashion_mnist`` function that obtains and
reads the Fashion-MNIST dataset. It returns the data iterators for both
the training set and the test set. In addition, it accepts an optional
argument to resize images to another shape.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def load_data_fashion_mnist(batch_size, resize=None):  #@save
       """Download the Fashion-MNIST dataset and then load it into memory."""
       dataset = gluon.data.vision
       trans = [dataset.transforms.ToTensor()]
       if resize:
           trans.insert(0, dataset.transforms.Resize(resize))
       trans = dataset.transforms.Compose(trans)
       mnist_train = dataset.FashionMNIST(train=True).transform_first(trans)
       mnist_test = dataset.FashionMNIST(train=False).transform_first(trans)
       return (gluon.data.DataLoader(mnist_train, batch_size, shuffle=True,
                                     num_workers=get_dataloader_workers()),
               gluon.data.DataLoader(mnist_test, batch_size, shuffle=False,
                                     num_workers=get_dataloader_workers()))
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def load_data_fashion_mnist(batch_size, resize=None):  #@save
       """Download the Fashion-MNIST dataset and then load it into memory."""
       trans = [transforms.ToTensor()]
       if resize:
           trans.insert(0, transforms.Resize(resize))
       trans = transforms.Compose(trans)
       mnist_train = torchvision.datasets.FashionMNIST(
           root="../data", train=True, transform=trans, download=True)
       mnist_test = torchvision.datasets.FashionMNIST(
           root="../data", train=False, transform=trans, download=True)
       return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                               num_workers=get_dataloader_workers()),
               data.DataLoader(mnist_test, batch_size, shuffle=False,
                               num_workers=get_dataloader_workers()))
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def load_data_fashion_mnist(batch_size, resize=None):  #@save
       """Download the Fashion-MNIST dataset and then load it into memory."""
       mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
       # Divide all numbers by 255 so that all pixel values are between
       # 0 and 1, add a channel dimension at the end, and cast labels to int32
       process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
                               tf.cast(y, dtype='int32'))
       resize_fn = lambda X, y: (
           tf.image.resize_with_pad(X, resize, resize) if resize else X, y)
       return (
           tf.data.Dataset.from_tensor_slices(process(*mnist_train)).batch(
               batch_size).shuffle(len(mnist_train[0])).map(resize_fn),
           tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(
               batch_size).map(resize_fn))
Below we test the image resizing feature of the
``load_data_fashion_mnist`` function by specifying the ``resize``
argument.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
   for X, y in train_iter:
       print(X.shape, X.dtype, y.shape, y.dtype)
       break

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (32, 1, 64, 64) (32,)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
   for X, y in train_iter:
       print(X.shape, X.dtype, y.shape, y.dtype)
       break

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
   for X, y in train_iter:
       print(X.shape, X.dtype, y.shape, y.dtype)
       break

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   (32, 64, 64, 1) (32,)
We are now ready to work with the Fashion-MNIST dataset in the sections
that follow.
Summary
-------
- Fashion-MNIST is an apparel classification dataset consisting of
  images representing 10 categories. We will use this dataset in
  subsequent sections and chapters to evaluate various classification
  algorithms.
- We store the shape of any image with height :math:`h` and width
  :math:`w` pixels as :math:`h \times w` or (:math:`h`, :math:`w`).
- Data iterators are a key component of efficient performance. Rely on
  well-implemented data iterators that exploit high-performance
  computing to avoid slowing down your training loop.
Exercises
---------
1. Does reducing the ``batch_size`` (for instance, to 1) affect the
   reading performance?
2. The data iterator performance is important. Do you think the current
   implementation is fast enough? Explore various options to improve it.
3. Check out the framework's online API documentation. Which other
   datasets are available?