.. _sec_nadaraya-watson:
Attention Pooling: Nadaraya-Watson Kernel Regression
====================================================
Now you know the major components of attention mechanisms under the
framework in :numref:`fig_qkv`. To recapitulate, the interactions
between queries (volitional cues) and keys (nonvolitional cues) result
in *attention pooling*. The attention pooling selectively aggregates
values (sensory inputs) to produce the output. In this section, we will
describe attention pooling in greater detail to give you a high-level
view of how attention mechanisms work in practice. Specifically, the
Nadaraya-Watson kernel regression model proposed in 1964 is a simple yet
complete example for demonstrating machine learning with attention
mechanisms.
.. code:: python
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. code:: python
import torch
from torch import nn
from d2l import torch as d2l
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
tf.random.set_seed(seed=1322)
Generating the Dataset
----------------------
To keep things simple, let us consider the following regression problem:
given a dataset of input-output pairs
:math:`\{(x_1, y_1), \ldots, (x_n, y_n)\}`, how can we learn :math:`f` to
predict the output :math:`\hat{y} = f(x)` for any new input :math:`x`?
Here we generate an artificial dataset according to the following
nonlinear function with the noise term :math:`\epsilon`:
.. math:: y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon,
where :math:`\epsilon` obeys a normal distribution with zero mean and
standard deviation 0.5. We generate 50 training examples and 50 testing
examples. To better visualize the pattern of attention later, the
training inputs are sorted.
.. code:: python
n_train = 50 # No. of training examples
x_train = np.sort(np.random.rand(n_train) * 5) # Training inputs
def f(x):
return 2 * np.sin(x) + x**0.8
y_train = f(x_train) + np.random.normal(0.0, 0.5, (n_train,)) # Training outputs
x_test = np.arange(0, 5, 0.1) # Testing examples
y_truth = f(x_test) # Ground-truth outputs for the testing examples
n_test = len(x_test) # No. of testing examples
n_test
.. parsed-literal::
:class: output
50
.. code:: python
n_train = 50 # No. of training examples
x_train, _ = torch.sort(torch.rand(n_train) * 5) # Training inputs
def f(x):
return 2 * torch.sin(x) + x**0.8
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # Training outputs
x_test = torch.arange(0, 5, 0.1) # Testing examples
y_truth = f(x_test) # Ground-truth outputs for the testing examples
n_test = len(x_test) # No. of testing examples
n_test
.. parsed-literal::
:class: output
50
.. code:: python
n_train = 50  # No. of training examples
x_train = tf.sort(tf.random.uniform(shape=(n_train,), maxval=5))  # Training inputs
def f(x):
return 2 * tf.sin(x) + x**0.8
y_train = f(x_train) + tf.random.normal((n_train,), 0.0, 0.5) # Training outputs
x_test = tf.range(0, 5, 0.1) # Testing examples
y_truth = f(x_test) # Ground-truth outputs for the testing examples
n_test = len(x_test) # No. of testing examples
n_test
.. parsed-literal::
:class: output
50
The following function plots all the training examples (represented by
circles), the ground-truth data generation function ``f`` without the
noise term (labeled "Truth"), and the learned prediction function
(labeled "Pred").
.. code:: python
def plot_kernel_reg(y_hat):
d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
xlim=[0, 5], ylim=[-1, 5])
d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);
Average Pooling
---------------
We begin with perhaps the world's "dumbest" estimator for this
regression problem: using average pooling to average over all the
training outputs:
.. math:: f(x) = \frac{1}{n}\sum_{i=1}^n y_i,
:label: eq_avg-pooling
which is plotted below. As we can see, this estimator is indeed not so
smart.
.. code:: python
y_hat = y_train.mean().repeat(n_test)
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_29_0.svg
.. code:: python
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_32_0.svg
.. code:: python
y_hat = tf.repeat(tf.reduce_mean(y_train), repeats=n_test)
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_35_0.svg
Nonparametric Attention Pooling
-------------------------------
Obviously, average pooling omits the inputs :math:`x_i`. A better idea
was proposed by Nadaraya :cite:`Nadaraya.1964` and Watson
:cite:`Watson.1964` to weigh the outputs :math:`y_i` according to
their input locations:
.. math:: f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)} y_i,
:label: eq_nadaraya-watson
where :math:`K` is a *kernel*. The estimator in
:eq:`eq_nadaraya-watson` is called *Nadaraya-Watson kernel
regression*. Here we will not dive into details of kernels. Recall the
framework of attention mechanisms in :numref:`fig_qkv`. From the
perspective of attention, we can rewrite :eq:`eq_nadaraya-watson`
in a more generalized form of *attention pooling*:
.. math:: f(x) = \sum_{i=1}^n \alpha(x, x_i) y_i,
:label: eq_attn-pooling
where :math:`x` is the query and :math:`(x_i, y_i)` is the key-value
pair. Comparing :eq:`eq_attn-pooling` and
:eq:`eq_avg-pooling`, the attention pooling here is a weighted
average of values :math:`y_i`. The *attention weight*
:math:`\alpha(x, x_i)` in :eq:`eq_attn-pooling` is assigned to the
corresponding value :math:`y_i` based on the interaction between the
query :math:`x` and the key :math:`x_i` modeled by :math:`\alpha`. For
any query, its attention weights over all the key-value pairs are a
valid probability distribution: they are non-negative and sum up to one.
To gain intuition about attention pooling, consider the *Gaussian
kernel* defined as

.. math::

   K(u) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{u^2}{2}\right).
Plugging the Gaussian kernel into :eq:`eq_attn-pooling` and
:eq:`eq_nadaraya-watson` gives
.. math:: \begin{aligned} f(x) &=\sum_{i=1}^n \alpha(x, x_i) y_i\\ &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i. \end{aligned}
:label: eq_nadaraya-watson-gaussian
In :eq:`eq_nadaraya-watson-gaussian`, a key :math:`x_i` that is
closer to the given query :math:`x` will get *more attention* via a
*larger attention weight* assigned to the key's corresponding value
:math:`y_i`.
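Note that the kernel's normalization constant :math:`1/\sqrt{2\pi}` appears in both the numerator and the denominator of :eq:`eq_nadaraya-watson` and therefore cancels, which is what turns the kernel weights into a softmax. The following framework-neutral NumPy sketch (a quick numerical check with made-up toy values, not part of the book's code) verifies this equivalence and that the resulting attention weights form a valid probability distribution:

.. code:: python

    import numpy as np

    keys = np.array([0.5, 1.0, 2.5, 4.0])  # toy key locations
    query = 2.0                             # a single toy query

    # Nadaraya-Watson weights computed directly from the Gaussian kernel
    k = np.exp(-(query - keys)**2 / 2) / np.sqrt(2 * np.pi)
    w_kernel = k / k.sum()

    # The same weights via softmax: the constant 1/sqrt(2*pi) cancels
    logits = -(query - keys)**2 / 2
    w_softmax = np.exp(logits) / np.exp(logits).sum()

    print(np.allclose(w_kernel, w_softmax))         # True
    print((w_softmax >= 0).all(), w_softmax.sum())  # non-negative, sums to 1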
Notably, Nadaraya-Watson kernel regression is a nonparametric model;
thus :eq:`eq_nadaraya-watson-gaussian` is an example of
*nonparametric attention pooling*. In the following, we plot the
prediction based on this nonparametric attention model. The predicted
line is smooth and closer to the ground-truth than that produced by
average pooling.
.. code:: python
# Shape of `X_repeat`: (`n_test`, `n_train`), where each row contains the
# same testing inputs (i.e., same queries)
X_repeat = x_test.repeat(n_train).reshape((-1, n_train))
# Note that `x_train` contains the keys. Shape of `attention_weights`:
# (`n_test`, `n_train`), where each row contains attention weights to be
# assigned among the values (`y_train`) given each query
attention_weights = npx.softmax(-(X_repeat - x_train)**2 / 2)
# Each element of `y_hat` is a weighted average of values, where the weights
# are attention weights
y_hat = np.dot(attention_weights, y_train)
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_41_0.svg
.. code:: python
# Shape of `X_repeat`: (`n_test`, `n_train`), where each row contains the
# same testing inputs (i.e., same queries)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# Note that `x_train` contains the keys. Shape of `attention_weights`:
# (`n_test`, `n_train`), where each row contains attention weights to be
# assigned among the values (`y_train`) given each query
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# Each element of `y_hat` is a weighted average of values, where the weights
# are attention weights
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_44_0.svg
.. code:: python
# Shape of `X_repeat`: (`n_test`, `n_train`), where each row contains the
# same testing input (i.e., the same query) repeated `n_train` times
X_repeat = tf.repeat(tf.expand_dims(x_test, axis=1), repeats=n_train, axis=1)
# Note that `x_train` contains the keys. Shape of `attention_weights`:
# (`n_test`, `n_train`), where each row contains attention weights to be
# assigned among the values (`y_train`) given each query
attention_weights = tf.nn.softmax(-(X_repeat - x_train)**2 / 2, axis=1)
# Each element of `y_hat` is a weighted average of values, where the weights
# are attention weights
y_hat = tf.squeeze(tf.matmul(attention_weights, tf.expand_dims(y_train, axis=1)))
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_47_0.svg
Now let us take a look at the attention weights. Here testing inputs are
queries while training inputs are keys. Since both are sorted, we can
see that the closer a query-key pair is, the higher the attention weight
in the attention pooling.
.. code:: python
d2l.show_heatmaps(np.expand_dims(np.expand_dims(attention_weights, 0), 0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
.. figure:: output_nadaraya-watson_61a333_53_0.svg
.. code:: python
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
.. figure:: output_nadaraya-watson_61a333_56_0.svg
.. code:: python
d2l.show_heatmaps(tf.expand_dims(tf.expand_dims(attention_weights, axis=0), axis=0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
.. figure:: output_nadaraya-watson_61a333_59_0.svg
Parametric Attention Pooling
----------------------------
Nonparametric Nadaraya-Watson kernel regression enjoys the *consistency*
benefit: given enough data this model converges to the optimal solution.
Nonetheless, we can easily integrate learnable parameters into attention
pooling.
As an example, slightly different from
:eq:`eq_nadaraya-watson-gaussian`, in the following the distance
between the query :math:`x` and the key :math:`x_i` is multiplied by a
learnable parameter :math:`w`:
.. math:: \begin{aligned}f(x) &= \sum_{i=1}^n \alpha(x, x_i) y_i \\&= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}((x - x_i)w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}((x - x_j)w)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x - x_i)w)^2\right) y_i.\end{aligned}
:label: eq_nadaraya-watson-gaussian-para
In the rest of this section, we will train this model by learning the
parameter of the attention pooling in
:eq:`eq_nadaraya-watson-gaussian-para`.
.. _subsec_batch_dot:
Batch Matrix Multiplication
~~~~~~~~~~~~~~~~~~~~~~~~~~~
To more efficiently compute attention for minibatches, we can leverage
batch matrix multiplication utilities provided by deep learning
frameworks.
Suppose that the first minibatch contains :math:`n` matrices
:math:`\mathbf{X}_1, \ldots, \mathbf{X}_n` of shape :math:`a\times b`,
and the second minibatch contains :math:`n` matrices
:math:`\mathbf{Y}_1, \ldots, \mathbf{Y}_n` of shape :math:`b\times c`.
Their batch matrix multiplication results in :math:`n` matrices
:math:`\mathbf{X}_1\mathbf{Y}_1, \ldots, \mathbf{X}_n\mathbf{Y}_n` of
shape :math:`a\times c`. Therefore, given two tensors of shape
(:math:`n`, :math:`a`, :math:`b`) and (:math:`n`, :math:`b`, :math:`c`),
the shape of their batch matrix multiplication output is (:math:`n`,
:math:`a`, :math:`c`).
.. code:: python
X = np.ones((2, 1, 4))
Y = np.ones((2, 4, 6))
npx.batch_dot(X, Y).shape
.. parsed-literal::
:class: output
(2, 1, 6)
.. code:: python
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
.. parsed-literal::
:class: output
torch.Size([2, 1, 6])
.. code:: python
X = tf.ones((2, 1, 4))
Y = tf.ones((2, 4, 6))
tf.matmul(X, Y).shape
.. parsed-literal::
:class: output
TensorShape([2, 1, 6])
In the context of attention mechanisms, we can use minibatch matrix
multiplication to compute weighted averages of values in a minibatch.
.. code:: python
weights = np.ones((2, 10)) * 0.1
values = np.arange(20).reshape((2, 10))
npx.batch_dot(np.expand_dims(weights, 1), np.expand_dims(values, -1))
.. parsed-literal::
:class: output
array([[[ 4.5]],
[[14.5]]])
.. code:: python
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
.. parsed-literal::
:class: output
tensor([[[ 4.5000]],
[[14.5000]]])
.. code:: python
weights = tf.ones((2, 10)) * 0.1
values = tf.reshape(tf.range(20.0), shape=(2, 10))
tf.matmul(tf.expand_dims(weights, axis=1), tf.expand_dims(values, axis=-1)).numpy()
.. parsed-literal::
:class: output
array([[[ 4.5]],
[[14.5]]], dtype=float32)
Defining the Model
~~~~~~~~~~~~~~~~~~
Using minibatch matrix multiplication, below we define the parametric
version of Nadaraya-Watson kernel regression based on the parametric
attention pooling in :eq:`eq_nadaraya-watson-gaussian-para`.
.. code:: python
class NWKernelRegression(nn.Block):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.w = self.params.get('w', shape=(1,))
def forward(self, queries, keys, values):
# Shape of the output `queries` and `attention_weights`:
# (no. of queries, no. of key-value pairs)
queries = queries.repeat(keys.shape[1]).reshape((-1, keys.shape[1]))
self.attention_weights = npx.softmax(
-((queries - keys) * self.w.data())**2 / 2)
# Shape of `values`: (no. of queries, no. of key-value pairs)
return npx.batch_dot(np.expand_dims(self.attention_weights, 1),
np.expand_dims(values, -1)).reshape(-1)
.. code:: python
class NWKernelRegression(nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.w = nn.Parameter(torch.rand((1,), requires_grad=True))
def forward(self, queries, keys, values):
# Shape of the output `queries` and `attention_weights`:
# (no. of queries, no. of key-value pairs)
queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
self.attention_weights = nn.functional.softmax(
-((queries - keys) * self.w)**2 / 2, dim=1)
# Shape of `values`: (no. of queries, no. of key-value pairs)
return torch.bmm(self.attention_weights.unsqueeze(1),
values.unsqueeze(-1)).reshape(-1)
.. code:: python
class NWKernelRegression(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = tf.Variable(initial_value=tf.random.uniform(shape=(1,)))

    def call(self, queries, keys, values, **kwargs):
        # During training, queries are `x_train`, keys are the training
        # inputs excluding each query itself, and values are the
        # corresponding training outputs
        # Shape of the output `queries` and `attention_weights`:
        # (no. of queries, no. of key-value pairs)
        queries = tf.repeat(tf.expand_dims(queries, axis=1),
                            repeats=keys.shape[1], axis=1)
        self.attention_weights = tf.nn.softmax(
            -((queries - keys) * self.w)**2 / 2, axis=1)
        # Shape of `values`: (no. of queries, no. of key-value pairs)
        return tf.squeeze(tf.matmul(tf.expand_dims(self.attention_weights, axis=1),
                                    tf.expand_dims(values, axis=-1)))
Training
~~~~~~~~
In the following, we transform the training dataset into keys and values
to train the attention model. In the parametric attention pooling, each
training input uses key-value pairs from all the training examples
except itself to predict its own output; the toy sketch below
illustrates the masking trick that achieves this.
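The following framework-neutral NumPy sketch (three made-up training inputs, not part of the book's code) shows how dropping the diagonal of the tiled matrix leaves each input with all the *other* examples as its keys:

.. code:: python

    import numpy as np

    x = np.array([10., 20., 30.])          # 3 toy training inputs
    X_tile = np.tile(x, (3, 1))            # each row repeats all inputs
    mask = (1 - np.eye(3)).astype(bool)    # drop the diagonal (self-pairs)
    keys = X_tile[mask].reshape((3, -1))
    print(keys)
    # [[20. 30.]
    #  [10. 30.]
    #  [10. 20.]]

The framework-specific code below applies the same construction to the actual training inputs and outputs.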
.. code:: python
# Shape of `X_tile`: (`n_train`, `n_train`), where each column contains the
# same training inputs
X_tile = np.tile(x_train, (n_train, 1))
# Shape of `Y_tile`: (`n_train`, `n_train`), where each column contains the
# same training outputs
Y_tile = np.tile(y_train, (n_train, 1))
# Shape of `keys`: (`n_train`, `n_train` - 1)
keys = X_tile[(1 - np.eye(n_train)).astype('bool')].reshape((n_train, -1))
# Shape of `values`: (`n_train`, `n_train` - 1)
values = Y_tile[(1 - np.eye(n_train)).astype('bool')].reshape((n_train, -1))
.. code:: python
# Shape of `X_tile`: (`n_train`, `n_train`), where each column contains the
# same training inputs
X_tile = x_train.repeat((n_train, 1))
# Shape of `Y_tile`: (`n_train`, `n_train`), where each column contains the
# same training outputs
Y_tile = y_train.repeat((n_train, 1))
# Shape of `keys`: (`n_train`, `n_train` - 1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# Shape of `values`: (`n_train`, `n_train` - 1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
.. code:: python
# Shape of `X_tile`: (`n_train`, `n_train`), where each column contains the
# same training inputs
X_tile = tf.repeat(tf.expand_dims(x_train, axis=0), repeats=n_train, axis=0)
# Shape of `Y_tile`: (`n_train`, `n_train`), where each column contains the
# same training outputs
Y_tile = tf.repeat(tf.expand_dims(y_train, axis=0), repeats=n_train, axis=0)
# Shape of `keys`: (`n_train`, `n_train` - 1)
keys = tf.reshape(X_tile[tf.cast(1 - tf.eye(n_train), dtype=tf.bool)], shape=(n_train, -1))
# Shape of `values`: (`n_train`, `n_train` - 1)
values = tf.reshape(Y_tile[tf.cast(1 - tf.eye(n_train), dtype=tf.bool)], shape=(n_train, -1))
Using the squared loss and stochastic gradient descent, we train the
parametric attention model.
.. code:: python
net = NWKernelRegression()
net.initialize()
loss = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
for epoch in range(5):
with autograd.record():
l = loss(net(x_train, keys, values), y_train)
l.backward()
trainer.step(1)
print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
animator.add(epoch + 1, float(l.sum()))
.. figure:: output_nadaraya-watson_61a333_113_0.svg
.. code:: python
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
for epoch in range(5):
trainer.zero_grad()
l = loss(net(x_train, keys, values), y_train)
l.sum().backward()
trainer.step()
print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
animator.add(epoch + 1, float(l.sum()))
.. figure:: output_nadaraya-watson_61a333_116_0.svg
.. code:: python
net = NWKernelRegression()
loss_object = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
for epoch in range(5):
with tf.GradientTape() as t:
loss = loss_object(y_train, net(x_train, keys, values)) * len(y_train)
grads = t.gradient(loss, net.trainable_variables)
optimizer.apply_gradients(zip(grads, net.trainable_variables))
print(f'epoch {epoch + 1}, loss {float(loss):.6f}')
animator.add(epoch + 1, float(loss))
.. figure:: output_nadaraya-watson_61a333_119_0.svg
After training the parametric attention model, we can plot its
prediction. Since the model tries to fit the noisy training dataset, the
predicted line is less smooth than its nonparametric counterpart plotted
earlier.
.. code:: python
# Shape of `keys`: (`n_test`, `n_train`), where each column contains the same
# training inputs (i.e., same keys)
keys = np.tile(x_train, (n_test, 1))
# Shape of `values`: (`n_test`, `n_train`)
values = np.tile(y_train, (n_test, 1))
y_hat = net(x_test, keys, values)
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_125_0.svg
.. code:: python
# Shape of `keys`: (`n_test`, `n_train`), where each column contains the same
# training inputs (i.e., same keys)
keys = x_train.repeat((n_test, 1))
# Shape of `values`: (`n_test`, `n_train`)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_128_0.svg
.. code:: python
# Shape of `keys`: (`n_test`, `n_train`), where each column contains the same
# training inputs (i.e., same keys)
keys = tf.repeat(tf.expand_dims(x_train, axis=0), repeats=n_test, axis=0)
# Shape of `values`: (`n_test`, `n_train`)
values = tf.repeat(tf.expand_dims(y_train, axis=0), repeats=n_test, axis=0)
y_hat = net(x_test, keys, values)
plot_kernel_reg(y_hat)
.. figure:: output_nadaraya-watson_61a333_131_0.svg
Compared with nonparametric attention pooling, the region with large
attention weights becomes sharper in the learnable, parametric setting.
.. code:: python
d2l.show_heatmaps(np.expand_dims(np.expand_dims(net.attention_weights, 0), 0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
.. figure:: output_nadaraya-watson_61a333_137_0.svg
.. code:: python
d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
.. figure:: output_nadaraya-watson_61a333_140_0.svg
.. code:: python
d2l.show_heatmaps(tf.expand_dims(tf.expand_dims(net.attention_weights, axis=0), axis=0),
xlabel='Sorted training inputs',
ylabel='Sorted testing inputs')
.. figure:: output_nadaraya-watson_61a333_143_0.svg
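The sharpening visible in these heatmaps can be reproduced in isolation. In the following framework-neutral NumPy sketch (toy distances and hypothetical values of the scale parameter, not the learned :math:`w`), a larger scale concentrates the softmax mass on the closest key:

.. code:: python

    import numpy as np

    d = np.array([0.0, 0.5, 1.0, 1.5, 2.0])  # toy query-key distances

    def softmax(z):
        e = np.exp(z - z.max())
        return e / e.sum()

    for w in (1.0, 3.0):  # hypothetical scale values
        print(w, softmax(-(d * w)**2 / 2).round(3))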
Summary
-------
- Nadaraya-Watson kernel regression is an example of machine learning
with attention mechanisms.
- The attention pooling of Nadaraya-Watson kernel regression is a
weighted average of the training outputs. From the attention
perspective, the attention weight is assigned to a value based on a
function of a query and the key that is paired with the value.
- Attention pooling can be either nonparametric or parametric.
Exercises
---------
1. Increase the number of training examples. Does the nonparametric Nadaraya-Watson kernel regression model predict better?
2. What is the value of our learned :math:`w` in the parametric
attention pooling experiment? Why does it make the weighted region
sharper when visualizing the attention weights?
3. How can we add hyperparameters to nonparametric Nadaraya-Watson
kernel regression to predict better?
4. Design another parametric attention pooling for the kernel regression
of this section. Train this new model and visualize its attention
weights.