diff --git a/prototype_source/README.txt b/prototype_source/README.txt
index 42753fe18f8..2da500a830c 100644
--- a/prototype_source/README.txt
+++ b/prototype_source/README.txt
@@ -2,4 +2,4 @@ Prototype Tutorials
 ------------------
 1. distributed_rpc_profiling.rst
    Profiling PyTorch RPC-Based Workloads
-   https://github.com/pytorch/tutorials/blob/release/1.6/prototype_source/distributed_rpc_profiling.rst
\ No newline at end of file
+   https://github.com/pytorch/tutorials/blob/release/1.6/prototype_source/distributed_rpc_profiling.rst
diff --git a/prototype_source/torchscript_freezing.py b/prototype_source/torchscript_freezing.py
new file mode 100644
index 00000000000..0b6115c3dc7
--- /dev/null
+++ b/prototype_source/torchscript_freezing.py
@@ -0,0 +1,134 @@
+"""
+Model Freezing in TorchScript
+=============================
+
+In this tutorial, we introduce the syntax for *model freezing* in TorchScript.
+Freezing is the process of inlining PyTorch module parameters and attribute
+values into the TorchScript internal representation. Parameter and attribute
+values are treated as final values, and they cannot be modified in the
+resulting frozen module.
+
+Basic Syntax
+------------
+Model freezing can be invoked using the API below:
+
+  ``torch.jit.freeze(mod : ScriptModule, names : List[str]) -> ScriptModule``
+
+Note that the input module can be the result of either scripting or tracing.
+See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
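+
+For example, a traced module can be frozen the same way as a scripted one.
+As a minimal sketch (``MyModule`` and ``example_input`` are placeholder
+names, and freezing expects the module to already be in eval mode)::
+
+    module = torch.jit.trace(MyModule().eval(), example_input)
+    frozen_module = torch.jit.freeze(module)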
+
+Next, we demonstrate how freezing works using an example:
+"""
+
+import time
+import torch
+
+class Net(torch.nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = torch.nn.Dropout2d(0.25)
+        self.dropout2 = torch.nn.Dropout2d(0.5)
+        self.fc1 = torch.nn.Linear(9216, 128)
+        self.fc2 = torch.nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = torch.nn.functional.relu(x)
+        x = self.conv2(x)
+        x = torch.nn.functional.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = torch.nn.functional.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = torch.nn.functional.log_softmax(x, dim=1)
+        return output
+
+    @torch.jit.export
+    def version(self):
+        return 1.0
+
+# Freezing expects the module to be in eval mode.
+net = torch.jit.script(Net().eval())
+fnet = torch.jit.freeze(net)
+
+print(net.conv1.weight.size())
+print(net.conv1.bias)
+
+try:
+    print(fnet.conv1.bias)
+    # without exception handling, prints:
+    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
+    # with name 'conv1'
+except RuntimeError:
+    print("field 'conv1' is inlined. It does not exist in 'fnet'")
+
+try:
+    fnet.version()
+    # without exception handling, prints:
+    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
+    # with name 'version'
+except RuntimeError:
+    print("method 'version' is deleted in 'fnet'. Only 'forward' is preserved")
+
+fnet2 = torch.jit.freeze(net, ["version"])
+
+print(fnet2.version())
+
+B = 1
+warmup = 1
+iters = 1000
+input = torch.rand(B, 1, 28, 28)
+
+start = time.time()
+for i in range(warmup):
+    net(input)
+end = time.time()
+print("Scripted - Warm up time: {0:7.4f}".format(end - start), flush=True)
+
+start = time.time()
+for i in range(warmup):
+    fnet(input)
+end = time.time()
+print("Frozen - Warm up time: {0:7.4f}".format(end - start), flush=True)
+
+start = time.time()
+for i in range(iters):
+    input = torch.rand(B, 1, 28, 28)
+    net(input)
+end = time.time()
+print("Scripted - Inference time: {0:5.2f}".format(end - start), flush=True)
+
+start = time.time()
+for i in range(iters):
+    input = torch.rand(B, 1, 28, 28)
+    fnet(input)
+end = time.time()
+print("Frozen - Inference time: {0:5.2f}".format(end - start), flush=True)
+
+###############################################################
+# On my machine, I measured the following times:
+#
+# * Scripted - Warm up time: 0.0107
+# * Frozen - Warm up time: 0.0048
+# * Scripted - Inference time: 1.35
+# * Frozen - Inference time: 1.17
+
+###############################################################
+# In our example, the warm up time measures the first run of each model.
+# Here the frozen model is about 50% faster than the scripted model. On some
+# more complex models, we observed even larger warm up speedups. Freezing
+# achieves this speedup because it performs ahead of time some of the work
+# that TorchScript otherwise has to do during the first couple of runs.
+#
+# The inference time measures execution time after the model is warmed up.
+# Although we observed significant variation in execution time, the frozen
+# model is often about 15% faster than the scripted model. When the input is
+# larger, we observe a smaller speedup because execution is dominated by
+# tensor operations.
+
+###############################################################
+# Conclusion
+# -----------
+# In this tutorial, we learned about model freezing. Freezing is a useful
+# technique for optimizing models for inference, and it can also
+# significantly reduce TorchScript warm up time.
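+
+###############################################################
+# As a final note, here is a minimal sketch of a natural follow-up step,
+# not part of the measurements above: a frozen module is still a
+# ``ScriptModule``, so it can be saved and reloaded with the usual
+# TorchScript serialization APIs. The file name ``frozen_net.pt`` is just
+# an illustrative choice.
+
+torch.jit.save(fnet2, "frozen_net.pt")
+loaded = torch.jit.load("frozen_net.pt")
+print(loaded.version())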