Skip to content

Commit

Permalink
Merge branch 'exp/amp-nhwc-autoscheduler' into 'develop'
Browse files Browse the repository at this point in the history
[EXP] TVM AMP + NHWC + autoscheduler

Closes apache#18

See merge request RTST_AI/tvm!4
  • Loading branch information
neigh80 authored and Minhae Ye committed Jan 19, 2024
2 parents 18467c9 + 8c4a6ea commit d91b2db
Showing 1 changed file with 177 additions and 0 deletions.
177 changes: 177 additions & 0 deletions experiments/amp_nhwc_autoscheduler/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from copy import deepcopy

from PIL import Image
import numpy as np
import timm
import torch
import torchvision
from torchvision import transforms
import tvm
from tvm import relay, auto_scheduler, autotvm
from tvm.contrib.download import download_testdata
from tvm.contrib import graph_executor
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt


def tvm_amp(mod, params, to_nhwc=False):
mod = tvm.relay.transform.EliminateCommonSubexpr()(mod)

BindPass = tvm.relay.transform.function_pass(
lambda fn, new_mod, ctx: tvm.relay.build_module.bind_params_by_name(fn, params),
opt_level=1,
)
mod = BindPass(mod)
mod = tvm.relay.transform.FoldConstant()(mod)

mod = tvm.relay.transform.CombineParallelBatchMatmul()(mod)
mod = tvm.relay.transform.FoldConstant()(mod)

mod = tvm.relay.transform.InferType()(mod)
mod = tvm.relay.transform.ToMixedPrecision()(mod)

if to_nhwc:
desired_layouts = {"nn.conv2d": ["NHWC", "default"], "qnn.conv2d": ["NHWC", "default"]}
mod = relay.transform.ConvertLayout(desired_layouts)(mod)

mod = tvm.relay.transform.EliminateCommonSubexpr()(mod)
mod = tvm.relay.transform.FoldConstant()(mod)

return mod


if __name__ == "__main__":
import os

from absl import app, flags

os.environ["PATH"] += os.pathsep + "/usr/local/cuda/bin/"

FLAGS = flags.FLAGS

flags.DEFINE_enum(
"model", "mobilenet_v2", ["mobilenet_v2", "resnet50", "efficientnet_v2_s"], "Choose model."
)

def main(_):
img_url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
img_path = download_testdata(img_url, "dog.jpg", module="data")
img = Image.open(img_path)
preprocess_input = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
data = preprocess_input(img)
data = np.expand_dims(data, 0)

torch_model = (
timm.create_model("tf_efficientnetv2_s", pretrained=True)
if FLAGS.model == "efficientnet_v2_s"
else getattr(torchvision.models, FLAGS.model)(pretrained=True)
)
torch_model.eval()
scripted_torch_model = torch.jit.trace(torch_model, torch.randn(data.shape)).eval()
shape_list = [("input_1", data.shape)]
mod, params = relay.frontend.from_pytorch(scripted_torch_model, shape_list)
mod = tvm_amp(mod, params, to_nhwc=True)
params = None

target = tvm.target.Target(
"cuda -arch=sm_72", host="llvm -mtriple=aarch64-linux-gnu -mcpu=carmel"
)
device_key = "xavier"
host = "0.0.0.0"
port = 9190

log_file = f"{FLAGS.model}.json"
lib_file = f"{FLAGS.model}.tar"

# Extract tasks from the network
print("Extract tasks...")
tasks, task_weights = auto_scheduler.extract_tasks(deepcopy(mod), params, target)

for idx, task in enumerate(tasks):
print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key))
print(task.compute_dag)

def run_tuning():
print("Begin tuning...")
remote_runner = auto_scheduler.RPCRunner(
key=device_key, host=host, port=port, repeat=1, min_repeat_ms=300, timeout=600
)

tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=len(tasks) * 750,
builder=auto_scheduler.LocalBuilder(timeout=60),
runner=remote_runner,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

tuner.tune(tune_option)

run_tuning()

# Compile with the history best
print("Compile...")
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=3, config={"relay.backend.use_auto_scheduler": True}
):
lib = relay.build(mod, target=target)
lib.export_library(lib_file)

remote = tvm.auto_scheduler.utils.request_remote(
device_key=device_key, host=host, port=port, timeout=180
)
dev = remote.device(str(target))

remote.upload(lib_file)
lib = remote.load_module(lib_file)

# Create graph executor
module = graph_executor.GraphModule(lib["default"](dev))
dtype = "float32"
data_tvm = tvm.nd.array(data.astype(dtype))
module.set_input("input_1", data_tvm)

# Evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", dev, repeat=50)
prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond
print(
"Mean inference time (std dev): %.2f ms (%.2f ms)"
% (np.mean(prof_res), np.std(prof_res))
)

module.run()
tvm_out = module.get_output(0)
top1_tvm = np.argmax(tvm_out.asnumpy())

synset_url = "".join(
[
"https://gist.githubusercontent.com/zhreshold/",
"4d0b62f3d01426887599d4f7ede23ee5/raw/",
"596b27d23537e5a1b5751d2b0481ef172f58b539/",
"imagenet1000_clsid_to_human.txt",
]
)
synset_name = "imagenet1000_clsid_to_human.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
synset = eval(f.read())
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, synset[top1_tvm]))
# confirm correctness with torch output
with torch.no_grad():
torch_img = torch.from_numpy(data)
output = torch_model(torch_img)

# Get top-1 result for PyTorch
top1_torch = np.argmax(output.numpy())

print("Torch top-1 id: {}, class name: {}".format(top1_torch, synset[top1_torch]))

app.run(main)

0 comments on commit d91b2db

Please sign in to comment.