11.8. RMSProp
One of the key issues with Adagrad (Section 11.7) is that the learning rate decreases according to a predefined schedule of effectively \(\mathcal{O}(t^{-\frac{1}{2}})\). While this is generally appropriate for convex problems, it might not be ideal for nonconvex ones, such as those encountered in deep learning. Yet, the coordinate-wise adaptivity of Adagrad is highly desirable as a preconditioner.
Tieleman and Hinton (2012) proposed the RMSProp algorithm as a simple fix to decouple rate scheduling from coordinate-adaptive learning rates. The issue is that Adagrad accumulates the squares of the gradient \(\mathbf{g}_t\) into a state vector \(\mathbf{s}_t = \mathbf{s}_{t-1} + \mathbf{g}_t^2\). As a result, \(\mathbf{s}_t\) keeps on growing without bound due to the lack of normalization, essentially linearly as the algorithm converges.
One way of fixing this problem would be to use \(\mathbf{s}_t / t\). For reasonable distributions of \(\mathbf{g}_t\) this will converge. Unfortunately it might take a very long time until the limit behavior starts to matter since the procedure remembers the full trajectory of values. An alternative is to use a leaky average in the same way we used in the momentum method, i.e., \(\mathbf{s}_t \leftarrow \gamma \mathbf{s}_{t-1} + (1-\gamma) \mathbf{g}_t^2\) for some parameter \(\gamma > 0\). Keeping all other parts unchanged yields RMSProp.
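The difference between the three accumulators can be illustrated with a small sketch. The gradient history below is made up for illustration (squared gradients of 1.0 for 500 steps, then 0.01 for 500 steps), and the helper name `accumulate` is hypothetical rather than part of any library.

```python
def accumulate(sq_grads, gamma=0.9):
    """Track the Adagrad sum, the running mean s_t / t, and the leaky average."""
    s_sum, s_leaky = 0.0, 0.0
    sums, means, leaky = [], [], []
    for t, g2 in enumerate(sq_grads, start=1):
        s_sum += g2                                   # Adagrad: s_t = s_{t-1} + g_t^2
        s_leaky = gamma * s_leaky + (1 - gamma) * g2  # leaky average used by RMSProp
        sums.append(s_sum)
        means.append(s_sum / t)                       # normalized sum s_t / t
        leaky.append(s_leaky)
    return sums, means, leaky

# Hypothetical history: the gradient scale drops by a factor of 10 halfway through.
sq_grads = [1.0] * 500 + [0.01] * 500
sums, means, leaky = accumulate(sq_grads)
print(f'sum: {sums[-1]:.2f}, mean: {means[-1]:.3f}, leaky: {leaky[-1]:.4f}')
```

In this toy setting the plain sum keeps growing, the running mean still reflects the early large gradients after 1000 steps, and the leaky average has settled close to the recent value of 0.01.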
11.8.1. The Algorithm
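Let us write out the equations in detail.
\[\begin{aligned}
\mathbf{s}_t & \leftarrow \gamma \mathbf{s}_{t-1} + (1 - \gamma) \mathbf{g}_t^2, \\
\mathbf{x}_t & \leftarrow \mathbf{x}_{t-1} - \frac{\eta}{\sqrt{\mathbf{s}_t + \epsilon}} \odot \mathbf{g}_t.
\end{aligned}\]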
The constant \(\epsilon > 0\) is typically set to \(10^{-6}\) to ensure that we do not suffer from division by zero or overly large step sizes. Given this expansion we are now free to control the learning rate \(\eta\) independently of the scaling that is applied on a per-coordinate basis. In terms of leaky averages we can apply the same reasoning as previously applied in the case of the momentum method. Expanding the definition of \(\mathbf{s}_t\) yields
\[\begin{aligned}
\mathbf{s}_t & = (1 - \gamma) \mathbf{g}_t^2 + \gamma \mathbf{s}_{t-1} \\
& = (1 - \gamma) \left(\mathbf{g}_t^2 + \gamma \mathbf{g}_{t-1}^2 + \gamma^2 \mathbf{g}_{t-2}^2 + \ldots \right).
\end{aligned}\]
As before in Section 11.6 we use \(1 + \gamma + \gamma^2 + \ldots = \frac{1}{1-\gamma}\). Hence the sum of weights is normalized to \(1\), with an effective averaging window of roughly \(\frac{1}{1-\gamma}\) observations. Let us visualize the weights for the past 40 time steps for various choices of \(\gamma\).
# MXNet version
%matplotlib inline
import math
from mxnet import np, npx
from d2l import mxnet as d2l
npx.set_np()
d2l.set_figsize()
gammas = [0.95, 0.9, 0.8, 0.7]
for gamma in gammas:
    x = np.arange(40).asnumpy()
    d2l.plt.plot(x, (1-gamma) * gamma ** x, label=f'gamma = {gamma:.2f}')
d2l.plt.xlabel('time');
# PyTorch version
import math
import torch
from d2l import torch as d2l
d2l.set_figsize()
gammas = [0.95, 0.9, 0.8, 0.7]
for gamma in gammas:
    x = torch.arange(40).detach().numpy()
    d2l.plt.plot(x, (1-gamma) * gamma ** x, label=f'gamma = {gamma:.2f}')
d2l.plt.xlabel('time');
# TensorFlow version
import math
import tensorflow as tf
from d2l import tensorflow as d2l
d2l.set_figsize()
gammas = [0.95, 0.9, 0.8, 0.7]
for gamma in gammas:
    x = tf.range(40).numpy()
    d2l.plt.plot(x, (1-gamma) * gamma ** x, label=f'gamma = {gamma:.2f}')
d2l.plt.xlabel('time');
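The weights plotted above can also be checked numerically. The following is a minimal sketch in plain Python, assuming a zero initial state \(\mathbf{s}_0 = 0\) and a short made-up list of squared gradients. The helper names `leaky_recursive` and `leaky_expanded` are hypothetical. It compares the recursive update with the expanded weighted sum and computes the total weight carried by the most recent 40 steps.

```python
def leaky_recursive(sq_grads, gamma):
    """Apply s <- gamma * s + (1 - gamma) * g^2 step by step, starting from s = 0."""
    s = 0.0
    for g2 in sq_grads:
        s = gamma * s + (1 - gamma) * g2
    return s

def leaky_expanded(sq_grads, gamma):
    """Evaluate (1 - gamma) * (g_t^2 + gamma * g_{t-1}^2 + gamma^2 * g_{t-2}^2 + ...)."""
    return (1 - gamma) * sum(
        gamma ** k * g2 for k, g2 in enumerate(reversed(sq_grads)))

gamma = 0.9
sq_grads = [0.5, 2.0, 1.5, 0.1, 3.0]  # made-up squared gradients, oldest first

s_rec = leaky_recursive(sq_grads, gamma)
s_exp = leaky_expanded(sq_grads, gamma)

# Total weight of the last 40 observations; the geometric series gives 1 - gamma**40.
weights_40 = sum((1 - gamma) * gamma ** k for k in range(40))
effective_window = 1 / (1 - gamma)

print(f'recursive: {s_rec:.6f}, expanded: {s_exp:.6f}')
print(f'weight of last 40 steps: {weights_40:.4f}, window: {effective_window:.1f}')
```

Both forms agree up to floating point error, and for \(\gamma = 0.9\) the last 40 observations already carry almost all of the weight.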
11.8.2. Implementation from Scratch
As before we use the quadratic function \(f(\mathbf{x})=0.1x_1^2+2x_2^2\) to observe the trajectory of RMSProp. Recall that in Section 11.7, when we used Adagrad with a learning rate of 0.4, the variables moved only very slowly in the later stages of the algorithm since the learning rate decreased too quickly. Since \(\eta\) is controlled separately this does not happen with RMSProp.
def rmsprop_2d(x1, x2, s1, s2):
    # Gradients of f(x) = 0.1 * x1^2 + 2 * x2^2
    g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6
    # Leaky average of the squared gradients
    s1 = gamma * s1 + (1 - gamma) * g1 ** 2
    s2 = gamma * s2 + (1 - gamma) * g2 ** 2
    # Per-coordinate scaled update
    x1 -= eta / math.sqrt(s1 + eps) * g1
    x2 -= eta / math.sqrt(s2 + eps) * g2
    return x1, x2, s1, s2

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

eta, gamma = 0.4, 0.9
d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))
epoch 20, x1: -0.010599, x2: 0.000000
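To make the comparison with Adagrad concrete without the d2l plotting helpers, the sketch below runs both update rules on the same quadratic in plain Python. The starting point \((-5, -2)\), the 20 steps, and the helper name `optimize` are assumptions chosen for illustration. The Adagrad step is written from the accumulation rule \(\mathbf{s}_t = \mathbf{s}_{t-1} + \mathbf{g}_t^2\) stated earlier in this section.

```python
import math

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

def adagrad_step(x1, x2, s1, s2, eta, eps=1e-6):
    g1, g2 = 0.2 * x1, 4 * x2
    s1 += g1 ** 2                      # unbounded sum of squared gradients
    s2 += g2 ** 2
    x1 -= eta / math.sqrt(s1 + eps) * g1
    x2 -= eta / math.sqrt(s2 + eps) * g2
    return x1, x2, s1, s2

def rmsprop_step(x1, x2, s1, s2, eta, gamma=0.9, eps=1e-6):
    g1, g2 = 0.2 * x1, 4 * x2
    s1 = gamma * s1 + (1 - gamma) * g1 ** 2   # leaky average instead of a sum
    s2 = gamma * s2 + (1 - gamma) * g2 ** 2
    x1 -= eta / math.sqrt(s1 + eps) * g1
    x2 -= eta / math.sqrt(s2 + eps) * g2
    return x1, x2, s1, s2

def optimize(step, eta, steps=20, start=(-5.0, -2.0)):
    """Run `steps` updates from an assumed starting point with zero initial state."""
    x1, x2 = start
    s1 = s2 = 0.0
    for _ in range(steps):
        x1, x2, s1, s2 = step(x1, x2, s1, s2, eta)
    return x1, x2

ada = optimize(adagrad_step, eta=0.4)
rms = optimize(rmsprop_step, eta=0.4)
print(f'Adagrad: x1 {ada[0]:.6f}, x2 {ada[1]:.6f}, f {f_2d(*ada):.6f}')
print(f'RMSProp: x1 {rms[0]:.6f}, x2 {rms[1]:.6f}, f {f_2d(*rms):.6f}')
```

With the same learning rate, the only difference between the two runs is how the squared gradients are aggregated, which isolates the effect of the leaky average on the step size in later iterations.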
Next, we implement RMSProp to be used in a deep network. This is equally straightforward.
# MXNet version
def init_rmsprop_states(feature_dim):
    s_w = np.zeros((feature_dim, 1))
    s_b = np.zeros(1)
    return (s_w, s_b)

def rmsprop(params, states, hyperparams):
    gamma, eps = hyperparams['gamma'], 1e-6
    for p, s in zip(params, states):
        s[:] = gamma * s + (1 - gamma) * np.square(p.grad)
        p[:] -= hyperparams['lr'] * p.grad / np.sqrt(s + eps)
# PyTorch version
def init_rmsprop_states(feature_dim):
    s_w = torch.zeros((feature_dim, 1))
    s_b = torch.zeros(1)
    return (s_w, s_b)

def rmsprop(params, states, hyperparams):
    gamma, eps = hyperparams['gamma'], 1e-6
    for p, s in zip(params, states):
        with torch.no_grad():
            s[:] = gamma * s + (1 - gamma) * torch.square(p.grad)
            p[:] -= hyperparams['lr'] * p.grad / torch.sqrt(s + eps)
        p.grad.data.zero_()
# TensorFlow version
def init_rmsprop_states(feature_dim):
    s_w = tf.Variable(tf.zeros((feature_dim, 1)))
    s_b = tf.Variable(tf.zeros(1))
    return (s_w, s_b)

def rmsprop(params, grads, states, hyperparams):
    gamma, eps = hyperparams['gamma'], 1e-6
    for p, s, g in zip(params, states, grads):
        s[:].assign(gamma * s + (1 - gamma) * tf.math.square(g))
        p[:].assign(p - hyperparams['lr'] * g / tf.math.sqrt(s + eps))
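The same update can be sketched without any framework. The example below is an illustration under stated assumptions, not the d2l training loop: parameters, gradients, and states are plain Python lists, the toy regression data are made up (four noiseless points on \(y = 2x + 1\)), gradients are computed on the full batch, and the names `rmsprop_update` and `loss_and_grads` are hypothetical.

```python
import math

def rmsprop_update(params, grads, states, lr=0.01, gamma=0.9, eps=1e-6):
    """In-place RMSProp update on plain lists, one float per coordinate."""
    for i, g in enumerate(grads):
        states[i] = gamma * states[i] + (1 - gamma) * g * g
        params[i] -= lr * g / math.sqrt(states[i] + eps)

# Made-up toy data: four noiseless points on the line y = 2x + 1.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * x + 1.0 for x in xs]

def loss_and_grads(params):
    """Halved mean squared error and its gradient with respect to [w, b]."""
    w, b = params
    n = len(xs)
    errs = [w * x + b - y for x, y in zip(xs, ys)]
    loss = sum(e * e for e in errs) / (2 * n)
    grad_w = sum(e * x for e, x in zip(errs, xs)) / n
    grad_b = sum(errs) / n
    return loss, [grad_w, grad_b]

params, states = [0.0, 0.0], [0.0, 0.0]  # [w, b] and their squared-gradient states
initial_loss, _ = loss_and_grads(params)
for _ in range(2000):
    _, grads = loss_and_grads(params)
    rmsprop_update(params, grads, states)
final_loss, _ = loss_and_grads(params)
print(f'initial loss: {initial_loss:.4f}, final loss: {final_loss:.6f}')
print(f'w: {params[0]:.3f}, b: {params[1]:.3f}')
```

Because the step is normalized by the recent gradient magnitude, a fixed learning rate leaves the parameters hovering near the optimum rather than settling exactly on it, which is one reason the learning rate is usually scheduled separately.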
We set the initial learning rate to 0.01 and the weighting term \(\gamma\) to 0.9. That is, \(\mathbf{s}\) aggregates on average over the past \(1/(1-\gamma) = 10\) observations of the squared gradient.
data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(rmsprop, init_rmsprop_states(feature_dim),
{'lr': 0.01, 'gamma': 0.9}, data_iter, feature_dim);
loss: 0.246, 0.169 sec/epoch
data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(rmsprop, init_rmsprop_states(feature_dim),
{'lr': 0.01, 'gamma': 0.9}, data_iter, feature_dim);
loss: 0.244, 0.015 sec/epoch
data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(rmsprop, init_rmsprop_states(feature_dim),
{'lr': 0.01, 'gamma': 0.9}, data_iter, feature_dim);
loss: 0.244, 0.113 sec/epoch
11.8.3. Concise Implementation
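Since RMSProp is a rather popular algorithm, it is also available in the deep learning frameworks. In MXNet, all we need to do is instantiate a Trainer instance using an algorithm named rmsprop, assigning \(\gamma\) to the parameter gamma1. The PyTorch and TensorFlow optimizers used below expose the same quantity as alpha and rho, respectively.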
d2l.train_concise_ch11('rmsprop', {'learning_rate': 0.01, 'gamma1': 0.9},
data_iter)
loss: 0.244, 0.086 sec/epoch
trainer = torch.optim.RMSprop
d2l.train_concise_ch11(trainer, {'lr': 0.01, 'alpha': 0.9},
data_iter)
loss: 0.244, 0.014 sec/epoch
trainer = tf.keras.optimizers.RMSprop
d2l.train_concise_ch11(trainer, {'learning_rate': 0.01, 'rho': 0.9},
data_iter)
loss: 0.243, 0.139 sec/epoch
11.8.4. Summary
RMSProp is very similar to Adagrad insofar as both use the square of the gradient to scale coefficients.
RMSProp shares with momentum the leaky averaging. However, RMSProp uses the technique to adjust the coefficient-wise preconditioner.
The learning rate needs to be scheduled by the experimenter in practice.
The coefficient \(\gamma\) determines how long the history is when adjusting the per-coordinate scale.
11.8.5. Exercises
What happens experimentally if we set \(\gamma = 1\)? Why?
Rotate the optimization problem to minimize \(f(\mathbf{x}) = 0.1 (x_1 + x_2)^2 + 2 (x_1 - x_2)^2\). What happens to the convergence?
Try out what happens to RMSProp on a real machine learning problem, such as training on Fashion-MNIST. Experiment with different choices for adjusting the learning rate.
Would you want to adjust \(\gamma\) as optimization progresses? How sensitive is RMSProp to this?