.. _sec_autograd:

Automatic Differentiation
=========================

As we have explained in :numref:`sec_calculus`, differentiation is a
crucial step in nearly all deep learning optimization algorithms. While
the calculations for taking these derivatives are straightforward,
requiring only some basic calculus, for complex models, working out the
updates by hand can be a pain (and often error-prone).

Deep learning frameworks expedite this work by automatically calculating
derivatives, i.e., *automatic differentiation*. In practice, based on our
designed model, the system builds a *computational graph*, tracking which
data combined through which operations to produce the output. Automatic
differentiation enables the system to subsequently backpropagate
gradients. Here, *backpropagate* simply means to trace through the
computational graph, filling in the partial derivatives with respect to
each parameter.

A Simple Example
----------------

As a toy example, say that we are interested in differentiating the
function :math:`y = 2\mathbf{x}^{\top}\mathbf{x}` with respect to the
column vector :math:`\mathbf{x}`. To start, let us create the variable
``x`` and assign it an initial value.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    from mxnet import autograd, np, npx

    npx.set_np()

    x = np.arange(4.0)
    x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([0., 1., 2., 3.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import torch

    x = torch.arange(4.0)
    x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([0., 1., 2., 3.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import tensorflow as tf

    x = tf.range(4, dtype=tf.float32)
    x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 1., 2., 3.], dtype=float32)>
Before we even calculate the gradient of :math:`y` with respect to
:math:`\mathbf{x}`, we will need a place to store it. It is important
that we do not allocate new memory every time we take a derivative with
respect to a parameter because we will often update the same parameters
thousands or millions of times and could quickly run out of memory. Note
that a gradient of a scalar-valued function with respect to a vector
:math:`\mathbf{x}` is itself vector-valued and has the same shape as
:math:`\mathbf{x}`.
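Concretely, if :math:`y` is a scalar and
:math:`\mathbf{x} \in \mathbb{R}^n`, the gradient collects all :math:`n`
partial derivatives into a vector of the same shape as :math:`\mathbf{x}`:

.. math:: \nabla_{\mathbf{x}}\, y = \left[\frac{\partial y}{\partial x_1}, \frac{\partial y}{\partial x_2}, \ldots, \frac{\partial y}{\partial x_n}\right]^\top \in \mathbb{R}^n.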
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    # We allocate memory for a tensor's gradient by invoking `attach_grad`
    x.attach_grad()
    # After we calculate a gradient taken with respect to `x`, we will be able to
    # access it via the `grad` attribute, whose values are initialized with 0s
    x.grad

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([0., 0., 0., 0.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    x.requires_grad_(True)  # Same as `x = torch.arange(4.0, requires_grad=True)`
    x.grad  # The default value is None
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    x = tf.Variable(x)
Now let us calculate :math:`y`.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet: place our code inside an `autograd.record` scope to build the
    # computational graph
    with autograd.record():
        y = 2 * np.dot(x, x)
    y

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array(28.)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    y = 2 * torch.dot(x, x)
    y

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor(28., grad_fn=<MulBackward0>)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow: record all computations onto a tape
    with tf.GradientTape() as t:
        y = 2 * tf.tensordot(x, x, axes=1)
    y

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(), dtype=float32, numpy=28.0>
Since ``x`` is a vector of length 4, the dot product of ``x`` and ``x``
is performed, yielding the scalar output that we assign to ``y`` (indeed,
:math:`y = 2(0^2 + 1^2 + 2^2 + 3^2) = 28`, matching the output above).
Next, we can automatically calculate the gradient of ``y`` with respect
to each component of ``x`` by calling the function for backpropagation
and printing the gradient.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    y.backward()
    x.grad

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([ 0.,  4.,  8., 12.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    y.backward()
    x.grad

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([ 0.,  4.,  8., 12.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    x_grad = t.gradient(y, x)
    x_grad

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 0.,  4.,  8., 12.], dtype=float32)>
The gradient of the function :math:`y = 2\mathbf{x}^{\top}\mathbf{x}`
with respect to :math:`\mathbf{x}` should be :math:`4\mathbf{x}`. Let us
quickly verify that our desired gradient was calculated correctly.
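To see why, expand the dot product into a sum of squares and
differentiate term by term:

.. math:: y = 2\mathbf{x}^{\top}\mathbf{x} = 2\sum_{i} x_i^2, \qquad \frac{\partial y}{\partial x_i} = 4 x_i, \qquad \text{so} \quad \nabla_{\mathbf{x}}\, y = 4\mathbf{x}.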
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    x.grad == 4 * x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([ True,  True,  True,  True])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    x.grad == 4 * x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([True, True, True, True])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    x_grad == 4 * x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(4,), dtype=bool, numpy=array([ True,  True,  True,  True])>
Now let us calculate another function of ``x``.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    with autograd.record():
        y = x.sum()
    y.backward()
    x.grad  # Overwritten by the newly calculated gradient

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([1., 1., 1., 1.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch accumulates the gradient by default; we need to clear the
    # previous values
    x.grad.zero_()
    y = x.sum()
    y.backward()
    x.grad

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([1., 1., 1., 1.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    with tf.GradientTape() as t:
        y = tf.reduce_sum(x)
    t.gradient(y, x)  # Overwritten by the newly calculated gradient

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(4,), dtype=float32, numpy=array([1., 1., 1., 1.], dtype=float32)>
Backward for Non-Scalar Variables
---------------------------------

Technically, when ``y`` is not a scalar, the most natural interpretation
of the differentiation of a vector ``y`` with respect to a vector ``x``
is a matrix. For higher-order and higher-dimensional ``y`` and ``x``, the
differentiation result could be a high-order tensor.

However, while these more exotic objects do show up in advanced machine
learning (including in deep learning), more often when we are calling
backward on a vector, we are trying to calculate the derivatives of the
loss functions for each constituent of a *batch* of training examples.
Here, our intent is not to calculate the differentiation matrix but
rather the sum of the partial derivatives computed individually for each
example in the batch. (A sketch of how to obtain the full differentiation
matrix itself appears at the end of this section.)
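To state what backpropagation computes here a bit more precisely: for
:math:`\mathbf{y} \in \mathbb{R}^m` and :math:`\mathbf{x} \in \mathbb{R}^n`,
the differentiation matrix is the Jacobian
:math:`\mathbf{J} \in \mathbb{R}^{m \times n}` with entries
:math:`J_{ij} = \partial y_i / \partial x_j`. Reverse-mode automatic
differentiation multiplies this Jacobian by a vector: invoking backward
with a gradient vector :math:`\mathbf{v}` stores
:math:`\mathbf{J}^\top \mathbf{v}` in the gradient of :math:`\mathbf{x}`,
and choosing :math:`\mathbf{v} = \mathbf{1}` yields exactly the sum of
per-element partial derivatives described above:

.. math:: \left[\mathbf{J}^\top \mathbf{1}\right]_j = \sum_{i=1}^{m} \frac{\partial y_i}{\partial x_j}.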
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet: when we invoke `backward` on a vector-valued variable `y` (a
    # function of `x`), a new scalar variable is created by summing the
    # elements in `y`. Then the gradient of that scalar variable with
    # respect to `x` is computed
    with autograd.record():
        y = x * x  # `y` is a vector
    y.backward()
    x.grad  # Equals the gradient of y = sum(x * x)

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([0., 2., 4., 6.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch: invoking `backward` on a non-scalar requires passing in a
    # `gradient` argument, which specifies the gradient of the differentiated
    # function w.r.t. `self`. In our case, we simply want to sum the partial
    # derivatives, so passing in a gradient of ones is appropriate
    x.grad.zero_()
    y = x * x
    # y.backward(torch.ones(len(x))) is equivalent to the line below
    y.sum().backward()
    x.grad

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([0., 2., 4., 6.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    with tf.GradientTape() as t:
        y = x * x
    t.gradient(y, x)  # Same as the gradient of `y = tf.reduce_sum(x * x)`

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 2., 4., 6.], dtype=float32)>
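As promised above, here is a minimal sketch of how to obtain the full
differentiation matrix itself rather than the summed gradient. It uses
PyTorch's ``torch.autograd.functional.jacobian`` utility; the helper
``square`` and the variable names below are illustrative and not part of
the section's running example.

.. code:: python

    import torch
    from torch.autograd.functional import jacobian

    def square(x):
        # Elementwise square, so the Jacobian is diagonal with entries 2 * x_i
        return x * x

    x = torch.arange(4.0)
    J = jacobian(square, x)  # Shape (4, 4): row i is the gradient of y_i
    print(J)
    # Summing over the output index (J^T 1) recovers the same vector that
    # `y.sum().backward()` produced above: tensor([0., 2., 4., 6.])
    print(J.sum(dim=0))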
Detaching Computation
---------------------

Sometimes, we wish to move some calculations outside of the recorded
computational graph. For example, say that ``y`` was calculated as a
function of ``x``, and that subsequently ``z`` was calculated as a
function of both ``y`` and ``x``. Now, imagine that we wanted to
calculate the gradient of ``z`` with respect to ``x``, but wanted for
some reason to treat ``y`` as a constant, and only take into account the
role that ``x`` played after ``y`` was calculated.

Here, we can detach ``y`` to return a new variable ``u`` that has the
same value as ``y`` but discards any information about how ``y`` was
computed in the computational graph. In other words, the gradient will
not flow backwards through ``u`` to ``x``. Thus, the following
backpropagation function computes the partial derivative of ``z = u * x``
with respect to ``x`` while treating ``u`` as a constant, instead of the
partial derivative of ``z = x * x * x`` with respect to ``x``.
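To make the distinction concrete, consider a single component. With
:math:`u_i = x_i^2` held fixed, detaching changes the derivative that
backpropagation reports:

.. math:: \frac{\partial (u_i x_i)}{\partial x_i} = u_i = x_i^2, \qquad \text{whereas} \qquad \frac{\partial (x_i^3)}{\partial x_i} = 3 x_i^2.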
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    with autograd.record():
        y = x * x
        u = y.detach()
        z = u * x
    z.backward()
    x.grad == u

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([ True,  True,  True,  True])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    x.grad.zero_()
    y = x * x
    u = y.detach()
    z = u * x

    z.sum().backward()
    x.grad == u

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([True, True, True, True])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow: set `persistent=True` to run `t.gradient` more than once
    with tf.GradientTape(persistent=True) as t:
        y = x * x
        u = tf.stop_gradient(y)
        z = u * x

    x_grad = t.gradient(z, x)
    x_grad == u

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(4,), dtype=bool, numpy=array([ True,  True,  True,  True])>
Since the computation of ``y`` was recorded, we can subsequently invoke
backpropagation on ``y`` to get the derivative of ``y = x * x`` with
respect to ``x``, which is ``2 * x``.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    y.backward()
    x.grad == 2 * x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([ True,  True,  True,  True])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    x.grad.zero_()
    y.sum().backward()
    x.grad == 2 * x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([True, True, True, True])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    t.gradient(y, x) == 2 * x

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(4,), dtype=bool, numpy=array([ True,  True,  True,  True])>
Computing the Gradient of Python Control Flow
---------------------------------------------

One benefit of using automatic differentiation is that even if building
the computational graph of a function required passing through a maze of
Python control flow (e.g., conditionals, loops, and arbitrary function
calls), we can still calculate the gradient of the resulting variable. In
the following snippet, note that the number of iterations of the
``while`` loop and the evaluation of the ``if`` statement both depend on
the value of the input ``a``.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    def f(a):
        b = a * 2
        while np.linalg.norm(b) < 1000:
            b = b * 2
        if b.sum() > 0:
            c = b
        else:
            c = 100 * b
        return c
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    def f(a):
        b = a * 2
        while b.norm() < 1000:
            b = b * 2
        if b.sum() > 0:
            c = b
        else:
            c = 100 * b
        return c
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    def f(a):
        b = a * 2
        while tf.norm(b) < 1000:
            b = b * 2
        if tf.reduce_sum(b) > 0:
            c = b
        else:
            c = 100 * b
        return c
Let us compute the gradient.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    a = np.random.normal()
    a.attach_grad()
    with autograd.record():
        d = f(a)
    d.backward()
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    a = torch.randn(size=(), requires_grad=True)
    d = f(a)
    d.backward()
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    a = tf.Variable(tf.random.normal(shape=()))
    with tf.GradientTape() as t:
        d = f(a)
    d_grad = t.gradient(d, a)
    d_grad
We can now analyze the ``f`` function defined above. Note that it is
piecewise linear in its input ``a``. In other words, for any ``a`` there
exists some constant scalar ``k`` such that ``f(a) = k * a``, where the
value of ``k`` depends on the input ``a``. Consequently ``d / a`` allows
us to verify that the gradient is correct.
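More explicitly, on the branch that the control flow actually takes,
``f`` just multiplies its input by a fixed constant :math:`k`, so for
:math:`a \neq 0`

.. math:: f(a) = k a \quad \Longrightarrow \quad \frac{d f(a)}{d a} = k = \frac{f(a)}{a},

and the right-hand side is exactly the quantity ``d / a`` computed below.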
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # MXNet
    a.grad == d / a

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array(True)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # PyTorch
    a.grad == d / a

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor(True)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # TensorFlow
    d_grad == d / a

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(), dtype=bool, numpy=True>
Summary
-------

-  Deep learning frameworks can automate the calculation of derivatives.
   To use automatic differentiation, we first attach gradients to those
   variables with respect to which we desire partial derivatives. We then
   record the computation of our target value, execute its function for
   backpropagation, and access the resulting gradient.

Exercises
---------

1. Why is the second derivative much more expensive to compute than the
   first derivative?
2. After running the function for backpropagation, immediately run it
   again and see what happens.
3. In the control flow example where we calculate the derivative of ``d``
   with respect to ``a``, what would happen if we changed the variable
   ``a`` to a random vector or matrix? At this point, the result of the
   calculation ``f(a)`` is no longer a scalar. What happens to the
   result? How do we analyze this?
4. Redesign an example of finding the gradient of the control flow. Run
   and analyze the result.
5. Let :math:`f(x) = \sin(x)`. Plot :math:`f(x)` and
   :math:`\frac{df(x)}{dx}`, where the latter is computed without
   exploiting that :math:`f'(x) = \cos(x)`.