.. _sec_lstm:
Long Short-Term Memory (LSTM)
=============================
Preserving long-term information while skipping short-term input has
long been a challenge in latent variable models. One of the earliest
approaches to address it was the long short-term memory (LSTM)
:cite:`Hochreiter.Schmidhuber.1997`. It shares many of the properties of
the GRU. Interestingly, LSTMs have a slightly more complex design than
GRUs but predate GRUs by almost two decades.
Gated Memory Cell
-----------------
Arguably, the LSTM's design is inspired by the logic gates of a computer.
The LSTM introduces a *memory cell* (or *cell* for short) that has the
same shape as the hidden state (some of the literature considers the
memory cell a special type of hidden state), engineered to record
additional information. To control the memory cell we need a number of
gates. One
gate is needed to read out the entries from the cell. We will refer to
this as the *output gate*. A second gate is needed to decide when to
read data into the cell. We refer to this as the *input gate*. Last, we
need a mechanism to reset the content of the cell, governed by a *forget
gate*. The motivation for such a design is the same as that of GRUs,
namely to be able to decide when to remember and when to ignore inputs
in the hidden state via a dedicated mechanism. Let us see how this works
in practice.
Input Gate, Forget Gate, and Output Gate
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Just like in GRUs, the data feeding into the LSTM gates are the input at
the current time step and the hidden state of the previous time step, as
illustrated in :numref:`lstm_0`. They are processed by three
fully connected layers with a sigmoid activation function to compute the
values of the input, forget, and output gates. As a result, the values
of the three gates lie in the range :math:`(0, 1)`.
.. _lstm_0:
.. figure:: ../img/lstm-0.svg
Computing the input gate, the forget gate, and the output gate in an
LSTM model.
Mathematically, suppose that there are :math:`h` hidden units, the batch
size is :math:`n`, and the number of inputs is :math:`d`. Thus, the
input is :math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` and the hidden
state of the previous time step is
:math:`\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}`. Correspondingly,
the gates at time step :math:`t` are defined as follows: the input gate
is :math:`\mathbf{I}_t \in \mathbb{R}^{n \times h}`, the forget gate is
:math:`\mathbf{F}_t \in \mathbb{R}^{n \times h}`, and the output gate is
:math:`\mathbf{O}_t \in \mathbb{R}^{n \times h}`. They are calculated as
follows:
.. math::

   \begin{aligned}
   \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\
   \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\
   \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),
   \end{aligned}

where
:math:`\mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h}`
and
:math:`\mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h}`
are weight parameters and
:math:`\mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h}`
are bias parameters.
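
The three gate equations can be checked on a small random batch. The
following is a minimal NumPy sketch, not the implementation used later
in this section: the sizes ``n``, ``d``, ``h`` and the helpers
``sigmoid`` and ``gate_params`` are illustrative assumptions.

.. code:: python

    import numpy as np

    rng = np.random.default_rng(0)
    n, d, h = 2, 5, 4  # Batch size, number of inputs, hidden units (arbitrary)

    def sigmoid(z):
        """Elementwise logistic function (hypothetical helper)."""
        return 1.0 / (1.0 + np.exp(-z))

    def gate_params():
        """Return (W_x, W_h, b) for one gate (hypothetical helper)."""
        return (rng.normal(scale=0.01, size=(d, h)),
                rng.normal(scale=0.01, size=(h, h)),
                np.zeros((1, h)))

    X_t = rng.normal(size=(n, d))     # Input at time step t
    H_prev = rng.normal(size=(n, h))  # Hidden state of the previous time step

    W_xi, W_hi, b_i = gate_params()
    W_xf, W_hf, b_f = gate_params()
    W_xo, W_ho, b_o = gate_params()

    I_t = sigmoid(X_t @ W_xi + H_prev @ W_hi + b_i)
    F_t = sigmoid(X_t @ W_xf + H_prev @ W_hf + b_f)
    O_t = sigmoid(X_t @ W_xo + H_prev @ W_ho + b_o)

    print(I_t.shape, F_t.shape, O_t.shape)  # (2, 4) (2, 4) (2, 4)

Each bias has shape :math:`1 \times h` and is broadcast over the
:math:`n` rows of the batch, so every gate ends up with the same shape
as the hidden state.
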
Candidate Memory Cell
~~~~~~~~~~~~~~~~~~~~~
Next we design the memory cell. Since we have not specified the action
of the various gates yet, we first introduce the *candidate* memory cell
:math:`\tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h}`. Its
computation is similar to that of the three gates described above, but
it uses a :math:`\tanh` function, with a value range of :math:`(-1, 1)`,
as the activation function. This leads to the following equation at time
step :math:`t`:
.. math:: \tilde{\mathbf{C}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),
where :math:`\mathbf{W}_{xc} \in \mathbb{R}^{d \times h}` and
:math:`\mathbf{W}_{hc} \in \mathbb{R}^{h \times h}` are weight
parameters and :math:`\mathbf{b}_c \in \mathbb{R}^{1 \times h}` is a
bias parameter.
A quick illustration of the candidate memory cell is shown in
:numref:`lstm_1`.
.. _lstm_1:
.. figure:: ../img/lstm-1.svg
Computing the candidate memory cell in an LSTM model.
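
To see how the :math:`\tanh` activation of the candidate memory cell
differs from the sigmoid activation of the gates, the small sketch below
evaluates both on the same pre-activation values. The helper ``sigmoid``
and the sample values are illustrative assumptions.

.. code:: python

    import math

    def sigmoid(z):
        """Logistic function for a scalar (hypothetical helper)."""
        return 1.0 / (1.0 + math.exp(-z))

    # The same pre-activation passed through both activation functions
    for z in (-2.0, 0.0, 2.0):
        print(f"z={z:+.1f}  sigmoid={sigmoid(z):.3f}  tanh={math.tanh(z):+.3f}")

The sigmoid values stay positive, which suits a gate that only scales
what passes through it. The :math:`\tanh` values can be negative, so the
candidate memory cell can propose both increases and decreases of the
cell content.
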
Memory Cell
~~~~~~~~~~~
In GRUs, we have a mechanism to govern input and forgetting (or
skipping). Similarly, in LSTMs we have two dedicated gates for such
purposes: the input gate :math:`\mathbf{I}_t` governs how much we take
new data into account via :math:`\tilde{\mathbf{C}}_t` and the forget
gate :math:`\mathbf{F}_t` addresses how much of the old memory cell
content :math:`\mathbf{C}_{t-1} \in \mathbb{R}^{n \times h}` we retain.
Using the same pointwise multiplication trick as before, we arrive at
the following update equation:
.. math:: \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.
If the forget gate is always approximately 1 and the input gate is
always approximately 0, the past memory cells :math:`\mathbf{C}_{t-1}`
will be saved over time and passed to the current time step. This design
is introduced to alleviate the vanishing gradient problem and to better
capture long-range dependencies within sequences.
We thus arrive at the flow diagram in :numref:`lstm_2`.
.. _lstm_2:
.. figure:: ../img/lstm-2.svg
Computing the memory cell in an LSTM model.
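
The effect of the forget and input gates on the update equation can be
illustrated with scalars. The sketch below is a toy example under the
assumption that the gates and the candidate stay constant over time; the
functions ``update_cell`` and ``run`` are hypothetical helpers.

.. code:: python

    def update_cell(c_prev, f, i, c_tilde):
        """One scalar memory cell update: C_t = F_t * C_{t-1} + I_t * C~_t."""
        return f * c_prev + i * c_tilde

    def run(c0, f, i, c_tilde, steps):
        """Apply the same update repeatedly, starting from c0."""
        c = c0
        for _ in range(steps):
            c = update_cell(c, f, i, c_tilde)
        return c

    c0 = 1.0
    kept = run(c0, f=1.0, i=0.0, c_tilde=-0.5, steps=100)         # Remember
    overwritten = run(c0, f=0.0, i=1.0, c_tilde=-0.5, steps=100)  # Replace
    decayed = run(c0, f=0.9, i=0.0, c_tilde=-0.5, steps=100)      # Slowly forget
    print(kept, overwritten, decayed)

With ``f=1.0`` and ``i=0.0`` the initial content survives all 100 steps
unchanged, whereas with ``f=0.9`` it shrinks to ``0.9 ** 100``, roughly
``2.7e-5``. Along this direct path, and ignoring the dependence of the
gates on the previous hidden state, the derivative of
:math:`\mathbf{C}_t` with respect to :math:`\mathbf{C}_{t-1}` is
:math:`\mathbf{F}_t`, so a forget gate close to 1 also keeps the gradient
from shrinking.
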
Hidden State
~~~~~~~~~~~~
Last, we need to define how to compute the hidden state
:math:`\mathbf{H}_t \in \mathbb{R}^{n \times h}`. This is where the
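output gate comes into play. In the LSTM, the hidden state is simply a
gated version of the :math:`\tanh` of the memory cell. Because the
entries of :math:`\mathbf{O}_t` lie in :math:`(0, 1)` and those of
:math:`\tanh(\mathbf{C}_t)` lie in :math:`(-1, 1)`, the values of
:math:`\mathbf{H}_t` are always in the interval :math:`(-1, 1)`.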
.. math:: \mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).
Whenever the output gate is close to 1 we effectively pass all memory
information through to the predictor, whereas when the output gate is
close to 0 we retain all the information within the memory cell and
perform no further processing.
:numref:`lstm_3` has a graphical illustration of the data flow.
.. _lstm_3:
.. figure:: ../img/lstm-3.svg
Computing the hidden state in an LSTM model.
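
The role of the output gate in the hidden state equation can also be
sketched with scalars. The function ``hidden_state`` and the sample
values below are illustrative assumptions.

.. code:: python

    import math

    def hidden_state(o, c):
        """Scalar version of H_t = O_t * tanh(C_t) (hypothetical helper)."""
        return o * math.tanh(c)

    for c in (-3.0, 0.8, 3.0):
        open_gate = hidden_state(1.0, c)    # Output gate fully open
        half_gate = hidden_state(0.5, c)    # Output gate half open
        closed_gate = hidden_state(0.0, c)  # Output gate closed
        print(f"c={c:+.1f}  open={open_gate:+.3f}  "
              f"half={half_gate:+.3f}  closed={closed_gate:+.3f}")

With the gate closed the hidden state is 0 regardless of the cell
content, while the cell itself is left untouched and remains available
for later time steps.
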
Implementation from Scratch
---------------------------
Now let us implement an LSTM from scratch. As in the experiments in
:numref:`sec_rnn_scratch`, we first load the time machine dataset.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l
npx.set_np()
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...
.. raw:: html
.. raw:: html
Initializing Model Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Next we need to define and initialize the model parameters. As before,
the hyperparameter ``num_hiddens`` defines the number of hidden units.
We initialize the weights from a Gaussian distribution with a standard
deviation of 0.01, and we set the biases to 0.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def get_lstm_params(vocab_size, num_hiddens, device):
        num_inputs = num_outputs = vocab_size

        def normal(shape):
            return np.random.normal(scale=0.01, size=shape, ctx=device)

        def three():
            return (normal((num_inputs, num_hiddens)),
                    normal((num_hiddens, num_hiddens)),
                    np.zeros(num_hiddens, ctx=device))

        W_xi, W_hi, b_i = three()  # Input gate parameters
        W_xf, W_hf, b_f = three()  # Forget gate parameters
        W_xo, W_ho, b_o = three()  # Output gate parameters
        W_xc, W_hc, b_c = three()  # Candidate memory cell parameters
        # Output layer parameters
        W_hq = normal((num_hiddens, num_outputs))
        b_q = np.zeros(num_outputs, ctx=device)
        # Attach gradients
        params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
                  b_c, W_hq, b_q]
        for param in params:
            param.attach_grad()
        return params

.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def get_lstm_params(vocab_size, num_hiddens, device):
        num_inputs = num_outputs = vocab_size

        def normal(shape):
            return torch.randn(size=shape, device=device) * 0.01

        def three():
            return (normal((num_inputs, num_hiddens)),
                    normal((num_hiddens, num_hiddens)),
                    torch.zeros(num_hiddens, device=device))

        W_xi, W_hi, b_i = three()  # Input gate parameters
        W_xf, W_hf, b_f = three()  # Forget gate parameters
        W_xo, W_ho, b_o = three()  # Output gate parameters
        W_xc, W_hc, b_c = three()  # Candidate memory cell parameters
        # Output layer parameters
        W_hq = normal((num_hiddens, num_outputs))
        b_q = torch.zeros(num_outputs, device=device)
        # Attach gradients
        params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
                  b_c, W_hq, b_q]
        for param in params:
            param.requires_grad_(True)
        return params

.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def get_lstm_params(vocab_size, num_hiddens):
        num_inputs = num_outputs = vocab_size

        def normal(shape):
            return tf.Variable(tf.random.normal(shape=shape, stddev=0.01,
                                                mean=0, dtype=tf.float32))

        def three():
            return (normal((num_inputs, num_hiddens)),
                    normal((num_hiddens, num_hiddens)),
                    tf.Variable(tf.zeros(num_hiddens), dtype=tf.float32))

        W_xi, W_hi, b_i = three()  # Input gate parameters
        W_xf, W_hf, b_f = three()  # Forget gate parameters
        W_xo, W_ho, b_o = three()  # Output gate parameters
        W_xc, W_hc, b_c = three()  # Candidate memory cell parameters
        # Output layer parameters
        W_hq = normal((num_hiddens, num_outputs))
        b_q = tf.Variable(tf.zeros(num_outputs), dtype=tf.float32)
        # Collect all parameters; no explicit gradient attachment is needed
        # here because each one is a `tf.Variable`
        params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
                  b_c, W_hq, b_q]
        return params

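
The parameter shapes above determine the size of the model. The short
sketch below counts the scalars that this initialization scheme creates.
The function ``lstm_param_count`` is a hypothetical helper, and the
vocabulary size of 28 is an assumption made for illustration rather than
a value stated in this section.

.. code:: python

    def lstm_param_count(vocab_size, num_hiddens):
        """Count the scalars created by the initialization scheme above."""
        num_inputs = num_outputs = vocab_size
        per_block = (num_inputs * num_hiddens     # W_x*
                     + num_hiddens * num_hiddens  # W_h*
                     + num_hiddens)               # b_*
        recurrent = 4 * per_block  # Input, forget, output gates and candidate
        output_layer = num_hiddens * num_outputs + num_outputs  # W_hq, b_q
        return recurrent + output_layer

    print(lstm_param_count(vocab_size=28, num_hiddens=256))  # 299036

Most of these parameters sit in the four hidden-to-hidden matrices, whose
size grows quadratically with ``num_hiddens``.
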
.. raw:: html
.. raw:: html
Defining the Model
~~~~~~~~~~~~~~~~~~
The state initialization function of the LSTM needs to return, alongside
the hidden state, an *additional* memory cell with a value of 0 and a
shape of (batch size, number of hidden units). Hence we get the
following state initialization.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def init_lstm_state(batch_size, num_hiddens, device):
        # Return the initial hidden state and the initial memory cell
        return (np.zeros((batch_size, num_hiddens), ctx=device),
                np.zeros((batch_size, num_hiddens), ctx=device))

.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def init_lstm_state(batch_size, num_hiddens, device):
        # Return the initial hidden state and the initial memory cell
        return (torch.zeros((batch_size, num_hiddens), device=device),
                torch.zeros((batch_size, num_hiddens), device=device))

.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def init_lstm_state(batch_size, num_hiddens):
        # Return the initial hidden state and the initial memory cell
        return (tf.zeros(shape=(batch_size, num_hiddens)),
                tf.zeros(shape=(batch_size, num_hiddens)))

.. raw:: html
.. raw:: html
The actual model is defined just as discussed above: it provides three
gates and an auxiliary memory cell. Note that only the hidden state is
passed to the output layer. The memory cell :math:`\mathbf{C}_t` does
not directly participate in the output computation.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def lstm(inputs, state, params):
        [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
         W_hq, b_q] = params
        (H, C) = state
        outputs = []
        for X in inputs:
            I = npx.sigmoid(np.dot(X, W_xi) + np.dot(H, W_hi) + b_i)
            F = npx.sigmoid(np.dot(X, W_xf) + np.dot(H, W_hf) + b_f)
            O = npx.sigmoid(np.dot(X, W_xo) + np.dot(H, W_ho) + b_o)
            C_tilda = np.tanh(np.dot(X, W_xc) + np.dot(H, W_hc) + b_c)
            C = F * C + I * C_tilda  # Memory cell update
            H = O * np.tanh(C)  # Hidden state
            Y = np.dot(H, W_hq) + b_q  # Only H feeds the output layer
            outputs.append(Y)
        return np.concatenate(outputs, axis=0), (H, C)

.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def lstm(inputs, state, params):
        [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
         W_hq, b_q] = params
        (H, C) = state
        outputs = []
        for X in inputs:
            I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
            F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
            O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
            C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
            C = F * C + I * C_tilda  # Memory cell update
            H = O * torch.tanh(C)  # Hidden state
            Y = (H @ W_hq) + b_q  # Only H feeds the output layer
            outputs.append(Y)
        return torch.cat(outputs, dim=0), (H, C)

.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def lstm(inputs, state, params):
        [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
         W_hq, b_q] = params
        (H, C) = state
        outputs = []
        for X in inputs:
            X = tf.reshape(X, [-1, W_xi.shape[0]])
            I = tf.sigmoid(tf.matmul(X, W_xi) + tf.matmul(H, W_hi) + b_i)
            F = tf.sigmoid(tf.matmul(X, W_xf) + tf.matmul(H, W_hf) + b_f)
            O = tf.sigmoid(tf.matmul(X, W_xo) + tf.matmul(H, W_ho) + b_o)
            C_tilda = tf.tanh(tf.matmul(X, W_xc) + tf.matmul(H, W_hc) + b_c)
            C = F * C + I * C_tilda  # Memory cell update
            H = O * tf.tanh(C)  # Hidden state
            Y = tf.matmul(H, W_hq) + b_q  # Only H feeds the output layer
            outputs.append(Y)
        return tf.concat(outputs, axis=0), (H, C)

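
To check the shapes of the forward pass without any deep learning
framework, the loop above can be restated in plain NumPy. This is only
an illustrative sketch: the names ``lstm_numpy``, ``sigmoid``, and
``block``, the tiny sizes, and the random one-hot inputs are all
assumptions.

.. code:: python

    import numpy as np

    def sigmoid(z):
        """Elementwise logistic function (hypothetical helper)."""
        return 1.0 / (1.0 + np.exp(-z))

    def lstm_numpy(inputs, state, params):
        """Forward pass over inputs of shape (num_steps, batch, vocab)."""
        (W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
         W_xc, W_hc, b_c, W_hq, b_q) = params
        H, C = state
        outputs = []
        for X in inputs:
            I = sigmoid(X @ W_xi + H @ W_hi + b_i)
            F = sigmoid(X @ W_xf + H @ W_hf + b_f)
            O = sigmoid(X @ W_xo + H @ W_ho + b_o)
            C_tilde = np.tanh(X @ W_xc + H @ W_hc + b_c)
            C = F * C + I * C_tilde
            H = O * np.tanh(C)
            outputs.append(H @ W_hq + b_q)
        return np.concatenate(outputs, axis=0), (H, C)

    rng = np.random.default_rng(0)
    num_steps, batch_size, vocab_size, num_hiddens = 3, 2, 5, 4

    def block():
        """Return (W_x, W_h, b) for one gate or for the candidate cell."""
        return (rng.normal(scale=0.01, size=(vocab_size, num_hiddens)),
                rng.normal(scale=0.01, size=(num_hiddens, num_hiddens)),
                np.zeros(num_hiddens))

    # Order: input gate, forget gate, output gate, candidate, output layer
    params = [p for _ in range(4) for p in block()]
    params += [rng.normal(scale=0.01, size=(num_hiddens, vocab_size)),
               np.zeros(vocab_size)]

    # Random token indices turned into time-major one-hot vectors
    tokens = rng.integers(0, vocab_size, size=(num_steps, batch_size))
    inputs = np.eye(vocab_size)[tokens]
    state = (np.zeros((batch_size, num_hiddens)),
             np.zeros((batch_size, num_hiddens)))

    Y, (H, C) = lstm_numpy(inputs, state, params)
    print(Y.shape, H.shape, C.shape)  # (6, 5) (2, 4) (2, 4)

The outputs of all time steps are concatenated along the first axis, so
the result has ``num_steps * batch_size`` rows, while the returned state
keeps the shape (batch size, number of hidden units).
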
.. raw:: html
.. raw:: html
Training and Prediction
~~~~~~~~~~~~~~~~~~~~~~~
Let us train an LSTM in the same way as in :numref:`sec_gru`, by
instantiating the ``RNNModelScratch`` class as introduced in
:numref:`sec_rnn_scratch`.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.3, 10214.7 tokens/sec on gpu(0)
time travellerit which we had on instance that is this onicalle
travellerit would be remarkably convester is the time dime
.. figure:: output_lstm_86eb9f_51_1.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.1, 22161.1 tokens/sec on cuda:0
time traveller three the lermot you of a listouscount but hewh t
traveller comes the larther absert and this now more a mong
.. figure:: output_lstm_86eb9f_54_1.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    vocab_size, num_hiddens = len(vocab), 256
    device_name = d2l.try_gpu()._device_name
    num_epochs, lr = 500, 1
    strategy = tf.distribute.OneDeviceStrategy(device_name)
    with strategy.scope():
        model = d2l.RNNModelScratch(len(vocab), num_hiddens, init_lstm_state,
                                    lstm, get_lstm_params)
    d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, strategy)

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.1, 4139.1 tokens/sec on /GPU:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
.. figure:: output_lstm_86eb9f_57_1.svg
.. raw:: html
.. raw:: html
Concise Implementation
----------------------
Using high-level APIs, we can directly instantiate an ``LSTM`` model.
This encapsulates all the configuration details that we made explicit
above. The code is significantly faster because it uses compiled
operators rather than Python for many of the steps that we spelled out
before.
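
One technique that optimized implementations can use to reduce overhead
is to compute the pre-activations of all three gates and the candidate
memory cell with a single matrix multiplication over concatenated
weights. The NumPy sketch below only illustrates that the fused and the
separate computations agree. It describes the general technique under
assumed sizes and variable names, not the internals of any particular
backend.

.. code:: python

    import numpy as np

    rng = np.random.default_rng(0)
    n, d, h = 2, 5, 4  # Batch size, number of inputs, hidden units (arbitrary)
    X = rng.normal(size=(n, d))
    H = rng.normal(size=(n, h))

    # Separate weights: input gate, forget gate, output gate, candidate cell
    W_x = [rng.normal(size=(d, h)) for _ in range(4)]
    W_h = [rng.normal(size=(h, h)) for _ in range(4)]
    b = [rng.normal(size=(1, h)) for _ in range(4)]
    separate = [X @ W_x[k] + H @ W_h[k] + b[k] for k in range(4)]

    # Fused: concatenate [X, H] along the feature axis and stack the weights
    XH = np.concatenate([X, H], axis=1)                        # (n, d + h)
    W = np.concatenate([np.concatenate(W_x, axis=1),
                        np.concatenate(W_h, axis=1)], axis=0)  # (d + h, 4h)
    b_all = np.concatenate(b, axis=1)                          # (1, 4h)
    fused = np.split(XH @ W + b_all, 4, axis=1)                # Four (n, h) blocks

    print(all(np.allclose(s, f) for s, f in zip(separate, fused)))  # True

One large multiplication replaces eight small ones per time step, which
tends to use the hardware more efficiently than launching many small
operations from Python.
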
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
lstm_layer = rnn.LSTM(num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.1, 166028.1 tokens/sec on gpu(0)
time traveller of the surmacely caneerianser sish whishand trive
traveller afcer the pauserepsines of space and timeas the d
.. figure:: output_lstm_86eb9f_63_1.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.0, 315170.8 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby
.. figure:: output_lstm_86eb9f_66_1.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    lstm_cell = tf.keras.layers.LSTMCell(num_hiddens,
                                         kernel_initializer='glorot_uniform')
    lstm_layer = tf.keras.layers.RNN(lstm_cell, time_major=True,
                                     return_sequences=True, return_state=True)
    device_name = d2l.try_gpu()._device_name
    strategy = tf.distribute.OneDeviceStrategy(device_name)
    with strategy.scope():
        model = d2l.RNNModel(lstm_layer, vocab_size=len(vocab))
    d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, strategy)

.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
perplexity 1.0, 9225.6 tokens/sec on /GPU:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby
.. figure:: output_lstm_86eb9f_69_1.svg
.. raw:: html
.. raw:: html
LSTMs are the prototypical latent variable autoregressive model with
nontrivial state control. Many variants have been proposed over the
years, e.g., multiple layers, residual connections, and different types
of regularization. However, training LSTMs and other sequence models
(such as GRUs) is quite costly due to the long-range dependency of the
sequence. Later we will encounter alternative models such as
transformers that can be used in some cases.
Summary
-------
- LSTMs have three types of gates: input gates, forget gates, and
  output gates that control the flow of information.
- The hidden layer output of LSTM includes the hidden state and the
  memory cell. Only the hidden state is passed into the output layer.
  The memory cell is entirely internal.
- LSTMs can alleviate vanishing and exploding gradients.
Exercises
---------
1. Adjust the hyperparameters and analyze their influence on running
   time, perplexity, and the output sequence.
2. How would you need to change the model to generate proper words as
   opposed to sequences of characters?
3. Compare the computational cost for GRUs, LSTMs, and regular RNNs for
   a given hidden dimension. Pay special attention to the training and
   inference cost.
4. Since the candidate memory cell ensures that the value range is
   between :math:`-1` and :math:`1` by using the :math:`\tanh` function,
   why does the hidden state need to use the :math:`\tanh` function
   again to ensure that the output value range is between :math:`-1` and
   :math:`1`?
5. Implement an LSTM model for time series prediction rather than
   character sequence prediction.