9.1. Gated Recurrent Units (GRU)¶
In Section 8.7, we discussed how gradients are calculated in RNNs. In particular we found that long products of matrices can lead to vanishing or exploding gradients. Let us briefly think about what such gradient anomalies mean in practice:
We might encounter a situation where an early observation is highly significant for predicting all future observations. Consider the somewhat contrived case where the first observation contains a checksum and the goal is to discern whether the checksum is correct at the end of the sequence. In this case, the influence of the first token is vital. We would like to have some mechanism for storing vital early information in a memory cell. Without such a mechanism, we will have to assign a very large gradient to this observation, since it affects all the subsequent observations.
We might encounter situations where some tokens carry no pertinent information. For instance, when parsing a web page there might be auxiliary HTML code that is irrelevant for the purpose of assessing the sentiment conveyed on the page. We would like to have some mechanism for skipping such tokens in the latent state representation.
We might encounter situations where there is a logical break between parts of a sequence. For instance, there might be a transition between chapters in a book, or a transition between a bear and a bull market for securities. In this case it would be nice to have a means of resetting our internal state representation.
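The vanishing and exploding behavior recalled at the start of this section can be seen in the simplest possible setting, where the matrices in the long product are replaced by a single scalar factor. The following is only a toy sketch using the Python standard library; the helper name `repeated_product` and the factors 0.9, 1.0, and 1.1 are illustrative assumptions, not part of the d2l code.

```python
def repeated_product(factor, num_steps):
    """Multiply `factor` by itself `num_steps` times.

    This is a scalar stand-in for the long matrix products that appear
    when gradients are propagated through many time steps.
    """
    value = 1.0
    for _ in range(num_steps):
        value *= factor
    return value


# A factor slightly below 1 shrinks towards zero (vanishing), a factor
# slightly above 1 grows without bound (exploding).
for factor in (0.9, 1.0, 1.1):
    product = repeated_product(factor, 100)
    print(f"factor={factor}: product after 100 steps = {product:.3e}")
```

Even a factor that deviates from 1 by only 10% changes the product by several orders of magnitude after 100 steps, which is why a mechanism that can pass information along unchanged is attractive.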
A number of methods have been proposed to address these issues. One of the earliest is long short-term memory, which we will discuss in Section 9.2. The gated recurrent unit (GRU) is a slightly more streamlined variant that often offers comparable performance and is significantly faster to compute. Due to its simplicity, let us start with the GRU.
9.1.2. Implementation from Scratch¶
To gain a better understanding of the GRU model, let us implement it from scratch. We begin by reading the time machine dataset that we used in Section 8.5. The code for reading the dataset is given below.
# MXNet
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l
npx.set_np()
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# PyTorch
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# TensorFlow
import tensorflow as tf
from d2l import tensorflow as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
9.1.2.1. Initializing Model Parameters¶
The next step is to initialize the model parameters. We draw the weights from a Gaussian distribution with a standard deviation of 0.01 and set the biases to 0. The hyperparameter num_hiddens defines the number of hidden units. We instantiate all weights and biases relating to the update gate, the reset gate, the candidate hidden state, and the output layer.
# MXNet
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return np.random.normal(scale=0.01, size=shape, ctx=device)

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                np.zeros(num_hiddens, ctx=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = np.zeros(num_outputs, ctx=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params
# PyTorch
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
# TensorFlow
def get_params(vocab_size, num_hiddens):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return tf.random.normal(shape=shape, stddev=0.01, mean=0,
                                dtype=tf.float32)

    def three():
        return (tf.Variable(normal((num_inputs, num_hiddens)), dtype=tf.float32),
                tf.Variable(normal((num_hiddens, num_hiddens)), dtype=tf.float32),
                tf.Variable(tf.zeros(num_hiddens), dtype=tf.float32))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = tf.Variable(normal((num_hiddens, num_outputs)), dtype=tf.float32)
    b_q = tf.Variable(tf.zeros(num_outputs), dtype=tf.float32)
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    return params
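As a quick sanity check on the shapes used in get_params, the number of trainable values can be counted by hand. The sketch below uses only the Python standard library; the helper name `gru_param_count` is hypothetical, and the vocabulary size of 28 is an assumption made for illustration (the actual value is whatever len(vocab) returns for the dataset).

```python
def gru_param_count(vocab_size, num_hiddens):
    """Count the scalars created by get_params for a from-scratch GRU."""
    num_inputs = num_outputs = vocab_size
    # Each of the three groups (update gate, reset gate, candidate hidden
    # state) has an input weight, a hidden weight, and a bias.
    per_group = (num_inputs * num_hiddens
                 + num_hiddens * num_hiddens
                 + num_hiddens)
    # Output layer: weight W_hq and bias b_q.
    output_layer = num_hiddens * num_outputs + num_outputs
    return 3 * per_group + output_layer


# Assumed values: vocab_size=28 is illustrative, num_hiddens=256 matches
# the setting used for training in this section.
print(gru_param_count(vocab_size=28, num_hiddens=256))
```

Under these assumptions the three gate-related groups account for almost all of the parameters, roughly three times what a basic RNN with the same hidden size would need for its recurrent part.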
9.1.2.2. Defining the Model¶
Now we will define the hidden state initialization function init_gru_state. Just like the init_rnn_state function defined in Section 8.5, this function returns a tensor of shape (batch size, number of hidden units) whose values are all zeros, wrapped in a one-element tuple.
# MXNet
def init_gru_state(batch_size, num_hiddens, device):
    return (np.zeros(shape=(batch_size, num_hiddens), ctx=device), )

# PyTorch
def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

# TensorFlow
def init_gru_state(batch_size, num_hiddens):
    return (tf.zeros((batch_size, num_hiddens)), )
Now we are ready to define the GRU model. Its structure is the same as that of the basic RNN cell, except that the update equations are more complex.
# MXNet
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
        H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = np.dot(H, W_hq) + b_q
        outputs.append(Y)
    return np.concatenate(outputs, axis=0), (H,)
# PyTorch
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
# TensorFlow
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        X = tf.reshape(X, [-1, W_xh.shape[0]])
        Z = tf.sigmoid(tf.matmul(X, W_xz) + tf.matmul(H, W_hz) + b_z)
        R = tf.sigmoid(tf.matmul(X, W_xr) + tf.matmul(H, W_hr) + b_r)
        H_tilda = tf.tanh(tf.matmul(X, W_xh) + tf.matmul(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = tf.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return tf.concat(outputs, axis=0), (H,)
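To see what the update H = Z * H + (1 - Z) * H_tilda in the gru functions does at its extremes, it can help to shrink everything to a single scalar hidden unit. The following is a minimal sketch using only the Python standard library; the function names `gru_step` and `rnn_step`, the dictionary of scalar parameters, and all numeric values are assumptions chosen for illustration. Large positive or negative gate biases are used to push the sigmoid close to 1 or 0.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gru_step(x, h, p):
    """One GRU step for a scalar input x and scalar hidden state h."""
    z = sigmoid(p["w_xz"] * x + p["w_hz"] * h + p["b_z"])  # update gate
    r = sigmoid(p["w_xr"] * x + p["w_hr"] * h + p["b_r"])  # reset gate
    h_tilda = math.tanh(p["w_xh"] * x + p["w_hh"] * (r * h) + p["b_h"])
    return z * h + (1 - z) * h_tilda


def rnn_step(x, h, p):
    """One step of a basic RNN that shares the candidate-state weights."""
    return math.tanh(p["w_xh"] * x + p["w_hh"] * h + p["b_h"])


base = {"w_xz": 0.5, "w_hz": 0.5, "b_z": 0.0,
        "w_xr": 0.5, "w_hr": 0.5, "b_r": 0.0,
        "w_xh": 0.5, "w_hh": 0.5, "b_h": 0.0}
x, h = 1.0, 0.3

# Update gate saturated near 1: the old state is carried over unchanged,
# so the current input is effectively skipped.
keep = gru_step(x, h, {**base, "b_z": 30.0})

# Update gate near 0 and reset gate near 1: the step matches a basic RNN.
vanilla = gru_step(x, h, {**base, "b_z": -30.0, "b_r": 30.0})

print(abs(keep - h) < 1e-9)
print(abs(vanilla - rnn_step(x, h, base)) < 1e-9)
```

In a trained model the gates are of course not pinned by the biases alone; they vary with the input and the previous state, so the network can move between these two regimes from one time step to the next.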
9.1.2.3. Training and Predicting¶
Training and prediction work in exactly the same manner as in Section 8.5. After training, we print out the perplexity on the training set and the predicted sequence following the provided prefixes “time traveller” and “traveller”, respectively.
# MXNet
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 11354.3 tokens/sec on gpu(0)
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
# PyTorch
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 20459.5 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
# TensorFlow
vocab_size, num_hiddens, device_name = len(vocab), 256, d2l.try_gpu()._device_name
# Define the TensorFlow training strategy
strategy = tf.distribute.OneDeviceStrategy(device_name)
num_epochs, lr = 500, 1
with strategy.scope():
    model = d2l.RNNModelScratch(len(vocab), num_hiddens, init_gru_state, gru,
                                get_params)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, strategy)
perplexity 1.3, 4375.9 tokens/sec on /GPU:0
time traveller but now you begin to seethe object on the fire wi
traveller with a slight you cand the time traveller you can
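To put the reported perplexity values in context, recall that perplexity is commonly defined as the exponential of the average negative log-likelihood that the model assigns to the correct next token. The sketch below assumes that definition and uses only the Python standard library; the helper name `perplexity` and the probabilities fed to it are made up for illustration and are not taken from the training run.

```python
import math


def perplexity(target_probs):
    """Exponential of the average negative log-likelihood.

    `target_probs` holds the probability the model assigned to the
    correct token at each position.
    """
    nll = [-math.log(p) for p in target_probs]
    return math.exp(sum(nll) / len(nll))


# A model that guesses uniformly among 28 tokens (an assumed vocabulary
# size) has perplexity equal to the number of tokens.
uniform = perplexity([1 / 28] * 10)

# A model that gives the correct token probability 0.9 at every position.
confident = perplexity([0.9] * 10)

print(f"uniform guess: {uniform:.2f}, confident model: {confident:.2f}")
```

Under this definition, a perplexity close to 1.1 corresponds to assigning a probability of about 0.9 to the correct next character on average (in the geometric-mean sense). Since it is measured on the training set, it mostly reflects how well the model has memorized the text.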
9.1.3. Concise Implementation¶
In high-level APIs, we can directly instantiate a GRU model. This encapsulates all the configuration detail that we made explicit above. The code is significantly faster as it uses compiled operators rather than Python for many details that we spelled out before.
# MXNet
gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 171217.1 tokens/sec on gpu(0)
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
# PyTorch
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 283912.8 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby
# TensorFlow
gru_cell = tf.keras.layers.GRUCell(num_hiddens,
                                   kernel_initializer='glorot_uniform')
gru_layer = tf.keras.layers.RNN(gru_cell, time_major=True,
                                return_sequences=True, return_state=True)
device_name = d2l.try_gpu()._device_name
strategy = tf.distribute.OneDeviceStrategy(device_name)
with strategy.scope():
    model = d2l.RNNModel(gru_layer, vocab_size=len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, strategy)
perplexity 1.0, 6472.3 tokens/sec on /GPU:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby
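The speed difference between the from-scratch and concise implementations can be read off the training logs in this section. The short sketch below just divides the reported throughputs; the numbers are copied from the logs above, while the dictionary layout and the helper name `speedup` are illustrative assumptions. The figures come from single runs, so they are only a rough indication.

```python
# Throughput in tokens/sec, copied from the training logs in this section.
throughput = {
    "MXNet": {"scratch": 11354.3, "concise": 171217.1},
    "PyTorch": {"scratch": 20459.5, "concise": 283912.8},
    "TensorFlow": {"scratch": 4375.9, "concise": 6472.3},
}


def speedup(scratch, concise):
    """Ratio of concise to from-scratch throughput."""
    return concise / scratch


speedups = {name: speedup(**t) for name, t in throughput.items()}
for name, ratio in speedups.items():
    print(f"{name}: about {ratio:.1f}x faster")
```

For these particular runs the gain is well over tenfold for MXNet and PyTorch, and noticeably smaller for the TensorFlow setup shown here.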
9.1.4. Summary¶
Gated RNNs can better capture dependencies for sequences with large time step distances.
Reset gates help capture short-term dependencies in sequences.
Update gates help capture long-term dependencies in sequences.
GRUs contain basic RNNs as their extreme case whenever the reset gate is switched on and the update gate is switched off. They can also skip subsequences by turning on the update gate.
9.1.5. Exercises¶
Assume that we only want to use the input at time step \(t'\) to predict the output at time step \(t > t'\). What are the best values for the reset and update gates for each time step?
Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.
Compare runtime, perplexity, and the output strings for rnn.RNN and rnn.GRU implementations with each other.
What happens if you implement only parts of a GRU, e.g., with only a reset gate or only an update gate?