.. _sec_naive_bayes:
Naive Bayes
===========
Throughout the previous sections, we learned about the theory of
probability and random variables. To put this theory to work, let us
introduce the *naive Bayes* classifier, which uses nothing but
probabilistic fundamentals to perform classification of handwritten
digits.

Learning is all about making assumptions. If we want to classify a new
data example that we have never seen before, we have to make some
assumptions about which data examples are similar to each other. The
naive Bayes classifier, a popular and remarkably clear algorithm,
assumes that all features are independent of each other in order to
simplify the computation. In this section, we will apply this model to
recognize characters in images.
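
To make this assumption precise, recall Bayes' theorem. For a feature
vector :math:`\mathbf{x} = (x_1, x_2, \ldots, x_d)` and a label
:math:`y`,

.. math:: p(y \mid \mathbf{x}) = \frac{p(\mathbf{x} \mid y)\, p(y)}{p(\mathbf{x})} \propto p(\mathbf{x} \mid y)\, p(y).

The naive Bayes assumption is that the features are conditionally
independent given the label, so that

.. math:: p(\mathbf{x} \mid y) = \prod_{i=1}^{d} p(x_i \mid y),

and the classifier predicts the label with the largest value of
:math:`p(y) \prod_{i=1}^{d} p(x_i \mid y)`. The rest of this section
estimates these quantities from the MNIST training set and evaluates
the resulting classifier on the test set.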
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    %matplotlib inline
    import math
    from mxnet import gluon, np, npx
    from d2l import mxnet as d2l
    
    npx.set_np()
    d2l.use_svg_display()
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    %matplotlib inline
    import math
    import torch
    import torchvision
    from d2l import torch as d2l
    
    d2l.use_svg_display()
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    %matplotlib inline
    import math
    import tensorflow as tf
    from d2l import tensorflow as d2l
    
    d2l.use_svg_display()
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def transform(data, label):
        # Binarize: pixels of 128 or above become 1, the rest 0; drop channel axis
        return np.floor(data.astype('float32') / 128).squeeze(axis=-1), label
    
    mnist_train = gluon.data.vision.MNIST(train=True, transform=transform)
    mnist_test = gluon.data.vision.MNIST(train=False, transform=transform)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    data_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # ToTensor scales pixels to [0, 1]; rescaling to [0, 255] and
        # floor-dividing by 128 binarizes them, then drop the channel axis
        lambda x: torch.floor(x * 255 / 128).squeeze(dim=0)
    ])
    
    mnist_train = torchvision.datasets.MNIST(
        root='./temp', train=True, transform=data_transform, download=True)
    mnist_test = torchvision.datasets.MNIST(
        root='./temp', train=False, transform=data_transform, download=True)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./temp/MNIST/raw/train-images-idx3-ubyte.gz
    100.0%
    Extracting ./temp/MNIST/raw/train-images-idx3-ubyte.gz to ./temp/MNIST/raw
    100.0%
    Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./temp/MNIST/raw/train-labels-idx1-ubyte.gz
    Extracting ./temp/MNIST/raw/train-labels-idx1-ubyte.gz to ./temp/MNIST/raw
    
    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    
    Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./temp/MNIST/raw/t10k-images-idx3-ubyte.gz
    100.0%
    100.0%
    Extracting ./temp/MNIST/raw/t10k-images-idx3-ubyte.gz to ./temp/MNIST/raw
    
    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
    Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./temp/MNIST/raw/t10k-labels-idx1-ubyte.gz
    Extracting ./temp/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./temp/MNIST/raw
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    ((train_images, train_labels), (
        test_images, test_labels)) = tf.keras.datasets.mnist.load_data()
    
    # Original pixel values of MNIST range from 0-255 (as the digits are stored
    # as uint8). For this section, pixel values of 128 or above (in the original
    # image) are converted to 1 and values below 128 are converted to 0. See
    # sections 18.9.2 and 18.9.3 for why
    train_images = tf.floor(tf.constant(train_images / 128, dtype=tf.float32))
    test_images = tf.floor(tf.constant(test_images / 128, dtype=tf.float32))

    train_labels = tf.constant(train_labels, dtype=tf.int32)
    test_labels = tf.constant(test_labels, dtype=tf.int32)
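
In all three implementations, each :math:`28 \times 28` grayscale image
is binarized by thresholding its pixels, so every feature is either 0
or 1:

.. math:: x_i = \begin{cases} 1 & \text{if pixel } i \geq 128, \\ 0 & \text{otherwise,} \end{cases}

which keeps each conditional distribution :math:`p(x_i \mid y)` a
simple Bernoulli distribution that we can estimate by counting.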
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    image, label = mnist_train[2]
    image.shape, label
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    ((28, 28), array(4, dtype=int32))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    image, label = mnist_train[2]
    image.shape, label
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (torch.Size([28, 28]), 4)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    image, label = train_images[2], train_labels[2]
    image.shape, label.numpy()
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (TensorShape([28, 28]), 4)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    image.shape, image.dtype
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    ((28, 28), dtype('float32'))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    image.shape, image.dtype
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (torch.Size([28, 28]), torch.float32)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    image.shape, image.dtype
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (TensorShape([28, 28]), tf.float32)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    label, type(label), label.dtype
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (array(4, dtype=int32), mxnet.numpy.ndarray, dtype('int32'))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    label, type(label)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (4, int)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    label.numpy(), label.dtype
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (4, tf.int32)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    images, labels = mnist_train[10:38]
    images.shape, labels.shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    ((28, 28, 28), (28,))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    images = torch.stack([mnist_train[i][0] for i in range(10, 38)], dim=0)
    labels = torch.tensor([mnist_train[i][1] for i in range(10, 38)])
    images.shape, labels.shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (torch.Size([28, 28, 28]), torch.Size([28]))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    images = tf.stack([train_images[i] for i in range(10, 38)], axis=0)
    labels = tf.constant([train_labels[i].numpy() for i in range(10, 38)])
    images.shape, labels.shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (TensorShape([28, 28, 28]), TensorShape([28]))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    d2l.show_images(images, 2, 9);
.. figure:: output_naive-bayes_6e475d_75_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    d2l.show_images(images, 2, 9);
.. figure:: output_naive-bayes_6e475d_78_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    d2l.show_images(images, 2, 9);
.. figure:: output_naive-bayes_6e475d_81_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X, Y = mnist_train[:]  # All training examples
    
    n_y = np.zeros((10))
    for y in range(10):
        n_y[y] = (Y == y).sum()
    P_y = n_y / n_y.sum()
    P_y
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([0.09871667, 0.11236667, 0.0993    , 0.10218333, 0.09736667,
           0.09035   , 0.09863333, 0.10441667, 0.09751666, 0.09915   ])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = torch.stack([mnist_train[i][0] for i in range(len(mnist_train))], dim=0)
    Y = torch.tensor([mnist_train[i][1] for i in range(len(mnist_train))])
    
    n_y = torch.zeros(10)
    for y in range(10):
        n_y[y] = (Y == y).sum()
    P_y = n_y / n_y.sum()
    P_y
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([0.0987, 0.1124, 0.0993, 0.1022, 0.0974, 0.0904, 0.0986, 0.1044, 0.0975,
            0.0992])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = train_images
    Y = train_labels
    
    n_y = tf.Variable(tf.zeros(10))
    for y in range(10):
        n_y[y].assign(tf.reduce_sum(tf.cast(Y == y, tf.float32)))
    P_y = n_y / tf.reduce_sum(n_y)
    P_y
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_x = np.zeros((10, 28, 28))
    for y in range(10):
        n_x[y] = np.array(X.asnumpy()[Y.asnumpy() == y].sum(axis=0))
    P_xy = (n_x + 1) / (n_y + 1).reshape(10, 1, 1)
    
    d2l.show_images(P_xy, 2, 5);
.. figure:: output_naive-bayes_6e475d_99_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_x = torch.zeros((10, 28, 28))
    for y in range(10):
        n_x[y] = torch.tensor(X.numpy()[Y.numpy() == y].sum(axis=0))
    P_xy = (n_x + 1) / (n_y + 1).reshape(10, 1, 1)
    
    d2l.show_images(P_xy, 2, 5);
.. figure:: output_naive-bayes_6e475d_102_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_x = tf.Variable(tf.zeros((10, 28, 28)))
    for y in range(10):
        n_x[y].assign(tf.cast(tf.reduce_sum(
            X.numpy()[Y.numpy() == y], axis=0), tf.float32))
    P_xy = (n_x + 1) / tf.reshape((n_y + 1), (10, 1, 1))
    
    d2l.show_images(P_xy, 2, 5);
.. figure:: output_naive-bayes_6e475d_105_0.svg
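
The two arrays computed above are exactly the estimates the classifier
needs. Writing :math:`n` for the number of training images,
:math:`n_y` for the number of images with label :math:`y`, and
:math:`n_{iy}` for the number of those images whose pixel :math:`i`
equals 1, the code uses

.. math:: \hat{p}(y) = \frac{n_y}{n}, \qquad \hat{p}(x_i = 1 \mid y) = \frac{n_{iy} + 1}{n_y + 1},

where adding 1 to the counts (Laplace smoothing) keeps
:math:`\hat{p}(x_i = 1 \mid y)` away from zero even for pixels that are
never on for some class. Given these estimates, the prediction in the
next code block evaluates, for a binarized image :math:`\mathbf{x}`,

.. math:: \hat{p}(y \mid \mathbf{x}) \propto \hat{p}(y) \prod_{i} \hat{p}(x_i = 1 \mid y)^{x_i} \bigl(1 - \hat{p}(x_i = 1 \mid y)\bigr)^{1 - x_i},

which is what the expression ``P_xy * x + (1 - P_xy) * (1 - x)``
followed by a product over pixels computes.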
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def bayes_pred(x):
        x = np.expand_dims(x, axis=0)  # (28, 28) -> (1, 28, 28)
        p_xy = P_xy * x + (1 - P_xy)*(1 - x)
        p_xy = p_xy.reshape(10, -1).prod(axis=1)  # p(x|y)
        return np.array(p_xy) * P_y
    
    image, label = mnist_test[0]
    bayes_pred(image)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def bayes_pred(x):
        x = x.unsqueeze(0)  # (28, 28) -> (1, 28, 28)
        p_xy = P_xy * x + (1 - P_xy)*(1 - x)
        p_xy = p_xy.reshape(10, -1).prod(dim=1)  # p(x|y)
        return p_xy * P_y
    
    image, label = mnist_test[0]
    bayes_pred(image)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def bayes_pred(x):
        x = tf.expand_dims(x, axis=0)  # (28, 28) -> (1, 28, 28)
        p_xy = P_xy * x + (1 - P_xy)*(1 - x)
        p_xy = tf.math.reduce_prod(tf.reshape(p_xy, (10, -1)), axis=1)  # p(x|y)
        return p_xy * P_y
    
    image, label = train_images[0], train_labels[0]
    bayes_pred(image)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    a = 0.1
    print('underflow:', a**784)
    print('logarithm is normal:', 784*math.log(a))
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    underflow: 0.0
    logarithm is normal: -1805.2267129073316
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    a = 0.1
    print('underflow:', a**784)
    print('logarithm is normal:', 784*math.log(a))
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    underflow: 0.0
    logarithm is normal: -1805.2267129073316
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    a = 0.1
    print('underflow:', a**784)
    print('logarithm is normal:', 784*tf.math.log(a).numpy())
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    underflow: 0.0
    logarithm is normal: -1805.2267379760742
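
A product of 784 probabilities, each well below 1, underflows to zero
in floating point, as the experiment above shows. The standard fix is
to work in log space: maximizing the posterior is the same as
maximizing its logarithm,

.. math:: \log \hat{p}(y) + \sum_{i} \Bigl( x_i \log \hat{p}(x_i = 1 \mid y) + (1 - x_i) \log \bigl(1 - \hat{p}(x_i = 1 \mid y)\bigr) \Bigr),

which replaces the fragile product with a numerically well-behaved sum.
The ``bayes_pred_stable`` implementations below compute exactly this
quantity.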
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    log_P_xy = np.log(P_xy)
    log_P_xy_neg = np.log(1 - P_xy)
    log_P_y = np.log(P_y)
    
    def bayes_pred_stable(x):
        x = np.expand_dims(x, axis=0)  # (28, 28) -> (1, 28, 28)
        p_xy = log_P_xy * x + log_P_xy_neg * (1 - x)
        p_xy = p_xy.reshape(10, -1).sum(axis=1)  # p(x|y)
        return p_xy + log_P_y
    
    py = bayes_pred_stable(image)
    py
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([-269.0042 , -301.73447, -245.21458, -218.8941 , -193.46907,
           -206.10315, -292.54315, -114.62834, -220.35619, -163.18881])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    log_P_xy = torch.log(P_xy)
    log_P_xy_neg = torch.log(1 - P_xy)
    log_P_y = torch.log(P_y)
    
    def bayes_pred_stable(x):
        x = x.unsqueeze(0)  # (28, 28) -> (1, 28, 28)
        p_xy = log_P_xy * x + log_P_xy_neg * (1 - x)
        p_xy = p_xy.reshape(10, -1).sum(axis=1)  # p(x|y)
        return p_xy + log_P_y
    
    py = bayes_pred_stable(image)
    py
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([-269.0042, -301.7345, -245.2146, -218.8941, -193.4691, -206.1031,
            -292.5432, -114.6283, -220.3562, -163.1888])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    log_P_xy = tf.math.log(P_xy)
    log_P_xy_neg = tf.math.log(1 - P_xy)
    log_P_y = tf.math.log(P_y)
    
    def bayes_pred_stable(x):
        x = tf.expand_dims(x, axis=0)  # (28, 28) -> (1, 28, 28)
        p_xy = log_P_xy * x + log_P_xy_neg * (1 - x)
        p_xy = tf.math.reduce_sum(tf.reshape(p_xy, (10, -1)), axis=1)  # p(x|y)
        return p_xy + log_P_y
    
    py = bayes_pred_stable(image)
    py
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    # Convert `label`, which is a scalar tensor of int32 dtype, to a Python
    # scalar integer for the comparison
    py.argmax(axis=0) == int(label)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array(True)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    py.argmax(dim=0) == label
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor(True)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    tf.argmax(py, axis=0, output_type=tf.int32) == label
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def predict(X):
        return [bayes_pred_stable(x).argmax(axis=0).astype(np.int32) for x in X]
    
    X, y = mnist_test[:18]
    preds = predict(X)
    d2l.show_images(X, 2, 9, titles=[str(d) for d in preds]);
.. figure:: output_naive-bayes_6e475d_159_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def predict(X):
        return [bayes_pred_stable(x).argmax(dim=0).type(torch.int32).item()
                for x in X]
    
    X = torch.stack([mnist_test[i][0] for i in range(18)], dim=0)
    y = torch.tensor([mnist_test[i][1] for i in range(18)])
    preds = predict(X)
    d2l.show_images(X, 2, 9, titles=[str(d) for d in preds]);
.. figure:: output_naive-bayes_6e475d_162_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def predict(X):
        return [tf.argmax(
            bayes_pred_stable(x), axis=0, output_type=tf.int32).numpy()
                for x in X]
    
    X = tf.stack([train_images[i] for i in range(10, 38)], axis=0)
    y = tf.constant([train_labels[i].numpy() for i in range(10, 38)])
    preds = predict(X)
    d2l.show_images(X, 2, 9, titles=[str(d) for d in preds]);
.. figure:: output_naive-bayes_6e475d_165_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X, y = mnist_test[:]
    preds = np.array(predict(X), dtype=np.int32)
    float((preds == y).sum()) / len(y)  # Validation accuracy
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    0.8426
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = torch.stack([mnist_test[i][0] for i in range(len(mnist_test))], dim=0)
    y = torch.tensor([mnist_test[i][1] for i in range(len(mnist_test))])
    preds = torch.tensor(predict(X), dtype=torch.int32)
    float((preds == y).sum()) / len(y)  # Validation accuracy
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    0.8426
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = test_images
    y = test_labels
    preds = tf.constant(predict(X), dtype=tf.int32)
    # Validation accuracy
    tf.reduce_sum(tf.cast(preds == y, tf.float32)).numpy() / len(y)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    0.8426
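
The classifier labels roughly 84% of the test digits correctly using
nothing more than counting and the independence assumption. To tie the
pieces together, here is a compact NumPy-only sketch of the same
procedure for arbitrary binarized data; the class name
``BernoulliNaiveBayes``, the flattened ``(n, d)`` input layout, and the
small clipping guard are choices made for this sketch rather than part
of the code above.

.. code:: python

    import numpy as np

    class BernoulliNaiveBayes:
        """Naive Bayes for binary features, estimated by counting (sketch)."""
        def fit(self, X, y, n_classes=10):
            # X: (n, d) array of 0/1 features, y: (n,) integer labels
            n, d = X.shape
            self.log_prior = np.zeros(n_classes)
            self.log_p = np.zeros((n_classes, d))      # log p(x_i = 1 | y)
            self.log_not_p = np.zeros((n_classes, d))  # log p(x_i = 0 | y)
            for c in range(n_classes):
                Xc = X[y == c]
                # Same Laplace smoothing as above: (count + 1) / (n_c + 1);
                # clip to avoid log(0) if a pixel is always on for a class
                p = np.clip((Xc.sum(axis=0) + 1) / (len(Xc) + 1),
                            1e-12, 1 - 1e-12)
                self.log_prior[c] = np.log(len(Xc) / n)
                self.log_p[c] = np.log(p)
                self.log_not_p[c] = np.log(1 - p)
            return self

        def predict(self, X):
            # Log-posterior up to a constant:
            # log p(y) + sum_i [x_i log p_i + (1 - x_i) log(1 - p_i)]
            scores = (X @ self.log_p.T + (1 - X) @ self.log_not_p.T
                      + self.log_prior)
            return scores.argmax(axis=1)

With the binarized training images reshaped to shape ``(60000, 784)``
and the labels as an integer vector, ``BernoulliNaiveBayes().fit(...)``
followed by ``predict`` on the reshaped test images should reproduce
essentially the same accuracy as above.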