From 2ca73d9250d4172454c4f165807b1bcb993230d2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 13 Apr 2020 15:57:41 +0900 Subject: [PATCH 1/8] add pytorch tutorial code and doc stub --- tutorials/frontend/deploy_prequantized.py | 147 ++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 tutorials/frontend/deploy_prequantized.py diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py new file mode 100644 index 000000000000..c5b5a749b400 --- /dev/null +++ b/tutorials/frontend/deploy_prequantized.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Deploy a Quantized Model on Cuda +================================ +**Author**: `Masahiro Masuda `_ + +This is an a tutorial on loading models quantized by deep learning frameworks into TVM. +Pre-quantized model import is one of the quantization support we have in TVM. More details on +the quantization story in TVM can be found +`here `_. +""" +from PIL import Image + +import numpy as np + +import torch +from torchvision.models.quantization import mobilenet as qmobilenet + +import tvm +from tvm import relay +from tvm.contrib.download import download_testdata + + +################################################################################# +# Helper functions +def get_transform(): + import torchvision.transforms as transforms + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + return transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]) + + +def get_real_image(im_height, im_width): + img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' + img_path = download_testdata(img_url, 'cat.png', module='data') + return Image.open(img_path).resize((im_height, im_width)) + + +def get_imagenet_input(): + im = get_real_image(224, 224) + preprocess = get_transform() + pt_tensor = preprocess(im) + return np.expand_dims(pt_tensor.numpy(), 0) + + +def get_synset(): + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'imagenet1000_clsid_to_human.txt' + synset_path = download_testdata(synset_url, synset_name, module='data') + with open(synset_path) as f: + return eval(f.read()) + + +################################################################################# +# A mapping from label to class name and an input cat image used for demonstration +synset = get_synset() +inp = get_imagenet_input() + +############################################################################### +# Deploy quantized PyTorch Model +# ------------------ +def 
quantize_model(model, inp): + model.fuse_model() + model.qconfig = torch.quantization.get_default_qconfig('fbgemm') + torch.quantization.prepare(model, inplace=True) + # Dummy calibration + model(inp) + torch.quantization.convert(model, inplace=True) + + +###################################################################### +# Load quantization ready Mobilenet v2 model from torchvision +# ----------------- +qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval() + +###################################################################### +# Quantize, trace and run the PyTorch Mobilenet v2 model +# ----------------- +pt_inp = torch.from_numpy(inp) +quantize_model(qmodel, pt_inp) +script_module = torch.jit.trace(qmodel, pt_inp).eval() + +with torch.no_grad(): + pt_result = script_module(pt_inp).numpy() + +###################################################################### +# Convert quantized Mobilenet v2 to Relay-QNN using the PyTorch frontend +# ----------------- +input_name = "input" +input_shapes = [(input_name, (1, 3, 224, 224))] +mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + +###################################################################### +# Compile and run the Relay module +# ----------------- +with relay.build_config(opt_level=3): + json, lib, params = relay.build(mod, target="llvm", params=params) + +runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0)) +runtime.set_input(**params) + +runtime.set_input(input_name, inp) +runtime.run() +tvm_result = runtime.get_output(0).asnumpy() + +###################################################################### +# Compare the output labels +# ----------------- +pt_top3_labels = np.argsort(pt_result[0])[::-1][:3] +tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3] + +print("PyTorch top3 label:", [synset[label] for label in pt_top3_labels]) +print("TVM top3 label:", [synset[label] for label in tvm_top3_labels]) + + +############################################################################### +# Deploy quantized MXNet Model +# ------------------ +# TODO + +############################################################################### +# Deploy quantized TFLite Model +# ------------------ +# TODO From 992e96513a84b373782871ec97d695143576ae78 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Apr 2020 04:01:05 +0900 Subject: [PATCH 2/8] add more docs --- tutorials/frontend/deploy_prequantized.py | 81 +++++++++++++++++------ tutorials/frontend/from_pytorch.py | 4 +- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index c5b5a749b400..e541d71eb5c5 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """ -Deploy a Quantized Model on Cuda +Deploy a Framework-prequantized Model with TVM ================================ **Author**: `Masahiro Masuda `_ @@ -23,6 +23,8 @@ Pre-quantized model import is one of the quantization support we have in TVM. More details on the quantization story in TVM can be found `here `_. +Here, we demonstrate how to load and run models quantized by PyTorch, MXNet, and TFLite. +Once loaded, we can run quantized models on any hardware TVM supports. 
""" from PIL import Image @@ -37,7 +39,7 @@ ################################################################################# -# Helper functions +# Helper functions to run the demo def get_transform(): import torchvision.transforms as transforms normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], @@ -74,14 +76,40 @@ def get_synset(): return eval(f.read()) +def run_tvm_model(mod, params, input_name, inp, target="llvm"): + with relay.build_config(opt_level=3): + json, lib, params = relay.build(mod, target=target, params=params) + + runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.context(target, 0)) + runtime.set_input(**params) + + runtime.set_input(input_name, inp) + runtime.run() + return runtime.get_output(0).asnumpy() + + ################################################################################# -# A mapping from label to class name and an input cat image used for demonstration +# A mapping from label to class name, to verify that the outputs from models below +# are reasonable synset = get_synset() + +################################################################################# +# Everyone's favorite cat image for demonstration inp = get_imagenet_input() -############################################################################### -# Deploy quantized PyTorch Model +################################################################################ +# Deploy a quantized PyTorch Model # ------------------ +# First, we demonstrate how to load deep learning models quantized by PyTorch, +# using our PyTorch frontend. + +################################################################################## +# A helper function for converting floating point PyTorch models to quantized ones +# Please refer to the PyTorch static quantization tutorial below to learn about +# their quantization workflow. +# https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html +# In short, this function takes a floating point model and converts it to a uint8 +# model. A model is per-channel quantized. def quantize_model(model, inp): model.fuse_model() model.qconfig = torch.quantization.get_default_qconfig('fbgemm') @@ -91,14 +119,18 @@ def quantize_model(model, inp): torch.quantization.convert(model, inplace=True) -###################################################################### -# Load quantization ready Mobilenet v2 model from torchvision +############################################################################## +# Load quantization-ready, pretrained Mobilenet v2 model from torchvision # ----------------- +# We choose mobilenet v2 because this model was trained with quantization aware +# training. Other models require a full post training calibration. qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval() -###################################################################### +############################################################################## # Quantize, trace and run the PyTorch Mobilenet v2 model # ----------------- +# The details are out of scope for this tutorial. Please refer to the tutorials +# on the PyTorch website to learn about quantization and jit. 
pt_inp = torch.from_numpy(inp) quantize_model(qmodel, pt_inp) script_module = torch.jit.trace(qmodel, pt_inp).eval() @@ -106,29 +138,34 @@ def quantize_model(model, inp): with torch.no_grad(): pt_result = script_module(pt_inp).numpy() -###################################################################### +############################################################################## # Convert quantized Mobilenet v2 to Relay-QNN using the PyTorch frontend # ----------------- -input_name = "input" +# The PyTorch frontend has support for converting a quantized PyTorch model to +# an equivalent Relay module enriched with quantization-aware operators. +# We call this representation Relay QNN dialect. +# You can print the output from the frontend to see how quantized models are +# represented. You would see operators specfic to quantization such as +# qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d etc. +input_name = "input" # the input name can be be arbitrary for PyTorch frontend. input_shapes = [(input_name, (1, 3, 224, 224))] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) +# print(mod) -###################################################################### +############################################################################## # Compile and run the Relay module # ----------------- -with relay.build_config(opt_level=3): - json, lib, params = relay.build(mod, target="llvm", params=params) - -runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.cpu(0)) -runtime.set_input(**params) - -runtime.set_input(input_name, inp) -runtime.run() -tvm_result = runtime.get_output(0).asnumpy() +# Once we obtained the quantized Relay module, the rest of the workflow +# is the same as running floating point models. Please refer to other +# tutorials for more details. +# Under the hood, quantization specific operators are lowered to a sequence of +# standard Relay operators before compilation. +tvm_result = run_tvm_model(mod, params, input_name, inp, target="llvm") ###################################################################### # Compare the output labels # ----------------- +# We should see identical labels printed. pt_top3_labels = np.argsort(pt_result[0])[::-1][:3] tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3] @@ -137,11 +174,11 @@ def quantize_model(model, inp): ############################################################################### -# Deploy quantized MXNet Model +# Deploy a quantized MXNet Model # ------------------ # TODO ############################################################################### -# Deploy quantized TFLite Model +# Deploy a quantized TFLite Model # ------------------ # TODO diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 45e3cb8af8ff..8354b0eca193 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -88,8 +88,8 @@ ###################################################################### # Import the graph to Relay # ------------------------- -# Convert PyTorch graph to Relay graph. -input_name = 'input0' # only one input, set it to this name +# Convert PyTorch graph to Relay graph. The input name can be arbitrary. 
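+# For example, a different placeholder name would work just as well (shown
+# only to illustrate that the name is user-chosen):
+#
+#   shape_list = [('my_input', img.shape)]
+#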
+input_name = 'input0' shape_list = [(input_name, img.shape)] mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) From 18aeaf04dd002bd39f17c4e109f185bd3a9bca1b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Apr 2020 04:21:35 +0900 Subject: [PATCH 3/8] formatting, more docs --- tutorials/frontend/deploy_prequantized.py | 27 ++++++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index e541d71eb5c5..81b917ce17f3 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -23,9 +23,13 @@ Pre-quantized model import is one of the quantization support we have in TVM. More details on the quantization story in TVM can be found `here `_. + Here, we demonstrate how to load and run models quantized by PyTorch, MXNet, and TFLite. Once loaded, we can run quantized models on any hardware TVM supports. """ + +################################################################################# +# First, necessary imports from PIL import Image import numpy as np @@ -102,14 +106,15 @@ def run_tvm_model(mod, params, input_name, inp, target="llvm"): # ------------------ # First, we demonstrate how to load deep learning models quantized by PyTorch, # using our PyTorch frontend. - -################################################################################## -# A helper function for converting floating point PyTorch models to quantized ones +# # Please refer to the PyTorch static quantization tutorial below to learn about # their quantization workflow. # https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html -# In short, this function takes a floating point model and converts it to a uint8 -# model. A model is per-channel quantized. +# +# We use this function to quantize PyTorch models. +# In short, this function takes a floating point model and converts it to uint8. +# The model is per-channel quantized. + def quantize_model(model, inp): model.fuse_model() model.qconfig = torch.quantization.get_default_qconfig('fbgemm') @@ -144,8 +149,11 @@ def quantize_model(model, inp): # The PyTorch frontend has support for converting a quantized PyTorch model to # an equivalent Relay module enriched with quantization-aware operators. # We call this representation Relay QNN dialect. +# # You can print the output from the frontend to see how quantized models are -# represented. You would see operators specfic to quantization such as +# represented. +# +# You would see operators specfic to quantization such as # qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d etc. input_name = "input" # the input name can be be arbitrary for PyTorch frontend. input_shapes = [(input_name, (1, 3, 224, 224))] @@ -158,6 +166,7 @@ def quantize_model(model, inp): # Once we obtained the quantized Relay module, the rest of the workflow # is the same as running floating point models. Please refer to other # tutorials for more details. +# # Under the hood, quantization specific operators are lowered to a sequence of # standard Relay operators before compilation. 
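# You can observe this lowering yourself before building (a sketch; it assumes
# the QNN dialect passes shipped with the TVM version used here):
#
#   lowered = relay.qnn.transform.CanonicalizeOps()(mod)
#   # print(lowered)  # qnn.* ops now appear as standard Relay ops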
tvm_result = run_tvm_model(mod, params, input_name, inp, target="llvm") @@ -172,6 +181,12 @@ def quantize_model(model, inp): print("PyTorch top3 label:", [synset[label] for label in pt_top3_labels]) print("TVM top3 label:", [synset[label] for label in tvm_top3_labels]) +############################################################################## +# However, due to the difference in numerics, in general the raw floating point +# outputs are not expected to be identical. Here, we print how many floating point +# output values are identical out of 1000 outputs from mobilenet v2. +print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_result[0])) + ############################################################################### # Deploy a quantized MXNet Model From 567000c76cc5f0c5ea5f4fe11e56a7a07486b088 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Apr 2020 04:58:22 +0900 Subject: [PATCH 4/8] typo fix --- tutorials/frontend/deploy_prequantized.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index 81b917ce17f3..49a08cb64f0f 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -19,13 +19,13 @@ ================================ **Author**: `Masahiro Masuda `_ -This is an a tutorial on loading models quantized by deep learning frameworks into TVM. +This is a tutorial on loading models quantized by deep learning frameworks into TVM. Pre-quantized model import is one of the quantization support we have in TVM. More details on the quantization story in TVM can be found `here `_. Here, we demonstrate how to load and run models quantized by PyTorch, MXNet, and TFLite. -Once loaded, we can run quantized models on any hardware TVM supports. +Once loaded, we can run compiled, quantized models on any hardware TVM supports. """ ################################################################################# @@ -153,7 +153,7 @@ def quantize_model(model, inp): # You can print the output from the frontend to see how quantized models are # represented. # -# You would see operators specfic to quantization such as +# You would see operators specific to quantization such as # qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d etc. input_name = "input" # the input name can be be arbitrary for PyTorch frontend. input_shapes = [(input_name, (1, 3, 224, 224))] From 5e954f34d66ba55876dd62dd7493d07e9f57dde0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Apr 2020 06:14:20 +0900 Subject: [PATCH 5/8] try make sphinx happy --- tutorials/frontend/deploy_prequantized.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index 49a08cb64f0f..a90e9dad9abf 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -16,7 +16,7 @@ # under the License. """ Deploy a Framework-prequantized Model with TVM -================================ +============================================== **Author**: `Masahiro Masuda `_ This is a tutorial on loading models quantized by deep learning frameworks into TVM. 
@@ -103,7 +103,7 @@ def run_tvm_model(mod, params, input_name, inp, target="llvm"): ################################################################################ # Deploy a quantized PyTorch Model -# ------------------ +# -------------------------------- # First, we demonstrate how to load deep learning models quantized by PyTorch, # using our PyTorch frontend. # @@ -126,14 +126,14 @@ def quantize_model(model, inp): ############################################################################## # Load quantization-ready, pretrained Mobilenet v2 model from torchvision -# ----------------- +# ----------------------------------------------------------------------- # We choose mobilenet v2 because this model was trained with quantization aware # training. Other models require a full post training calibration. qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval() ############################################################################## # Quantize, trace and run the PyTorch Mobilenet v2 model -# ----------------- +# ------------------------------------------------------ # The details are out of scope for this tutorial. Please refer to the tutorials # on the PyTorch website to learn about quantization and jit. pt_inp = torch.from_numpy(inp) @@ -145,7 +145,7 @@ def quantize_model(model, inp): ############################################################################## # Convert quantized Mobilenet v2 to Relay-QNN using the PyTorch frontend -# ----------------- +# ---------------------------------------------------------------------- # The PyTorch frontend has support for converting a quantized PyTorch model to # an equivalent Relay module enriched with quantization-aware operators. # We call this representation Relay QNN dialect. @@ -162,7 +162,7 @@ def quantize_model(model, inp): ############################################################################## # Compile and run the Relay module -# ----------------- +# -------------------------------- # Once we obtained the quantized Relay module, the rest of the workflow # is the same as running floating point models. Please refer to other # tutorials for more details. @@ -171,9 +171,9 @@ def quantize_model(model, inp): # standard Relay operators before compilation. tvm_result = run_tvm_model(mod, params, input_name, inp, target="llvm") -###################################################################### +########################################################################## # Compare the output labels -# ----------------- +# ------------------------- # We should see identical labels printed. pt_top3_labels = np.argsort(pt_result[0])[::-1][:3] tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3] @@ -181,7 +181,7 @@ def quantize_model(model, inp): print("PyTorch top3 label:", [synset[label] for label in pt_top3_labels]) print("TVM top3 label:", [synset[label] for label in tvm_top3_labels]) -############################################################################## +########################################################################################### # However, due to the difference in numerics, in general the raw floating point # outputs are not expected to be identical. Here, we print how many floating point # output values are identical out of 1000 outputs from mobilenet v2. 
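# A tolerance-based comparison (a sketch; the tolerance is an illustrative
# choice) is often more informative than counting exactly identical values:
#
#   print("max abs diff:", np.max(np.abs(tvm_result[0] - pt_result[0])))
#   print("close within atol:", np.allclose(tvm_result[0], pt_result[0], atol=1e-2))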
@@ -190,10 +190,10 @@ def quantize_model(model, inp): ############################################################################### # Deploy a quantized MXNet Model -# ------------------ +# ------------------------------ # TODO ############################################################################### # Deploy a quantized TFLite Model -# ------------------ +# ------------------------------- # TODO From af8831499870e7d9192e3447b0d17bc08c3aeac1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 14 Apr 2020 09:55:31 +0900 Subject: [PATCH 6/8] add performance section --- tutorials/frontend/deploy_prequantized.py | 43 +++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index a90e9dad9abf..4b55c1c862f5 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -89,7 +89,7 @@ def run_tvm_model(mod, params, input_name, inp, target="llvm"): runtime.set_input(input_name, inp) runtime.run() - return runtime.get_output(0).asnumpy() + return runtime.get_output(0).asnumpy(), runtime ################################################################################# @@ -169,7 +169,7 @@ def quantize_model(model, inp): # # Under the hood, quantization specific operators are lowered to a sequence of # standard Relay operators before compilation. -tvm_result = run_tvm_model(mod, params, input_name, inp, target="llvm") +tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target="llvm") ########################################################################## # Compare the output labels @@ -188,6 +188,45 @@ def quantize_model(model, inp): print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_result[0])) +########################################################################## +# Measure performance +# ------------------------- +# Here we give an example of how to measure performance of TVM compiled models. +n_repeat = 100 # should be bigger to make the measurement more accurate +ctx = tvm.cpu(0) +ftimer = rt_mod.module.time_evaluator("run", ctx, number=1, + repeat=n_repeat) +prof_res = np.array(ftimer().results) * 1e3 +print("Elapsed ms:", np.mean(prof_res)) + +###################################################################### +# .. note:: +# +# We recommend this method for the following reasons: +# +# * Measurements are done in C++, so there is no Python overhead +# * It includes several warm up runs +# * The same method can be used to profile on remote devices (android etc.). + + +###################################################################### +# .. note:: +# +# Unless the hardware has special support for fast 8 bit instructions, quantized models are +# not expected to be any faster than FP32 models. Without fast 8 bit instructions, TVM does +# quantized convolution in 16 bit, even if the model itself is 8 bit. +# +# For x86, the best performance can be acheived on CPUs with AVX512 instructions set. +# In this case, TVM utilizes the fastest available 8 bit instructions for the given target. +# This includes support for the VNNI 8 bit dot product instruction (CascadeLake or newer). 
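#
# For example (a sketch; it reuses the run_tvm_model helper defined earlier and
# assumes the machine actually has these instructions):
#
#   tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp,
#                                      target="llvm -mcpu=cascadelake")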
+#
+# Moreover, the following general tips for CPU performance apply equally:
+#
+# * Set the environment variable TVM_NUM_THREADS to the number of physical cores
+# * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or
+#   "llvm -mcpu=cascadelake" (more CPUs with AVX512 will be supported in the future)


###############################################################################
# Deploy a quantized MXNet Model
# ------------------------------
# TODO

###############################################################################
# Deploy a quantized TFLite Model
# -------------------------------
# TODO

From 9e3c5c0fcde2463ce28a9e7d1d9a562a077d3798 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 15 Apr 2020 11:53:30 +0900
Subject: [PATCH 7/8] typo and nit fix

---
 docs/dev/relay_pass_infra.rst             |  2 +-
 tutorials/frontend/deploy_prequantized.py | 18 +++++++++---------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst
index 3b443fab9e57..b40b06e21d0a 100644
--- a/docs/dev/relay_pass_infra.rst
+++ b/docs/dev/relay_pass_infra.rst
@@ -612,7 +612,7 @@ sequential pass example could be like the following to enable IR dumping for
    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
-        relay.transform.PrintIR(),
+        transform.PrintIR(),
        relay.transform.EliminateCommonSubexpr(),
        relay.transform.AlterOpLayout()
    ])
diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py
index 4b55c1c862f5..9be3b725f61b 100644
--- a/tutorials/frontend/deploy_prequantized.py
+++ b/tutorials/frontend/deploy_prequantized.py
@@ -149,15 +149,16 @@ def quantize_model(model, inp):
# The PyTorch frontend has support for converting a quantized PyTorch model to
# an equivalent Relay module enriched with quantization-aware operators.
# We call this representation Relay QNN dialect.
-#
+input_name = "input"  # the input name can be be arbitrary for PyTorch frontend.
+input_shapes = [(input_name, (1, 3, 224, 224))]
+mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
# You can print the output from the frontend to see how quantized models are
# represented.
#
# You would see operators specific to quantization such as
# qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d etc.
-input_name = "input"  # the input name can be be arbitrary for PyTorch frontend.
-input_shapes = [(input_name, (1, 3, 224, 224))]
-mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+#
+# print(mod)

##############################################################################
# Compile and run the Relay module
@@ -178,8 +179,8 @@ def quantize_model(model, inp):
pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]

-print("PyTorch top3 label:", [synset[label] for label in pt_top3_labels])
-print("TVM top3 label:", [synset[label] for label in tvm_top3_labels])
+print("PyTorch top3 labels:", [synset[label] for label in pt_top3_labels])
+print("TVM top3 labels:", [synset[label] for label in tvm_top3_labels])

###########################################################################################
# However, due to the difference in numerics, in general the raw floating point
# outputs are not expected to be identical. Here, we print how many floating point
# output values are identical out of 1000 outputs from mobilenet v2.
print("%d in 1000 raw floating outputs identical."
% np.sum(tvm_result[0] == pt_result[0]))
-
##########################################################################
# Measure performance
# -------------------------
@@ -197,7 +197,7 @@ def quantize_model(model, inp):
ftimer = rt_mod.module.time_evaluator("run", ctx, number=1,
                                      repeat=n_repeat)
prof_res = np.array(ftimer().results) * 1e3
-print("Elapsed ms:", np.mean(prof_res))
+print("Average elapsed ms:", np.mean(prof_res))

######################################################################
# .. note::
@@ -216,7 +216,7 @@ def quantize_model(model, inp):
# not expected to be any faster than FP32 models. Without fast 8 bit instructions, TVM does
# quantized convolution in 16 bit, even if the model itself is 8 bit.
#
-# For x86, the best performance can be acheived on CPUs with AVX512 instructions set.
+# For x86, the best performance can be achieved on CPUs with the AVX512 instruction set.
# In this case, TVM utilizes the fastest available 8 bit instructions for the given target.
# This includes support for the VNNI 8 bit dot product instruction (CascadeLake or newer).
#

From bf4f5e4df766ba5bb08b92193a41dc3bd59663cc Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 15 Apr 2020 12:07:16 +0900
Subject: [PATCH 8/8] format fix

---
 tutorials/frontend/deploy_prequantized.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py
index 9be3b725f61b..40279778c045 100644
--- a/tutorials/frontend/deploy_prequantized.py
+++ b/tutorials/frontend/deploy_prequantized.py
@@ -149,17 +149,16 @@ def quantize_model(model, inp):
# The PyTorch frontend has support for converting a quantized PyTorch model to
# an equivalent Relay module enriched with quantization-aware operators.
# We call this representation Relay QNN dialect.
-input_name = "input"  # the input name can be be arbitrary for PyTorch frontend.
-input_shapes = [(input_name, (1, 3, 224, 224))]
-mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
-
+#
# You can print the output from the frontend to see how quantized models are
# represented.
#
# You would see operators specific to quantization such as
# qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d etc.
-#
-# print(mod)
+input_name = "input"  # the input name can be arbitrary for the PyTorch frontend.
+input_shapes = [(input_name, (1, 3, 224, 224))]
+mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+# print(mod)  # uncomment to see the QNN IR dump

##############################################################################
# Compile and run the Relay module