Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[microTVM][CRT][DOCS] Add a PyTorch tutorial for microTVM with CRT (a…
Browse files Browse the repository at this point in the history
…pache#13324)

This commit adds a tutorial to compile and run a PyTorch model using   
microTVM, the AOT host-driven executor, and C runtime (CRT).
  • Loading branch information
mehrdadh authored and xinetzone committed Nov 25, 2022
1 parent 8894808 commit 588537b
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def git_describe_version(original_version):
"micro_ethosu.py",
"micro_tvmc.py",
"micro_aot.py",
"micro_pytorch.py",
],
}

Expand Down
6 changes: 1 addition & 5 deletions gallery/how_to/work_with_microtvm/micro_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
# Use the C runtime (crt) and enable static linking by setting system-lib to True
RUNTIME = Runtime("crt", {"system-lib": True})

# Simulate a microcontroller on the host machine. Uses the main() from `src/runtime/crt/host/main.cc <https://github.com/apache/tvm/blob/main/src/runtime/crt/host/main.cc>`_.
# Simulate a microcontroller on the host machine. Uses the main() from `src/runtime/crt/host/main.cc`.
# To use physical hardware, replace "host" with something matching your hardware.
TARGET = tvm.target.target.micro("host")

Expand Down Expand Up @@ -174,7 +174,3 @@
aot_executor.run()
result = aot_executor.get_output(0).numpy()
print(f"Label is `{labels[np.argmax(result)]}` with index `{np.argmax(result)}`")
#
# Output:
# Label is `left` with index `6`
#
206 changes: 206 additions & 0 deletions gallery/how_to/work_with_microtvm/micro_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _tutorial-micro-Pytorch:
microTVM PyTorch Tutorial
===========================
**Authors**:
`Mehrdad Hessar <https://github.com/mehrdadh>`_
This tutorial is showcasing microTVM host-driven AoT compilation with
a PyTorch model. This tutorial can be executed on a x86 CPU using C runtime (CRT).
**Note:** This tutorial only runs on x86 CPU using CRT and does not run on Zephyr
since the model would not fit on our current supported Zephyr boards.
"""

# sphinx_gallery_start_ignore
from tvm import testing

testing.utils.install_request_hook(depth=3)
# sphinx_gallery_end_ignore

import pathlib

import torch
import torchvision
from torchvision import transforms
import numpy as np
from PIL import Image

import tvm
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.backend import Executor

##################################
# Load a pre-trained PyTorch model
# --------------------------------
#
# To begin with, load pre-trained MobileNetV2 from torchvision. Then,
# download a cat image and preprocess it to use as the model input.
#

model = torchvision.models.quantization.mobilenet_v2(weights="DEFAULT", quantize=True)
model = model.eval()

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# Preprocess the image and convert to tensor
my_preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)

input_name = "input0"
shape_list = [(input_name, input_shape)]
relay_mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

#####################################
# Define Target, Runtime and Executor
# -----------------------------------
#
# In this tutorial we use AOT host-driven executor. To compile the model
# for an emulated embedded environment on an x86 machine we use C runtime (CRT)
# and we use `host` micro target. Using this setup, TVM compiles the model
# for C runtime which can run on a x86 CPU machine with the same flow that
# would run on a physical microcontroller.
#


# Simulate a microcontroller on the host machine. Uses the main() from `src/runtime/crt/host/main.cc`
# To use physical hardware, replace "host" with another physical micro target, e.g. `nrf52840`
# or `mps2_an521`. See more more target examples in micro_train.py and micro_tflite.py tutorials.
target = tvm.target.target.micro("host")

# Use the C runtime (crt) and enable static linking by setting system-lib to True
runtime = tvm.relay.backend.Runtime("crt", {"system-lib": True})

# Use the AOT executor rather than graph or vm executors. Don't use unpacked API or C calling style.
executor = Executor("aot")

####################
# Compile the model
# ------------------
#
# Now, we compile the model for the target:
#

with tvm.transform.PassContext(
opt_level=3,
config={"tir.disable_vectorize": True},
):
module = tvm.relay.build(
relay_mod, target=target, runtime=runtime, executor=executor, params=params
)

###########################
# Create a microTVM project
# -------------------------
#
# Now that we have the compiled model as an IRModule, we need to create a firmware project
# to use the compiled model with microTVM. To do this, we use Project API.
#

template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("crt"))
project_options = {"verbose": False, "memory_size_bytes": 6 * 1024 * 1024}

temp_dir = tvm.contrib.utils.tempdir() / "project"
project = tvm.micro.generate_project(
str(template_project_path),
module,
temp_dir,
project_options,
)

####################################
# Build, flash and execute the model
# ----------------------------------
# Next, we build the microTVM project and flash it. Flash step is specific to
# physical microcontroller and it is skipped if it is simulating a microcontroller
# via the host `main.cc`` or if a Zephyr emulated board is selected as the target.
#

project.build()
project.flash()

input_data = {input_name: tvm.nd.array(img.astype("float32"))}
with tvm.micro.Session(project.transport()) as session:
aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())
aot_executor.set_input(**input_data)
aot_executor.run()
result = aot_executor.get_output(0).numpy()

#####################
# Look up synset name
# -------------------
# Look up prediction top 1 index in 1000 class synset.
#

synset_url = (
"https://raw.githubusercontent.com/Cadene/"
"pretrained-models.pytorch/master/data/"
"imagenet_synsets.txt"
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
synsets = f.readlines()

synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}

class_url = (
"https://raw.githubusercontent.com/Cadene/"
"pretrained-models.pytorch/master/data/"
"imagenet_classes.txt"
)
class_path = download_testdata(class_url, "imagenet_classes.txt", module="data")
with open(class_path) as f:
class_id_to_key = f.readlines()

class_id_to_key = [x.strip() for x in class_id_to_key]

# Get top-1 result for TVM
top1_tvm = np.argmax(result)
tvm_class_key = class_id_to_key[top1_tvm]

# Convert input to PyTorch variable and get PyTorch result for comparison
with torch.no_grad():
torch_img = torch.from_numpy(img)
output = model(torch_img)

# Get top-1 result for PyTorch
top1_torch = np.argmax(output.numpy())
torch_class_key = class_id_to_key[top1_torch]

print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))
2 changes: 1 addition & 1 deletion src/runtime/crt/host/Makefile.template
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ CXXFLAGS ?= -Werror -Wall -std=c++11 -DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE -DMEMO
LDFLAGS ?= -Werror -Wall

# Codegen produces spurious lines like: int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
MODEL_CFLAGS ?= -Wno-error=unused-variable -Wno-error=missing-braces -Wno-error=unused-const-variable
MODEL_CFLAGS ?= -Wno-error=unused-variable -Wno-error=missing-braces -Wno-error=unused-const-variable -Wno-unused-variable

AR ?= ${PREFIX}ar
CC ?= ${PREFIX}gcc
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/crt/host/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

IS_TEMPLATE = not os.path.exists(os.path.join(PROJECT_DIR, MODEL_LIBRARY_FORMAT_RELPATH))

# Used this size to pass most CRT tests in TVM.
MEMORY_SIZE_BYTES = 2 * 1024 * 1024

MAKEFILE_FILENAME = "Makefile"
Expand All @@ -62,6 +63,7 @@ def server_info_query(self, tvm_version):
"verbose",
optional=["build"],
type="bool",
default=False,
help="Run make with verbose output",
),
server.ProjectOption(
Expand Down
1 change: 1 addition & 0 deletions tests/scripts/task_python_microtvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ run_pytest ctypes python-microtvm-project_api tests/micro/project_api
python3 gallery/how_to/work_with_microtvm/micro_tflite.py
python3 gallery/how_to/work_with_microtvm/micro_autotune.py
python3 gallery/how_to/work_with_microtvm/micro_aot.py
python3 gallery/how_to/work_with_microtvm/micro_pytorch.py
./gallery/how_to/work_with_microtvm/micro_tvmc.sh

# Tutorials running with Zephyr
Expand Down

0 comments on commit 588537b

Please sign in to comment.