5.6. GPUs

In Section 1.5, we discussed the rapid growth of computation over the past two decades. In a nutshell, GPU performance has increased by a factor of 1000 every decade since 2000. This offers great opportunities, but it also suggests that there is significant demand for such performance.

In this section, we begin to discuss how to harness this computational performance for your research: first by using a single GPU and, at a later point, multiple GPUs and multiple servers (each with multiple GPUs).

Specifically, we will discuss how to use a single NVIDIA GPU for calculations. First, make sure you have at least one NVIDIA GPU installed. Then, download the NVIDIA driver and CUDA and follow the prompts to set the appropriate path. Once these preparations are complete, the nvidia-smi command can be used to view the graphics card information.

!nvidia-smi
Sat Nov 12 19:53:16 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:1B.0 Off |                    0 |
| N/A   45C    P0    63W / 300W |   1784MiB / 16160MiB |     12%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:1C.0 Off |                    0 |
| N/A   36C    P0    48W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:1D.0 Off |                    0 |
| N/A   37C    P0    54W / 300W |   1752MiB / 16160MiB |     11%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   35C    P0    51W / 300W |      0MiB / 16160MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     36249      C   ...l-en-classic-1/bin/python     1781MiB |
|    2   N/A  N/A     36249      C   ...l-en-classic-1/bin/python     1749MiB |
+-----------------------------------------------------------------------------+

You might have noticed that an MXNet tensor looks almost identical to a NumPy ndarray. But there are a few crucial differences. One of the key features that distinguishes MXNet from NumPy is its support for diverse hardware devices.

In MXNet, every array has a context. So far, by default, all variables and associated computation have been assigned to the CPU. Typically, other contexts might be various GPUs. Things can get even hairier when we deploy jobs across multiple servers. By assigning arrays to contexts intelligently, we can minimize the time spent transferring data between devices. For example, when training neural networks on a server with a GPU, we typically prefer for the model’s parameters to live on the GPU.

Next, we need to confirm that the GPU version of MXNet is installed. If a CPU version of MXNet is already installed, we need to uninstall it first. For example, use the pip uninstall mxnet command, then install the corresponding MXNet version according to your CUDA version. Assuming you have CUDA 10.0 installed, you can install the MXNet version that supports CUDA 10.0 via pip install mxnet-cu100.

!nvidia-smi
Sat Nov 12 20:46:02 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:1B.0 Off |                    0 |
| N/A   54C    P0    41W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:1C.0 Off |                    0 |
| N/A   39C    P0    59W / 300W |   3376MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:1D.0 Off |                    0 |
| N/A   55C    P0    52W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   41C    P0    54W / 300W |   3284MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    1   N/A  N/A     65370      C   ...l-en-classic-1/bin/python     3373MiB |
|    3   N/A  N/A     65370      C   ...l-en-classic-1/bin/python     3281MiB |
+-----------------------------------------------------------------------------+

In PyTorch, every array has a device; we often refer to it as a context. So far, by default, all variables and associated computation have been assigned to the CPU. Typically, other contexts might be various GPUs. Things can get even hairier when we deploy jobs across multiple servers. By assigning arrays to contexts intelligently, we can minimize the time spent transferring data between devices. For example, when training neural networks on a server with a GPU, we typically prefer for the model’s parameters to live on the GPU.

Next, we need to confirm that the GPU version of PyTorch is installed. If a CPU-only version of PyTorch is already installed, we need to uninstall it first, for example with the pip uninstall torch command, and then install the PyTorch build that matches your CUDA version; the official PyTorch installation page lists the exact pip command for each supported CUDA release.

!nvidia-smi
Sat Nov 12 21:23:06 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:1B.0 Off |                    0 |
| N/A   23C    P0    46W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:1C.0 Off |                    0 |
| N/A   23C    P0    46W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:1D.0 Off |                    0 |
| N/A   23C    P0    49W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   20C    P0    48W / 300W |      0MiB / 16160MiB |      5%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

To run the programs in this section, you need at least two GPUs. Note that this might be extravagant for most desktop computers but it is easily available in the cloud, e.g., by using the AWS EC2 multi-GPU instances. Almost all other sections do not require multiple GPUs. Instead, this is simply to illustrate how data flow between different devices.

5.6.1. Computing Devices

We can specify devices, such as CPUs and GPUs, for storage and calculation. By default, tensors are created in the main memory and the CPU is used for calculations.

In MXNet, the CPU and GPU can be indicated by cpu() and gpu(). It should be noted that cpu() (or any integer in the parentheses) means all physical CPUs and memory. This means that MXNet’s calculations will try to use all CPU cores. However, gpu() only represents one card and the corresponding memory. If there are multiple GPUs, we use gpu(i) to represent the \(i^\mathrm{th}\) GPU (\(i\) starts from 0). Also, gpu(0) and gpu() are equivalent.

from mxnet import np, npx
from mxnet.gluon import nn

npx.set_np()

npx.cpu(), npx.gpu(), npx.gpu(1)
(cpu(0), gpu(0), gpu(1))

In PyTorch, the CPU and GPU can be indicated by torch.device('cpu') and torch.device('cuda'). It should be noted that the cpu device means all physical CPUs and memory. This means that PyTorch’s calculations will try to use all CPU cores. However, a gpu device only represents one card and the corresponding memory. If there are multiple GPUs, we use torch.device(f'cuda:{i}') to represent the \(i^\mathrm{th}\) GPU (\(i\) starts from 0). Also, cuda:0 and cuda are equivalent.

import torch
from torch import nn

torch.device('cpu'), torch.device('cuda'), torch.device('cuda:1')
(device(type='cpu'), device(type='cuda'), device(type='cuda', index=1))
import tensorflow as tf

tf.device('/CPU:0'), tf.device('/GPU:0'), tf.device('/GPU:1')
(<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc75ed73a00>,
 <tensorflow.python.eager.context._EagerDeviceContext at 0x7fc768284c80>,
 <tensorflow.python.eager.context._EagerDeviceContext at 0x7fc75ed75340>)

We can query the number of available GPUs.

npx.num_gpus()
2
torch.cuda.device_count()
2
len(tf.config.experimental.list_physical_devices('GPU'))
2

Now we define two convenient functions that allow us to run code even if the requested GPUs do not exist.

def try_gpu(i=0):  #@save
    """Return gpu(i) if exists, otherwise return cpu()."""
    return npx.gpu(i) if npx.num_gpus() >= i + 1 else npx.cpu()

def try_all_gpus():  #@save
    """Return all available GPUs, or [cpu()] if no GPU exists."""
    devices = [npx.gpu(i) for i in range(npx.num_gpus())]
    return devices if devices else [npx.cpu()]

try_gpu(), try_gpu(10), try_all_gpus()
(gpu(0), cpu(0), [gpu(0), gpu(1)])
def try_gpu(i=0):  #@save
    """Return gpu(i) if exists, otherwise return cpu()."""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

def try_all_gpus():  #@save
    """Return all available GPUs, or [cpu(),] if no GPU exists."""
    devices = [torch.device(f'cuda:{i}')
             for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device('cpu')]

try_gpu(), try_gpu(10), try_all_gpus()
(device(type='cuda', index=0),
 device(type='cpu'),
 [device(type='cuda', index=0), device(type='cuda', index=1)])
def try_gpu(i=0):  #@save
    """Return gpu(i) if exists, otherwise return cpu()."""
    if len(tf.config.experimental.list_physical_devices('GPU')) >= i + 1:
        return tf.device(f'/GPU:{i}')
    return tf.device('/CPU:0')

def try_all_gpus():  #@save
    """Return all available GPUs, or [cpu(),] if no GPU exists."""
    num_gpus = len(tf.config.experimental.list_physical_devices('GPU'))
    devices = [tf.device(f'/GPU:{i}') for i in range(num_gpus)]
    return devices if devices else [tf.device('/CPU:0')]

try_gpu(), try_gpu(10), try_all_gpus()
(<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc75ed10080>,
 <tensorflow.python.eager.context._EagerDeviceContext at 0x7fc75ed10400>,
 [<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc75ed10440>,
  <tensorflow.python.eager.context._EagerDeviceContext at 0x7fc75ed104c0>])

5.6.2. Tensors and GPUs

By default, tensors are created on the CPU. We can query the device where the tensor is located.

x = np.array([1, 2, 3])
x.ctx
cpu(0)
x = torch.tensor([1, 2, 3])
x.device
device(type='cpu')
x = tf.constant([1, 2, 3])
x.device
'/job:localhost/replica:0/task:0/device:GPU:0'

It is important to note that whenever we want to operate on multiple terms, they need to be on the same device. For instance, if we sum two tensors, we need to make sure that both arguments live on the same device—otherwise the framework would not know where to store the result or even how to decide where to perform the computation.
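For instance, here is a minimal PyTorch sketch (assuming at least one CUDA device is available) showing that mixing devices in a single operation raises an error instead of triggering an implicit copy:

import torch

x_cpu = torch.ones(3)                       # lives in main memory
if torch.cuda.is_available():
    y_gpu = torch.ones(3, device='cuda:0')  # lives on the first GPU
    try:
        x_cpu + y_gpu                       # no implicit transfer is performed
    except RuntimeError as e:
        print(e)  # e.g., "Expected all tensors to be on the same device ..."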

5.6.2.1. Storage on the GPU

There are several ways to store a tensor on the GPU. For example, we can specify a storage device when creating a tensor. Next, we create the tensor variable X on the first GPU. A tensor created on a GPU only consumes the memory of that GPU. We can use the nvidia-smi command to view GPU memory usage. In general, we need to make sure that we do not create data that exceed the GPU memory limit.

X = np.ones((2, 3), ctx=try_gpu())
X
array([[1., 1., 1.],
       [1., 1., 1.]], ctx=gpu(0))
X = torch.ones(2, 3, device=try_gpu())
X
tensor([[1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')
with try_gpu():
    X = tf.ones((2, 3))
X
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)>

Assuming that you have at least two GPUs, the following code will create a random tensor on the second GPU.

Y = np.random.uniform(size=(2, 3), ctx=try_gpu(1))
Y
array([[0.67478997, 0.07540122, 0.9956977 ],
       [0.09488854, 0.415456  , 0.11231736]], ctx=gpu(1))
Y = torch.rand(2, 3, device=try_gpu(1))
Y
tensor([[0.4144, 0.7294, 0.5390],
        [0.1520, 0.7057, 0.2716]], device='cuda:1')
with try_gpu(1):
    Y = tf.random.uniform((2, 3))
Y
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.70973015, 0.35046446, 0.00469756],
       [0.16234517, 0.6785749 , 0.14071739]], dtype=float32)>

5.6.2.2. Copying

If we want to compute X + Y, we need to decide where to perform this operation. For instance, as shown in Fig. 5.6.1, we can transfer X to the second GPU and perform the operation there. Do not simply add X and Y, since this will result in an exception. The runtime engine would not know what to do: it cannot find data on the same device and it fails. Since Y lives on the second GPU, we need to move X there before we can add the two.

Fig. 5.6.1 Copy data to perform an operation on the same device.

Z = X.copyto(try_gpu(1))
print(X)
print(Z)
[[1. 1. 1.]
 [1. 1. 1.]] @gpu(0)
[[1. 1. 1.]
 [1. 1. 1.]] @gpu(1)
Z = X.cuda(1)
print(X)
print(Z)
tensor([[1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')
tensor([[1., 1., 1.],
        [1., 1., 1.]], device='cuda:1')
with try_gpu(1):
    Z = X
print(X)
print(Z)
tf.Tensor(
[[1. 1. 1.]
 [1. 1. 1.]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
 [1. 1. 1.]], shape=(2, 3), dtype=float32)

Now that the data are on the same GPU (both Z and Y are), we can add them up.

Y + Z
array([[1.6747899, 1.0754012, 1.9956977],
       [1.0948886, 1.415456 , 1.1123173]], ctx=gpu(1))

Imagine that your variable Z already lives on your second GPU. What happens if we still call Z.copyto(gpu(1))? It will make a copy and allocate new memory, even though that variable already lives on the desired device. There are cases where, depending on the environment our code is running in, two variables may already live on the same device. So we want to make a copy only if the variables currently live on different devices. In these cases, we can call as_in_ctx. If the variable already lives on the specified device then this is a no-op. Unless you specifically want to make a copy, as_in_ctx is the method of choice.

Z.as_in_ctx(try_gpu(1)) is Z
True
Y + Z
tensor([[1.4144, 1.7294, 1.5390],
        [1.1520, 1.7057, 1.2716]], device='cuda:1')

Imagine that your variable Z already lives on your second GPU. What happens if we still call Z.cuda(1)? It will return Z instead of making a copy and allocating new memory.

Z.cuda(1) is Z
True
Y + Z
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1.7097301, 1.3504645, 1.0046976],
       [1.1623452, 1.6785749, 1.1407174]], dtype=float32)>

Imagine that your variable Z already lives on your second GPU. What happens if we still call Z2 = Z under the same device scope? It will return Z instead of making a copy and allocating new memory.

with try_gpu(1):
    Z2 = Z
Z2 is Z
True

5.6.2.3. Side Notes

People use GPUs to do machine learning because they expect them to be fast. But transferring variables between devices is slow. So we want you to be 100% certain that you want to do something slow before we let you do it. If the deep learning framework just did the copy automatically without crashing then you might not realize that you had written some slow code.

Also, transferring data between devices (CPU, GPUs, and other machines) is something that is much slower than computation. It also makes parallelization a lot more difficult, since we have to wait for data to be sent (or rather to be received) before we can proceed with more operations. This is why copy operations should be taken with great care. As a rule of thumb, many small operations are much worse than one big operation. Moreover, several operations at a time are much better than many single operations interspersed in the code unless you know what you are doing. This is the case since such operations can block if one device has to wait for the other before it can do something else. It is a bit like ordering your coffee in a queue rather than pre-ordering it by phone and finding out that it is ready when you are.

Last, when we print tensors or convert tensors to the NumPy format, if the data is not in the main memory, the framework will copy it to the main memory first, resulting in additional transmission overhead. Even worse, it is now subject to the dreaded global interpreter lock that makes everything wait for Python to complete.
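As a concrete illustration, the following PyTorch-style sketch (the random tensors and the loop are placeholders of our own, not code from this book) contrasts transferring a scalar on every step with keeping the log on the device and transferring it once at the end:

import torch

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Slow pattern: every .item() call copies a scalar to main memory and forces
# the GPU to synchronize before Python can continue.
per_step_log = []
for step in range(100):
    loss = (torch.randn(256, 256, device=device) ** 2).mean()  # stand-in for a real loss
    per_step_log.append(loss.item())  # one device-to-host copy per step

# Better pattern: keep the running log on the device and transfer it once.
losses = torch.zeros(100, device=device)
for step in range(100):
    losses[step] = (torch.randn(256, 256, device=device) ** 2).mean()
final_log = losses.cpu().numpy()  # a single transfer at the end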

5.6.3. Neural Networks and GPUs

Similarly, a neural network model can specify devices. The following code puts the model parameters on the GPU.

net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(ctx=try_gpu())
net = nn.Sequential(nn.Linear(3, 1))
net = net.to(device=try_gpu())
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    net = tf.keras.models.Sequential([
        tf.keras.layers.Dense(1)])
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')

We will see many more examples of how to run models on GPUs in the following chapters, simply because the models will become somewhat more computationally intensive.

When the input is a tensor on the GPU, the model will calculate the result on the same GPU.

net(X)
array([[0.04995865],
       [0.04995865]], ctx=gpu(0))
net(X)
tensor([[-0.2492],
        [-0.2492]], device='cuda:0', grad_fn=<AddmmBackward0>)
net(X)
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[1.2325809],
       [1.2325809]], dtype=float32)>

Let us confirm that the model parameters are stored on the same GPU.

net[0].weight.data().ctx
gpu(0)
net[0].weight.data.device
device(type='cuda', index=0)
net.layers[0].weights[0].device, net.layers[0].weights[1].device
('/job:localhost/replica:0/task:0/device:GPU:0',
 '/job:localhost/replica:0/task:0/device:GPU:0')

In short, as long as all data and parameters are on the same device, we can learn models efficiently. In the following chapters we will see several such examples.
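As a minimal end-to-end sketch (reusing the try_gpu helper defined above; the layer size, data, and hyperparameters here are arbitrary illustrations), one PyTorch training step with the parameters and the data on the same device looks as follows:

import torch
from torch import nn

device = try_gpu()  # falls back to the CPU if no GPU exists

net = nn.Sequential(nn.Linear(3, 1)).to(device)  # parameters live on `device`
X = torch.rand(8, 3, device=device)              # inputs live on `device`
y = torch.rand(8, 1, device=device)

loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

l = loss(net(X), y)  # the forward pass runs entirely on `device`
trainer.zero_grad()
l.backward()
trainer.step()
print(l.item(), net[0].weight.data.device)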

5.6.4. Summary

  • We can specify devices for storage and calculation, such as the CPU or GPU. By default, data are created in the main memory and the CPU is used for calculations.

  • The deep learning framework requires all input data for calculation to be on the same device, be it CPU or the same GPU.

  • You can lose significant performance by moving data without care. A typical mistake is as follows: computing the loss for every minibatch on the GPU and reporting it back to the user on the command line (or logging it in a NumPy ndarray) will trigger a global interpreter lock which stalls all GPUs. It is much better to allocate memory for logging inside the GPU and only move larger logs.

5.6.5. Exercises

  1. Try a larger computation task, such as the multiplication of large matrices, and see the difference in speed between the CPU and GPU. What about a task with a small amount of calculations?

  2. How should we read and write model parameters on the GPU?

  3. Measure the time it takes to compute 1000 matrix-matrix multiplications of \(100 \times 100\) matrices and log the Frobenius norm of the output matrix one result at a time vs. keeping a log on the GPU and transferring only the final result.

  4. Measure how much time it takes to perform two matrix-matrix multiplications on two GPUs at the same time vs. in sequence on one GPU. Hint: you should see almost linear scaling.