.. _sec_attention-scoring-functions:
Attention Scoring Functions
===========================
In :numref:`sec_nadaraya-watson`, we used a Gaussian kernel to model
interactions between queries and keys. Treating the exponent of the
Gaussian kernel in :eq:`eq_nadaraya-watson-gaussian` as an
*attention scoring function* (or *scoring function* for short), we
essentially fed the results of this function into a softmax operation.
As a result, we obtained a probability distribution (attention weights)
over values that are paired with keys. In the end, the output of the
attention pooling is simply a weighted sum of the values based on these
attention weights.
At a high level, we can use the above algorithm to instantiate the
framework of attention mechanisms in :numref:`fig_qkv`. Denoting an
attention scoring function by :math:`a`,
:numref:`fig_attention_output` illustrates how the output of attention
pooling can be computed as a weighted sum of values. Since attention
weights are a probability distribution, the weighted sum is essentially
a weighted average.
.. _fig_attention_output:

.. figure:: ../img/attention-output.svg

   Computing the output of attention pooling as a weighted average of
   values.
Mathematically, suppose that we have a query
:math:`\mathbf{q} \in \mathbb{R}^q` and :math:`m` key-value pairs
:math:`(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)`,
where any :math:`\mathbf{k}_i \in \mathbb{R}^k` and any
:math:`\mathbf{v}_i \in \mathbb{R}^v`. The attention pooling :math:`f`
is instantiated as a weighted sum of the values:
.. math:: f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,
:label: eq_attn-pooling
where the attention weight (scalar) for the query :math:`\mathbf{q}` and
key :math:`\mathbf{k}_i` is computed by the softmax operation of an
attention scoring function :math:`a` that maps two vectors to a scalar:
.. math:: \alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.
:label: eq_attn-scoring-alpha
As we can see, different choices of the attention scoring function
:math:`a` lead to different behaviors of attention pooling. In this
section, we introduce two popular scoring functions that we will use to
develop more sophisticated attention mechanisms later.
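To make :eq:`eq_attn-pooling` and :eq:`eq_attn-scoring-alpha` concrete
before turning to specific scoring functions, here is a minimal,
framework-agnostic sketch in plain NumPy. The name ``attention_pooling``
and its ``score`` argument are illustrative placeholders for :math:`f`
and :math:`a`; they are not part of the implementations below.

.. code:: python

    import numpy as np

    def attention_pooling(q, keys, values, score):
        # Scores a(q, k_i) for all m keys, then a numerically stable softmax
        # turns them into the attention weights alpha(q, k_i)
        scores = np.array([score(q, k) for k in keys])
        weights = np.exp(scores - scores.max())
        weights = weights / weights.sum()
        # The output is the weighted sum (weighted average) of the values
        return weights @ values

    # Toy example: one query and m = 3 key-value pairs, scored by a dot product
    q = np.array([1.0, 0.0])
    keys = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    values = np.array([[10.0], [20.0], [30.0]])
    attention_pooling(q, keys, values, score=lambda q, k: q @ k)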
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import math
from mxnet import np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import math
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
Masked Softmax Operation
------------------------
As we just mentioned, a softmax operation is used to output a
probability distribution as attention weights. In some cases, not all
the values should be fed into attention pooling. For instance, for
efficient minibatch processing in :numref:`sec_machine_translation`,
some text sequences are padded with special tokens that do not carry
meaning. To get an attention pooling over only meaningful tokens as
values, we can specify a valid sequence length (in number of tokens) to
filter out those beyond this specified range when computing softmax. In
this way, we can implement such a *masked softmax operation* in the
following ``masked_softmax`` function, where any value beyond the valid
length is masked as zero.
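Before the framework-specific implementations of ``masked_softmax``
below, a minimal NumPy illustration of the idea: scores beyond the valid
length are replaced with a very large negative number, so their
exponentials vanish and the corresponding attention weights are zero.
The variable names here are only for illustration.

.. code:: python

    import numpy as np

    scores = np.array([1.0, 2.0, 3.0, 4.0])
    valid_len = 2
    # Keep the first `valid_len` scores; push the rest towards -infinity
    masked = np.where(np.arange(len(scores)) < valid_len, scores, -1e6)
    weights = np.exp(masked - masked.max())
    weights / weights.sum()  # the last two weights are (numerically) zero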
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def masked_softmax(X, valid_lens):
        """Perform softmax operation by masking elements on the last axis."""
        # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
        if valid_lens is None:
            return npx.softmax(X)
        else:
            shape = X.shape
            if valid_lens.ndim == 1:
                valid_lens = valid_lens.repeat(shape[1])
            else:
                valid_lens = valid_lens.reshape(-1)
            # On the last axis, replace masked elements with a very large negative
            # value, whose exponentiation outputs 0
            X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                                  value=-1e6, axis=1)
            return npx.softmax(X).reshape(shape)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def masked_softmax(X, valid_lens):
        """Perform softmax operation by masking elements on the last axis."""
        # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
        if valid_lens is None:
            return nn.functional.softmax(X, dim=-1)
        else:
            shape = X.shape
            if valid_lens.dim() == 1:
                valid_lens = torch.repeat_interleave(valid_lens, shape[1])
            else:
                valid_lens = valid_lens.reshape(-1)
            # On the last axis, replace masked elements with a very large negative
            # value, whose exponentiation outputs 0
            X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                                  value=-1e6)
            return nn.functional.softmax(X.reshape(shape), dim=-1)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def masked_softmax(X, valid_lens):
        """Perform softmax operation by masking elements on the last axis."""
        # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
        if valid_lens is None:
            return tf.nn.softmax(X, axis=-1)
        else:
            shape = X.shape
            if len(valid_lens.shape) == 1:
                valid_lens = tf.repeat(valid_lens, repeats=shape[1])
            else:
                valid_lens = tf.reshape(valid_lens, shape=-1)
            # On the last axis, replace masked elements with a very large negative
            # value, whose exponentiation outputs 0
            X = d2l.sequence_mask(tf.reshape(X, shape=(-1, shape[-1])),
                                  valid_lens, value=-1e6)
            return tf.nn.softmax(tf.reshape(X, shape=shape), axis=-1)
To demonstrate how this function works, consider a minibatch of two
:math:`2 \times 4` matrix examples, where the valid lengths for these
two examples are two and three, respectively. As a result of the masked
softmax operation, values beyond the valid lengths are all masked as
zero.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[0.488994 , 0.511006 , 0. , 0. ],
[0.4365484 , 0.56345165, 0. , 0. ]],
[[0.288171 , 0.3519408 , 0.3598882 , 0. ],
[0.29034296, 0.25239873, 0.45725837, 0. ]]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[0.5667, 0.4333, 0.0000, 0.0000],
[0.6657, 0.3343, 0.0000, 0.0000]],
[[0.2451, 0.3035, 0.4514, 0.0000],
[0.4595, 0.2742, 0.2663, 0.0000]]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(tf.random.uniform(shape=(2, 2, 4)), tf.constant([2, 3]))
Similarly, we can also use a two-dimensional tensor to specify valid
lengths for every row in each matrix example.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(np.random.uniform(size=(2, 2, 4)),
np.array([[1, 3], [2, 4]]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[1. , 0. , 0. , 0. ],
[0.35848376, 0.3658879 , 0.27562833, 0. ]],
[[0.54370314, 0.45629686, 0. , 0. ],
[0.19598778, 0.25580427, 0.19916739, 0.3490406 ]]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
[0.3734, 0.1832, 0.4434, 0.0000]],
[[0.5169, 0.4831, 0.0000, 0.0000],
[0.3576, 0.1722, 0.1807, 0.2894]]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(tf.random.uniform((2, 2, 4)), tf.constant([[1, 3], [2, 4]]))
.. _subsec_additive-attention:
Additive Attention
------------------
In general, when queries and keys are vectors of different lengths, we
can use additive attention as the scoring function. Given a query
:math:`\mathbf{q} \in \mathbb{R}^q` and a key
:math:`\mathbf{k} \in \mathbb{R}^k`, the *additive attention* scoring
function is

.. math:: a(\mathbf q, \mathbf k) = \mathbf w_v^\top \tanh(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},
:label: eq_additive-attn

where the learnable parameters are :math:`\mathbf W_q\in\mathbb R^{h\times q}`,
:math:`\mathbf W_k\in\mathbb R^{h\times k}`, and
:math:`\mathbf w_v\in\mathbb R^{h}`. Equivalent to
:eq:`eq_additive-attn`, the query and the key are concatenated and
fed into an MLP with a single hidden layer whose number of hidden units
is :math:`h`, a hyperparameter. By using :math:`\tanh` as the activation
function and disabling bias terms, we implement additive attention in
the following.
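As a quick numerical check of this equivalence before the
implementation, note that
:math:`\mathbf W_q\mathbf q + \mathbf W_k \mathbf k = [\mathbf W_q, \mathbf W_k]\,[\mathbf q; \mathbf k]`,
so the score can be computed either from the two projections or from a
single-hidden-layer MLP applied to the concatenation. The following
standalone PyTorch sketch (with arbitrary sizes, separate from the
``AdditiveAttention`` classes below) verifies this.

.. code:: python

    import torch

    q_dim, k_dim, h = 20, 2, 8  # arbitrary sizes for illustration
    q, k = torch.randn(q_dim), torch.randn(k_dim)
    W_q, W_k, w_v = torch.randn(h, q_dim), torch.randn(h, k_dim), torch.randn(h)

    # Additive attention score as in the equation above
    score = w_v @ torch.tanh(W_q @ q + W_k @ k)
    # The same score from an MLP applied to the concatenation [q; k]
    W = torch.cat([W_q, W_k], dim=1)
    score_mlp = w_v @ torch.tanh(W @ torch.cat([q, k]))
    torch.allclose(score, score_mlp)  # True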
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class AdditiveAttention(nn.Block):
        """Additive attention."""
        def __init__(self, num_hiddens, dropout, **kwargs):
            super(AdditiveAttention, self).__init__(**kwargs)
            # Use `flatten=False` to only transform the last axis so that the
            # shapes for the other axes are kept the same
            self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
            self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
            self.w_v = nn.Dense(1, use_bias=False, flatten=False)
            self.dropout = nn.Dropout(dropout)

        def forward(self, queries, keys, values, valid_lens):
            queries, keys = self.W_q(queries), self.W_k(keys)
            # After dimension expansion, shape of `queries`: (`batch_size`, no. of
            # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
            # no. of key-value pairs, `num_hiddens`). Sum them up with
            # broadcasting
            features = np.expand_dims(queries, axis=2) + np.expand_dims(
                keys, axis=1)
            features = np.tanh(features)
            # There is only one output of `self.w_v`, so we remove the last
            # one-dimensional entry from the shape. Shape of `scores`:
            # (`batch_size`, no. of queries, no. of key-value pairs)
            scores = np.squeeze(self.w_v(features), axis=-1)
            self.attention_weights = masked_softmax(scores, valid_lens)
            # Shape of `values`: (`batch_size`, no. of key-value pairs, value
            # dimension)
            return npx.batch_dot(self.dropout(self.attention_weights), values)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class AdditiveAttention(nn.Module):
        """Additive attention."""
        def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
            super(AdditiveAttention, self).__init__(**kwargs)
            self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
            self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
            self.w_v = nn.Linear(num_hiddens, 1, bias=False)
            self.dropout = nn.Dropout(dropout)

        def forward(self, queries, keys, values, valid_lens):
            queries, keys = self.W_q(queries), self.W_k(keys)
            # After dimension expansion, shape of `queries`: (`batch_size`, no. of
            # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
            # no. of key-value pairs, `num_hiddens`). Sum them up with
            # broadcasting
            features = queries.unsqueeze(2) + keys.unsqueeze(1)
            features = torch.tanh(features)
            # There is only one output of `self.w_v`, so we remove the last
            # one-dimensional entry from the shape. Shape of `scores`:
            # (`batch_size`, no. of queries, no. of key-value pairs)
            scores = self.w_v(features).squeeze(-1)
            self.attention_weights = masked_softmax(scores, valid_lens)
            # Shape of `values`: (`batch_size`, no. of key-value pairs, value
            # dimension)
            return torch.bmm(self.dropout(self.attention_weights), values)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class AdditiveAttention(tf.keras.layers.Layer):
        """Additive attention."""
        def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
            super().__init__(**kwargs)
            self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=False)
            self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=False)
            self.w_v = tf.keras.layers.Dense(1, use_bias=False)
            self.dropout = tf.keras.layers.Dropout(dropout)

        def call(self, queries, keys, values, valid_lens, **kwargs):
            queries, keys = self.W_q(queries), self.W_k(keys)
            # After dimension expansion, shape of `queries`: (`batch_size`, no. of
            # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
            # no. of key-value pairs, `num_hiddens`). Sum them up with
            # broadcasting
            features = tf.expand_dims(queries, axis=2) + tf.expand_dims(
                keys, axis=1)
            features = tf.nn.tanh(features)
            # There is only one output of `self.w_v`, so we remove the last
            # one-dimensional entry from the shape. Shape of `scores`:
            # (`batch_size`, no. of queries, no. of key-value pairs)
            scores = tf.squeeze(self.w_v(features), axis=-1)
            self.attention_weights = masked_softmax(scores, valid_lens)
            # Shape of `values`: (`batch_size`, no. of key-value pairs, value
            # dimension)
            return tf.matmul(self.dropout(
                self.attention_weights, **kwargs), values)
Let us demonstrate the above ``AdditiveAttention`` class with a toy
example, where shapes (batch size, number of steps or sequence length in
tokens, feature size) of queries, keys, and values are (:math:`2`,
:math:`1`, :math:`20`), (:math:`2`, :math:`10`, :math:`2`), and
(:math:`2`, :math:`10`, :math:`4`), respectively. The attention pooling
output has a shape of (batch size, number of steps for queries, feature
size for values).
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries, keys = np.random.normal(0, 1, (2, 1, 20)), np.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = np.arange(40).reshape(1, 10, 4).repeat(2, axis=0)
valid_lens = np.array([2, 6])
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.initialize()
attention(queries, keys, values, valid_lens)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[ 2. , 3. , 4. , 5. ]],
[[10. , 11. , 12.000001, 13. ]]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries, keys = tf.random.normal(shape=(2, 1, 20)), tf.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = tf.repeat(tf.reshape(
tf.range(40, dtype=tf.float32), shape=(1, 10, 4)), repeats=2, axis=0)
valid_lens = tf.constant([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
dropout=0.1)
attention(queries, keys, values, valid_lens, training=False)
Although additive attention contains learnable parameters, every key is
the same in this example, so all scores are equal and the attention
weights are uniform over the positions within the specified valid
lengths.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_2a8fdc_75_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_2a8fdc_78_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_2a8fdc_81_0.svg
Scaled Dot-Product Attention
----------------------------
A more computationally efficient design for the scoring function is
simply the dot product. However, the dot product operation requires that
both the query and the key have the same vector length, say :math:`d`.
Assume that all the elements of the query and the key are independent
random variables with zero mean and unit variance. The dot product of
both vectors has zero mean and a variance of :math:`d`. To ensure that
the variance of the dot product still remains one regardless of vector
length, the *scaled dot-product attention* scoring function
.. math:: a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d}
divides the dot product by :math:`\sqrt{d}`. In practice, we often think
in minibatches for efficiency, such as computing attention for :math:`n`
queries and :math:`m` key-value pairs, where queries and keys are of
length :math:`d` and values are of length :math:`v`. The scaled
dot-product attention of queries
:math:`\mathbf Q\in\mathbb R^{n\times d}`, keys
:math:`\mathbf K\in\mathbb R^{m\times d}`, and values
:math:`\mathbf V\in\mathbb R^{m\times v}` is
.. math:: \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.
:label: eq_softmax_QK_V
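Before the implementation, the scaling argument can be checked
empirically: with standard normal queries and keys, the raw dot products
have variance close to :math:`d`, while dividing by :math:`\sqrt{d}`
brings the variance back to roughly one. The following PyTorch sketch is
illustrative only and not part of the ``DotProductAttention`` classes
below.

.. code:: python

    import torch

    d = 256
    q, k = torch.randn(10000, d), torch.randn(10000, d)
    raw = (q * k).sum(dim=1)  # unscaled dot products, variance is roughly d
    scaled = raw / d ** 0.5   # scaled dot products, variance is roughly 1
    raw.var(), scaled.var()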
In the following implementation of the scaled dot product attention, we
use dropout for model regularization.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class DotProductAttention(nn.Block):
        """Scaled dot product attention."""
        def __init__(self, dropout, **kwargs):
            super(DotProductAttention, self).__init__(**kwargs)
            self.dropout = nn.Dropout(dropout)

        # Shape of `queries`: (`batch_size`, no. of queries, `d`)
        # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
        def forward(self, queries, keys, values, valid_lens=None):
            d = queries.shape[-1]
            # Set `transpose_b=True` to swap the last two dimensions of `keys`
            scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
            self.attention_weights = masked_softmax(scores, valid_lens)
            return npx.batch_dot(self.dropout(self.attention_weights), values)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class DotProductAttention(nn.Module):
        """Scaled dot product attention."""
        def __init__(self, dropout, **kwargs):
            super(DotProductAttention, self).__init__(**kwargs)
            self.dropout = nn.Dropout(dropout)

        # Shape of `queries`: (`batch_size`, no. of queries, `d`)
        # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
        def forward(self, queries, keys, values, valid_lens=None):
            d = queries.shape[-1]
            # Swap the last two dimensions of `keys` with `keys.transpose(1, 2)`
            scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
            self.attention_weights = masked_softmax(scores, valid_lens)
            return torch.bmm(self.dropout(self.attention_weights), values)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class DotProductAttention(tf.keras.layers.Layer):
        """Scaled dot product attention."""
        def __init__(self, dropout, **kwargs):
            super().__init__(**kwargs)
            self.dropout = tf.keras.layers.Dropout(dropout)

        # Shape of `queries`: (`batch_size`, no. of queries, `d`)
        # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
        def call(self, queries, keys, values, valid_lens, **kwargs):
            d = queries.shape[-1]
            scores = tf.matmul(queries, keys, transpose_b=True) / tf.math.sqrt(
                tf.cast(d, dtype=tf.float32))
            self.attention_weights = masked_softmax(scores, valid_lens)
            return tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
To demonstrate the above ``DotProductAttention`` class, we use the same
keys, values, and valid lengths from the earlier toy example for
additive attention. For the dot product operation, we make the feature
size of queries the same as that of keys.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = np.random.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.initialize()
attention(queries, keys, values, valid_lens)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[ 2. , 3. , 4. , 5. ]],
[[10. , 11. , 12.000001, 13. ]]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
[[10.0000, 11.0000, 12.0000, 13.0000]]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = tf.random.normal(shape=(2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention(queries, keys, values, valid_lens, training=False)
As in the additive attention demonstration, since ``keys`` contains
identical elements that cannot be differentiated by any query, uniform
attention weights are obtained.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_2a8fdc_111_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_2a8fdc_114_0.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_2a8fdc_117_0.svg
Summary
-------
- We can compute the output of attention pooling as a weighted average
of values, where different choices of the attention scoring function
lead to different behaviors of attention pooling.
- When queries and keys are vectors of different lengths, we can use
the additive attention scoring function. When they are the same, the
scaled dot-product attention scoring function is more computationally
efficient.
Exercises
---------
1. Modify keys in the toy example and visualize attention weights. Do
additive attention and scaled dot-product attention still output the
same attention weights? Why or why not?
2. Using matrix multiplications only, can you design a new scoring
function for queries and keys with different vector lengths?
3. When queries and keys have the same vector length, is vector
summation a better design than dot product for the scoring function?
Why or why not?