"
tokens that represent rare (unknown) words.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
vocab = d2l.Vocab(sentences, min_freq=10)
f'vocab size: {len(vocab)}'
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'vocab size: 6719'
Subsampling
-----------
Text data typically have high-frequency words such as "the", "a", and
"in": they may even occur billions of times in very large corpora.
However, these words often co-occur with many different words in
context windows, providing little useful signal. For instance, consider
the word "chip" in a context window: intuitively its co-occurrence with
the low-frequency word "intel" is more useful for training than its
co-occurrence with the high-frequency word "a". Moreover, training with
vast numbers of (high-frequency) words is slow. Thus, when training
word embedding models, high-frequency words can be *subsampled*
:cite:`Mikolov.Sutskever.Chen.ea.2013`. Specifically, each indexed
word :math:`w_i` in the dataset is discarded with probability

.. math:: P(w_i) = \max\left(1 - \sqrt{\frac{t}{f(w_i)}}, 0\right),

where :math:`f(w_i)` is the ratio of the number of occurrences of word
:math:`w_i` to the total number of words in the dataset, and the
constant :math:`t` is a hyperparameter (:math:`10^{-4}` in the
experiment). We can see that only when the relative frequency
:math:`f(w_i) > t` can the (high-frequency) word :math:`w_i` be
discarded, and the higher the relative frequency of a word, the greater
its probability of being discarded.
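
As a quick sanity check, the snippet below evaluates this discard
probability for a few relative frequencies with :math:`t=10^{-4}`. This
is a minimal sketch of the formula above; the helper ``discard_prob``
is ours for illustration and is not part of ``d2l``.

.. code:: python

    import math

    def discard_prob(rel_freq, t=1e-4):
        """Probability of discarding a word with relative frequency `rel_freq`."""
        return max(1 - math.sqrt(t / rel_freq), 0)

    # At or below the threshold t a word is never discarded; far above it,
    # a word is discarded almost always
    for f in [1e-5, 1e-4, 1e-3, 5e-2]:
        print(f'f(w) = {f:g}: P(discard) = {discard_prob(f):.3f}')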
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
    #@save
    def subsample(sentences, vocab):
        """Subsample high-frequency words."""
        # Exclude unknown ('<unk>') tokens
        sentences = [[token for token in line if vocab[token] != vocab.unk]
                     for line in sentences]
        counter = d2l.count_corpus(sentences)
        num_tokens = sum(counter.values())

        # Return True if `token` is kept during subsampling
        def keep(token):
            return (random.uniform(0, 1) <
                    math.sqrt(1e-4 / counter[token] * num_tokens))

        return ([[token for token in line if keep(token)] for line in sentences],
                counter)

    subsampled, counter = subsample(sentences, vocab)
The following code snippet plots the histogram of the number of tokens
per sentence before and after subsampling. As expected, subsampling
significantly shortens sentences by dropping high-frequency words,
which will speed up training.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_list_len_pair_hist(['origin', 'subsampled'], '# tokens per sentence',
'count', sentences, subsampled);
.. figure:: output_word-embedding-dataset_f77071_39_0.svg
For individual tokens, the sampling rate of the high-frequency word
"the" is less than 1/20.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def compare_counts(token):
return (f'# of "{token}": '
f'before={sum([l.count(token) for l in sentences])}, '
f'after={sum([l.count(token) for l in subsampled])}')
compare_counts('the')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'# of "the": before=50770, after=1995'
In contrast, the low-frequency word "join" is completely kept.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
compare_counts('join')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'# of "join": before=45, after=45'
After subsampling, we map tokens to their indices for the corpus.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
corpus = [vocab[line] for line in subsampled]
corpus[:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[[], [71, 392, 2115, 145, 406], [22, 5277, 3054, 1580]]
Extracting Center Words and Context Words
-----------------------------------------
The following ``get_centers_and_contexts`` function extracts all the
center words and their context words from ``corpus``. It uniformly
samples an integer between 1 and ``max_window_size`` at random as the
context window size. For any center word, those words whose distance
from it does not exceed the sampled context window size are its context
words.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
def get_centers_and_contexts(corpus, max_window_size):
"""Return center words and context words in skip-gram."""
centers, contexts = [], []
for line in corpus:
# To form a "center word--context word" pair, each sentence needs to
# have at least 2 words
if len(line) < 2:
continue
centers += line
for i in range(len(line)): # Context window centered at `i`
window_size = random.randint(1, max_window_size)
indices = list(range(max(0, i - window_size),
min(len(line), i + 1 + window_size)))
# Exclude the center word from the context words
indices.remove(i)
contexts.append([line[idx] for idx in indices])
return centers, contexts
Next, we create an artificial dataset containing two sentences of 7 and
3 words, respectively. Let the maximum context window size be 2 and
print all the center words and their context words.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
print('center', center, 'has contexts', context)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2, 3]
center 2 has contexts [1, 3]
center 3 has contexts [1, 2, 4, 5]
center 4 has contexts [3, 5]
center 5 has contexts [3, 4, 6]
center 6 has contexts [4, 5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [7, 8]
When training on the PTB dataset, we set the maximum context window
size to 5. The following code extracts all the center words and their
context words in the dataset.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# center-context pairs: {sum([len(contexts) for contexts in all_contexts])}'
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'# center-context pairs: 1500885'
Negative Sampling
-----------------
We use negative sampling for approximate training. To sample noise words
according to a predefined distribution, we define the following
``RandomGenerator`` class, where the (possibly unnormalized) sampling
distribution is passed via the argument ``sampling_weights``.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
    #@save
    class RandomGenerator:
        """Randomly draw among {1, ..., n} according to n sampling weights."""
        def __init__(self, sampling_weights):
            # Exclude the unknown token '<unk>' (index 0) from the population
            self.population = list(range(1, len(sampling_weights) + 1))
            self.sampling_weights = sampling_weights
            self.candidates = []
            self.i = 0

        def draw(self):
            if self.i == len(self.candidates):
                # Cache `k` random sampling results
                self.candidates = random.choices(
                    self.population, self.sampling_weights, k=10000)
                self.i = 0
            self.i += 1
            return self.candidates[self.i - 1]
For example, we can draw 10 random variables :math:`X` among indices 1,
2, and 3 with sampling probabilities :math:`P(X=1)=2/9, P(X=2)=3/9`, and
:math:`P(X=3)=4/9` as follows.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[3, 3, 2, 2, 3, 1, 3, 1, 2, 3]
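
Caching draws in batches of 10000 is purely a performance optimization:
``random.choices`` has noticeable per-call overhead, so drawing one
sample per call would slow down data preparation. The rough comparison
below is our own sketch (absolute timings will vary by machine):

.. code:: python

    import random, timeit

    population, weights = [1, 2, 3], [2, 3, 4]

    # 10000 draws made one at a time vs. a single bulk call
    one_at_a_time = timeit.timeit(
        lambda: random.choices(population, weights, k=1), number=10000)
    in_bulk = timeit.timeit(
        lambda: random.choices(population, weights, k=10000), number=1)
    print(f'single draws: {one_at_a_time:.3f}s, bulk draw: {in_bulk:.3f}s')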
For a pair of center word and context word, we randomly sample ``K`` (5
in the experiment) noise words. According to the suggestions in the
word2vec paper, the sampling probability :math:`P(w)` of a noise word
:math:`w` is set to its relative frequency in the dictionary raised to
the power of 0.75 :cite:`Mikolov.Sutskever.Chen.ea.2013`.
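
To see what this exponent does, the toy example below compares a raw
unigram distribution with its 0.75-power counterpart. The frequencies
are made up for illustration; the exponent shifts probability mass from
very frequent words toward rare ones.

.. code:: python

    freqs = [0.9, 0.09, 0.01]  # hypothetical relative frequencies

    weights = [f ** 0.75 for f in freqs]
    probs = [w / sum(weights) for w in weights]
    # Roughly [0.825, 0.147, 0.028]: flatter than the raw [0.9, 0.09, 0.01]
    print(probs)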
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
def get_negatives(all_contexts, vocab, counter, K):
"""Return noise words in negative sampling."""
# Sampling weights for words with indices 1, 2, ... (index 0 is the
# excluded unknown token) in the vocabulary
sampling_weights = [counter[vocab.to_tokens(i)]**0.75
for i in range(1, len(vocab))]
all_negatives, generator = [], RandomGenerator(sampling_weights)
for contexts in all_contexts:
negatives = []
while len(negatives) < len(contexts) * K:
neg = generator.draw()
# Noise words cannot be context words
if neg not in contexts:
negatives.append(neg)
all_negatives.append(negatives)
return all_negatives
all_negatives = get_negatives(all_contexts, vocab, counter, 5)
.. _subsec_word2vec-minibatch-loading:
Loading Training Examples in Minibatches
----------------------------------------
After all the center words together with their context words and
sampled noise words are extracted, they will be transformed into
minibatches of examples that can be iteratively loaded during training.

In a minibatch, the :math:`i^\mathrm{th}` example includes a center
word and its :math:`n_i` context words and :math:`m_i` noise words. Due
to varying context window sizes, :math:`n_i+m_i` varies for different
:math:`i`. Thus, for each example we concatenate its context words and
noise words in the ``contexts_negatives`` variable, and pad zeros until
the concatenation length reaches :math:`\max_i n_i+m_i` (``max_len``).
To exclude paddings in the calculation of the loss, we define a mask
variable ``masks``. There is a one-to-one correspondence between
elements in ``masks`` and elements in ``contexts_negatives``, where
zeros (otherwise ones) in ``masks`` correspond to paddings in
``contexts_negatives``.

To distinguish between positive and negative examples, we separate
context words from noise words in ``contexts_negatives`` via a
``labels`` variable. Similar to ``masks``, there is also a one-to-one
correspondence between elements in ``labels`` and elements in
``contexts_negatives``, where ones (otherwise zeros) in ``labels``
correspond to context words (positive examples) in
``contexts_negatives``.

The above idea is implemented in the following ``batchify`` function.
Its input ``data`` is a list with length equal to the batch size, where
each element is an example consisting of the center word ``center``,
its context words ``context``, and its noise words ``negative``. This
function returns a minibatch that can be loaded for calculations during
training, including the mask variable.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
def batchify(data):
"""Return a minibatch of examples for skip-gram with negative sampling."""
max_len = max(len(c) + len(n) for _, c, n in data)
centers, contexts_negatives, masks, labels = [], [], [], []
for center, context, negative in data:
cur_len = len(context) + len(negative)
centers += [center]
contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
masks += [[1] * cur_len + [0] * (max_len - cur_len)]
labels += [[1] * len(context) + [0] * (max_len - len(context))]
return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(
contexts_negatives), torch.tensor(masks), torch.tensor(labels))
Let us test this function using a minibatch of two examples.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = batchify((x_1, x_2))
names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
print(name, '=', data)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
centers = tensor([[1],
[1]])
contexts_negatives = tensor([[2, 2, 3, 3, 3, 3],
[2, 2, 2, 3, 3, 0]])
masks = tensor([[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 0]])
labels = tensor([[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0]])
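
The ``masks`` variable is what will let the loss function ignore
padding during training: multiplying the per-element loss by ``masks``
zeroes out the contribution of padded positions. The sketch below
illustrates this idea with PyTorch's built-in binary cross-entropy on
the minibatch above; it is our own illustration rather than the book's
training code, and the predictions are random placeholders.

.. code:: python

    from torch import nn

    loss_fn = nn.BCEWithLogitsLoss(reduction='none')

    # Random stand-in predictions with the shape of `contexts_negatives`
    pred = torch.randn(2, 6)
    label, mask = batch[3].float(), batch[2].float()

    # Zero out padded positions, then average over the real positions only
    loss = (loss_fn(pred, label) * mask).sum(dim=1) / mask.sum(dim=1)
    print(loss.shape)  # torch.Size([2])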
Putting All Things Together
---------------------------
Last, we define the ``load_data_ptb`` function that reads the PTB
dataset and returns the data iterator and the vocabulary.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
def load_data_ptb(batch_size, max_window_size, num_noise_words):
"""Download the PTB dataset and then load it into memory."""
num_workers = d2l.get_dataloader_workers()
sentences = read_ptb()
vocab = d2l.Vocab(sentences, min_freq=10)
subsampled, counter = subsample(sentences, vocab)
corpus = [vocab[line] for line in subsampled]
all_centers, all_contexts = get_centers_and_contexts(
corpus, max_window_size)
all_negatives = get_negatives(
all_contexts, vocab, counter, num_noise_words)
class PTBDataset(torch.utils.data.Dataset):
def __init__(self, centers, contexts, negatives):
assert len(centers) == len(contexts) == len(negatives)
self.centers = centers
self.contexts = contexts
self.negatives = negatives
def __getitem__(self, index):
return (self.centers[index], self.contexts[index],
self.negatives[index])
def __len__(self):
return len(self.centers)
dataset = PTBDataset(all_centers, all_contexts, all_negatives)
data_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,
collate_fn=batchify,
num_workers=num_workers)
return data_iter, vocab
Let us print the first minibatch of the data iterator.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data_iter, vocab = load_data_ptb(512, 5, 5)
for batch in data_iter:
for name, data in zip(names, batch):
print(name, 'shape:', data.shape)
break
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])
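
Note that the width 60 of ``contexts_negatives`` follows from the
hyperparameters: with a maximum window size of 5, an example has at
most :math:`2 \times 5 = 10` context words, and with :math:`K=5` noise
words per context word at most :math:`10 \times 5 = 50` noise words,
i.e., at most :math:`10 + 50 = 60` entries before padding.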
Summary
-------
-  High-frequency words may not be so useful in training. We can
   subsample them to speed up training.
-  For computational efficiency, we load examples in minibatches. We
   can define other variables to distinguish paddings from non-paddings,
   and positive examples from negative ones.
Exercises
---------
1. How does the running time of code in this section change if
   subsampling is not used?
2. The ``RandomGenerator`` class caches ``k`` random sampling results.
   Set ``k`` to other values and see how it affects the data loading
   speed.
3. What other hyperparameters in the code of this section may affect
   the data loading speed?