.. _sec_rnn-concise:
Concise Implementation of Recurrent Neural Networks
===================================================
While :numref:`sec_rnn_scratch` was instructive for seeing how RNNs are
implemented, such an implementation is neither convenient nor fast. This
section shows how to implement the same language model more efficiently
using functions provided by the high-level APIs of a deep learning
framework. The code is given for MXNet, PyTorch, and TensorFlow in turn.
We begin, as before, by reading the time machine dataset.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l
npx.set_np()
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Defining the Model
------------------
High-level APIs provide implementations of recurrent neural networks. We
construct the recurrent neural network layer ``rnn_layer`` with a single
hidden layer and 256 hidden units. In fact, we have not even discussed
yet what it means to have multiple layers---this will happen in
:numref:`sec_deep_rnn`. For now, suffice it to say that multiple
layers simply amount to the output of one RNN layer being used as the
input of the next RNN layer.
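As a brief preview (a minimal sketch using PyTorch's ``nn.RNN`` with a
hypothetical input size; multilayer RNNs are treated properly in
:numref:`sec_deep_rnn`), stacking layers only requires an extra
``num_layers`` argument:

.. code:: python

    import torch
    from torch import nn

    # Hypothetical sizes for illustration only
    stacked_rnn = nn.RNN(input_size=28, hidden_size=256, num_layers=2)
    X = torch.rand(35, 32, 28)  # (num_steps, batch_size, input_size)
    Y, state = stacked_rnn(X)   # `state` stacks one hidden state per layer
    Y.shape, state.shape        # (torch.Size([35, 32, 256]), torch.Size([2, 32, 256]))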
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens = 256
rnn_layer = rnn.RNN(num_hiddens)
rnn_layer.initialize()
Initializing the hidden state is straightforward: we invoke the member
function ``begin_state``. It returns a list (``state``) containing an
initial hidden state for each example in the minibatch; this hidden state
has shape (number of hidden layers, batch size, number of hidden units).
For some models to be introduced later (e.g., long short-term memory),
the list also contains other information.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
state = rnn_layer.begin_state(batch_size=batch_size)
len(state), state[0].shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(1, (1, 32, 256))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)
We use a tensor to initialize the hidden state; its shape is (number of
hidden layers, batch size, number of hidden units).
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([1, 32, 256])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens = 256
rnn_cell = tf.keras.layers.SimpleRNNCell(num_hiddens,
                                         kernel_initializer='glorot_uniform')
rnn_layer = tf.keras.layers.RNN(rnn_cell, time_major=True,
                                return_sequences=True, return_state=True)
state = rnn_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
state.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([32, 256])
With a hidden state and an input, we can compute the output together
with the updated hidden state. It should be emphasized that the "output"
(``Y``) of ``rnn_layer`` does *not* involve any computation of output
layers: it refers to the hidden state at *each* time step, and these
hidden states can be used as the input to a subsequent output layer.
In addition, the updated hidden state (``state_new``) returned by
``rnn_layer`` refers to the hidden state at the *last* time step of the
minibatch. It can be used to initialize the hidden state for the next
minibatch within an epoch under sequential partitioning. When there are
multiple hidden layers, the hidden state of each layer is stored in this
variable (``state_new``). For some models to be introduced later (e.g.,
long short-term memory), this variable also contains other information.
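As a quick sanity check of this relationship, here is a minimal PyTorch
sketch (with hypothetical sizes): for a single-layer, unidirectional RNN,
the last time step of ``Y`` coincides with the returned state.

.. code:: python

    import torch
    from torch import nn

    rnn = nn.RNN(input_size=28, hidden_size=256)  # hypothetical sizes
    X = torch.rand(35, 32, 28)                    # (num_steps, batch_size, input_size)
    state0 = torch.zeros(1, 32, 256)
    Y, state_new = rnn(X, state0)
    # `Y` stacks the hidden states of every time step, while `state_new`
    # keeps only the hidden state of the last time step.
    torch.allclose(Y[-1], state_new[0])           # True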
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = np.random.uniform(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, len(state_new), state_new[0].shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
((35, 32, 256), 1, (1, 32, 256))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.random.uniform((num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, len(state_new), state_new[0].shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(TensorShape([35, 32, 256]), 32, TensorShape([256]))
Similar to :numref:`sec_rnn_scratch`, we define an ``RNNModel`` class
for the complete RNN model. Note that ``rnn_layer`` only contains the
hidden recurrent layers, so we need to create a separate output layer.
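To make the role of the separate output layer concrete, here is a minimal
PyTorch sketch (hypothetical sizes, not part of the book's code): the
per-step hidden states are flattened and mapped to per-token logits, which
is exactly the pattern that the ``RNNModel`` classes below wrap together
with one-hot encoding of the inputs.

.. code:: python

    import torch
    from torch import nn

    num_steps, batch_size, num_hiddens, vocab_size = 35, 32, 256, 28  # hypothetical
    Y = torch.rand(num_steps, batch_size, num_hiddens)  # hidden states from `rnn_layer`
    output_layer = nn.Linear(num_hiddens, vocab_size)
    # Flatten to (num_steps * batch_size, num_hiddens) so that the same fully
    # connected layer is applied at every time step.
    logits = output_layer(Y.reshape(-1, num_hiddens))
    logits.shape  # torch.Size([1120, 28])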
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class RNNModel(nn.Block):
    """The RNN model."""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.dense = nn.Dense(vocab_size)

    def forward(self, inputs, state):
        X = npx.one_hot(inputs.T, self.vocab_size)
        Y, state = self.rnn(X, state)
        # The fully connected layer will first change the shape of `Y` to
        # (`num_steps` * `batch_size`, `num_hiddens`). Its output shape is
        # (`num_steps` * `batch_size`, `vocab_size`).
        output = self.dense(Y.reshape(-1, Y.shape[-1]))
        return output, state

    def begin_state(self, *args, **kwargs):
        return self.rnn.begin_state(*args, **kwargs)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class RNNModel(nn.Module):
    """The RNN model."""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # If the RNN is bidirectional (to be introduced later),
        # `num_directions` should be 2, else it should be 1.
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # The fully connected layer will first change the shape of `Y` to
        # (`num_steps` * `batch_size`, `num_hiddens`). Its output shape is
        # (`num_steps` * `batch_size`, `vocab_size`).
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # `nn.GRU` takes a tensor as hidden state
            return torch.zeros((self.num_directions * self.rnn.num_layers,
                                batch_size, self.num_hiddens),
                               device=device)
        else:
            # `nn.LSTM` takes a tuple of hidden states
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
#@save
class RNNModel(tf.keras.layers.Layer):
    """The RNN model."""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, state):
        X = tf.one_hot(tf.transpose(inputs), self.vocab_size)
        # RNN layers built on later cells such as `tf.keras.layers.LSTMCell`
        # return more than two values
        Y, *state = self.rnn(X, state)
        output = self.dense(tf.reshape(Y, (-1, Y.shape[-1])))
        return output, state

    def begin_state(self, *args, **kwargs):
        return self.rnn.cell.get_initial_state(*args, **kwargs)
Training and Predicting
-----------------------
Before training the model, let us make a prediction with a model that
has random weights.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
device = d2l.try_gpu()
net = RNNModel(rnn_layer, len(vocab))
net.initialize(force_reinit=True, ctx=device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'time travellervmoopwrrrr'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'time travellerllllllllll'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
device_name = d2l.try_gpu()._device_name
strategy = tf.distribute.OneDeviceStrategy(device_name)
with strategy.scope():
    net = RNNModel(rnn_layer, vocab_size=len(vocab))

d2l.predict_ch8('time traveller', 10, net, vocab)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'time travellerdbmoafoam'
As is quite obvious, this model does not work at all yet. Next, we call
``train_ch8`` with the same hyperparameters as in
:numref:`sec_rnn_scratch` and train our model using the high-level APIs.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.2, 145303.3 tokens/sec on gpu(0)
time traveller for so it will be convenient to sperk of my insoa
traveller held in his hand was a glittering this began to r
.. figure:: output_rnn-concise_eff2f4_68_1.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.3, 279664.0 tokens/sec on cuda:0
time traveller but now you begin to seethe object of my investig
travellerey thind to move ato the engron th you mere heregh
.. figure:: output_rnn-concise_eff2f4_71_1.svg
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, strategy)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.4, 15631.2 tokens/sec on /GPU:0
time travellerit wo th y s callont fowhow scane of thedimensions
travellerthere whac orecrmereee said the time travellerit
.. figure:: output_rnn-concise_eff2f4_74_1.svg
Compared with the last section, this model achieves comparable
perplexity in a shorter period of time, because the code is more heavily
optimized by the high-level APIs of the deep learning framework.
Summary
-------
- High-level APIs of the deep learning framework provide an
  implementation of the RNN layer.
- The RNN layer of the high-level APIs returns an output and an updated
  hidden state, where the output does not involve output layer
  computation.
- Using high-level APIs leads to faster RNN training than implementing
  the RNN from scratch.
Exercises
---------
1. Can you make the RNN model overfit using the high-level APIs?
2. What happens if you increase the number of hidden layers in the RNN
   model? Can you make the model work?
3. Implement the autoregressive model of :numref:`sec_sequence` using an
   RNN.
`Discussions `__