”. We will
consistently distinguish the terminology "BERT input sequence" from
other types of "sequences". For instance, one *BERT input sequence* may
include either one *text sequence* or two *text sequences*.
To distinguish text pairs, the learned segment embeddings
:math:`\mathbf{e}_A` and :math:`\mathbf{e}_B` are added to the token
embeddings of the first sequence and the second sequence, respectively.
For single text inputs, only :math:`\mathbf{e}_A` is used.
The following ``get_tokens_and_segments`` takes either one sentence or
two sentences as the input, then returns tokens of the BERT input
sequence and their corresponding segment IDs.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    def get_tokens_and_segments(tokens_a, tokens_b=None):
        """Get tokens of the BERT input sequence and their segment IDs."""
        tokens = ['<cls>'] + tokens_a + ['<sep>']
        # 0 and 1 are marking segment A and B, respectively
        segments = [0] * (len(tokens_a) + 2)
        if tokens_b is not None:
            tokens += tokens_b + ['<sep>']
            segments += [1] * (len(tokens_b) + 1)
        return tokens, segments

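
Since ``get_tokens_and_segments`` only manipulates Python lists, its
behavior can be checked without any deep learning framework. The snippet
below repeats the definition so that it runs on its own; the example
sentences are made up.

.. code:: python

    def get_tokens_and_segments(tokens_a, tokens_b=None):
        """Get tokens of the BERT input sequence and their segment IDs."""
        tokens = ['<cls>'] + tokens_a + ['<sep>']
        segments = [0] * (len(tokens_a) + 2)
        if tokens_b is not None:
            tokens += tokens_b + ['<sep>']
            segments += [1] * (len(tokens_b) + 1)
        return tokens, segments

    # A single text sequence: every position belongs to segment A (ID 0)
    single_tokens, single_segments = get_tokens_and_segments(
        ['this', 'movie', 'is', 'great'])
    print(single_tokens)    # ['<cls>', 'this', 'movie', 'is', 'great', '<sep>']
    print(single_segments)  # [0, 0, 0, 0, 0, 0]

    # A text pair: the second sequence and its closing '<sep>' get segment ID 1
    pair_tokens, pair_segments = get_tokens_and_segments(
        ['i', 'like', 'it'], ['me', 'too'])
    print(pair_tokens)    # ['<cls>', 'i', 'like', 'it', '<sep>', 'me', 'too', '<sep>']
    print(pair_segments)  # [0, 0, 0, 0, 0, 1, 1, 1]

Note that the first “<sep>” still belongs to segment A, so the segment
boundary falls right after it.
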
BERT uses the transformer encoder as its bidirectional architecture. As
in the transformer encoder, positional embeddings are added at every
position of the BERT input sequence. However, unlike the original
transformer encoder, which uses fixed sinusoidal positional encodings,
BERT uses *learnable* positional embeddings. To sum up,
:numref:`fig_bert-input` shows that the embeddings of the BERT input
sequence are the sum of the token embeddings, segment embeddings, and
positional embeddings.
.. _fig_bert-input:

.. figure:: ../img/bert-input.svg

   The embeddings of the BERT input sequence are the sum of the token
   embeddings, segment embeddings, and positional embeddings.

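
To make the sum in the figure concrete, the following framework-free
sketch looks up one row per table for every position and adds the three
rows elementwise. The embedding tables, their 2-dimensional size, and all
values are made-up placeholders (in BERT they are learned and
``num_hiddens`` wide), and ``bert_input_embeddings`` is a hypothetical
helper, not part of ``d2l``.

.. code:: python

    # Hypothetical toy tables with embedding dimension 2 (values are made up)
    token_table = {'<cls>': [0.1, 0.2], 'i': [0.3, 0.4], 'like': [0.5, 0.6],
                   'it': [0.7, 0.8], '<sep>': [0.9, 1.0], 'me': [1.1, 1.2],
                   'too': [1.3, 1.4]}
    segment_table = [[0.01, 0.01],   # e_A, used for segment ID 0
                     [0.02, 0.02]]   # e_B, used for segment ID 1
    toy_max_len = 16
    # Learnable in BERT; here simply a fixed placeholder row per position
    position_table = [[0.001 * p, 0.001 * p] for p in range(toy_max_len)]

    def bert_input_embeddings(tokens, segments):
        """Return token + segment + positional embedding for each position."""
        assert len(tokens) == len(segments) <= toy_max_len
        out = []
        for pos, (tok, seg) in enumerate(zip(tokens, segments)):
            tok_vec = token_table[tok]
            seg_vec = segment_table[seg]
            pos_vec = position_table[pos]
            out.append([t + s + p for t, s, p in zip(tok_vec, seg_vec, pos_vec)])
        return out

    toy_tokens = ['<cls>', 'i', 'like', 'it', '<sep>', 'me', 'too', '<sep>']
    toy_segments = [0, 0, 0, 0, 0, 1, 1, 1]
    toy_X = bert_input_embeddings(toy_tokens, toy_segments)
    print(len(toy_X), len(toy_X[0]))  # 8 2

The two “<sep>” tokens at positions 4 and 7 share the same token
embedding, yet they receive different input vectors because their segment
and positional embeddings differ.
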
The following ``BERTEncoder`` class is similar to the
``TransformerEncoder`` class as implemented in
:numref:`sec_transformer`. Unlike ``TransformerEncoder``,
``BERTEncoder`` uses segment embeddings and learnable positional
embeddings.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class BERTEncoder(nn.Block):
        """BERT encoder."""
        def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                     num_layers, dropout, max_len=1000, **kwargs):
            super(BERTEncoder, self).__init__(**kwargs)
            self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
            self.segment_embedding = nn.Embedding(2, num_hiddens)
            self.blks = nn.Sequential()
            for _ in range(num_layers):
                self.blks.add(d2l.EncoderBlock(
                    num_hiddens, ffn_num_hiddens, num_heads, dropout, True))
            # In BERT, positional embeddings are learnable, thus we create a
            # parameter of positional embeddings that are long enough
            self.pos_embedding = self.params.get('pos_embedding',
                                                 shape=(1, max_len, num_hiddens))

        def forward(self, tokens, segments, valid_lens):
            # Shape of `X` remains unchanged in the following code snippet:
            # (batch size, max sequence length, `num_hiddens`)
            X = self.token_embedding(tokens) + self.segment_embedding(segments)
            X = X + self.pos_embedding.data(ctx=X.ctx)[:, :X.shape[1], :]
            for blk in self.blks:
                X = blk(X, valid_lens)
            return X

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class BERTEncoder(nn.Module):
        """BERT encoder."""
        def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                     ffn_num_hiddens, num_heads, num_layers, dropout,
                     max_len=1000, key_size=768, query_size=768, value_size=768,
                     **kwargs):
            super(BERTEncoder, self).__init__(**kwargs)
            self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
            self.segment_embedding = nn.Embedding(2, num_hiddens)
            self.blks = nn.Sequential()
            for i in range(num_layers):
                self.blks.add_module(f"{i}", d2l.EncoderBlock(
                    key_size, query_size, value_size, num_hiddens, norm_shape,
                    ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
            # In BERT, positional embeddings are learnable, thus we create a
            # parameter of positional embeddings that are long enough
            self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                          num_hiddens))

        def forward(self, tokens, segments, valid_lens):
            # Shape of `X` remains unchanged in the following code snippet:
            # (batch size, max sequence length, `num_hiddens`)
            X = self.token_embedding(tokens) + self.segment_embedding(segments)
            # Index the parameter directly (not via `.data`) so that gradients
            # flow into the learnable positional embeddings
            X = X + self.pos_embedding[:, :X.shape[1], :]
            for blk in self.blks:
                X = blk(X, valid_lens)
            return X

Suppose that the vocabulary size is 10000. To demonstrate forward
inference of ``BERTEncoder``, let us create an instance of it and
initialize its parameters.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
    num_layers, dropout = 2, 0.2
    encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                          num_layers, dropout)
    encoder.initialize()

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
    norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
    encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,
                          ffn_num_hiddens, num_heads, num_layers, dropout)

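
As a sanity check on the sizes involved, the entries of the three
embedding tables of this encoder can be counted by hand. The sketch below
assumes the default ``max_len=1000`` of the constructor and counts only
the embeddings, not the parameters of the transformer blocks.

.. code:: python

    # Values from the demonstration above; max_len is the constructor default
    vocab_size, num_hiddens, max_len = 10000, 768, 1000

    token_params = vocab_size * num_hiddens    # one row per vocabulary entry
    segment_params = 2 * num_hiddens           # e_A and e_B
    position_params = max_len * num_hiddens    # one learnable row per position
    embedding_params = token_params + segment_params + position_params
    print(token_params, segment_params, position_params, embedding_params)
    # 7680000 1536 768000 8449536

The token table dominates. With the PyTorch encoder, these counts can be
compared against ``encoder.token_embedding.weight.numel()`` and
``encoder.pos_embedding.numel()``.
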
We define ``tokens`` to be 2 BERT input sequences of length 8, where
each token is an index of the vocabulary. The forward inference of
``BERTEncoder`` with the input ``tokens`` returns the encoded result
where each token is represented by a vector whose length is predefined
by the hyperparameter ``num_hiddens``. This hyperparameter is usually
referred to as the *hidden size* (number of hidden units) of the
transformer encoder.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    tokens = np.random.randint(0, vocab_size, (2, 8))
    segments = np.array([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
    encoded_X = encoder(tokens, segments, None)
    encoded_X.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 8, 768)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    tokens = torch.randint(0, vocab_size, (2, 8))
    segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
    encoded_X = encoder(tokens, segments, None)
    encoded_X.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 8, 768])
.. _subsec_bert_pretraining_tasks:
Pretraining Tasks
-----------------
The forward inference of ``BERTEncoder`` gives the BERT representation
of each token of the input text and of the inserted special tokens
“<cls>” and “<sep>”. Next, we will use these representations to compute
the loss function for pretraining BERT. Pretraining consists of the
following two tasks: masked language modeling and next sentence
prediction.
.. _subsec_mlm:
Masked Language Modeling
~~~~~~~~~~~~~~~~~~~~~~~~
As illustrated in :numref:`sec_language_model`, a language model
predicts a token using the context on its left. To encode context
bidirectionally for representing each token, BERT randomly masks tokens
and uses tokens from the bidirectional context to predict the masked
tokens in a self-supervised fashion. This task is referred to as a
*masked language model*.
In this pretraining task, 15% of tokens will be selected at random as
the masked tokens for prediction. To predict a masked token without
cheating by using the label, one straightforward approach is to always
replace it with a special “<mask>” token in the BERT input sequence.
However, the artificial special token “<mask>” will never appear in
fine-tuning. To avoid such a mismatch between pretraining and
fine-tuning, if a token is masked for prediction (e.g., "great" is
selected to be masked and predicted in "this movie is great"), in the
input it will be replaced with:

- a special “<mask>” token 80% of the time (e.g., "this movie is
  great" becomes "this movie is <mask>");
- a random token 10% of the time (e.g., "this movie is great"
  becomes "this movie is drink");
- the unchanged label token 10% of the time (e.g., "this movie is
  great" stays "this movie is great").

Note that for 10% of the selected 15% of tokens (i.e., 1.5% of all
tokens), a random token is inserted. This occasional noise encourages
BERT to be less biased towards the masked token (especially when the
label token remains unchanged) in its bidirectional context encoding.
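
The 80%/10%/10% rule can be sketched in a few lines of plain Python. This
is a minimal illustration under stated assumptions rather than the actual
data pipeline: ``mask_tokens`` is a hypothetical helper, the special
tokens are assumed to be “<cls>” and “<sep>” (they are never selected),
and the number of predictions is assumed to be
``max(1, round(0.15 * len(tokens)))``.

.. code:: python

    import random

    def mask_tokens(tokens, vocab, mask_ratio=0.15, rng=random):
        """Hypothetical helper sketching BERT's 80%/10%/10% replacement rule.

        Returns the corrupted input tokens, the positions to predict, and the
        original tokens at those positions (the labels).
        """
        # Special tokens are never selected for prediction
        candidates = [i for i, tok in enumerate(tokens)
                      if tok not in ('<cls>', '<sep>')]
        num_preds = min(max(1, round(len(tokens) * mask_ratio)), len(candidates))
        positions = sorted(rng.sample(candidates, num_preds))
        corrupted = list(tokens)
        for i in positions:
            r = rng.random()
            if r < 0.8:
                corrupted[i] = '<mask>'           # 80%: the special mask token
            elif r < 0.9:
                corrupted[i] = rng.choice(vocab)  # 10%: a random token
            # Remaining 10%: keep the original token unchanged
        labels = [tokens[i] for i in positions]
        return corrupted, positions, labels

    rng = random.Random(0)
    sentence = ['<cls>', 'this', 'movie', 'is', 'great', '<sep>']
    vocab = ['this', 'movie', 'is', 'great', 'drink']
    corrupted, positions, labels = mask_tokens(sentence, vocab, rng=rng)
    print(corrupted, positions, labels)

Whichever of the three replacements is applied to the input, the label at
each selected position is always the original token, so the loss is
computed against ``labels``.
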
We implement the following ``MaskLM`` class to predict masked tokens in
the masked language model task of BERT pretraining. The prediction uses
a one-hidden-layer MLP (``self.mlp``). In forward inference, it takes
two inputs: the encoded result of ``BERTEncoder`` and the token
positions for prediction. The outputs are the prediction results at
these positions.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class MaskLM(nn.Block):
        """The masked language model task of BERT."""
        def __init__(self, vocab_size, num_hiddens, **kwargs):
            super(MaskLM, self).__init__(**kwargs)
            self.mlp = nn.Sequential()
            self.mlp.add(
                nn.Dense(num_hiddens, flatten=False, activation='relu'))
            self.mlp.add(nn.LayerNorm())
            self.mlp.add(nn.Dense(vocab_size, flatten=False))

        def forward(self, X, pred_positions):
            num_pred_positions = pred_positions.shape[1]
            pred_positions = pred_positions.reshape(-1)
            batch_size = X.shape[0]
            batch_idx = np.arange(0, batch_size)
            # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
            # `batch_idx` is `np.array([0, 0, 0, 1, 1, 1])`
            batch_idx = np.repeat(batch_idx, num_pred_positions)
            masked_X = X[batch_idx, pred_positions]
            masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
            mlm_Y_hat = self.mlp(masked_X)
            return mlm_Y_hat

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class MaskLM(nn.Module):
        """The masked language model task of BERT."""
        def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
            super(MaskLM, self).__init__(**kwargs)
            self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
                                     nn.ReLU(),
                                     nn.LayerNorm(num_hiddens),
                                     nn.Linear(num_hiddens, vocab_size))

        def forward(self, X, pred_positions):
            num_pred_positions = pred_positions.shape[1]
            pred_positions = pred_positions.reshape(-1)
            batch_size = X.shape[0]
            batch_idx = torch.arange(0, batch_size)
            # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
            # `batch_idx` is `torch.tensor([0, 0, 0, 1, 1, 1])`
            batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
            masked_X = X[batch_idx, pred_positions]
            masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
            mlm_Y_hat = self.mlp(masked_X)
            return mlm_Y_hat

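
The indexing in ``MaskLM.forward`` pairs every prediction position with
the sequence it belongs to. The following framework-free sketch traces
the same steps on made-up data, where each entry of ``toy_enc`` is just a
``(sequence index, position)`` tag rather than a ``num_hiddens``-long
vector, so the gathered result is easy to read.

.. code:: python

    # Toy stand-in for `encoded_X`: 2 sequences of length 4, where each entry
    # is a (sequence index, position) tag instead of a vector
    toy_batch_size, toy_seq_len = 2, 4
    toy_enc = [[(b, t) for t in range(toy_seq_len)] for b in range(toy_batch_size)]
    toy_positions = [[1, 3, 2], [0, 1, 3]]

    num_pred_positions = len(toy_positions[0])
    # pred_positions.reshape(-1)
    flat_positions = [p for row in toy_positions for p in row]
    # np.repeat / torch.repeat_interleave: [0, 0, 0, 1, 1, 1]
    batch_idx = [b for b in range(toy_batch_size) for _ in range(num_pred_positions)]
    # X[batch_idx, pred_positions]: one entry per (sequence, position) pair
    gathered = [toy_enc[b][p] for b, p in zip(batch_idx, flat_positions)]
    # reshape((batch_size, num_pred_positions, -1))
    toy_masked_X = [gathered[i * num_pred_positions:(i + 1) * num_pred_positions]
                    for i in range(toy_batch_size)]
    print(toy_masked_X)
    # [[(0, 1), (0, 3), (0, 2)], [(1, 0), (1, 1), (1, 3)]]

Each row of the result holds the entries of one sequence at its own
prediction positions, which is the layout the MLP then maps to
vocabulary-sized scores.
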
To demonstrate the forward inference of ``MaskLM``, we create its
instance ``mlm`` and initialize it. Recall that ``encoded_X`` from the
forward inference of ``BERTEncoder`` represents 2 BERT input sequences.
We define ``mlm_positions`` as the 3 indices to predict in each BERT
input sequence of ``encoded_X``. The forward inference of ``mlm``
returns prediction results ``mlm_Y_hat`` at all the masked positions
``mlm_positions`` of ``encoded_X``. For each prediction, the size of the
result is equal to the vocabulary size.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    mlm = MaskLM(vocab_size, num_hiddens)
    mlm.initialize()
    mlm_positions = np.array([[1, 5, 2], [6, 1, 5]])
    mlm_Y_hat = mlm(encoded_X, mlm_positions)
    mlm_Y_hat.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 3, 10000)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    mlm = MaskLM(vocab_size, num_hiddens)
    mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
    mlm_Y_hat = mlm(encoded_X, mlm_positions)
    mlm_Y_hat.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 3, 10000])
With the ground truth labels ``mlm_Y`` of the predicted tokens
``mlm_Y_hat`` under masks, we can calculate the cross-entropy loss of
the masked language model task in BERT pretraining.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    mlm_Y = np.array([[7, 8, 9], [10, 20, 30]])
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
    mlm_l.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(6,)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
    loss = nn.CrossEntropyLoss(reduction='none')
    mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
    mlm_l.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([6])
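
Both ``gluon.loss.SoftmaxCrossEntropyLoss`` and
``nn.CrossEntropyLoss(reduction='none')`` return one value per row,
namely the negative log-softmax probability of the label, which is why
the result has 6 entries (2 sequences times 3 predictions). The sketch
below recomputes this per-prediction loss in plain Python on made-up
logits over a hypothetical 4-token vocabulary; ``cross_entropy`` is a
hypothetical helper that uses the log-sum-exp trick for numerical
stability.

.. code:: python

    import math

    def cross_entropy(logits, label):
        """-log softmax(logits)[label], computed stably via log-sum-exp."""
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        return log_sum_exp - logits[label]

    # Made-up logits: 2 sequences x 3 masked positions x 4-token vocabulary
    toy_logits = [[[2.0, 0.5, 0.1, -1.0], [0.0, 0.0, 0.0, 0.0], [1.0, 3.0, 0.2, 0.2]],
                  [[0.3, 0.3, 2.5, 0.1], [1.5, -0.5, 0.0, 0.0], [0.0, 0.0, 0.0, 4.0]]]
    toy_labels = [[0, 2, 1], [2, 0, 3]]

    # Flatten (batch, positions) into one loss per masked prediction
    per_token_loss = [cross_entropy(logits, y)
                      for seq_logits, seq_labels in zip(toy_logits, toy_labels)
                      for logits, y in zip(seq_logits, seq_labels)]
    print(len(per_token_loss))  # 6

For the uniform logits at the second position, the loss is exactly
``log(4)``: a prediction that carries no information pays the log of the
vocabulary size.
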
.. _subsec_nsp:
Next Sentence Prediction
~~~~~~~~~~~~~~~~~~~~~~~~
Although masked language modeling is able to encode bidirectional
context for representing words, it does not explicitly model the logical
relationship between text pairs. To help understand the relationship
between two text sequences, BERT considers a binary classification task,
*next sentence prediction*, in its pretraining. When generating sentence
pairs for pretraining, half of the time they are indeed consecutive
sentences with the label "True", while the other half of the time the
second sentence is randomly sampled from the corpus with the label
"False".
The following ``NextSentencePred`` class uses a one-hidden-layer MLP to
predict whether the second sentence is the next sentence of the first in
the BERT input sequence. Due to self-attention in the transformer
encoder, the BERT representation of the special token “<cls>” encodes
both sentences from the input. Hence, the output layer
(``self.output``) of the MLP classifier takes ``X`` as the input, where
``X`` is the output of the MLP hidden layer whose input is the encoded
“<cls>” token.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class NextSentencePred(nn.Block):
        """The next sentence prediction task of BERT."""
        def __init__(self, **kwargs):
            super(NextSentencePred, self).__init__(**kwargs)
            self.output = nn.Dense(2)

        def forward(self, X):
            # `X` shape: (batch size, `num_hiddens`)
            return self.output(X)

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class NextSentencePred(nn.Module):
        """The next sentence prediction task of BERT."""
        def __init__(self, num_inputs, **kwargs):
            super(NextSentencePred, self).__init__(**kwargs)
            self.output = nn.Linear(num_inputs, 2)

        def forward(self, X):
            # `X` shape: (batch size, `num_hiddens`)
            return self.output(X)

We can see that the forward inference of a ``NextSentencePred``
instance returns binary predictions for each BERT input sequence.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    nsp = NextSentencePred()
    nsp.initialize()
    nsp_Y_hat = nsp(encoded_X)
    nsp_Y_hat.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 2)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    # Unlike MXNet's `nn.Dense`, which by default (`flatten=True`) collapses
    # all but the first axis of its input, PyTorch's `nn.Linear` does not
    # flatten, so we flatten `encoded_X` explicitly for this shape demo
    encoded_X = torch.flatten(encoded_X, start_dim=1)
    # NSP expects (batch size, number of input features); after flattening
    # this is (batch size, 8 * `num_hiddens`)
    nsp = NextSentencePred(encoded_X.shape[-1])
    nsp_Y_hat = nsp(encoded_X)
    nsp_Y_hat.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 2])
The cross-entropy loss of the 2 binary classifications can also be
computed.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    nsp_y = np.array([0, 1])
    nsp_l = loss(nsp_Y_hat, nsp_y)
    nsp_l.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2,)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    nsp_y = torch.tensor([0, 1])
    nsp_l = loss(nsp_Y_hat, nsp_y)
    nsp_l.shape

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2])
It is noteworthy that all the labels in both the aforementioned
pretraining tasks can be trivially obtained from the pretraining corpus
without manual labeling effort. The original BERT has been pretrained on
the concatenation of BookCorpus :cite:`Zhu.Kiros.Zemel.ea.2015` and
English Wikipedia. These two text corpora are huge: they have 800
million words and 2.5 billion words, respectively.
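
Because the next sentence prediction labels come from the corpus itself,
generating them is a simple sampling step. A minimal sketch, assuming the
corpus is a list of paragraphs, each a list of tokenized sentences;
``make_nsp_example`` and the toy corpus are hypothetical. As a
simplification, the randomly drawn sentence may occasionally be the true
next sentence but is still labeled ``False``.

.. code:: python

    import random

    def make_nsp_example(sentence, next_sentence, paragraphs, rng=random):
        """Hypothetical helper returning (sentence_a, sentence_b, is_next)."""
        if rng.random() < 0.5:
            # Half of the time: keep the true next sentence, labeled True
            is_next = True
        else:
            # Otherwise: a random sentence from a random paragraph, labeled False
            next_sentence = rng.choice(rng.choice(paragraphs))
            is_next = False
        return sentence, next_sentence, is_next

    # A made-up corpus: each paragraph is a list of tokenized sentences
    toy_paragraphs = [[['the', 'cat', 'sat'], ['it', 'was', 'happy']],
                      [['stocks', 'fell'], ['investors', 'worried'],
                       ['markets', 'closed']]]
    nsp_rng = random.Random(42)
    nsp_examples = [make_nsp_example(p[i], p[i + 1], toy_paragraphs, nsp_rng)
                    for p in toy_paragraphs for i in range(len(p) - 1)]
    for sentence_a, sentence_b, is_next in nsp_examples:
        print(is_next, sentence_a, sentence_b)

Every label here is produced by the sampling itself; no manual
annotation is involved.
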
Putting All Things Together
---------------------------
When pretraining BERT, the final loss function is a linear combination
of both the loss functions for masked language modeling and next
sentence prediction. Now we can define the ``BERTModel`` class by
instantiating the three classes ``BERTEncoder``, ``MaskLM``, and
``NextSentencePred``. The forward inference returns the encoded BERT
representations ``encoded_X``, predictions of masked language modeling
``mlm_Y_hat``, and next sentence predictions ``nsp_Y_hat``.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class BERTModel(nn.Block):
        """The BERT model."""
        def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                     num_layers, dropout, max_len=1000):
            super(BERTModel, self).__init__()
            self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens,
                                       num_heads, num_layers, dropout, max_len)
            self.hidden = nn.Dense(num_hiddens, activation='tanh')
            self.mlm = MaskLM(vocab_size, num_hiddens)
            self.nsp = NextSentencePred()

        def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
            encoded_X = self.encoder(tokens, segments, valid_lens)
            if pred_positions is not None:
                mlm_Y_hat = self.mlm(encoded_X, pred_positions)
            else:
                mlm_Y_hat = None
            # The hidden layer of the MLP classifier for next sentence prediction.
            # 0 is the index of the '<cls>' token
            nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
            return encoded_X, mlm_Y_hat, nsp_Y_hat

.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    #@save
    class BERTModel(nn.Module):
        """The BERT model."""
        def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                     ffn_num_hiddens, num_heads, num_layers, dropout,
                     max_len=1000, key_size=768, query_size=768, value_size=768,
                     hid_in_features=768, mlm_in_features=768,
                     nsp_in_features=768):
            super(BERTModel, self).__init__()
            self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                        ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                        dropout, max_len=max_len, key_size=key_size,
                        query_size=query_size, value_size=value_size)
            self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                        nn.Tanh())
            self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
            self.nsp = NextSentencePred(nsp_in_features)

        def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
            encoded_X = self.encoder(tokens, segments, valid_lens)
            if pred_positions is not None:
                mlm_Y_hat = self.mlm(encoded_X, pred_positions)
            else:
                mlm_Y_hat = None
            # The hidden layer of the MLP classifier for next sentence prediction.
            # 0 is the index of the '<cls>' token
            nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
            return encoded_X, mlm_Y_hat, nsp_Y_hat

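
The text describes the final pretraining loss only as a linear
combination of the two task losses. The following sketch makes one such
combination concrete under stated assumptions: both terms get weight 1,
the masked language modeling loss is averaged only over real
(non-padding) prediction slots marked by a 0/1 weight, and the next
sentence prediction loss is averaged over the batch.
``bert_pretraining_loss`` is a hypothetical helper and the loss values
are made up.

.. code:: python

    def bert_pretraining_loss(mlm_losses, mlm_weights, nsp_losses):
        """Hypothetical helper combining the two pretraining losses.

        mlm_losses:  one MLM loss per prediction slot (reduction='none')
        mlm_weights: 1.0 for a real prediction, 0.0 for a padded slot
        nsp_losses:  one NSP loss per BERT input sequence
        """
        # Weighted mean over real predictions; 1e-8 guards against division by 0
        mlm_l = (sum(l * w for l, w in zip(mlm_losses, mlm_weights))
                 / (sum(mlm_weights) + 1e-8))
        nsp_l = sum(nsp_losses) / len(nsp_losses)
        # Equal weights for both terms (an assumption; other positive weights
        # would also give a linear combination)
        return mlm_l + nsp_l, mlm_l, nsp_l

    # Made-up per-slot losses; the last MLM slot is padding and is ignored
    total, mlm_part, nsp_part = bert_pretraining_loss(
        [2.0, 1.0, 3.0, 4.0], [1.0, 1.0, 1.0, 0.0], [0.5, 0.7])
    print(round(total, 4), round(mlm_part, 4), round(nsp_part, 4))  # 2.6 2.0 0.6

Masking out padded slots matters because input sequences in a batch can
have different numbers of tokens selected for prediction.
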
Summary
-------

- Word embedding models such as word2vec and GloVe are
  context-independent. They assign the same pretrained vector to the
  same word regardless of the context of the word (if any), so they
  struggle to handle polysemy or complex semantics in natural
  languages.
- For context-sensitive word representations such as ELMo and GPT,
  representations of words depend on their contexts.
- ELMo encodes context bidirectionally but uses task-specific
  architectures (however, it is practically non-trivial to craft a
  specific architecture for every natural language processing task),
  while GPT is task-agnostic but encodes context left-to-right.
- BERT combines the best of both worlds: it encodes context
  bidirectionally and requires minimal architecture changes for a wide
  range of natural language processing tasks.
- The embeddings of the BERT input sequence are the sum of the token
  embeddings, segment embeddings, and positional embeddings.
- Pretraining BERT is composed of two tasks: masked language modeling
  and next sentence prediction. The former is able to encode
  bidirectional context for representing words, while the latter
  explicitly models the logical relationship between text pairs.

Exercises
---------

1. Why does BERT succeed?
2. All other things being equal, will a masked language model require
   more or fewer pretraining steps to converge than a left-to-right
   language model? Why?
3. In the original implementation of BERT, the positionwise feed-forward
   network in ``BERTEncoder`` (via ``d2l.EncoderBlock``) and the
   fully-connected layer in ``MaskLM`` both use the Gaussian error
   linear unit (GELU) :cite:`Hendrycks.Gimpel.2016` as the activation
   function. Research the difference between GELU and ReLU.