1.6 model freezing tutorial #1077

Merged: 25 commits, Jul 22, 2020
Commits
a9bd604
Update feature classification labels
Jun 18, 2020
45d02c7
Update NVidia -> Nvidia
Jun 23, 2020
8d32bae
Merge branch 'master' into master
Jun 25, 2020
9a7250d
Merge branch 'master' into master
Jul 2, 2020
0d48b72
Merge pull request #1035 from jlin27/master
Jul 4, 2020
68c22a0
Bring back default filename_pattern so that by default we run all gal…
ezyang Jul 8, 2020
b6d1838
Add prototype_source directory
Jul 8, 2020
01fc130
Add prototype directory
Jul 8, 2020
26511cc
Add prototype
Jul 8, 2020
fb779e1
Remove extra "done"
Jul 8, 2020
494d037
Add REAME.txt
Jul 9, 2020
23fb4c7
Merge pull request #1058 from jlin27/master
Jul 9, 2020
d32aa04
Update for prototype instructions
Jul 9, 2020
67f76d3
Update for prototype feature
Jul 9, 2020
958aa33
refine torchvision_tutorial doc for windows
guyang3532 Jul 9, 2020
c83c23d
Merge pull request #1060 from guyang3532/fix_torchvision_tutorial_win
ezyang Jul 9, 2020
9b0635d
Update neural_style_tutorial.py (#1059)
hritikbhandari Jul 9, 2020
3740027
torch_script_custom_ops restructure (#1057)
ezyang Jul 9, 2020
3e32d22
Port custom ops tutorial to new registration API, increase testability.
ezyang Jul 9, 2020
999a029
Kill some other occurrences of RegisterOperators
ezyang Jul 9, 2020
f90f773
Update README.md
Jul 9, 2020
c6059ec
Make torch_script_custom_classes tutorial runnable
ezyang Jul 9, 2020
32e5407
Update torch_script_custom_classes to use TORCH_LIBRARY (#1062)
ezyang Jul 14, 2020
36c67f6
Add Model Freezing in TorchScript
Jul 21, 2020
36089ae
Merge branch 'release/1.6' into 1.6-model-freezing-tutorial
Jul 22, 2020
2 changes: 1 addition & 1 deletion prototype_source/README.txt
@@ -2,4 +2,4 @@ Prototype Tutorials
------------------
1. distributed_rpc_profiling.rst
Profiling PyTorch RPC-Based Workloads
https://github.com/pytorch/tutorials/blob/release/1.6/prototype_source/distributed_rpc_profiling.rst
https://github.com/pytorch/tutorials/blob/release/1.6/prototype_source/distributed_rpc_profiling.rst
134 changes: 134 additions & 0 deletions prototype_source/torchscript_freezing.py
@@ -0,0 +1,134 @@
"""
Model Freezing in TorchScript
=============================

In this tutorial, we introduce the syntax for *model freezing* in TorchScript.
Freezing is the process of inlining PyTorch module parameter and attribute
values into the TorchScript internal representation. Parameter and attribute
values are treated as final and cannot be modified in the resulting
frozen module.

Basic Syntax
------------
Model freezing can be invoked using the API below:

``torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule``

Note that the input module can be the result of either scripting or tracing.
See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

Next, we demonstrate how freezing works using an example:
"""
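As a minimal sketch of the API described above (using a hypothetical one-layer module, not the tutorial's `Net`), freezing also accepts a traced module, not just a scripted one:

```python
import torch

# Hypothetical one-layer module; freezing requires eval mode.
lin = torch.nn.Linear(4, 2).eval()
traced = torch.jit.trace(lin, torch.rand(1, 4))
frozen = torch.jit.freeze(traced)
print(frozen(torch.rand(1, 4)).shape)  # torch.Size([1, 2])
```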

import torch, time

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0

net = torch.jit.script(Net())
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' has been removed from 'fnet'. By default, only 'forward' is preserved")

fnet2 = torch.jit.freeze(net, ["version"])

print(fnet2.version())
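As a sanity check of what freezing inlines, a minimal sketch (hypothetical `Tiny` module, not part of this tutorial) shows that parameters disappear from the frozen module because they are folded into the graph as constants:

```python
import torch

# Hypothetical tiny module for illustration only.
class Tiny(torch.nn.Module):
    def __init__(self):
        super(Tiny, self).__init__()
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.fc(x)

# Freezing requires eval mode; parameters are then folded into the graph.
scripted = torch.jit.script(Tiny().eval())
frozen = torch.jit.freeze(scripted)
print(len(list(scripted.parameters())))  # 2: weight and bias
print(len(list(frozen.parameters())))    # 0: inlined as graph constants
```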

B = 1
warmup = 1
iter = 1000
input = torch.rand(B, 1, 28, 28)

start = time.time()
for i in range(warmup):
    net(input)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end-start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(input)
end = time.time()
print("Frozen - Warm up time: {0:7.4f}".format(end-start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1, 28, 28)
    net(input)
end = time.time()
print("Scripted - Inference: {0:5.2f}".format(end-start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1, 28, 28)
    fnet2(input)
end = time.time()
print("Frozen - Inference time: {0:5.2f}".format(end-start), flush=True)
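The timing loops above can be wrapped into a small reusable helper. The sketch below (a hypothetical `bench` function with a stand-in `Linear` model, not part of the tutorial) averages per-iteration latency under `torch.no_grad()`:

```python
import time
import torch

# Hypothetical helper: average per-iteration latency in seconds.
def bench(fn, x, iters=100):
    with torch.no_grad():
        start = time.time()
        for _ in range(iters):
            fn(x)
    return (time.time() - start) / iters

# Stand-in model for illustration; any frozen module works here.
frozen_lin = torch.jit.freeze(torch.jit.script(torch.nn.Linear(8, 8).eval()))
print("avg sec/iter: {0:.6f}".format(bench(frozen_lin, torch.rand(1, 8))))
```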

###############################################################
# On my machine, I measured the time:
#
# * Scripted - Warm up time: 0.0107
# * Frozen - Warm up time: 0.0048
# * Scripted - Inference: 1.35
# * Frozen - Inference time: 1.17

###############################################################
# In our example, warm-up time measures the first run of each model. The frozen
# model's warm-up is roughly twice as fast as the scripted model's. On some more
# complex models, we observed even larger warm-up speedups. Freezing achieves
# this speedup because it does ahead of time some of the work TorchScript would
# otherwise do during the first few runs.
#
# Inference time measures inference execution time after the model is warmed up.
# Although we observed significant variation in execution time, the frozen
# model is often about 15% faster than the scripted model. When the input is
# larger, we observe a smaller speedup because execution is dominated by tensor
# operations.

###############################################################
# Conclusion
# -----------
# In this tutorial, we learned about model freezing. Freezing is a useful
# technique for optimizing models for inference, and it can also significantly
# reduce TorchScript warm-up time.