From 9f87e318b59efa02a1797b2821e8cdad2108e053 Mon Sep 17 00:00:00 2001 From: Alex Wong <11878166+alexwong@users.noreply.github.com> Date: Tue, 25 Feb 2020 21:41:53 -0800 Subject: [PATCH] [Tutorial] Add a tutorial for PyTorch (#4936) * Add a tutorial for PyTorch * Fix sphinx formatting, add version support * Remove space * Remove version check * Some refactoring * Use no grad * Rename input * Update cat img source --- tutorials/frontend/from_pytorch.py | 166 +++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 tutorials/frontend/from_pytorch.py diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py new file mode 100644 index 0000000000000..c280c259c1fe4 --- /dev/null +++ b/tutorials/frontend/from_pytorch.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Compile PyTorch Models +====================== +**Author**: `Alex Wong `_ + +This article is an introductory tutorial to deploy PyTorch models with Relay. + +For us to begin with, PyTorch should be installed. +TorchVision is also required since we will be using it as our model zoo. + +A quick solution is to install via pip + +.. code-block:: bash + + pip install torch==1.4.0 + pip install torchvision==0.5.0 + +or please refer to official site +https://pytorch.org/get-started/locally/ + +PyTorch versions should be backwards compatible but should be used +with the proper TorchVision version. + +Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may +be unstable. +""" + +# tvm, relay +import tvm +from tvm import relay + +# numpy, packaging +import numpy as np +from packaging import version +from tvm.contrib.download import download_testdata + +# PyTorch imports +import torch +import torchvision + +###################################################################### +# Load a pretrained PyTorch model +# ------------------------------- +model_name = 'resnet18' +model = getattr(torchvision.models, model_name)(pretrained=True) +model = model.eval() + +# We grab the TorchScripted model via tracing +input_shape = [1, 3, 224, 224] +input_data = torch.randn(input_shape) +scripted_model = torch.jit.trace(model, input_data).eval() + +###################################################################### +# Load a test image +# ----------------- +# Classic cat example! +from PIL import Image +img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' +img_path = download_testdata(img_url, 'cat.png', module='data') +img = Image.open(img_path).resize((224, 224)) + +# Preprocess the image and convert to tensor +from torchvision import transforms +my_preprocess = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) +]) +img = my_preprocess(img) +img = np.expand_dims(img, 0) + +###################################################################### +# Import the graph to Relay +# ------------------------- +# Convert PyTorch graph to Relay graph. +shape_dict = {'img': img.shape} +mod, params = relay.frontend.from_pytorch(scripted_model, + shape_dict) + +###################################################################### +# Relay Build +# ----------- +# Compile the graph to llvm target with given input specification. +target = 'llvm' +target_host = 'llvm' +ctx = tvm.cpu(0) +with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, + target=target, + target_host=target_host, + params=params) + +###################################################################### +# Execute the portable graph on TVM +# --------------------------------- +# Now we can try deploying the compiled model on target. +from tvm.contrib import graph_runtime +dtype = 'float32' +m = graph_runtime.create(graph, lib, ctx) +# Set inputs +m.set_input('img', tvm.nd.array(img.astype(dtype))) +m.set_input(**params) +# Execute +m.run() +# Get outputs +tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32')) + +##################################################################### +# Look up synset name +# ------------------- +# Look up prediction top 1 index in 1000 class synset. +synset_url = ''.join(['https://raw.githubusercontent.com/Cadene/', + 'pretrained-models.pytorch/master/data/', + 'imagenet_synsets.txt']) +synset_name = 'imagenet_synsets.txt' +synset_path = download_testdata(synset_url, synset_name, module='data') +with open(synset_path) as f: + synsets = f.readlines() + +synsets = [x.strip() for x in synsets] +splits = [line.split(' ') for line in synsets] +key_to_classname = {spl[0]:' '.join(spl[1:]) for spl in splits} + +class_url = ''.join(['https://raw.githubusercontent.com/Cadene/', + 'pretrained-models.pytorch/master/data/', + 'imagenet_classes.txt']) +class_name = 'imagenet_classes.txt' +class_path = download_testdata(class_url, class_name, module='data') +with open(class_path) as f: + class_id_to_key = f.readlines() + +class_id_to_key = [x.strip() for x in class_id_to_key] + +# Get top-1 result for TVM +top1_tvm = np.argmax(tvm_output.asnumpy()[0]) +tvm_class_key = class_id_to_key[top1_tvm] + +# Convert input to PyTorch variable and get PyTorch result for comparison +with torch.no_grad(): + torch_img = torch.from_numpy(img) + output = model(torch_img) + + # Get top-1 result for PyTorch + top1_torch = np.argmax(output.numpy()) + torch_class_key = class_id_to_key[top1_torch] + +print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key])) +print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key])) \ No newline at end of file