.. _sec_channels:

Multiple Input and Multiple Output Channels
===========================================

While we have described the multiple channels that comprise each image
(e.g., color images have the standard RGB channels to indicate the amount of
red, green, and blue) and convolutional layers for multiple channels in
:numref:`subsec_why-conv-channels`, until now we have simplified all of our
numerical examples by working with just a single input and a single output
channel. This has allowed us to think of our inputs, convolution kernels,
and outputs each as two-dimensional tensors.

When we add channels into the mix, our inputs and hidden representations
both become three-dimensional tensors. For example, each RGB input image has
shape :math:`3\times h\times w`. We refer to this axis, with a size of 3, as
the *channel* dimension. In this section, we will take a deeper look at
convolution kernels with multiple input and multiple output channels.

Multiple Input Channels
-----------------------

When the input data contain multiple channels, we need to construct a
convolution kernel with the same number of input channels as the input data,
so that it can perform cross-correlation with the input data. Assuming that
the number of channels for the input data is :math:`c_i`, the number of input
channels of the convolution kernel also needs to be :math:`c_i`. If our
convolution kernel's window shape is :math:`k_h\times k_w`, then when
:math:`c_i=1`, we can think of our convolution kernel as just a
two-dimensional tensor of shape :math:`k_h\times k_w`.

However, when :math:`c_i>1`, we need a kernel that contains a tensor of shape
:math:`k_h\times k_w` for *every* input channel. Stacking these :math:`c_i`
tensors along a new leading axis yields a convolution kernel of shape
:math:`c_i\times k_h\times k_w`.
Since the input and convolution kernel each have :math:`c_i` channels, we
can perform a two-dimensional cross-correlation between each input channel
and the matching channel of the kernel, and then add the :math:`c_i` results
together (summing over the channels) to yield a single two-dimensional
tensor. This is the result of a two-dimensional cross-correlation between a
multi-channel input and a multi-input-channel convolution kernel.

In :numref:`fig_conv_multi_in`, we demonstrate an example of a
two-dimensional cross-correlation with two input channels. The shaded
portions are the first output element as well as the input and kernel tensor
elements used for the output computation:
:math:`(1\times1+2\times2+4\times3+5\times4)+(0\times0+1\times1+3\times2+4\times3)=56`.

.. _fig_conv_multi_in:

.. figure:: ../img/conv-multi-in.svg

   Cross-correlation computation with 2 input channels.

To make sure we really understand what is going on here, we can implement
cross-correlation operations with multiple input channels ourselves. Notice
that all we are doing is performing one cross-correlation operation per
channel and then adding up the results.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    from mxnet import np, npx
    from d2l import mxnet as d2l

    npx.set_np()


    def corr2d_multi_in(X, K):
        # First, iterate through the 0th dimension (channel dimension) of `X` and
        # `K`. Then, add them together
        return sum(d2l.corr2d(x, k) for x, k in zip(X, K))
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import torch
    from d2l import torch as d2l


    def corr2d_multi_in(X, K):
        # First, iterate through the 0th dimension (channel dimension) of `X` and
        # `K`. Then, add them together
        return sum(d2l.corr2d(x, k) for x, k in zip(X, K))
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import tensorflow as tf
    from d2l import tensorflow as d2l


    def corr2d_multi_in(X, K):
        # First, iterate through the 0th dimension (channel dimension) of `X` and
        # `K`. Then, add them together
        return tf.reduce_sum([d2l.corr2d(x, k) for x, k in zip(X, K)], axis=0)
We can construct the input tensor ``X`` and the kernel tensor ``K``
corresponding to the values in :numref:`fig_conv_multi_in` to validate the
output of the cross-correlation operation.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X = np.array([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
    K = np.array([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

    corr2d_multi_in(X, K)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([[ 56.,  72.],
           [104., 120.]])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                      [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
    K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

    corr2d_multi_in(X, K)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([[ 56.,  72.],
            [104., 120.]])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X = tf.constant([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                     [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
    K = tf.constant([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

    corr2d_multi_in(X, K)
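The value 56 from :numref:`fig_conv_multi_in` can also be checked without
looping over channels explicitly: each output element is the sum of the
elementwise product of the whole :math:`c_i\times k_h\times k_w` kernel with
the matching :math:`c_i\times k_h\times k_w` window of the input. The sketch
below uses plain NumPy (an assumption; it is independent of the
framework-specific code above) and recomputes every output position this
way, which should reproduce ``[[56, 72], [104, 120]]``.

.. code:: python

    import numpy as np

    # Input (c_i=2, h=3, w=3) and kernel (c_i=2, k_h=2, k_w=2) from the figure
    X = np.array([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
    K = np.array([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

    c_i, k_h, k_w = K.shape
    h_out, w_out = X.shape[1] - k_h + 1, X.shape[2] - k_w + 1

    # Each output element: multiply a c_i x k_h x k_w input window by the
    # kernel elementwise, then sum over all three axes (channels included)
    Y = np.zeros((h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            Y[i, j] = (X[:, i:i + k_h, j:j + k_w] * K).sum()

    print(Y)

Summing over the channel axis inside the same ``.sum()`` is what makes the
per-channel results "add up", exactly as in ``corr2d_multi_in``.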
.. _subsec_multi-output-channels:

Multiple Output Channels
------------------------

Regardless of the number of input channels, so far we always ended up with
one output channel. However, as we discussed in
:numref:`subsec_why-conv-channels`, it turns out to be essential to have
multiple channels at each layer. In the most popular neural network
architectures, we actually increase the channel dimension as we go higher up
in the neural network, typically downsampling to trade off spatial
resolution for greater *channel depth*. Intuitively, you could think of each
channel as responding to a different set of features. Reality is a bit more
complicated than the most naive interpretations of this intuition, since
representations are not learned independently but are rather optimized to
be jointly useful. So it may not be that a single channel learns an edge
detector but rather that some direction in channel space corresponds to
detecting edges.

Denote by :math:`c_i` and :math:`c_o` the number of input and output
channels, respectively, and let :math:`k_h` and :math:`k_w` be the height
and width of the kernel. To get an output with multiple channels, we can
create a kernel tensor of shape :math:`c_i\times k_h\times k_w` for *every*
output channel. We stack them along a new output channel dimension, so that
the shape of the convolution kernel is
:math:`c_o\times c_i\times k_h\times k_w`. In cross-correlation operations,
the result on each output channel is calculated from the convolution kernel
corresponding to that output channel and takes input from all channels in
the input tensor.

We implement a cross-correlation function to calculate the output of
multiple channels as shown below.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def corr2d_multi_in_out(X, K):
        # Iterate through the 0th dimension of `K`, and each time, perform
        # cross-correlation operations with input `X`. All of the results are
        # stacked together
        return np.stack([corr2d_multi_in(X, k) for k in K], 0)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def corr2d_multi_in_out(X, K):
        # Iterate through the 0th dimension of `K`, and each time, perform
        # cross-correlation operations with input `X`. All of the results are
        # stacked together
        return torch.stack([corr2d_multi_in(X, k) for k in K], 0)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def corr2d_multi_in_out(X, K):
        # Iterate through the 0th dimension of `K`, and each time, perform
        # cross-correlation operations with input `X`. All of the results are
        # stacked together
        return tf.stack([corr2d_multi_in(X, k) for k in K], 0)
We construct a convolution kernel with 3 output channels by stacking the
kernel tensor ``K`` with ``K+1`` (plus one for each element in ``K``) and
``K+2`` along a new leading axis.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    K = np.stack((K, K + 1, K + 2), 0)
    K.shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (3, 2, 2, 2)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    K = torch.stack((K, K + 1, K + 2), 0)
    K.shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    torch.Size([3, 2, 2, 2])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    K = tf.stack((K, K + 1, K + 2), 0)
    K.shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    TensorShape([3, 2, 2, 2])
Below, we perform cross-correlation operations on the input tensor ``X``
with the kernel tensor ``K``. Now the output contains 3 channels. The first
output channel matches the earlier result of applying the
multi-input-channel, single-output-channel kernel to ``X``.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    corr2d_multi_in_out(X, K)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([[[ 56.,  72.],
            [104., 120.]],

           [[ 76., 100.],
            [148., 172.]],

           [[ 96., 128.],
            [192., 224.]]])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    corr2d_multi_in_out(X, K)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([[[ 56.,  72.],
             [104., 120.]],

            [[ 76., 100.],
             [148., 172.]],

            [[ 96., 128.],
             [192., 224.]]])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    corr2d_multi_in_out(X, K)
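The two nested loops in ``corr2d_multi_in_out`` and ``corr2d_multi_in`` can
also be collapsed into a single tensor contraction. The following is a
minimal sketch in NumPy, assuming NumPy 1.20 or later for
``numpy.lib.stride_tricks.sliding_window_view``; the function name
``corr2d_multi_in_out_einsum`` is hypothetical and not part of ``d2l``, and
it only handles stride 1 without padding. It extracts every
:math:`k_h\times k_w` window of every input channel and then, for each output
channel, sums over the input channels and the window positions. Applied to
the same ``X`` and 3-output-channel ``K``, it should reproduce the three
channels printed above.

.. code:: python

    import numpy as np
    from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


    def corr2d_multi_in_out_einsum(X, K):
        """Hypothetical vectorized multi-input, multi-output cross-correlation.

        X has shape (c_i, h, w) and K has shape (c_o, c_i, k_h, k_w).
        Assumes stride 1 and no padding, so the result has shape
        (c_o, h - k_h + 1, w - k_w + 1).
        """
        c_o, c_i, k_h, k_w = K.shape
        if X.shape[0] != c_i:
            raise ValueError("X and K must have the same number of input channels")
        # windows[c, y, x, i, j] == X[c, y + i, x + j]
        windows = sliding_window_view(X, (k_h, k_w), axis=(1, 2))
        # For each output channel o and position (y, x), sum over c, i and j
        return np.einsum('cyxij,ocij->oyx', windows, K)


    # Usage with the input and the 3-output-channel kernel from this section
    X = np.array([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
    K = np.array([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
    K = np.stack((K, K + 1, K + 2), 0)  # shape (3, 2, 2, 2)
    Y = corr2d_multi_in_out_einsum(X, K)
    print(Y.shape)

The einsum subscripts make the shape bookkeeping explicit: the input-channel
index ``c`` appears in both operands but not in the output, so it is summed
over, while the output-channel index ``o`` is kept.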
:math:`1\times 1` Convolutional Layer
-------------------------------------

At first, a :math:`1 \times 1` convolution, i.e., :math:`k_h = k_w = 1`,
does not seem to make much sense. After all, a convolution correlates
adjacent pixels. A :math:`1 \times 1` convolution obviously does not.
Nonetheless, :math:`1 \times 1` convolutions are popular operations that are
sometimes included in the designs of complex deep networks. Let us see in
some detail what one actually does.

Because the minimum window is used, the :math:`1\times 1` convolution loses
the ability of larger convolutional layers to recognize patterns consisting
of interactions among adjacent elements in the height and width dimensions.
The only computation of the :math:`1\times 1` convolution occurs on the
channel dimension.

:numref:`fig_conv_1x1` shows the cross-correlation computation using the
:math:`1\times 1` convolution kernel with 3 input channels and 2 output
channels. Note that the inputs and outputs have the same height and width.
Each element in the output is derived from a linear combination of elements
*at the same position* in the input image. You could think of the
:math:`1\times 1` convolutional layer as constituting a fully-connected
layer applied at every single pixel location to transform the :math:`c_i`
corresponding input values into :math:`c_o` output values. Because this is
still a convolutional layer, the weights are tied across pixel locations.
Thus the :math:`1\times 1` convolutional layer requires
:math:`c_o\times c_i` weights (plus one bias per output channel).

.. _fig_conv_1x1:

.. figure:: ../img/conv-1x1.svg

   The cross-correlation computation uses the :math:`1\times 1` convolution
   kernel with 3 input channels and 2 output channels. The input and output
   have the same height and width.

Let us check whether this works in practice: we implement a
:math:`1 \times 1` convolution using a fully-connected layer. The only
adjustment needed is to reshape the data before and after the matrix
multiplication.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def corr2d_multi_in_out_1x1(X, K):
        c_i, h, w = X.shape
        c_o = K.shape[0]
        X = X.reshape((c_i, h * w))
        K = K.reshape((c_o, c_i))
        # Matrix multiplication in the fully-connected layer
        Y = np.dot(K, X)
        return Y.reshape((c_o, h, w))
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def corr2d_multi_in_out_1x1(X, K):
        c_i, h, w = X.shape
        c_o = K.shape[0]
        X = X.reshape((c_i, h * w))
        K = K.reshape((c_o, c_i))
        # Matrix multiplication in the fully-connected layer
        Y = torch.matmul(K, X)
        return Y.reshape((c_o, h, w))
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def corr2d_multi_in_out_1x1(X, K):
        c_i, h, w = X.shape
        c_o = K.shape[0]
        X = tf.reshape(X, (c_i, h * w))
        K = tf.reshape(K, (c_o, c_i))
        # Matrix multiplication in the fully-connected layer
        Y = tf.matmul(K, X)
        return tf.reshape(Y, (c_o, h, w))
When performing :math:`1\times 1` convolution, the above function is
equivalent to the previously implemented cross-correlation function
``corr2d_multi_in_out``. Let us check this with some sample data.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X = np.random.normal(0, 1, (3, 3, 3))
    K = np.random.normal(0, 1, (2, 3, 1, 1))

    Y1 = corr2d_multi_in_out_1x1(X, K)
    Y2 = corr2d_multi_in_out(X, K)
    assert float(np.abs(Y1 - Y2).sum()) < 1e-6
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X = torch.normal(0, 1, (3, 3, 3))
    K = torch.normal(0, 1, (2, 3, 1, 1))

    Y1 = corr2d_multi_in_out_1x1(X, K)
    Y2 = corr2d_multi_in_out(X, K)
    assert float(torch.abs(Y1 - Y2).sum()) < 1e-6
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X = tf.random.normal((3, 3, 3), 0, 1)
    K = tf.random.normal((2, 3, 1, 1), 0, 1)

    Y1 = corr2d_multi_in_out_1x1(X, K)
    Y2 = corr2d_multi_in_out(X, K)
    assert float(tf.reduce_sum(tf.abs(Y1 - Y2))) < 1e-6
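To see the "fully-connected layer at every pixel" view of the
:math:`1\times 1` convolution directly, the following NumPy sketch (NumPy and
the variable names ``W`` and ``Y_per_pixel`` are assumptions, not part of the
code above) treats the :math:`c_o\times c_i\times 1\times 1` kernel as a
:math:`c_o\times c_i` weight matrix ``W`` and applies it to the
:math:`c_i`-dimensional vector at each pixel in turn. The result agrees with
a single contraction over the channel axis, and ``W`` holds exactly
:math:`c_o\times c_i` weights.

.. code:: python

    import numpy as np

    rng = np.random.default_rng(0)
    c_i, c_o, h, w = 3, 2, 4, 5
    X = rng.normal(size=(c_i, h, w))
    K = rng.normal(size=(c_o, c_i, 1, 1))

    # Drop the two trailing singleton axes: a c_o x c_i weight matrix
    W = K.reshape(c_o, c_i)

    # 1x1 convolution as one contraction over the channel axis
    Y = np.einsum('oc,chw->ohw', W, X)

    # The same computation as a fully-connected layer applied pixel by pixel,
    # reusing (tying) the same W at every spatial location
    Y_per_pixel = np.empty((c_o, h, w))
    for y in range(h):
        for x in range(w):
            Y_per_pixel[:, y, x] = W @ X[:, y, x]

    print(np.allclose(Y, Y_per_pixel), W.size)

Because the same ``W`` is reused at all :math:`h\times w` locations, the
parameter count does not depend on the image size, which is what the weight
tying described above amounts to.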
Summary
-------

- Multiple channels can be used to extend the model parameters of the
  convolutional layer: a kernel of shape
  :math:`c_o\times c_i\times k_h\times k_w` combines all :math:`c_i` input
  channels into each of the :math:`c_o` output channels.
- The :math:`1\times 1` convolutional layer is equivalent to a
  fully-connected layer applied on a per-pixel basis.
- The :math:`1\times 1` convolutional layer is typically used to adjust the
  number of channels between network layers and to control model complexity.

Exercises
---------

1. Assume that we have two convolution kernels of size :math:`k_1` and
   :math:`k_2`, respectively (with no nonlinearity in between).

   1. Prove that the result of the operation can be expressed by a single
      convolution.
   2. What is the dimensionality of the equivalent single convolution?
   3. Is the converse true?

2. Assume an input of shape :math:`c_i\times h\times w` and a convolution
   kernel of shape :math:`c_o\times c_i\times k_h\times k_w`, padding of
   :math:`(p_h, p_w)`, and stride of :math:`(s_h, s_w)`.

   1. What is the computational cost (multiplications and additions) for
      the forward propagation?
   2. What is the memory footprint?
   3. What is the memory footprint for the backward computation?
   4. What is the computational cost for the backpropagation?

3. By what factor does the number of calculations increase if we double the
   number of input channels :math:`c_i` and the number of output channels
   :math:`c_o`? What happens if we double the padding?
4. If the height and width of a convolution kernel are :math:`k_h=k_w=1`,
   what is the computational complexity of the forward propagation?
5. Are the variables ``Y1`` and ``Y2`` in the last example of this section
   exactly the same? Why?
6. How would you implement convolutions using matrix multiplication when the
   convolution window is not :math:`1\times 1`?
`Discussions `__