.. _sec_seq2seq_attention:

Bahdanau Attention
==================

We studied the machine translation problem in :numref:`sec_seq2seq`,
where we designed an encoder-decoder architecture based on two RNNs for
sequence-to-sequence learning. Specifically, the RNN encoder transforms
a variable-length sequence into a fixed-shape context variable, and
then the RNN decoder generates the output (target) sequence token by
token based on the generated tokens and the context variable. However,
even though not all the input (source) tokens are useful for decoding a
certain token, the *same* context variable that encodes the entire
input sequence is still used at each decoding step.

In a separate but related challenge of handwriting generation for a
given text sequence, Graves designed a differentiable attention model
to align text characters with the much longer pen trace, where the
alignment moves only in one direction :cite:`Graves.2013`. Inspired by
the idea of learning to align, Bahdanau et al. proposed a
differentiable attention model without the severe unidirectional
alignment limitation :cite:`Bahdanau.Cho.Bengio.2014`. When predicting
a token, if not all the input tokens are relevant, the model aligns (or
attends) only to the parts of the input sequence that are relevant to
the current prediction. This is achieved by treating the context
variable as an output of attention pooling.

Model
-----

When describing Bahdanau attention for the RNN encoder-decoder below,
we follow the same notation as in :numref:`sec_seq2seq`. The new
attention-based model is the same as that in :numref:`sec_seq2seq`
except that the context variable :math:`\mathbf{c}` in
:eq:`eq_seq2seq_s_t` is replaced by :math:`\mathbf{c}_{t'}` at any
decoding time step :math:`t'`. Suppose that there are :math:`T` tokens
in the input sequence; then the context variable at decoding time step
:math:`t'` is the output of attention pooling:

.. math:: \mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t,

where the decoder hidden state :math:`\mathbf{s}_{t' - 1}` at time step
:math:`t' - 1` is the query, the encoder hidden states
:math:`\mathbf{h}_t` are both the keys and the values, and the
attention weight :math:`\alpha` is computed as in
:eq:`eq_attn-scoring-alpha` using the additive attention scoring
function defined by :eq:`eq_additive-attn`.

Slightly different from the vanilla RNN encoder-decoder architecture in
:numref:`fig_seq2seq_details`, the same architecture with Bahdanau
attention is depicted in :numref:`fig_s2s_attention_details`.

.. _fig_s2s_attention_details:

.. figure:: ../img/seq2seq-attention-details.svg

   Layers in an RNN encoder-decoder model with Bahdanau attention.
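To make the attention-pooling equation above concrete, here is a
minimal sketch in PyTorch, with randomly initialized states and
scoring-function parameters purely for illustration; the variable names
are ad hoc and not part of the book's code. It computes a single
context variable :math:`\mathbf{c}_{t'}` from the query
:math:`\mathbf{s}_{t'-1}` and the encoder hidden states
:math:`\mathbf{h}_t` using the additive scoring function. The actual
implementation below reuses ``d2l.AdditiveAttention`` instead.

.. code:: python

   import torch
   from torch.nn import functional as F

   batch_size, num_steps, num_hiddens = 2, 7, 16
   # Encoder hidden states at all time steps (keys and values)
   h = torch.randn(batch_size, num_steps, num_hiddens)
   # Decoder hidden state at time step t' - 1 (query)
   s_prev = torch.randn(batch_size, num_hiddens)

   # Parameters of the additive scoring function
   # a(q, k) = w_v^T tanh(W_q q + W_k k)
   W_q = torch.randn(num_hiddens, num_hiddens)
   W_k = torch.randn(num_hiddens, num_hiddens)
   w_v = torch.randn(num_hiddens)

   # Shape of `features`: (`batch_size`, `num_steps`, `num_hiddens`)
   features = torch.tanh((s_prev @ W_q).unsqueeze(1) + h @ W_k)
   # Attention weights over the `num_steps` encoder hidden states
   alpha = F.softmax(features @ w_v, dim=-1)
   # Context variable: attention-weighted sum of the encoder hidden states
   context = torch.bmm(alpha.unsqueeze(1), h).squeeze(1)
   context.shape  # (`batch_size`, `num_hiddens`)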
MXNet:

.. code:: python

   from mxnet import np, npx
   from mxnet.gluon import nn, rnn
   from d2l import mxnet as d2l

   npx.set_np()

PyTorch:

.. code:: python

   import torch
   from torch import nn
   from d2l import torch as d2l

TensorFlow:

.. code:: python

   import tensorflow as tf
   from d2l import tensorflow as d2l
Defining the Decoder with Attention
-----------------------------------

To implement the RNN encoder-decoder with Bahdanau attention, we only
need to redefine the decoder. To visualize the learned attention
weights more conveniently, the following ``AttentionDecoder`` class
defines the base interface for decoders with attention mechanisms.

.. code:: python

   #@save
   class AttentionDecoder(d2l.Decoder):
       """The base attention-based decoder interface."""
       def __init__(self, **kwargs):
           super(AttentionDecoder, self).__init__(**kwargs)

       @property
       def attention_weights(self):
           raise NotImplementedError

Now let us implement the RNN decoder with Bahdanau attention in the
following ``Seq2SeqAttentionDecoder`` class. The state of the decoder
is initialized with (i) the encoder final-layer hidden states at all
the time steps (as keys and values of the attention); (ii) the encoder
hidden states of all the layers at the final time step (to initialize
the hidden state of the decoder); and (iii) the encoder valid length
(to exclude the padding tokens in attention pooling). At each decoding
time step, the decoder final-layer hidden state at the previous time
step is used as the query of the attention. Then the attention output
and the input embedding are concatenated to form the input of the RNN
decoder.
MXNet:

.. code:: python

   class Seq2SeqAttentionDecoder(AttentionDecoder):
       def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                    dropout=0, **kwargs):
           super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
           self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
           self.embedding = nn.Embedding(vocab_size, embed_size)
           self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)
           self.dense = nn.Dense(vocab_size, flatten=False)

       def init_state(self, enc_outputs, enc_valid_lens, *args):
           # Shape of `outputs`: (`num_steps`, `batch_size`, `num_hiddens`).
           # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
           # `num_hiddens`)
           outputs, hidden_state = enc_outputs
           return (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens)

       def forward(self, X, state):
           # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`).
           # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
           # `num_hiddens`)
           enc_outputs, hidden_state, enc_valid_lens = state
           # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`)
           X = self.embedding(X).swapaxes(0, 1)
           outputs, self._attention_weights = [], []
           for x in X:
               # Shape of `query`: (`batch_size`, 1, `num_hiddens`)
               query = np.expand_dims(hidden_state[0][-1], axis=1)
               # Shape of `context`: (`batch_size`, 1, `num_hiddens`)
               context = self.attention(
                   query, enc_outputs, enc_outputs, enc_valid_lens)
               # Concatenate on the feature dimension
               x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1)
               # Reshape `x` as (1, `batch_size`, `embed_size` + `num_hiddens`)
               out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)
               outputs.append(out)
               self._attention_weights.append(self.attention.attention_weights)
           # After fully-connected layer transformation, shape of `outputs`:
           # (`num_steps`, `batch_size`, `vocab_size`)
           outputs = self.dense(np.concatenate(outputs, axis=0))
           return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,
                                           enc_valid_lens]

       @property
       def attention_weights(self):
           return self._attention_weights

PyTorch:

.. code:: python

   class Seq2SeqAttentionDecoder(AttentionDecoder):
       def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                    dropout=0, **kwargs):
           super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
           self.attention = d2l.AdditiveAttention(
               num_hiddens, num_hiddens, num_hiddens, dropout)
           self.embedding = nn.Embedding(vocab_size, embed_size)
           self.rnn = nn.GRU(
               embed_size + num_hiddens, num_hiddens, num_layers,
               dropout=dropout)
           self.dense = nn.Linear(num_hiddens, vocab_size)

       def init_state(self, enc_outputs, enc_valid_lens, *args):
           # Shape of `outputs`: (`num_steps`, `batch_size`, `num_hiddens`).
           # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
           # `num_hiddens`)
           outputs, hidden_state = enc_outputs
           return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

       def forward(self, X, state):
           # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`).
           # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
           # `num_hiddens`)
           enc_outputs, hidden_state, enc_valid_lens = state
           # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`)
           X = self.embedding(X).permute(1, 0, 2)
           outputs, self._attention_weights = [], []
           for x in X:
               # Shape of `query`: (`batch_size`, 1, `num_hiddens`)
               query = torch.unsqueeze(hidden_state[-1], dim=1)
               # Shape of `context`: (`batch_size`, 1, `num_hiddens`)
               context = self.attention(
                   query, enc_outputs, enc_outputs, enc_valid_lens)
               # Concatenate on the feature dimension
               x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
               # Reshape `x` as (1, `batch_size`, `embed_size` + `num_hiddens`)
               out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
               outputs.append(out)
               self._attention_weights.append(self.attention.attention_weights)
           # After fully-connected layer transformation, shape of `outputs`:
           # (`num_steps`, `batch_size`, `vocab_size`)
           outputs = self.dense(torch.cat(outputs, dim=0))
           return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                             enc_valid_lens]

       @property
       def attention_weights(self):
           return self._attention_weights

TensorFlow:

.. code:: python

   class Seq2SeqAttentionDecoder(AttentionDecoder):
       def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                    dropout=0, **kwargs):
           super().__init__(**kwargs)
           self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens,
                                                  num_hiddens, dropout)
           self.embedding = tf.keras.layers.Embedding(vocab_size, embed_size)
           self.rnn = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(
               [tf.keras.layers.GRUCell(num_hiddens, dropout=dropout)
                for _ in range(num_layers)]),
               return_sequences=True, return_state=True)
           self.dense = tf.keras.layers.Dense(vocab_size)

       def init_state(self, enc_outputs, enc_valid_lens, *args):
           # Shape of `outputs`: (`batch_size`, `num_steps`, `num_hiddens`).
           # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
           # `num_hiddens`)
           outputs, hidden_state = enc_outputs
           return (outputs, hidden_state, enc_valid_lens)

       def call(self, X, state, **kwargs):
           # Shape of `enc_outputs`: (`batch_size`, `num_steps`, `num_hiddens`).
           # Shape of `hidden_state[0]`: (`num_layers`, `batch_size`,
           # `num_hiddens`)
           enc_outputs, hidden_state, enc_valid_lens = state
           # Shape of the output `X`: (`num_steps`, `batch_size`, `embed_size`)
           X = self.embedding(X)  # Input `X` has shape: (`batch_size`, `num_steps`)
           X = tf.transpose(X, perm=(1, 0, 2))
           outputs, self._attention_weights = [], []
           for x in X:
               # Shape of `query`: (`batch_size`, 1, `num_hiddens`)
               query = tf.expand_dims(hidden_state[-1], axis=1)
               # Shape of `context`: (`batch_size`, 1, `num_hiddens`)
               context = self.attention(query, enc_outputs, enc_outputs,
                                        enc_valid_lens, **kwargs)
               # Concatenate on the feature dimension
               x = tf.concat((context, tf.expand_dims(x, axis=1)), axis=-1)
               out = self.rnn(x, hidden_state, **kwargs)
               hidden_state = out[1:]
               outputs.append(out[0])
               self._attention_weights.append(self.attention.attention_weights)
           # After fully-connected layer transformation, shape of `outputs`:
           # (`batch_size`, `num_steps`, `vocab_size`)
           outputs = self.dense(tf.concat(outputs, axis=1))
           return outputs, [enc_outputs, hidden_state, enc_valid_lens]

       @property
       def attention_weights(self):
           return self._attention_weights
In the following, we test the implemented decoder with Bahdanau
attention using a minibatch of 4 sequences, each with 7 time steps.
MXNet:

.. code:: python

   encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                num_layers=2)
   encoder.initialize()
   decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
                                     num_hiddens=16, num_layers=2)
   decoder.initialize()
   X = np.zeros((4, 7))  # (`batch_size`, `num_steps`)
   state = decoder.init_state(encoder(X), None)
   output, state = decoder(X, state)
   output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

.. parsed-literal::
   :class: output

   ((4, 7, 10), 3, (4, 7, 16), 1, (2, 4, 16))

PyTorch:

.. code:: python

   encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                num_layers=2)
   encoder.eval()
   decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
                                     num_hiddens=16, num_layers=2)
   decoder.eval()
   X = torch.zeros((4, 7), dtype=torch.long)  # (`batch_size`, `num_steps`)
   state = decoder.init_state(encoder(X), None)
   output, state = decoder(X, state)
   output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

.. parsed-literal::
   :class: output

   (torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

TensorFlow:

.. code:: python

   encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                num_layers=2)
   decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
                                     num_hiddens=16, num_layers=2)
   X = tf.zeros((4, 7))
   state = decoder.init_state(encoder(X, training=False), None)
   output, state = decoder(X, state, training=False)
   output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

.. parsed-literal::
   :class: output

   (TensorShape([4, 7, 10]), 3, TensorShape([4, 7, 16]), 2, TensorShape([4, 16]))
Training
--------

Similar to :numref:`sec_seq2seq_training`, here we specify
hyperparameters, instantiate an encoder and a decoder with Bahdanau
attention, and train this model for machine translation. Due to the
newly added attention mechanism, this training is much slower than that
in :numref:`sec_seq2seq_training` without the attention mechanism.
MXNet:

.. code:: python

   embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
   batch_size, num_steps = 64, 10
   lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

   train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
   encoder = d2l.Seq2SeqEncoder(
       len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
   decoder = Seq2SeqAttentionDecoder(
       len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
   net = d2l.EncoderDecoder(encoder, decoder)
   d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

.. parsed-literal::
   :class: output

   loss 0.026, 3042.4 tokens/sec on gpu(0)

.. figure:: output_bahdanau-attention_7f08d9_41_1.svg

PyTorch:

.. code:: python

   embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
   batch_size, num_steps = 64, 10
   lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

   train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
   encoder = d2l.Seq2SeqEncoder(
       len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
   decoder = Seq2SeqAttentionDecoder(
       len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
   net = d2l.EncoderDecoder(encoder, decoder)
   d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

.. parsed-literal::
   :class: output

   loss 0.020, 5455.6 tokens/sec on cuda:0

.. figure:: output_bahdanau-attention_7f08d9_44_1.svg

TensorFlow:

.. code:: python

   embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
   batch_size, num_steps = 64, 10
   lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

   train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
   encoder = d2l.Seq2SeqEncoder(
       len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
   decoder = Seq2SeqAttentionDecoder(
       len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
   net = d2l.EncoderDecoder(encoder, decoder)
   d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

.. parsed-literal::
   :class: output

   loss 0.026, 531.0 tokens/sec on

.. figure:: output_bahdanau-attention_7f08d9_47_1.svg
After the model is trained, we use it to translate a few English
sentences into French and compute their BLEU scores.
MXNet:

.. code:: python

   engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
   fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
   for eng, fra in zip(engs, fras):
       translation, dec_attention_weight_seq = d2l.predict_seq2seq(
           net, eng, src_vocab, tgt_vocab, num_steps, device, True)
       print(f'{eng} => {translation}, ',
             f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

.. parsed-literal::
   :class: output

   go . => va !, bleu 1.000
   i lost . => j'ai perdu ., bleu 1.000
   he's calm . => il est riche ., bleu 0.658
   i'm home . => je suis chez moi ., bleu 1.000

.. code:: python

   # `dec_attention_weight_seq` holds the attention weights of every decoding step
   attention_weights = np.concatenate(
       [step[0][0][0] for step in dec_attention_weight_seq], 0
   ).reshape((1, 1, -1, num_steps))

PyTorch:

.. code:: python

   engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
   fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
   for eng, fra in zip(engs, fras):
       translation, dec_attention_weight_seq = d2l.predict_seq2seq(
           net, eng, src_vocab, tgt_vocab, num_steps, device, True)
       print(f'{eng} => {translation}, ',
             f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

.. parsed-literal::
   :class: output

   go . => va !, bleu 1.000
   i lost . => j'ai perdu ., bleu 1.000
   he's calm . => il est riche ., bleu 0.658
   i'm home . => je suis chez moi ., bleu 1.000

.. code:: python

   # `dec_attention_weight_seq` holds the attention weights of every decoding step
   attention_weights = torch.cat(
       [step[0][0][0] for step in dec_attention_weight_seq], 0).reshape(
       (1, 1, -1, num_steps))

TensorFlow:

.. code:: python

   engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
   fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
   for eng, fra in zip(engs, fras):
       translation, dec_attention_weight_seq = d2l.predict_seq2seq(
           net, eng, src_vocab, tgt_vocab, num_steps, True)
       print(f'{eng} => {translation}, ',
             f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

.. parsed-literal::
   :class: output

   go . => va !, bleu 1.000
   i lost . => j'ai perdu ., bleu 1.000
   he's calm . => il est paresseux ., bleu 0.658
   i'm home . => je suis chez moi ., bleu 1.000

.. code:: python

   # `dec_attention_weight_seq` holds the attention weights of every decoding step
   attention_weights = tf.reshape(
       tf.concat([step[0][0][0] for step in dec_attention_weight_seq], 0),
       (1, 1, -1, num_steps))
By visualizing the attention weights when translating the last English
sentence, we can see that each query assigns non-uniform weights over
the key-value pairs. It shows that at each decoding step, different
parts of the input sequence are selectively aggregated in the attention
pooling.
MXNet:

.. code:: python

   # Plus one to include the end-of-sequence token
   d2l.show_heatmaps(
       attention_weights[:, :, :, :len(engs[-1].split()) + 1],
       xlabel='Key positions', ylabel='Query positions')

.. figure:: output_bahdanau-attention_7f08d9_68_0.svg

PyTorch:

.. code:: python

   # Plus one to include the end-of-sequence token
   d2l.show_heatmaps(
       attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
       xlabel='Key positions', ylabel='Query positions')

.. figure:: output_bahdanau-attention_7f08d9_71_0.svg

TensorFlow:

.. code:: python

   # Plus one to include the end-of-sequence token
   d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1],
                     xlabel='Key positions', ylabel='Query positions')

.. figure:: output_bahdanau-attention_7f08d9_74_0.svg
Summary
-------

-  When predicting a token, if not all the input tokens are relevant,
   the RNN encoder-decoder with Bahdanau attention selectively
   aggregates different parts of the input sequence. This is achieved
   by treating the context variable as an output of additive attention
   pooling.
-  In the RNN encoder-decoder, Bahdanau attention treats the decoder
   hidden state at the previous time step as the query, and the encoder
   hidden states at all the time steps as both the keys and values.

Exercises
---------

1. Replace the GRU with an LSTM in the experiment.
2. Modify the experiment to replace the additive attention scoring
   function with the scaled dot-product. How does it influence the
   training efficiency? (A possible starting point is sketched below.)
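As a starting point for the second exercise, the following is a minimal
sketch in PyTorch. It assumes the ``d2l.DotProductAttention`` class
from the section on attention scoring functions, whose constructor
takes only a dropout rate; since the decoder hidden state (query) and
the encoder hidden states (keys) both have ``num_hiddens`` dimensions,
no extra projection is needed. The subclass name below is hypothetical.

.. code:: python

   class Seq2SeqDotProductAttentionDecoder(Seq2SeqAttentionDecoder):
       """Decoder sketch that swaps in scaled dot-product attention."""
       def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                    dropout=0, **kwargs):
           super().__init__(vocab_size, embed_size, num_hiddens, num_layers,
                            dropout, **kwargs)
           # Replace the additive attention created by the parent class;
           # queries and keys already share the `num_hiddens` dimension
           self.attention = d2l.DotProductAttention(dropout)

The rest of the training loop stays unchanged; only the attention
module inside the decoder differs.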
MXNet: `Discussions `__

PyTorch: `Discussions `__