.. _chapter_gru:
Gated Recurrent Units (GRU)
===========================
In the previous section we discussed how gradients are calculated in a
recurrent neural network. In particular, we found that long products of
matrices can lead to vanishing or exploding gradients. Let’s briefly
think about what such gradient anomalies mean in practice:
- We might encounter a situation where an early observation is highly
significant for predicting all future observations. Consider the
somewhat contrived case where the first observation contains a
checksum and the goal is to discern whether the checksum is correct
at the end of the sequence. In this case the influence of the first
token is vital. We would like to have some mechanism for storing
vital early information in a *memory cell*. Without such a mechanism
we will have to assign a very large gradient to this observation,
since it affects all subsequent observations.
- We might encounter situations where some symbols carry no pertinent
  information. For instance, when parsing a webpage there might be
auxiliary HTML code that is irrelevant for the purpose of assessing
the sentiment conveyed on the page. We would like to have some
mechanism for *skipping such symbols* in the latent state
representation.
- We might encounter situations where there is a logical break between
parts of a sequence. For instance there might be a transition between
chapters in a book, a transition between a bear and a bull market for
  securities, etc. In such cases it would be nice to have a means of
*resetting* our internal state representation.
A number of methods have been proposed to address this. One of the
earliest is the Long Short-Term Memory (LSTM)
which we will discuss in
:numref:`chapter_lstm`. The Gated Recurrent Unit (GRU)
:cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014` is a slightly more
streamlined variant that often offers comparable performance and is
significantly faster to compute. See also
:cite:`Chung.Gulcehre.Cho.ea.2014` for more details. Due to its
simplicity we start with the GRU.
Gating the Hidden State
-----------------------
The key distinction between regular RNNs and GRUs is that the latter
support gating of the hidden state. This means that we have dedicated
mechanisms for when the hidden state should be updated and also when it
should be reset. These mechanisms are learned and they address the
concerns listed above. For instance, if the first symbol is of great
importance we will learn not to update the hidden state after the first
observation. Likewise, we will learn to skip irrelevant temporary
observations. Lastly, we will learn to reset the latent state whenever
needed. We discuss this in detail below.
Reset Gates and Update Gates
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The first things we need to introduce are the reset and update gates. We
engineer them to be vectors with entries in :math:`(0,1)` such that we
can perform convex combinations, e.g. of a hidden state and an
alternative. For instance, a reset variable would allow us to control
how much of the previous state we might still want to remember.
Likewise, an update variable would allow us to control how much of the
new state is just a copy of the old state.
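As a toy illustration of such a convex combination, suppose a single
coordinate of the old state equals :math:`1.0`, the corresponding
coordinate of a new candidate equals :math:`-0.5`, and the gate value is
:math:`0.9`. Then

.. math:: 0.9 \cdot 1.0 + (1 - 0.9) \cdot (-0.5) = 0.85,

i.e. the result stays close to the old value, whereas a gate value close
to :math:`0` would instead move it close to the candidate.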
We begin by engineering gates to generate these variables. The figure
below illustrates the inputs for both reset and update gates in a GRU,
given the current time step input :math:`\mathbf{X}_t` and the hidden
state of the previous time step :math:`\mathbf{H}_{t-1}`. The output is
given by a fully connected layer with a sigmoid as its activation
function.
.. figure:: ../img/gru_1.svg
Reset and update gate in a GRU.
Here, we assume there are :math:`h` hidden units and, for a given time
step :math:`t`, the mini-batch input is
:math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` (number of examples:
:math:`n`, number of inputs: :math:`d`) and the hidden state of the last
time step is :math:`\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}`. Then,
the reset gate :math:`\mathbf{R}_t \in \mathbb{R}^{n \times h}` and
update gate :math:`\mathbf{Z}_t \in \mathbb{R}^{n \times h}` are
computed as follows:
.. math::

   \begin{aligned}
   \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r)\\
   \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z)
   \end{aligned}
Here,
:math:`\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}` and
:math:`\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}` are
weight parameters and
:math:`\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}` are
biases. We use a sigmoid function (see :numref:`chapter_mlp` for a
description) to transform values to the interval :math:`(0,1)`.
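To make the shapes concrete, here is a minimal sketch that evaluates both
gates for a toy minibatch; the sizes (:math:`n=2`, :math:`d=3`,
:math:`h=4`) and the random weights are chosen purely for illustration.

.. code:: python

    from mxnet import nd

    n, d, h = 2, 3, 4                        # toy sizes: examples, inputs, hidden units
    X_t = nd.random.normal(shape=(n, d))     # current minibatch input
    H_prev = nd.random.normal(shape=(n, h))  # hidden state of the previous time step

    W_xr, W_hr, b_r = nd.random.normal(shape=(d, h)), nd.random.normal(shape=(h, h)), nd.zeros(h)
    W_xz, W_hz, b_z = nd.random.normal(shape=(d, h)), nd.random.normal(shape=(h, h)), nd.zeros(h)

    R_t = nd.sigmoid(nd.dot(X_t, W_xr) + nd.dot(H_prev, W_hr) + b_r)  # reset gate
    Z_t = nd.sigmoid(nd.dot(X_t, W_xz) + nd.dot(H_prev, W_hz) + b_z)  # update gate
    print(R_t.shape, Z_t.shape)  # both (n, h); all entries lie in (0, 1)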
Reset Gate in Action
~~~~~~~~~~~~~~~~~~~~
We begin by integrating the reset gate with a regular latent state
updating mechanism. In a conventional deep RNN we would have an update
of the form
.. math:: \mathbf{H}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1}\mathbf{W}_{hh} + \mathbf{b}_h).
This is essentially the same as the RNN update discussed in the previous section,
albeit with a nonlinearity in the form of :math:`\tanh` to ensure that
the values of the hidden state remain in the interval :math:`(-1, 1)`.
If we want to be able to reduce the influence of previous states we can
multiply :math:`\mathbf{H}_{t-1}` with :math:`\mathbf{R}_t` elementwise.
Whenever the entries in :math:`\mathbf{R}_t` are close to :math:`1` we
recover a conventional deep RNN. For all entries of :math:`\mathbf{R}_t`
that are close to :math:`0` the hidden state is the result of an MLP
with :math:`\mathbf{X}_t` as input. Any pre-existing hidden state is
thus ‘reset’ to defaults. This leads to the following candidate for a
new hidden state (it is a *candidate* since we still need to incorporate
the action of the update gate).
.. math:: \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h)
The figure below illustrates the computational flow after applying the
reset gate. The symbol :math:`\odot` indicates pointwise multiplication
between tensors.
.. figure:: ../img/gru_2.svg
Candidate hidden state computation in a GRU. The multiplication is
carried out elementwise.
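As a quick sanity check of this behavior, the sketch below reuses the toy
tensors from the earlier snippet (again with arbitrary random weights) and
contrasts the two extremes of the reset gate: all ones recovers the
conventional RNN candidate, while all zeros makes the candidate depend on
:math:`\mathbf{X}_t` alone.

.. code:: python

    W_xh, W_hh, b_h = nd.random.normal(shape=(d, h)), nd.random.normal(shape=(h, h)), nd.zeros(h)

    R_ones, R_zeros = nd.ones((n, h)), nd.zeros((n, h))
    # Reset gate of all ones: the plain RNN candidate state
    H_tilde_keep = nd.tanh(nd.dot(X_t, W_xh) + nd.dot(R_ones * H_prev, W_hh) + b_h)
    # Reset gate of all zeros: the previous state is discarded entirely
    H_tilde_reset = nd.tanh(nd.dot(X_t, W_xh) + nd.dot(R_zeros * H_prev, W_hh) + b_h)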
Update Gate in Action
~~~~~~~~~~~~~~~~~~~~~
Next we need to incorporate the effect of the update gate. This
determines the extent to which the new state :math:`\mathbf{H}_t` is
just the old state :math:`\mathbf{H}_{t-1}` and by how much the new
candidate state :math:`\tilde{\mathbf{H}}_t` is used. The gating
variable :math:`\mathbf{Z}_t` can be used for this purpose, simply by
taking elementwise convex combinations of :math:`\mathbf{H}_{t-1}` and
:math:`\tilde{\mathbf{H}}_t`. This
leads to the final update equation for the GRU.
.. math:: \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.
.. figure:: ../img/gru_3.svg
Hidden state computation in a GRU. As before, the multiplication is
carried out elementwise.
Whenever the update gate is close to :math:`1` we simply retain the old
state. In this case the information from :math:`\mathbf{X}_t` is
essentially ignored, effectively skipping time step :math:`t` in the
dependency chain. Whenever it is close to :math:`0` the new latent state
:math:`\mathbf{H}_t` approaches the candidate latent state
:math:`\tilde{\mathbf{H}}_t`. These designs can help cope with the
vanishing gradient problem in RNNs and better capture dependencies for
time series with large time step distances. In summary GRUs have the
following two distinguishing features:
- Reset gates help capture short-term dependencies in time series.
- Update gates help capture long-term dependencies in time series.
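Continuing the toy sketch from above, the two extremes of the update gate
make both points concrete: an update gate of all ones copies the old state
and effectively skips the time step, while all zeros adopts the candidate
state wholesale.

.. code:: python

    Z_ones, Z_zeros = nd.ones((n, h)), nd.zeros((n, h))
    H_skip = Z_ones * H_prev + (1 - Z_ones) * H_tilde_reset        # equals H_prev
    H_replace = Z_zeros * H_prev + (1 - Z_zeros) * H_tilde_reset   # equals the candidate
    print(nd.abs(H_skip - H_prev).sum())            # 0: the old state is retained
    print(nd.abs(H_replace - H_tilde_reset).sum())  # 0: the candidate is adopted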
Implementation from Scratch
---------------------------
To gain a better understanding of the model let us implement a GRU from
scratch.
Reading the Data Set
~~~~~~~~~~~~~~~~~~~~
We begin by reading *The Time Machine* corpus that we used in
:numref:`chapter_rnn_scratch`. The code for reading the data set is
given below:
.. code:: python

    import d2l
    from mxnet import nd
    from mxnet.gluon import rnn

    batch_size, num_steps = 32, 35
    train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
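Assuming the iterator behaves as in :numref:`chapter_rnn_scratch`, each
minibatch is a pair of input and label arrays of shape (batch size,
number of time steps); a quick peek at one batch confirms this.

.. code:: python

    for X, Y in train_iter:
        print(X.shape, Y.shape)  # each: (batch_size, num_steps) = (32, 35)
        break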
Initialize Model Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~
The next step is to initialize the model parameters. We draw the weights
from a Gaussian with variance :math:`0.01` and set the bias to
:math:`0`. The hyper-parameter ``num_hiddens`` defines the number of
hidden units. We instantiate all weights and biases relating to the
update gate, the reset gate, and the candidate hidden state.
Subsequently we attach gradients
to all parameters.
.. code:: python

    def get_params(vocab_size, num_hiddens, ctx):
        num_inputs = num_outputs = vocab_size
        normal = lambda shape: nd.random.normal(scale=0.01, shape=shape, ctx=ctx)
        three = lambda: (normal((num_inputs, num_hiddens)),
                         normal((num_hiddens, num_hiddens)),
                         nd.zeros(num_hiddens, ctx=ctx))
        W_xz, W_hz, b_z = three()  # Update gate parameters
        W_xr, W_hr, b_r = three()  # Reset gate parameters
        W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
        # Output layer parameters
        W_hq = normal((num_hiddens, num_outputs))
        b_q = nd.zeros(num_outputs, ctx=ctx)
        # Attach gradients to all parameters
        params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
        for param in params:
            param.attach_grad()
        return params
Define the Model
~~~~~~~~~~~~~~~~
Now we will define the hidden state initialization function
``init_gru_state``. Just like the ``init_rnn_state`` function defined in
:numref:`chapter_rnn_scratch`, this function returns a tuple containing
a single NDArray of shape (batch size, number of hidden units) with all
values set to 0.
.. code:: python

    def init_gru_state(batch_size, num_hiddens, ctx):
        return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx), )
Now we are ready to define the actual model. Its structure is the same
as that of the basic RNN cell, except that the update equations are more complex.
.. code:: python

    def gru(inputs, state, params):
        W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
        H, = state
        outputs = []
        for X in inputs:
            Z = nd.sigmoid(nd.dot(X, W_xz) + nd.dot(H, W_hz) + b_z)         # update gate
            R = nd.sigmoid(nd.dot(X, W_xr) + nd.dot(H, W_hr) + b_r)         # reset gate
            H_tilda = nd.tanh(nd.dot(X, W_xh) + nd.dot(R * H, W_hh) + b_h)  # candidate state
            H = Z * H + (1 - Z) * H_tilda                                   # new hidden state
            Y = nd.dot(H, W_hq) + b_q                                       # output layer
            outputs.append(Y)
        return nd.concat(*outputs, dim=0), (H,)
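Before training, it is worth verifying that a forward pass produces the
expected shapes. The sketch below feeds random inputs of shape (number of
steps, batch size, vocabulary size), mirroring how ``d2l.RNNModelScratch``
presents one-hot inputs to the cell (as in :numref:`chapter_rnn_scratch`);
the hidden size of 256 is simply the value used for training below.

.. code:: python

    ctx, num_hiddens = d2l.try_gpu(), 256
    X = nd.random.uniform(shape=(num_steps, batch_size, len(vocab)), ctx=ctx)
    state = init_gru_state(batch_size, num_hiddens, ctx)
    params = get_params(len(vocab), num_hiddens, ctx)
    Y, new_state = gru(X, state, params)
    # Y: (num_steps * batch_size, vocab_size); new_state[0]: (batch_size, num_hiddens)
    print(Y.shape, new_state[0].shape)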
Training and Prediction
~~~~~~~~~~~~~~~~~~~~~~~
Training and prediction work in exactly the same manner as before.
.. code:: python

    vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu()
    num_epochs, lr = 500, 1
    model = d2l.RNNModelScratch(len(vocab), num_hiddens, ctx, get_params,
                                init_gru_state, gru)
    d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)
.. parsed-literal::
:class: output
Perplexity 1.1, 13446 tokens/sec on gpu(0)
time traveller well thattime is only a kind of space here is a p
traveller it s against reason said filby what reason said
.. figure:: output_gru_11a47a_9_1.svg
Concise Implementation
----------------------
In Gluon, we can directly call the ``GRU`` class in the ``rnn`` module.
This encapsulates all the configuration details that we made explicit
above. The code is significantly faster as it uses compiled operators
rather than Python for many of the details that we spelled out
explicitly before.
.. code:: python

    gru_layer = rnn.GRU(num_hiddens)
    model = d2l.RNNModel(gru_layer, len(vocab))
    d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)
.. parsed-literal::
:class: output
Perplexity 1.1, 60693 tokens/sec on gpu(0)
time traveller it s against reason said filby what reason said
traveller it s against reason said filby what reason said
.. figure:: output_gru_11a47a_11_1.svg
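As a side note, ``rnn.GRU`` expects its inputs laid out as (number of
steps, batch size, number of inputs) by default. The minimal sketch below
(a standalone layer, separate from the trained model above, run on the CPU
for simplicity) shows how shapes map through the layer.

.. code:: python

    layer = rnn.GRU(num_hiddens)
    layer.initialize()
    X = nd.random.uniform(shape=(num_steps, batch_size, len(vocab)))
    state = layer.begin_state(batch_size=batch_size)
    Y, state = layer(X, state)
    print(Y.shape)  # (num_steps, batch_size, num_hiddens)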
Summary
-------
- Gated recurrent neural networks are better at capturing dependencies
for time series with large time step distances.
- Reset gates help capture short-term dependencies in time series.
- Update gates help capture long-term dependencies in time series.
- GRUs contain basic RNNs as their extreme case: when the reset gates are
  all :math:`1` and the update gates are all :math:`0`, a GRU reduces to a
  conventional RNN. They can also skip subsequences as needed by keeping
  the update gates close to :math:`1`.
Exercises
---------
1. Compare runtimes, perplexity and the extracted strings for
``rnn.RNN`` and ``rnn.GRU`` implementations with each other.
2. Assume that we only want to use the input for time step :math:`t'` to
predict the output at time step :math:`t > t'`. What are the best
values for reset and update gates for each time step?
3. Adjust the hyper-parameters and observe and analyze the impact on
   running time, perplexity, and the generated text.
4. What happens if you implement only parts of a GRU? That is, implement
a recurrent cell that only has a reset gate. Likewise, implement a
recurrent cell only with an update gate.