.. _sec_transformer:
Transformer
===========
We have compared CNNs, RNNs, and self-attention in
:numref:`subsec_cnn-rnn-self-attention`. Notably, self-attention
enjoys both parallel computation and the shortest maximum path length.
Therefore, it is natural and appealing to design deep architectures by
using self-attention. Unlike earlier self-attention models that still
rely on RNNs for input representations
:cite:`Cheng.Dong.Lapata.2016,Lin.Feng.Santos.ea.2017,Paulus.Xiong.Socher.2017`,
the transformer model is solely based on attention mechanisms without
any convolutional or recurrent layer
:cite:`Vaswani.Shazeer.Parmar.ea.2017`. Though originally proposed for
sequence-to-sequence learning on text data, transformers have become
pervasive in a wide range of modern deep learning applications, such as
language, vision, speech, and reinforcement learning.
Model
-----
The transformer is an instance of the encoder-decoder architecture; its
overall architecture is presented in :numref:`fig_transformer`. As we
can see, the transformer is composed of an encoder and a decoder.
Different from Bahdanau attention for sequence-to-sequence learning in
:numref:`fig_s2s_attention_details`, the input (source) and output
(target) sequence embeddings are added with positional encodings before
being fed into the encoder and the decoder, each of which stacks
modules based on self-attention.
.. _fig_transformer:
.. figure:: ../img/transformer.svg
:width: 500px
The transformer architecture.
Now we provide an overview of the transformer architecture in
:numref:`fig_transformer`. At a high level, the transformer encoder is
a stack of multiple identical layers, where each layer has two sublayers
(either of which is denoted as :math:`\mathrm{sublayer}`). The first is a
multi-head self-attention pooling and the second is a positionwise
feed-forward network. Specifically, in the encoder self-attention,
queries, keys, and values are all from the outputs of the previous
encoder layer. Inspired by the ResNet design in :numref:`sec_resnet`,
a residual connection is employed around both sublayers. In the
transformer, for any input :math:`\mathbf{x} \in \mathbb{R}^d` at any
position of the sequence, we require that
:math:`\mathrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d` so that the
residual connection
:math:`\mathbf{x} + \mathrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d` is
feasible. This addition from the residual connection is immediately
followed by layer normalization :cite:`Ba.Kiros.Hinton.2016`. As a
result, the transformer encoder outputs a :math:`d`-dimensional vector
representation for each position of the input sequence.
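To make this concrete, the computation performed by each encoder
sublayer can be sketched as follows. This is a minimal PyTorch sketch
with hypothetical placeholder names: ``sublayer`` merely stands in for
either self-attention or the positionwise feed-forward network that we
implement later in this section.

.. code:: python

    import torch
    from torch import nn

    d = 8                        # feature dimension of the model
    sublayer = nn.Linear(d, d)   # placeholder for self-attention or the FFN
    layer_norm = nn.LayerNorm(d)

    # Residual connection followed by layer normalization: the output has the
    # same shape as the input, so such sublayers can be stacked freely
    x = torch.randn(2, 5, d)     # (batch size, sequence length, d)
    y = layer_norm(x + sublayer(x))
    y.shape                      # torch.Size([2, 5, 8])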
The transformer decoder is also a stack of multiple identical layers
with residual connections and layer normalizations. Besides the two
sublayers described in the encoder, the decoder inserts a third
sublayer, known as the encoder-decoder attention, between these two. In
the encoder-decoder attention, queries are from the outputs of the
previous decoder layer, and the keys and values are from the transformer
encoder outputs. In the decoder self-attention, queries, keys, and
values are all from the outputs of the previous decoder layer.
However, each position in the decoder is only allowed to attend to
positions in the decoder up to that position. This *masked* attention
preserves the auto-regressive property, ensuring that the prediction
only depends on those output tokens that have been generated.
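The following is a minimal sketch (in PyTorch, with hypothetical
variable names) of how such masking can be specified: each query at
position :math:`t` is given a valid length of :math:`t`, so it attends
only to the first :math:`t` key-value pairs. The ``DecoderBlock``
implementation later in this section constructs ``dec_valid_lens`` in
the same spirit.

.. code:: python

    import torch

    batch_size, num_steps = 2, 5
    # The query at step t may attend to keys at steps 1, ..., t only
    dec_valid_lens = torch.arange(1, num_steps + 1).repeat(batch_size, 1)
    dec_valid_lens  # tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])

    # Equivalently, a boolean lower-triangular (causal) mask over
    # (query position, key position) pairs
    causal_mask = torch.tril(torch.ones(num_steps, num_steps, dtype=torch.bool))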
We have already described and implemented multi-head attention based on
scaled dot-products in :numref:`sec_multihead-attention` and
positional encoding in :numref:`subsec_positional-encoding`. In the
following, we will implement the rest of the transformer model.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import math
import pandas as pd
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import math
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import numpy as np
import pandas as pd
import tensorflow as tf
from d2l import tensorflow as d2l
.. raw:: html
.. raw:: html
Positionwise Feed-Forward Networks
----------------------------------
The positionwise feed-forward network transforms the representation at
all the sequence positions using the same MLP. This is why we call it
*positionwise*. In the implementation below, the input ``X`` with shape
(batch size, number of time steps or sequence length in tokens, number
of hidden units or feature dimension) will be transformed by a two-layer
MLP into an output tensor of shape (batch size, number of time steps,
``ffn_num_outputs``).
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class PositionWiseFFN(nn.Block):
"""Positionwise feed-forward network."""
def __init__(self, ffn_num_hiddens, ffn_num_outputs, **kwargs):
super(PositionWiseFFN, self).__init__(**kwargs)
self.dense1 = nn.Dense(ffn_num_hiddens, flatten=False,
activation='relu')
self.dense2 = nn.Dense(ffn_num_outputs, flatten=False)
def forward(self, X):
return self.dense2(self.dense1(X))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class PositionWiseFFN(nn.Module):
"""Positionwise feed-forward network."""
def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
**kwargs):
super(PositionWiseFFN, self).__init__(**kwargs)
self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)
def forward(self, X):
return self.dense2(self.relu(self.dense1(X)))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class PositionWiseFFN(tf.keras.layers.Layer):
"""Positionwise feed-forward network."""
def __init__(self, ffn_num_hiddens, ffn_num_outputs, **kwargs):
super().__init__(**kwargs)
self.dense1 = tf.keras.layers.Dense(ffn_num_hiddens)
self.relu = tf.keras.layers.ReLU()
self.dense2 = tf.keras.layers.Dense(ffn_num_outputs)
def call(self, X):
return self.dense2(self.relu(self.dense1(X)))
.. raw:: html
.. raw:: html
The following example shows that the innermost dimension of a tensor
changes to the number of outputs in the positionwise feed-forward
network. Since the same MLP transforms the inputs at all positions,
when the inputs at all these positions are the same, their outputs are
also identical.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ffn = PositionWiseFFN(4, 8)
ffn.initialize()
ffn(np.ones((2, 3, 4)))[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[ 0.00239431, 0.00927085, -0.00021069, -0.00923989, -0.0082903 ,
-0.00162741, 0.00659031, 0.00023905],
[ 0.00239431, 0.00927085, -0.00021069, -0.00923989, -0.0082903 ,
-0.00162741, 0.00659031, 0.00023905],
[ 0.00239431, 0.00927085, -0.00021069, -0.00923989, -0.0082903 ,
-0.00162741, 0.00659031, 0.00023905]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
ffn(torch.ones((2, 3, 4)))[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[-0.4656, 0.1496, -0.4186, -0.0076, 0.0727, -0.2567, -0.0681, -0.2657],
[-0.4656, 0.1496, -0.4186, -0.0076, 0.0727, -0.2567, -0.0681, -0.2657],
[-0.4656, 0.1496, -0.4186, -0.0076, 0.0727, -0.2567, -0.0681, -0.2657]],
grad_fn=)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ffn = PositionWiseFFN(4, 8)
ffn(tf.ones((2, 3, 4)))[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
Residual Connection and Layer Normalization
-------------------------------------------
Now let us focus on the "add & norm" component in
:numref:`fig_transformer`. As we described at the beginning of this
section, this is a residual connection immediately followed by layer
normalization. Both are key to effective deep architectures.
In :numref:`sec_batch_norm`, we explained how batch normalization
recenters and rescales across the examples within a minibatch. Layer
normalization is the same as batch normalization except that the former
normalizes across the feature dimension. Despite its pervasive
applications in computer vision, batch normalization is usually
empirically less effective than layer normalization in natural language
processing tasks, whose inputs are often variable-length sequences.
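Concretely, for an input :math:`\mathbf{x} \in \mathbb{R}^d` at a
single position, layer normalization standardizes across the :math:`d`
features of that example (this is the standard formulation; we omit the
small constant added to the denominator for numerical stability):

.. math::

   \mathrm{LN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\mu}}{\hat{\sigma}} + \boldsymbol{\beta},
   \quad \textrm{where } \hat{\mu} = \frac{1}{d} \sum_{i=1}^d x_i
   \textrm{ and } \hat{\sigma}^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \hat{\mu})^2,

where :math:`\boldsymbol{\gamma}` and :math:`\boldsymbol{\beta}` are a
learnable scale and shift. Batch normalization applies the same
standardization, but computes :math:`\hat{\mu}` and :math:`\hat{\sigma}`
per feature across the examples of a minibatch.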
The following code snippet compares the normalization across different
dimensions by layer normalization and batch normalization.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ln = nn.LayerNorm()
ln.initialize()
bn = nn.BatchNorm()
bn.initialize()
X = np.array([[1, 2], [2, 3]])
# Compute mean and variance from `X` in the training mode
with autograd.record():
print('layer norm:', ln(X), '\nbatch norm:', bn(X))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
layer norm: [[-0.99998 0.99998]
[-0.99998 0.99998]]
batch norm: [[-0.99998 -0.99998]
[ 0.99998 0.99998]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# Compute mean and variance from `X` in the training mode
print('layer norm:', ln(X), '\nbatch norm:', bn(X))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
layer norm: tensor([[-1.0000, 1.0000],
[-1.0000, 1.0000]], grad_fn=)
batch norm: tensor([[-1.0000, -1.0000],
[ 1.0000, 1.0000]], grad_fn=)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ln = tf.keras.layers.LayerNormalization()
bn = tf.keras.layers.BatchNormalization()
X = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
print('layer norm:', ln(X), '\nbatch norm:', bn(X, training=True))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
layer norm: tf.Tensor(
[[-0.998006 0.9980061]
[-0.9980061 0.998006 ]], shape=(2, 2), dtype=float32)
batch norm: tf.Tensor(
[[-0.998006 -0.9980061 ]
[ 0.9980061 0.99800587]], shape=(2, 2), dtype=float32)
.. raw:: html
.. raw:: html
Now we can implement the ``AddNorm`` class using a residual connection
followed by layer normalization. Dropout is also applied for
regularization.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class AddNorm(nn.Block):
"""Residual connection followed by layer normalization."""
def __init__(self, dropout, **kwargs):
super(AddNorm, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
self.ln = nn.LayerNorm()
def forward(self, X, Y):
return self.ln(self.dropout(Y) + X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class AddNorm(nn.Module):
"""Residual connection followed by layer normalization."""
def __init__(self, normalized_shape, dropout, **kwargs):
super(AddNorm, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
self.ln = nn.LayerNorm(normalized_shape)
def forward(self, X, Y):
return self.ln(self.dropout(Y) + X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class AddNorm(tf.keras.layers.Layer):
"""Residual connection followed by layer normalization."""
def __init__(self, normalized_shape, dropout, **kwargs):
super().__init__(**kwargs)
self.dropout = tf.keras.layers.Dropout(dropout)
self.ln = tf.keras.layers.LayerNormalization(normalized_shape)
def call(self, X, Y, **kwargs):
return self.ln(self.dropout(Y, **kwargs) + X)
.. raw:: html
.. raw:: html
The residual connection requires that the two inputs are of the same
shape so that the output tensor also has the same shape after the
addition operation.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
add_norm = AddNorm(0.5)
add_norm.initialize()
add_norm(np.ones((2, 3, 4)), np.ones((2, 3, 4))).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 3, 4)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
add_norm = AddNorm([3, 4], 0.5) # Normalized_shape is input.size()[1:]
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 3, 4])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
add_norm = AddNorm([1, 2], 0.5) # Normalized_shape is: [i for i in range(len(input.shape))][1:]
add_norm(tf.ones((2, 3, 4)), tf.ones((2, 3, 4)), training=False).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([2, 3, 4])
.. raw:: html
.. raw:: html
Encoder
-------
With all the essential components to assemble the transformer encoder,
let us start by implementing a single layer within the encoder. The
following ``EncoderBlock`` class contains two sublayers: multi-head
self-attention and positionwise feed-forward networks, where a residual
connection followed by layer normalization is employed around both
sublayers.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class EncoderBlock(nn.Block):
"""Transformer encoder block."""
def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
use_bias=False, **kwargs):
super(EncoderBlock, self).__init__(**kwargs)
self.attention = d2l.MultiHeadAttention(
num_hiddens, num_heads, dropout, use_bias)
self.addnorm1 = AddNorm(dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(dropout)
def forward(self, X, valid_lens):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
return self.addnorm2(Y, self.ffn(Y))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
dropout, use_bias=False, **kwargs):
super(EncoderBlock, self).__init__(**kwargs)
self.attention = d2l.MultiHeadAttention(
key_size, query_size, value_size, num_hiddens, num_heads, dropout,
use_bias)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(
ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(norm_shape, dropout)
def forward(self, X, valid_lens):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
return self.addnorm2(Y, self.ffn(Y))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class EncoderBlock(tf.keras.layers.Layer):
"""Transformer encoder block."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_hiddens, num_heads, dropout, bias=False, **kwargs):
super().__init__(**kwargs)
self.attention = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(norm_shape, dropout)
def call(self, X, valid_lens, **kwargs):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens, **kwargs), **kwargs)
return self.addnorm2(Y, self.ffn(Y), **kwargs)
.. raw:: html
.. raw:: html
As we can see, no layer in the transformer encoder changes the shape of
its input.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = np.ones((2, 100, 24))
valid_lens = np.array([3, 2])
encoder_blk = EncoderBlock(24, 48, 8, 0.5)
encoder_blk.initialize()
encoder_blk(X, valid_lens).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 100, 24)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 100, 24])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.ones((2, 100, 24))
valid_lens = tf.constant([3, 2])
norm_shape = [i for i in range(len(X.shape))][1:]
encoder_blk = EncoderBlock(24, 24, 24, 24, norm_shape, 48, 8, 0.5)
encoder_blk(X, valid_lens, training=False).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([2, 100, 24])
.. raw:: html
.. raw:: html
In the following transformer encoder implementation, we stack
``num_layers`` instances of the above ``EncoderBlock`` class. Since we
use fixed positional encodings whose values are always between -1 and
1, we multiply the values of the learnable input embeddings by the
square root of the embedding dimension to rescale them before summing
up the input embeddings and the positional encodings.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class TransformerEncoder(d2l.Encoder):
"""Transformer encoder."""
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, use_bias=False, **kwargs):
super(TransformerEncoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for _ in range(num_layers):
self.blks.add(
EncoderBlock(num_hiddens, ffn_num_hiddens, num_heads, dropout,
use_bias))
def forward(self, X, valid_lens, *args):
# Since positional encoding values are between -1 and 1, the embedding
# values are multiplied by the square root of the embedding dimension
# to rescale before they are summed up
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self.attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
self.attention_weights[
i] = blk.attention.attention.attention_weights
return X
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class TransformerEncoder(d2l.Encoder):
"""Transformer encoder."""
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, num_layers, dropout, use_bias=False, **kwargs):
super(TransformerEncoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block"+str(i),
EncoderBlock(key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, dropout, use_bias))
def forward(self, X, valid_lens, *args):
# Since positional encoding values are between -1 and 1, the embedding
# values are multiplied by the square root of the embedding dimension
# to rescale before they are summed up
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self.attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
self.attention_weights[
i] = blk.attention.attention.attention_weights
return X
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class TransformerEncoder(d2l.Encoder):
"""Transformer encoder."""
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_hiddens, num_heads,
num_layers, dropout, bias=False, **kwargs):
super().__init__(**kwargs)
self.num_hiddens = num_hiddens
self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = [EncoderBlock(
key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_hiddens, num_heads, dropout, bias) for _ in range(
num_layers)]
def call(self, X, valid_lens, **kwargs):
# Since positional encoding values are between -1 and 1, the embedding
# values are multiplied by the square root of the embedding dimension
# to rescale before they are summed up
X = self.pos_encoding(self.embedding(X) * tf.math.sqrt(
tf.cast(self.num_hiddens, dtype=tf.float32)), **kwargs)
self.attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens, **kwargs)
self.attention_weights[
i] = blk.attention.attention.attention_weights
return X
.. raw:: html
.. raw:: html
Below we specify hyperparameters to create a two-layer transformer
encoder. The shape of the transformer encoder output is (batch size,
number of time steps, ``num_hiddens``).
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)
encoder.initialize()
encoder(np.ones((2, 100)), valid_lens).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 100, 24)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoder = TransformerEncoder(
200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 100, 24])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoder = TransformerEncoder(200, 24, 24, 24, 24, [1, 2], 48, 8, 2, 0.5)
encoder(tf.ones((2, 100)), valid_lens, training=False).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([2, 100, 24])
.. raw:: html
.. raw:: html
Decoder
-------
As shown in :numref:`fig_transformer`, the transformer decoder is
composed of multiple identical layers. Each layer is implemented in the
following ``DecoderBlock`` class, which contains three sublayers:
decoder self-attention, encoder-decoder attention, and positionwise
feed-forward networks. Each of these sublayers is wrapped in a
residual connection followed by layer normalization.
As we described earlier in this section, in the masked multi-head
decoder self-attention (the first sublayer), queries, keys, and values
all come from the outputs of the previous decoder layer. When training
sequence-to-sequence models, tokens at all the positions (time steps) of
the output sequence are known. However, during prediction the output
sequence is generated token by token; thus, at any decoder time step
only the generated tokens can be used in the decoder self-attention. To
preserve auto-regression in the decoder, its masked self-attention
specifies ``dec_valid_lens`` so that any query only attends to all
positions in the decoder up to the query position.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DecoderBlock(nn.Block):
# The `i`-th block in the decoder
def __init__(self, num_hiddens, ffn_num_hiddens, num_heads,
dropout, i, **kwargs):
super(DecoderBlock, self).__init__(**kwargs)
self.i = i
self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads,
dropout)
self.addnorm1 = AddNorm(dropout)
self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads,
dropout)
self.addnorm2 = AddNorm(dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(dropout)
def forward(self, X, state):
enc_outputs, enc_valid_lens = state[0], state[1]
# During training, all the tokens of any output sequence are processed
# at the same time, so `state[2][self.i]` is `None` as initialized.
# When decoding any output sequence token by token during prediction,
# `state[2][self.i]` contains representations of the decoded output at
# the `i`-th block up to the current time step
if state[2][self.i] is None:
key_values = X
else:
key_values = np.concatenate((state[2][self.i], X), axis=1)
state[2][self.i] = key_values
if autograd.is_training():
batch_size, num_steps, _ = X.shape
# Shape of `dec_valid_lens`: (`batch_size`, `num_steps`), where
# every row is [1, 2, ..., `num_steps`]
dec_valid_lens = np.tile(np.arange(1, num_steps + 1, ctx=X.ctx),
(batch_size, 1))
else:
dec_valid_lens = None
# Self-attention
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
# Encoder-decoder attention. Shape of `enc_outputs`:
# (`batch_size`, `num_steps`, `num_hiddens`)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
return self.addnorm3(Z, self.ffn(Z)), state
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DecoderBlock(nn.Module):
# The `i`-th block in the decoder
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
dropout, i, **kwargs):
super(DecoderBlock, self).__init__(**kwargs)
self.i = i
self.attention1 = d2l.MultiHeadAttention(
key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.attention2 = d2l.MultiHeadAttention(
key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm2 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
num_hiddens)
self.addnorm3 = AddNorm(norm_shape, dropout)
def forward(self, X, state):
enc_outputs, enc_valid_lens = state[0], state[1]
# During training, all the tokens of any output sequence are processed
# at the same time, so `state[2][self.i]` is `None` as initialized.
# When decoding any output sequence token by token during prediction,
# `state[2][self.i]` contains representations of the decoded output at
# the `i`-th block up to the current time step
if state[2][self.i] is None:
key_values = X
else:
key_values = torch.cat((state[2][self.i], X), axis=1)
state[2][self.i] = key_values
if self.training:
batch_size, num_steps, _ = X.shape
# Shape of `dec_valid_lens`: (`batch_size`, `num_steps`), where
# every row is [1, 2, ..., `num_steps`]
dec_valid_lens = torch.arange(
1, num_steps + 1, device=X.device).repeat(batch_size, 1)
else:
dec_valid_lens = None
# Self-attention
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
# Encoder-decoder attention. Shape of `enc_outputs`:
# (`batch_size`, `num_steps`, `num_hiddens`)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
return self.addnorm3(Z, self.ffn(Z)), state
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DecoderBlock(tf.keras.layers.Layer):
# The `i`-th block in the decoder
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
super().__init__(**kwargs)
self.i = i
self.attention1 = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.attention2 = d2l.MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm2 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(norm_shape, dropout)
def call(self, X, state, **kwargs):
enc_outputs, enc_valid_lens = state[0], state[1]
# During training, all the tokens of any output sequence are processed
# at the same time, so `state[2][self.i]` is `None` as initialized.
# When decoding any output sequence token by token during prediction,
# `state[2][self.i]` contains representations of the decoded output at
# the `i`-th block up to the current time step
if state[2][self.i] is None:
key_values = X
else:
key_values = tf.concat((state[2][self.i], X), axis=1)
state[2][self.i] = key_values
if kwargs["training"]:
batch_size, num_steps, _ = X.shape
# Shape of `dec_valid_lens`: (`batch_size`, `num_steps`), where
# every row is [1, 2, ..., `num_steps`]
dec_valid_lens = tf.repeat(tf.reshape(tf.range(1, num_steps + 1),
shape=(-1, num_steps)), repeats=batch_size, axis=0)
else:
dec_valid_lens = None
# Self-attention
X2 = self.attention1(X, key_values, key_values, dec_valid_lens, **kwargs)
Y = self.addnorm1(X, X2, **kwargs)
# Encoder-decoder attention. Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens, **kwargs)
Z = self.addnorm2(Y, Y2, **kwargs)
return self.addnorm3(Z, self.ffn(Z), **kwargs), state
.. raw:: html
.. raw:: html
To facilitate scaled dot-product operations in the encoder-decoder
attention and addition operations in the residual connections, the
feature dimension (``num_hiddens``) of the decoder is the same as that
of the encoder.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
decoder_blk = DecoderBlock(24, 48, 8, 0.5, 0)
decoder_blk.initialize()
X = np.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state)[0].shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 100, 24)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state)[0].shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 100, 24])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
decoder_blk = DecoderBlock(24, 24, 24, 24, [1, 2], 48, 8, 0.5, 0)
X = tf.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state, training=False)[0].shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([2, 100, 24])
.. raw:: html
.. raw:: html
Now we construct the entire transformer decoder composed of
``num_layers`` instances of ``DecoderBlock``. In the end, a
fully-connected layer computes the prediction for all the ``vocab_size``
possible output tokens. Both the decoder self-attention weights and
the encoder-decoder attention weights are stored for later
visualization.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoder(d2l.AttentionDecoder):
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, **kwargs):
super(TransformerDecoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.num_layers = num_layers
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add(
DecoderBlock(num_hiddens, ffn_num_hiddens, num_heads,
dropout, i))
self.dense = nn.Dense(vocab_size, flatten=False)
def init_state(self, enc_outputs, enc_valid_lens, *args):
return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
# Decoder self-attention weights
self._attention_weights[0][
i] = blk.attention1.attention.attention_weights
# Encoder-decoder attention weights
self._attention_weights[1][
i] = blk.attention2.attention.attention_weights
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoder(d2l.AttentionDecoder):
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, num_layers, dropout, **kwargs):
super(TransformerDecoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.num_layers = num_layers
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block"+str(i),
DecoderBlock(key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, dropout, i))
self.dense = nn.Linear(num_hiddens, vocab_size)
def init_state(self, enc_outputs, enc_valid_lens, *args):
return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
# Decoder self-attention weights
self._attention_weights[0][
i] = blk.attention1.attention.attention_weights
# Encoder-decoder attention weights
self._attention_weights[1][
i] = blk.attention2.attention.attention_weights
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoder(d2l.AttentionDecoder):
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
super().__init__(**kwargs)
self.num_hiddens = num_hiddens
self.num_layers = num_layers
self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = [DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_hiddens, num_heads, dropout, i) for i in range(num_layers)]
self.dense = tf.keras.layers.Dense(vocab_size)
def init_state(self, enc_outputs, enc_valid_lens, *args):
return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
def call(self, X, state, **kwargs):
X = self.pos_encoding(self.embedding(X) * tf.math.sqrt(tf.cast(self.num_hiddens, dtype=tf.float32)), **kwargs)
self._attention_weights = [[None] * len(self.blks) for _ in range(2)] # 2 Attention layers in decoder
for i, blk in enumerate(self.blks):
X, state = blk(X, state, **kwargs)
# Decoder self-attention weights
self._attention_weights[0][i] = blk.attention1.attention.attention_weights
# Encoder-decoder attention weights
self._attention_weights[1][i] = blk.attention2.attention.attention_weights
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
.. raw:: html
.. raw:: html
Training
--------
Let us instantiate an encoder-decoder model by following the transformer
architecture. Here we specify that both the transformer encoder and the
transformer decoder have 2 layers using 4-head attention. Similar to
:numref:`sec_seq2seq_training`, we train the transformer model for
sequence-to-sequence learning on the English-French machine translation
dataset.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_hiddens, num_heads = 64, 4
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = TransformerEncoder(
len(src_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,
dropout)
decoder = TransformerDecoder(
len(tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,
dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
loss 0.030, 2201.9 tokens/sec on gpu(0)
.. figure:: output_transformer_5722f1_159_1.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = TransformerEncoder(
len(src_vocab), key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
num_layers, dropout)
decoder = TransformerDecoder(
len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
loss 0.031, 4599.2 tokens/sec on cuda:0
.. figure:: output_transformer_5722f1_162_1.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_hiddens, num_heads = 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [2]
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = TransformerEncoder(
len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(
len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_hiddens, num_heads, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
loss 0.028, 821.4 tokens/sec on
.. figure:: output_transformer_5722f1_165_1.svg
.. raw:: html
.. raw:: html
After training, we use the transformer model to translate a few English
sentences into French and compute their BLEU scores.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translation, dec_attention_weight_seq = d2l.predict_seq2seq(
net, eng, src_vocab, tgt_vocab, num_steps, device, True)
print(f'{eng} => {translation}, ',
f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est calme ., bleu 1.000
i'm home . => je suis chez moi ., bleu 0.803
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translation, dec_attention_weight_seq = d2l.predict_seq2seq(
net, eng, src_vocab, tgt_vocab, num_steps, device, True)
print(f'{eng} => {translation}, ',
f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est calme ., bleu 1.000
i'm home . => je suis chez moi ., bleu 1.000
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translation, dec_attention_weight_seq = d2l.predict_seq2seq(
net, eng, src_vocab, tgt_vocab, num_steps, True)
print(f'{eng} => {translation}, ',
f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est malade est malade ., bleu 0.473
i'm home . => je suis chez moi ., bleu 1.000
.. raw:: html
.. raw:: html
Let us visualize the transformer attention weights when translating the
last English sentence into French. The shape of the encoder
self-attention weights is (number of encoder layers, number of attention
heads, ``num_steps`` or number of queries, ``num_steps`` or number of
key-value pairs).
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
enc_attention_weights = np.concatenate(net.encoder.attention_weights, 0).reshape((num_layers,
num_heads, -1, num_steps))
enc_attention_weights.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 4, 10, 10)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads,
-1, num_steps))
enc_attention_weights.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 4, 10, 10])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
enc_attention_weights = tf.reshape(
tf.concat(net.encoder.attention_weights, 0),
(num_layers, num_heads, -1, num_steps))
enc_attention_weights.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([2, 4, 10, 10])
.. raw:: html
.. raw:: html
In the encoder self-attention, both queries and keys come from the same
input sequence. Since padding tokens do not carry meaning, with the
specified valid length of the input sequence, no query attends to
positions of padding tokens. In the following, two layers of multi-head
attention weights are presented row by row. Each head independently
attends based on a separate representation subspace of queries, keys,
and values.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
enc_attention_weights, xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_195_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
enc_attention_weights.cpu(), xlabel='Key positions',
ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_198_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
enc_attention_weights, xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_201_0.svg
.. raw:: html
.. raw:: html
To visualize both the decoder self-attention weights and the
encoder-decoder attention weights, we need more data manipulations. For
example, we fill the masked attention weights with zero. Note that the
decoder self-attention weights and the encoder-decoder attention weights
both have the same queries: the beginning-of-sequence token followed by
the output tokens.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
dec_attention_weights_2d = [np.array(head[0]).tolist()
for step in dec_attention_weight_seq
for attn in step for blk in attn for head in blk]
dec_attention_weights_filled = np.array(
pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
dec_attention_weights.transpose(1, 2, 3, 0, 4)
dec_self_attention_weights.shape, dec_inter_attention_weights.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
((2, 4, 7, 10), (2, 4, 7, 10))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
dec_attention_weights_2d = [head[0].tolist()
for step in dec_attention_weight_seq
for attn in step for blk in attn for head in blk]
dec_attention_weights_filled = torch.tensor(
pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
dec_attention_weights.permute(1, 2, 3, 0, 4)
dec_self_attention_weights.shape, dec_inter_attention_weights.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(torch.Size([2, 4, 6, 10]), torch.Size([2, 4, 6, 10]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
dec_attention_weights_2d = [head[0] for step in dec_attention_weight_seq
for attn in step
for blk in attn for head in blk]
dec_attention_weights_filled = tf.convert_to_tensor(
np.asarray(pd.DataFrame(dec_attention_weights_2d).fillna(
0.0).values).astype(np.float32))
dec_attention_weights = tf.reshape(dec_attention_weights_filled, shape=(
-1, 2, num_layers, num_heads, num_steps))
dec_self_attention_weights, dec_inter_attention_weights = tf.transpose(
dec_attention_weights, perm=(1, 2, 3, 0, 4))
print(dec_self_attention_weights.shape, dec_inter_attention_weights.shape)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 4, 6, 10) (2, 4, 6, 10)
.. raw:: html
.. raw:: html
Due to the auto-regressive property of the decoder self-attention, no
query attends to key-value pairs after the query position.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# Plus one to include the beginning-of-sequence token
d2l.show_heatmaps(
dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_219_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# Plus one to include the beginning-of-sequence token
d2l.show_heatmaps(
dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_222_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# Plus one to include the beginning-of-sequence token
d2l.show_heatmaps(
dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_225_0.svg
.. raw:: html
.. raw:: html
As in the encoder self-attention, via the specified valid length of
the input sequence, no query from the output sequence attends to the
padding tokens of the input sequence.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
dec_inter_attention_weights, xlabel='Key positions',
ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_231_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
dec_inter_attention_weights, xlabel='Key positions',
ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_234_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
dec_inter_attention_weights, xlabel='Key positions',
ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
figsize=(7, 3.5))
.. figure:: output_transformer_5722f1_237_0.svg
.. raw:: html
.. raw:: html
Although the transformer architecture was originally proposed for
sequence-to-sequence learning, as we will discover later in the book,
either the transformer encoder or the transformer decoder is often
individually used for different deep learning tasks.
Summary
-------
- The transformer is an instance of the encoder-decoder architecture,
though either the encoder or the decoder can be used individually in
practice.
- In the transformer, multi-head self-attention is used for
representing the input sequence and the output sequence, though the
decoder has to preserve the auto-regressive property via a masked
version.
- Both the residual connections and the layer normalization in the
transformer are important for training a very deep model.
- The positionwise feed-forward network in the transformer model
transforms the representation at all the sequence positions using the
same MLP.
Exercises
---------
1. Train a deeper transformer in the experiments. How does it affect the
training speed and the translation performance?
2. Is it a good idea to replace scaled dot-product attention with
additive attention in the transformer? Why?
3. For language modeling, should we use the transformer encoder,
decoder, or both? How would you design such a method?
4. What challenges can transformers face when input sequences are very
long? Why?
5. How can you improve the computational and memory efficiency of
transformers?
Hint: you may refer to the survey paper by Tay et al.
:cite:`Tay.Dehghani.Bahri.ea.2020`.
6. How can we design transformer-based models for image classification
tasks without using CNNs? Hint: you may refer to the vision
transformer :cite:`Dosovitskiy.Beyer.Kolesnikov.ea.2021`.
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html