12.1. Compilers and Interpreters
So far, this book has focused on imperative programming, which makes use of statements such as print, +, and if to change a program's state. Consider the following example of a simple imperative program.
def add(a, b):
return a + b
def fancy_func(a, b, c, d):
e = add(a, b)
f = add(c, d)
g = add(e, f)
return g
print(fancy_func(1, 2, 3, 4))
10
Python is an interpreted language. When evaluating the above fancy_func function it performs the operations making up the function's body in sequence. That is, it will evaluate e = add(a, b) and store the result as variable e, thereby changing the program's state. The next two statements f = add(c, d) and g = add(e, f) will be executed similarly, performing additions and storing the results as variables. Fig. 12.1.1 illustrates the flow of data.
Although imperative programming is convenient, it may be inefficient. On the one hand, even though the add function is repeatedly called throughout fancy_func, Python will execute the three function calls individually. If these are executed, say, on a GPU (or even on multiple GPUs), the overhead arising from the Python interpreter can become overwhelming. On the other hand, Python needs to keep the values of the variables e and f around until all the statements in fancy_func have been executed, since it does not know whether they will be used by other parts of the program after the statements e = add(a, b) and f = add(c, d) have run.
12.1.1. Symbolic Programming
Consider the alternative, symbolic programming, where computation is usually performed only once the process has been fully defined. This strategy is used by multiple deep learning frameworks, including Theano and TensorFlow (the latter has acquired imperative extensions). It usually involves the following steps:
Define the operations to be executed.
Compile the operations into an executable program.
Provide the required inputs and call the compiled program for execution.
This allows for a significant amount of optimization. First, we can skip the Python interpreter in many cases, thus removing a performance bottleneck that can become significant on multiple fast GPUs paired with a single Python thread on a CPU. Second, a compiler might optimize and rewrite the above code into print((1 + 2) + (3 + 4)) or even print(10). This is possible since a compiler gets to see the full code before turning it into machine instructions. For instance, it can release memory (or never allocate it) whenever a variable is no longer needed. Or it can transform the code entirely into an equivalent piece.
To get a better idea, consider the following simulation of symbolic programming in Python (it is Python after all).
def add_():
return '''
def add(a, b):
return a + b
'''
def fancy_func_():
return '''
def fancy_func(a, b, c, d):
e = add(a, b)
f = add(c, d)
g = add(e, f)
return g
'''
def evoke_():
return add_() + fancy_func_() + 'print(fancy_func(1, 2, 3, 4))'
prog = evoke_()
print(prog)
y = compile(prog, '', 'exec')
exec(y)
def add(a, b):
return a + b
def fancy_func(a, b, c, d):
e = add(a, b)
f = add(c, d)
g = add(e, f)
return g
print(fancy_func(1, 2, 3, 4))
10
The differences between imperative (interpreted) programming and symbolic programming are as follows:
Imperative programming is easier. When imperative programming is used in Python, the majority of the code is straightforward and easy to write. It is also easier to debug imperative programming code. This is because it is easier to obtain and print all relevant intermediate variable values, or use Python’s built-in debugging tools.
Symbolic programming is more efficient and easier to port. Symbolic programming makes it easier to optimize the code during compilation, while also having the ability to port the program into a format independent of Python. This allows the program to be run in a non-Python environment, thus avoiding any potential performance issues related to the Python interpreter.
12.1.2. Hybrid Programming
Historically most deep learning frameworks chose between an imperative or a symbolic approach. For example, Theano, TensorFlow (inspired by the former), Keras, and CNTK formulate models symbolically. Conversely, Chainer and PyTorch take an imperative approach. An imperative mode was added to TensorFlow 2.0 and Keras in later revisions.
When designing Gluon, developers considered whether it would be possible to combine the benefits of both programming paradigms. This led to a hybrid model that lets users develop and debug with pure imperative programming, while having the ability to convert most programs into symbolic programs to be run when production-level computing performance and deployment are required.
In practice this means that we build models using the HybridBlock or HybridSequential class. By default, either of them is executed in the same way the Block or Sequential class is executed in imperative programming. The HybridSequential class is a subclass of HybridBlock (just like Sequential subclasses Block). When the hybridize function is called, Gluon compiles the model into the form used in symbolic programming. This allows one to optimize the computation-intensive components without changing the way the model is implemented. We will illustrate the benefits below, focusing on sequential models and blocks.
As mentioned above, PyTorch is based on imperative programming and uses dynamic computation graphs. In an effort to leverage the portability and efficiency of symbolic programming, developers considered whether it would be possible to combine the benefits of both programming models. This led to TorchScript, which lets users develop and debug using pure imperative programming, while having the ability to convert most programs into symbolic programs to be run when production-level computing performance and deployment are required.
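As a minimal sketch of the idea (clipped_sum is a hypothetical function, not part of the book's code), torch.jit.script can compile an ordinary Python function, control flow included:
import torch

@torch.jit.script  # compile this function into TorchScript, control flow included
def clipped_sum(x):
    # The Python `if` below becomes part of the compiled program
    if x.sum() > 0:
        return 2 * x
    return torch.zeros_like(x)

print(clipped_sum(torch.tensor([1.0, 2.0])))  # tensor([2., 4.])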
The imperative programming paradigm is now the default in TensorFlow 2, a welcome change for those new to the framework. However, the same symbolic programming techniques and subsequent computational graphs still exist in TensorFlow, and can be accessed via the easy-to-use tf.function decorator. This brought the imperative programming paradigm to TensorFlow, allowing users to define more intuitive functions, then wrap them and compile them into computational graphs automatically using a feature the TensorFlow team refers to as autograph.
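As a minimal sketch of the decorator form (scaled_add is a hypothetical function, not part of the book's code):
import tensorflow as tf

@tf.function  # autograph traces this Python function into a TensorFlow graph
def scaled_add(a, b):
    return 2.0 * a + b

# The first call triggers tracing; subsequent calls reuse the compiled graph
print(scaled_add(tf.constant(1.0), tf.constant(2.0)))  # tf.Tensor(4.0, shape=(), dtype=float32)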
12.1.3. Hybridizing the Sequential Class
The easiest way to get a feel for how hybridization works is to consider deep networks with multiple layers. Conventionally the Python interpreter will need to execute the code for all layers to generate an instruction that can then be forwarded to a CPU or a GPU. For a single (fast) computing device this does not cause any major issues. On the other hand, if we use an advanced 8-GPU server such as an AWS P3dn.24xlarge instance, Python will struggle to keep all GPUs busy. The single-threaded Python interpreter becomes the bottleneck here. Let us see how we can address this for significant parts of the code by replacing Sequential with HybridSequential. We begin by defining a simple MLP.
from mxnet import np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
# Factory for networks
def get_net():
net = nn.HybridSequential()
net.add(nn.Dense(256, activation='relu'),
nn.Dense(128, activation='relu'),
nn.Dense(2))
net.initialize()
return net
x = np.random.normal(size=(1, 512))
net = get_net()
net(x)
array([[ 0.16526186, -0.14005628]])
By calling the hybridize function, we are able to compile and optimize the computation in the MLP. The model's computation result remains unchanged.
net.hybridize()
net(x)
array([[ 0.16526186, -0.14005628]])
This seems almost too good to be true: simply designate a block to be HybridSequential, write the same code as before and invoke hybridize. Once this happens the network is optimized (we will benchmark the performance below). Unfortunately this does not work magically for every layer. In particular, a layer will not be optimized if it inherits from the Block class instead of the HybridBlock class.
import torch
from torch import nn
from d2l import torch as d2l
# Factory for networks
def get_net():
net = nn.Sequential(nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 2))
return net
x = torch.randn(size=(1, 512))
net = get_net()
net(x)
tensor([[-0.1078, 0.1045]], grad_fn=<AddmmBackward0>)
By converting the model using the torch.jit.script function, we are able to compile and optimize the computation in the MLP. The model's computation result remains unchanged.
net = torch.jit.script(net)
net(x)
tensor([[-0.1078, 0.1045]], grad_fn=<AddmmBackward0>)
This seems almost too good to be true: write the same code as before and simply convert the model using torch.jit.script. Once this happens the network is optimized (we will benchmark the performance below).
import tensorflow as tf
from tensorflow.keras.layers import Dense
from d2l import tensorflow as d2l
# Factory for networks
def get_net():
net = tf.keras.Sequential()
    net.add(Dense(256, input_shape=(512,), activation="relu"))
    net.add(Dense(128, activation="relu"))
    net.add(Dense(2, activation="linear"))
return net
x = tf.random.normal([1,512])
net = get_net()
net(x)
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.26736248, 0.6471752 ]], dtype=float32)>
Formerly, all functions built in TensorFlow were built as a computational graph, and therefore JIT compiled by default. However, with the release of TensorFlow 2.X and EagerTensor, this is no longer the default behavior. We can re-enable this functionality with tf.function. tf.function is more commonly used as a function decorator; however, it is possible to call it directly as a normal Python function, as shown below. The model's computation result remains unchanged.
net = tf.function(net)
net(x)
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.26736248, 0.6471752 ]], dtype=float32)>
This seems almost too good to be true: write the same code as before and simply convert the model using tf.function. Once this happens the network is built as a computational graph in TensorFlow's MLIR intermediate representation and is heavily optimized at the compiler level for rapid execution (we will benchmark the performance below). Explicitly adding the jit_compile=True flag to the tf.function() call enables XLA (Accelerated Linear Algebra) functionality in TensorFlow. XLA can further optimize JIT compiled code in certain instances. Graph-mode execution is enabled without this explicit flag; however, XLA can make certain large linear algebra operations (in the vein of those we see in deep learning applications) much faster, particularly in a GPU environment.
12.1.3.1. Acceleration by Hybridization
To demonstrate the performance improvement gained by compilation we compare the time needed to evaluate net(x) before and after hybridization. Let us define a class to measure this time first. It will come in handy throughout the chapter as we set out to measure (and improve) performance.
#@save
class Benchmark:
"""For measuring running time."""
def __init__(self, description='Done'):
self.description = description
def __enter__(self):
self.timer = d2l.Timer()
return self
def __exit__(self, *args):
print(f'{self.description}: {self.timer.stop():.4f} sec')
Now we can invoke the network twice, once with and once without hybridization.
net = get_net()
with Benchmark('Without hybridization'):
for i in range(1000): net(x)
npx.waitall()
net.hybridize()
with Benchmark('With hybridization'):
for i in range(1000): net(x)
npx.waitall()
Without hybridization: 0.8476 sec
With hybridization: 0.2086 sec
As is observed in the above results, after a HybridSequential instance calls the hybridize function, computing performance is improved through the use of symbolic programming.
Now we can invoke the network twice, once with and once without torchscript.
net = get_net()
with Benchmark('Without torchscript'):
for i in range(1000): net(x)
net = torch.jit.script(net)
with Benchmark('With torchscript'):
for i in range(1000): net(x)
Without torchscript: 2.4646 sec
With torchscript: 2.7696 sec
As is observed in the above results, scripting an nn.Sequential instance with the torch.jit.script function does not guarantee a speedup: here the TorchScript version is in fact no faster than the eager one. The benefit of symbolic compilation depends on the model architecture and the workload.
Now we can invoke the network twice, once executed eagerly and once with graph-mode execution; a sketch of timing a third, XLA-compiled variant follows the results.
net = get_net()
with Benchmark('Eager Mode'):
for i in range(1000): net(x)
net = tf.function(net)
with Benchmark('Graph Mode'):
for i in range(1000): net(x)
Eager Mode: 1.2665 sec
Graph Mode: 0.4969 sec
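The XLA-compiled variant is not timed in the original run; below is a minimal sketch of how one might benchmark it, reusing get_net, Benchmark, and x from above (timings vary by hardware, and XLA helps most on GPUs):
net = get_net()
xla_net = tf.function(net, jit_compile=True)  # enable XLA compilation
xla_net(x)  # warm-up call: the first invocation pays the compilation cost
with Benchmark('Graph Mode with XLA'):
    for i in range(1000): xla_net(x)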
As is observed in the above results, after a tf.keras.Sequential instance is wrapped with the tf.function function, computing performance is improved through the use of symbolic programming via graph-mode execution in TensorFlow.
12.1.3.2. Serialization
One of the benefits of compiling the models is that we can serialize (save) the model and its parameters to disk. This allows us to store a model in a manner that is independent of the front-end language of choice, so that we can deploy trained models to other devices and easily use other front-end programming languages. At the same time the code is often faster than what can be achieved in imperative programming. Let us see the export function in action.
net.export('my_mlp')
!ls -lh my_mlp*
-rw-rw-r-- 1 d2l-worker d2l-worker 643K Nov 12 19:59 my_mlp-0000.params
-rw-rw-r-- 1 d2l-worker d2l-worker 3.0K Nov 12 19:59 my_mlp-symbol.json
The model is decomposed into a (large binary) parameter file and a JSON description of the program required to execute the model computation. The files can be read by other front-end languages supported by MXNet, such as C++, R, Scala, and Perl. Let us have a look at the first few lines in the model description.
!head my_mlp-symbol.json
{
"nodes": [
{
"op": "null",
"name": "data",
"inputs": []
},
{
"op": "null",
"name": "dense3_weight",
Earlier, we demonstrated that, after calling the hybridize function, the model is able to achieve superior computing performance and portability. Note, though, that hybridization can affect model flexibility, in particular in terms of control flow. Besides, in contrast to a Block instance, which needs to implement the forward function, a HybridBlock instance needs to implement the hybrid_forward function.
class HybridNet(nn.HybridBlock):
def __init__(self, **kwargs):
super(HybridNet, self).__init__(**kwargs)
self.hidden = nn.Dense(4)
self.output = nn.Dense(2)
def hybrid_forward(self, F, x):
print('module F: ', F)
print('value x: ', x)
x = F.npx.relu(self.hidden(x))
print('result : ', x)
return self.output(x)
The code above implements a simple network with 4 hidden units and 2 outputs. The hybrid_forward function takes an additional argument F. This is needed since, depending on whether the code has been hybridized or not, it will use a slightly different library (ndarray or symbol) for processing. Both classes perform very similar functions and MXNet automatically determines the argument. To understand what is going on we print the arguments as part of the function invocation.
net = HybridNet()
net.initialize()
x = np.random.normal(size=(1, 3))
net(x)
module F: <module 'mxnet.ndarray' from '/home/d2l-worker/miniconda3/envs/d2l-en-classic-1/lib/python3.9/site-packages/mxnet/ndarray/__init__.py'>
value x: [[-0.6338663 0.40156594 0.46456942]]
result : [[0.01641375 0. 0. 0. ]]
array([[0.00097611, 0.00019453]])
Repeating the forward computation will lead to the same output (we omit details). Now let us see what happens if we invoke the hybridize function.
net.hybridize()
net(x)
module F: <module 'mxnet.symbol' from '/home/d2l-worker/miniconda3/envs/d2l-en-classic-1/lib/python3.9/site-packages/mxnet/symbol/__init__.py'>
value x: <_Symbol data>
result : <_Symbol hybridnet0_relu0>
array([[0.00097611, 0.00019453]])
Instead of using ndarray we now use the symbol module for F. Moreover, even though the input is of ndarray type, the data flowing through the network is now converted to symbol type as part of the compilation process. Repeating the function call leads to a surprising outcome:
net(x)
array([[0.00097611, 0.00019453]])
This is quite different from what we saw previously. All print statements, as defined in hybrid_forward, are omitted. Indeed, after hybridization the execution of net(x) no longer involves the Python interpreter. This means that any spurious Python code (such as print statements) is omitted in favor of a much more streamlined execution and better performance: MXNet directly calls the C++ backend instead. Also note that some functions are not supported in the symbol module (e.g., asnumpy) and in-place operations such as a += b and a[:] = a + b must be rewritten as a = a + b.
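For instance, a minimal standalone illustration of the rewrite (not taken from the network above):
from mxnet import np, npx
npx.set_np()

a, b = np.ones((2,)), np.ones((2,))
# a += b and a[:] = a + b mutate `a` in place and fail once hybridized,
# because symbols cannot be mutated; write the update out-of-place instead:
a = a + b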
Nonetheless, compilation of models is worth the effort whenever speed
matters. The benefit can range from small percentage points to more than
twice the speed, depending on the complexity of the model, the speed of
the CPU, and the speed and number of GPUs.
One of the benefits of compiling the models is that we can serialize (save) the model and its parameters to disk. This allows us to store a model in a manner that is independent of the front-end language of choice, so that we can deploy trained models to other devices and easily use other front-end programming languages. At the same time the code is often faster than what can be achieved in imperative programming. Let us see the save function in action.
net.save('my_mlp')
!ls -lh my_mlp*
-rw-rw-r-- 1 d2l-worker d2l-worker 652K Nov 12 20:45 my_mlp
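A minimal sketch of loading the file back (assuming the my_mlp file saved above); no Python class definition is needed, since the program itself is stored:
# Restore the TorchScript program together with its parameters
loaded = torch.jit.load('my_mlp')
loaded(x)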
One of the benefits of compiling the models is that we can serialize (save) the model and its parameters to disk. This allows us to store a model in a manner that is independent of the front-end language of choice, so that we can deploy trained models to other devices, easily use other front-end programming languages, or execute a trained model on a server. At the same time the code is often faster than what can be achieved in imperative programming. The low-level API for saving models in TensorFlow is tf.saved_model. Let us see the saved_model instance in action.
net = get_net()
tf.saved_model.save(net, 'my_mlp')
!ls -lh my_mlp*
INFO:tensorflow:Assets written to: my_mlp/assets
total 72K
drwxr-xr-x 2 d2l-worker d2l-worker 4.0K Nov 12 22:09 assets
-rw-rw-r-- 1 d2l-worker d2l-worker 64K Nov 12 22:09 saved_model.pb
drwxr-xr-x 2 d2l-worker d2l-worker 4.0K Nov 12 22:09 variables
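A minimal sketch of restoring the model (assuming the my_mlp directory written above); the loaded object is not a Keras model, but it remains callable:
# Restore the saved computation; loaded(x) should match net(x)
loaded = tf.saved_model.load('my_mlp')
print(loaded(x))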
12.1.4. Summary
Imperative programming makes it easy to design new models since it is possible to write code with control flow and the ability to use a large amount of the Python software ecosystem.
Symbolic programming requires that we specify the program and compile it before executing it. The benefit is improved performance.
MXNet is able to combine the advantages of both approaches as needed.
Models constructed by the HybridSequential and HybridBlock classes are able to convert imperative programs into symbolic programs by calling the hybridize function.
12.1.5. Exercises
Add x.asnumpy() to the first line of the hybrid_forward function of the HybridNet class in this section. Execute the code and observe the errors you encounter. Why do they happen?
What happens if we add control flow, i.e., the Python statements if and for, in the hybrid_forward function?
Review the models that interest you in the previous chapters. Can you improve their computational performance by reimplementing them?