.. _sec_gru:

Gated Recurrent Units (GRU)
===========================

In :numref:`sec_bptt`, we discussed how gradients are calculated in RNNs. In particular we found that long products of matrices can lead to vanishing or exploding gradients. Let us briefly think about what such gradient anomalies mean in practice:

- We might encounter a situation where an early observation is highly significant for predicting all future observations. Consider the somewhat contrived case where the first observation contains a checksum and the goal is to discern whether the checksum is correct at the end of the sequence. In this case, the influence of the first token is vital. We would like to have some mechanism for storing vital early information in a *memory cell*. Without such a mechanism, we will have to assign a very large gradient to this observation, since it affects all the subsequent observations.
- We might encounter situations where some tokens carry no pertinent information. For instance, when parsing a web page there might be auxiliary HTML code that is irrelevant for the purpose of assessing the sentiment conveyed on the page. We would like to have some mechanism for *skipping* such tokens in the latent state representation.
- We might encounter situations where there is a logical break between parts of a sequence. For instance, there might be a transition between chapters in a book, or a transition between a bear and a bull market for securities. In this case it would be nice to have a means of *resetting* our internal state representation.

A number of methods have been proposed to address this. One of the earliest is long short-term memory :cite:`Hochreiter.Schmidhuber.1997`, which we will discuss in :numref:`sec_lstm`. The gated recurrent unit (GRU) :cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014` is a slightly more streamlined variant that often offers comparable performance and is significantly faster to compute :cite:`Chung.Gulcehre.Cho.ea.2014`.
Due to its simplicity, let us start with the GRU.

Gated Hidden State
------------------

The key distinction between vanilla RNNs and GRUs is that the latter support gating of the hidden state. This means that we have dedicated mechanisms for when a hidden state should be *updated* and also when it should be *reset*. These mechanisms are learned and they address the concerns listed above. For instance, if the first token is of great importance we will learn not to update the hidden state after the first observation. Likewise, we will learn to skip irrelevant temporary observations. Last, we will learn to reset the latent state whenever needed. We discuss this in detail below.

Reset Gate and Update Gate
~~~~~~~~~~~~~~~~~~~~~~~~~~

The first things we need to introduce are the *reset gate* and the *update gate*. We engineer them to be vectors with entries in :math:`(0, 1)` such that we can perform convex combinations. For instance, a reset gate would allow us to control how much of the previous state we might still want to remember. Likewise, an update gate would allow us to control how much of the new state is just a copy of the old state.

We begin by engineering these gates. :numref:`fig_gru_1` illustrates the inputs for both the reset and update gates in a GRU, given the input of the current time step and the hidden state of the previous time step. The outputs of the two gates are given by two fully-connected layers with a sigmoid activation function.

.. _fig_gru_1:

.. figure:: ../img/gru-1.svg

   Computing the reset gate and the update gate in a GRU model.

Mathematically, for a given time step :math:`t`, suppose that the input is a minibatch :math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` (number of examples: :math:`n`, number of inputs: :math:`d`) and the hidden state of the previous time step is :math:`\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}` (number of hidden units: :math:`h`).
Then, the reset gate :math:`\mathbf{R}_t \in \mathbb{R}^{n \times h}` and update gate :math:`\mathbf{Z}_t \in \mathbb{R}^{n \times h}` are computed as follows:

.. math::

   \begin{aligned}
   \mathbf{R}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\
   \mathbf{Z}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),
   \end{aligned}

where :math:`\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}` and :math:`\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}` are weight parameters and :math:`\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}` are biases. Note that broadcasting (see :numref:`subsec_broadcasting`) is triggered during the summation. We use sigmoid functions (as introduced in :numref:`sec_mlp`) to transform input values to the interval :math:`(0, 1)`.

Candidate Hidden State
~~~~~~~~~~~~~~~~~~~~~~

Next, let us integrate the reset gate :math:`\mathbf{R}_t` with the regular latent state updating mechanism in :eq:`rnn_h_with_state`. It leads to the following *candidate hidden state* :math:`\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}` at time step :math:`t`:

.. math::
   :label: gru_tilde_H

   \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),

where :math:`\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}` and :math:`\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}` are weight parameters, :math:`\mathbf{b}_h \in \mathbb{R}^{1 \times h}` is the bias, and the symbol :math:`\odot` is the Hadamard (elementwise) product operator. Here we use a nonlinearity in the form of tanh to ensure that the values in the candidate hidden state remain in the interval :math:`(-1, 1)`.

The result is a *candidate* since we still need to incorporate the action of the update gate.
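To make the shapes concrete, the gate and candidate computations above can be sketched in plain NumPy. This is only an illustration under assumed sizes (``n = 4``, ``d = 5``, ``h = 3``) with randomly drawn inputs and weights; the helper names ``sigmoid`` and ``init_block`` are hypothetical and not part of the book's code.

```python
import numpy as np


def sigmoid(x):
    """Elementwise logistic function, mapping real values into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


rng = np.random.default_rng(0)
n, d, h = 4, 5, 3  # assumed batch size, input size, number of hidden units

X_t = rng.normal(size=(n, d))     # minibatch input at time step t
H_prev = rng.normal(size=(n, h))  # hidden state of the previous time step


def init_block():
    """Hypothetical helper: (W_x, W_h, b) for one gate or the candidate."""
    return (rng.normal(scale=0.01, size=(d, h)),
            rng.normal(scale=0.01, size=(h, h)),
            np.zeros((1, h)))


W_xr, W_hr, b_r = init_block()  # reset gate parameters
W_xz, W_hz, b_z = init_block()  # update gate parameters
W_xh, W_hh, b_h = init_block()  # candidate hidden state parameters

# Reset and update gates; the (1, h) biases broadcast over the n rows
R_t = sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)
Z_t = sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)

# Candidate hidden state; `*` is the Hadamard (elementwise) product
H_tilde = np.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)

print(R_t.shape, Z_t.shape, H_tilde.shape)
```

Both gates come out with shape ``(n, h)`` and entries strictly between 0 and 1, while the candidate stays within :math:`(-1, 1)`.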
Comparing with :eq:`rnn_h_with_state`, now the influence of the previous states can be reduced with the elementwise multiplication of :math:`\mathbf{R}_t` and :math:`\mathbf{H}_{t-1}` in :eq:`gru_tilde_H`. Whenever the entries in the reset gate :math:`\mathbf{R}_t` are close to 1, we recover a vanilla RNN such as in :eq:`rnn_h_with_state`. For all entries of the reset gate :math:`\mathbf{R}_t` that are close to 0, the candidate hidden state is the result of an MLP with :math:`\mathbf{X}_t` as the input. Any pre-existing hidden state is thus *reset* to defaults.

:numref:`fig_gru_2` illustrates the computational flow after applying the reset gate.

.. _fig_gru_2:

.. figure:: ../img/gru-2.svg

   Computing the candidate hidden state in a GRU model.

Hidden State
~~~~~~~~~~~~

Finally, we need to incorporate the effect of the update gate :math:`\mathbf{Z}_t`. This determines the extent to which the new hidden state :math:`\mathbf{H}_t \in \mathbb{R}^{n \times h}` is just the old state :math:`\mathbf{H}_{t-1}` and by how much the new candidate state :math:`\tilde{\mathbf{H}}_t` is used. The update gate :math:`\mathbf{Z}_t` can be used for this purpose, simply by taking elementwise convex combinations between :math:`\mathbf{H}_{t-1}` and :math:`\tilde{\mathbf{H}}_t`. This leads to the final update equation for the GRU:

.. math::

   \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.

Whenever the update gate :math:`\mathbf{Z}_t` is close to 1, we simply retain the old state. In this case the information from :math:`\mathbf{X}_t` is essentially ignored, effectively skipping time step :math:`t` in the dependency chain. In contrast, whenever :math:`\mathbf{Z}_t` is close to 0, the new latent state :math:`\mathbf{H}_t` approaches the candidate latent state :math:`\tilde{\mathbf{H}}_t`.
These designs can help us cope with the vanishing gradient problem in RNNs and better capture dependencies for sequences with large time step distances. For instance, if the update gate has been close to 1 for all the time steps of an entire subsequence, the old hidden state at the time step of its beginning will be easily retained and passed to its end, regardless of the length of the subsequence.

:numref:`fig_gru_3` illustrates the computational flow after the update gate is in action.

.. _fig_gru_3:

.. figure:: ../img/gru-3.svg

   Computing the hidden state in a GRU model.

In summary, GRUs have the following two distinguishing features:

- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.

Implementation from Scratch
---------------------------

To gain a better understanding of the GRU model, let us implement it from scratch. We begin by reading the time machine dataset that we used in :numref:`sec_rnn_scratch`. The code for reading the dataset is given below.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   from mxnet import np, npx
   from mxnet.gluon import rnn
   from d2l import mxnet as d2l

   npx.set_np()

   batch_size, num_steps = 32, 35
   train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   import torch
   from torch import nn
   from d2l import torch as d2l

   batch_size, num_steps = 32, 35
   train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   import tensorflow as tf
   from d2l import tensorflow as d2l

   batch_size, num_steps = 32, 35
   train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

.. raw:: html
.. raw:: html
Initializing Model Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The next step is to initialize the model parameters. We draw the weights from a Gaussian distribution with a standard deviation of 0.01 and set the biases to 0. The hyperparameter ``num_hiddens`` defines the number of hidden units. We instantiate all weights and biases relating to the update gate, the reset gate, the candidate hidden state, and the output layer.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def get_params(vocab_size, num_hiddens, device):
       num_inputs = num_outputs = vocab_size

       def normal(shape):
           return np.random.normal(scale=0.01, size=shape, ctx=device)

       def three():
           return (normal((num_inputs, num_hiddens)),
                   normal((num_hiddens, num_hiddens)),
                   np.zeros(num_hiddens, ctx=device))

       W_xz, W_hz, b_z = three()  # Update gate parameters
       W_xr, W_hr, b_r = three()  # Reset gate parameters
       W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
       # Output layer parameters
       W_hq = normal((num_hiddens, num_outputs))
       b_q = np.zeros(num_outputs, ctx=device)
       # Attach gradients
       params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
       for param in params:
           param.attach_grad()
       return params

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def get_params(vocab_size, num_hiddens, device):
       num_inputs = num_outputs = vocab_size

       def normal(shape):
           return torch.randn(size=shape, device=device) * 0.01

       def three():
           return (normal((num_inputs, num_hiddens)),
                   normal((num_hiddens, num_hiddens)),
                   torch.zeros(num_hiddens, device=device))

       W_xz, W_hz, b_z = three()  # Update gate parameters
       W_xr, W_hr, b_r = three()  # Reset gate parameters
       W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
       # Output layer parameters
       W_hq = normal((num_hiddens, num_outputs))
       b_q = torch.zeros(num_outputs, device=device)
       # Attach gradients
       params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
       for param in params:
           param.requires_grad_(True)
       return params

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def get_params(vocab_size, num_hiddens):
       num_inputs = num_outputs = vocab_size

       def normal(shape):
           return tf.random.normal(shape=shape, stddev=0.01, mean=0,
                                   dtype=tf.float32)

       def three():
           return (tf.Variable(normal((num_inputs, num_hiddens)), dtype=tf.float32),
                   tf.Variable(normal((num_hiddens, num_hiddens)), dtype=tf.float32),
                   tf.Variable(tf.zeros(num_hiddens), dtype=tf.float32))

       W_xz, W_hz, b_z = three()  # Update gate parameters
       W_xr, W_hr, b_r = three()  # Reset gate parameters
       W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
       # Output layer parameters
       W_hq = tf.Variable(normal((num_hiddens, num_outputs)), dtype=tf.float32)
       b_q = tf.Variable(tf.zeros(num_outputs), dtype=tf.float32)
       params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
       return params

.. raw:: html
.. raw:: html
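As a rough cross-check of the parameter layout above, the number of trainable scalars can be counted by hand. The sketch below is framework-free; the function name ``gru_param_count`` is hypothetical, and the vocabulary size of 28 is an assumed value for a character-level vocabulary, so substitute ``len(vocab)`` in practice.

```python
def gru_param_count(vocab_size, num_hiddens):
    """Count the scalars created by a `get_params`-style initialization."""
    num_inputs = num_outputs = vocab_size
    # Each of the three blocks (update gate, reset gate, candidate state)
    # holds an input weight, a hidden weight, and a bias
    per_block = (num_inputs * num_hiddens
                 + num_hiddens * num_hiddens
                 + num_hiddens)
    # Output layer: weight W_hq and bias b_q
    output_layer = num_hiddens * num_outputs + num_outputs
    return 3 * per_block + output_layer


# Assumed sizes: 28 symbols in the vocabulary, 256 hidden units
total = gru_param_count(vocab_size=28, num_hiddens=256)
print(total)
```

The factor of three reflects that a GRU carries three weight blocks where a basic RNN has one, so its recurrent part needs about three times as many parameters for the same ``num_hiddens``.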
Defining the Model
~~~~~~~~~~~~~~~~~~

Now we will define the hidden state initialization function ``init_gru_state``. Just like the ``init_rnn_state`` function defined in :numref:`sec_rnn_scratch`, this function returns a tensor with shape (batch size, number of hidden units) whose values are all zeros, wrapped in a one-element tuple.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def init_gru_state(batch_size, num_hiddens, device):
       return (np.zeros(shape=(batch_size, num_hiddens), ctx=device), )

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def init_gru_state(batch_size, num_hiddens, device):
       return (torch.zeros((batch_size, num_hiddens), device=device), )

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def init_gru_state(batch_size, num_hiddens):
       return (tf.zeros((batch_size, num_hiddens)), )

.. raw:: html
.. raw:: html
Now we are ready to define the GRU model. Its structure is the same as that of the basic RNN cell, except that the update equations are more complex.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def gru(inputs, state, params):
       W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
       H, = state
       outputs = []
       for X in inputs:
           Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
           R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
           H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
           H = Z * H + (1 - Z) * H_tilda
           Y = np.dot(H, W_hq) + b_q
           outputs.append(Y)
       return np.concatenate(outputs, axis=0), (H,)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def gru(inputs, state, params):
       W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
       H, = state
       outputs = []
       for X in inputs:
           Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
           R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
           H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
           H = Z * H + (1 - Z) * H_tilda
           Y = H @ W_hq + b_q
           outputs.append(Y)
       return torch.cat(outputs, dim=0), (H,)

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def gru(inputs, state, params):
       W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
       H, = state
       outputs = []
       for X in inputs:
           X = tf.reshape(X, [-1, W_xh.shape[0]])
           Z = tf.sigmoid(tf.matmul(X, W_xz) + tf.matmul(H, W_hz) + b_z)
           R = tf.sigmoid(tf.matmul(X, W_xr) + tf.matmul(H, W_hr) + b_r)
           H_tilda = tf.tanh(tf.matmul(X, W_xh) + tf.matmul(R * H, W_hh) + b_h)
           H = Z * H + (1 - Z) * H_tilda
           Y = tf.matmul(H, W_hq) + b_q
           outputs.append(Y)
       return tf.concat(outputs, axis=0), (H,)

.. raw:: html
.. raw:: html
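The limiting behavior of the gates can be checked numerically with a single GRU step. The following NumPy sketch is a framework-free re-statement of the update equations, not the code used for training; ``gru_step`` and ``make_params`` are hypothetical helpers, and the gates are pushed towards 0 or 1 by setting their biases to large values (an artificial device chosen only for this illustration).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(X, H, params):
    """One GRU time step without the output layer."""
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h = params
    Z = sigmoid(X @ W_xz + H @ W_hz + b_z)              # update gate
    R = sigmoid(X @ W_xr + H @ W_hr + b_r)              # reset gate
    H_tilde = np.tanh(X @ W_xh + (R * H) @ W_hh + b_h)  # candidate state
    return Z * H + (1 - Z) * H_tilde


def make_params(rng, d, h, b_z=0.0, b_r=0.0):
    """Random weights; constant gate biases let us saturate the gates."""
    def w(*shape):
        return rng.normal(scale=0.1, size=shape)
    return [w(d, h), w(h, h), np.full(h, b_z),
            w(d, h), w(h, h), np.full(h, b_r),
            w(d, h), w(h, h), np.zeros(h)]


rng = np.random.default_rng(0)
n, d, h = 2, 4, 3  # assumed toy sizes
X = rng.normal(size=(n, d))
H = rng.uniform(-1, 1, size=(n, h))

# Update gate close to 1: the old state is carried over, X is ignored
H_keep = gru_step(X, H, make_params(rng, d, h, b_z=50.0))

# Update gate close to 0 and reset gate close to 1: a vanilla RNN step
rnn_like = make_params(rng, d, h, b_z=-50.0, b_r=50.0)
H_rnn = gru_step(X, H, rnn_like)
W_xh, W_hh, b_h = rnn_like[6:9]
H_vanilla = np.tanh(X @ W_xh + H @ W_hh + b_h)

print(np.allclose(H_keep, H), np.allclose(H_rnn, H_vanilla))
```

Under these assumptions ``H_keep`` agrees with the old state ``H`` and ``H_rnn`` agrees with ``H_vanilla`` up to floating-point tolerance, which mirrors the two extreme cases discussed for the update and reset gates.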
Training and Predicting
~~~~~~~~~~~~~~~~~~~~~~~

Training and prediction work in exactly the same manner as in :numref:`sec_rnn_scratch`. After training, we print out the perplexity on the training set and the predicted sequence following the provided prefixes "time traveller" and "traveller", respectively.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
   num_epochs, lr = 500, 1
   model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                               init_gru_state, gru)
   d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   perplexity 1.1, 11354.3 tokens/sec on gpu(0)
   time traveller for so it will be convenient to speak of himwas e
   travelleryou can show black is white by argument said filby

.. figure:: output_gru_b77a34_51_1.svg

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
   num_epochs, lr = 500, 1
   model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                               init_gru_state, gru)
   d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   perplexity 1.1, 20459.5 tokens/sec on cuda:0
   time traveller for so it will be convenient to speak of himwas e
   travelleryou can show black is white by argument said filby

.. figure:: output_gru_b77a34_54_1.svg

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   vocab_size, num_hiddens, device_name = len(vocab), 256, d2l.try_gpu()._device_name
   # defining tensorflow training strategy
   strategy = tf.distribute.OneDeviceStrategy(device_name)
   num_epochs, lr = 500, 1
   with strategy.scope():
       model = d2l.RNNModelScratch(len(vocab), num_hiddens, init_gru_state, gru,
                                   get_params)

   d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, strategy)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   perplexity 1.3, 4375.9 tokens/sec on /GPU:0
   time traveller but now you begin to seethe object on the fire wi
   traveller with a slight you cand the time traveller you can

.. figure:: output_gru_b77a34_57_1.svg

.. raw:: html
.. raw:: html
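To interpret the perplexity values reported above, recall that perplexity is the exponential of the average negative log-likelihood that the model assigns to the true next tokens. The sketch below assumes that definition; the function ``perplexity`` and the probabilities fed to it are illustrative, and the vocabulary size of 28 is an assumed value.

```python
import math


def perplexity(target_probs):
    """exp of the mean negative log-probability of the true tokens."""
    nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(nll)


# A model that is always certain and correct reaches the best value, 1
perfect = perplexity([1.0, 1.0, 1.0])

# A model that guesses uniformly over an assumed vocabulary of 28 symbols
# has a perplexity equal to the vocabulary size
uniform = perplexity([1 / 28] * 10)

print(perfect, uniform)
```

A training perplexity close to 1 therefore means that the model predicts almost every next character of the training text with near certainty.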
Concise Implementation
----------------------

In high-level APIs, we can directly instantiate a GRU model. This encapsulates all the configuration detail that we made explicit above. The code is significantly faster as it uses compiled operators rather than Python for many details that we spelled out before.

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   gru_layer = rnn.GRU(num_hiddens)
   model = d2l.RNNModel(gru_layer, len(vocab))
   d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   perplexity 1.1, 171217.1 tokens/sec on gpu(0)
   time traveller for so it will be convenient to speak of himwas e
   travelleryou can show black is white by argument said filby

.. figure:: output_gru_b77a34_63_1.svg

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   num_inputs = vocab_size
   gru_layer = nn.GRU(num_inputs, num_hiddens)
   model = d2l.RNNModel(gru_layer, len(vocab))
   model = model.to(device)
   d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   perplexity 1.0, 283912.8 tokens/sec on cuda:0
   time travelleryou can show black is white by argument said filby
   travelleryou can show black is white by argument said filby

.. figure:: output_gru_b77a34_66_1.svg

.. raw:: html
.. raw:: html
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   gru_cell = tf.keras.layers.GRUCell(num_hiddens,
                                      kernel_initializer='glorot_uniform')
   gru_layer = tf.keras.layers.RNN(gru_cell, time_major=True,
                                   return_sequences=True, return_state=True)

   device_name = d2l.try_gpu()._device_name
   strategy = tf.distribute.OneDeviceStrategy(device_name)
   with strategy.scope():
       model = d2l.RNNModel(gru_layer, vocab_size=len(vocab))

   d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, strategy)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   perplexity 1.0, 6472.3 tokens/sec on /GPU:0
   time travelleryou can show black is white by argument said filby
   travelleryou can show black is white by argument said filby

.. figure:: output_gru_b77a34_69_1.svg

.. raw:: html
.. raw:: html
Summary
-------

- Gated RNNs can better capture dependencies for sequences with large time step distances.
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
- GRUs contain basic RNNs as their extreme case whenever the reset gate is switched on and the update gate is switched off. They can also skip subsequences by turning on the update gate.

Exercises
---------

1. Assume that we only want to use the input at time step :math:`t'` to predict the output at time step :math:`t > t'`. What are the best values for the reset and update gates for each time step?
2. Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.
3. Compare runtime, perplexity, and the output strings for ``rnn.RNN`` and ``rnn.GRU`` implementations with each other.
4. What happens if you implement only parts of a GRU, e.g., with only a reset gate or only an update gate?

.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html