Skip to content

Commit

Permalink
[Unity] Relay -> Relax translator (#14026)
Browse files Browse the repository at this point in the history
This PR implements a Relay to Relax translator, which allows us to import Relay workloads to Relax for benchmarking and development purposes (tests and examples are added).
  • Loading branch information
YuchenJin authored and tqchen committed Mar 4, 2023
1 parent 4d17209 commit a2a59b0
Show file tree
Hide file tree
Showing 9 changed files with 1,116 additions and 0 deletions.
253 changes: 253 additions & 0 deletions apps/relax_examples/e2e_auto_tir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import os
import csv
import json
import argparse
import logging
from typing import Dict
import numpy as np # type: ignore

import tvm
from tvm import relay, relax, runtime, transform
from tvm.ir.module import IRModule
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.relax.testing import relay_translator
from tvm.target.target import Target


def _parse_args():
args = argparse.ArgumentParser()
args.add_argument(
"--workload",
type=str,
required=True,
)
args.add_argument(
"--input-shape",
type=str,
required=True,
)
args.add_argument(
"--target",
type=str,
required=True,
)
args.add_argument(
"--num-trials",
type=int,
required=True,
)
args.add_argument(
"--rpc-host",
type=str,
default=None,
)
args.add_argument(
"--rpc-port",
type=int,
default=None,
)
args.add_argument(
"--rpc-key",
type=str,
default=None,
)
args.add_argument(
"--work-dir",
type=str,
required=True,
)
args.add_argument(
"--cache-dir",
type=str,
default=None,
)
args.add_argument(
"--rpc-timeout-sec",
type=int,
default=180,
)
args.add_argument("--num-measurement-repeats", type=int, default=5)
args.add_argument("--num-measurements", type=int, default=10)
args.add_argument("--results-file", type=str, required=False, default=None)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
parsed.input_shape = json.loads(parsed.input_shape)
if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
parsed.alloc_repeat = 3
else:
parsed.alloc_repeat = 1
if parsed.rpc_host and parsed.rpc_port and parsed.rpc_key:
parsed.rpc_config = ms.runner.RPCConfig(
tracker_host=parsed.rpc_host,
tracker_port=parsed.rpc_port,
tracker_key=parsed.rpc_key,
session_timeout_sec=parsed.rpc_timeout_sec,
)
parsed.workers = parsed.rpc_config.count_num_servers(allow_missing=False)
else:
# check all rpc configs are None
assert (
(parsed.rpc_host is None) and (parsed.rpc_port is None) and (parsed.rpc_key is None)
), "Please set all 'rpc_host', 'rpc_port' and 'rpc_key' to use PRC server"
parsed.rpc_config = None
parsed.workers = 1
return parsed


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
ARGS = _parse_args()


def apply_opt_before_tuning(
relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target
):
with transform.PassContext(opt_level=3):
main_func = relay_mod["main"]
bind_main_func = relay.build_module.bind_params_by_name(main_func, params)
relay_mod = IRModule.from_expr(bind_main_func)
relay_mod = relay.transform.SimplifyInference()(relay_mod)
relay_mod = relay.transform.FoldConstant()(relay_mod)
relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
relay_mod = relay.transform.CanonicalizeOps()(relay_mod)
relay_mod = relay.transform.AlterOpLayout()(relay_mod)
relay_mod = relay.transform.FoldConstant()(relay_mod)

relax_mod = relay_translator.from_relay(relay_mod["main"], target=target)
relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod)
relax_mod = relax.transform.FuseOps()(relax_mod)
relax_mod = relax.transform.FuseTIR()(relax_mod)
return relax_mod


def f_measurement(
rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray]
):
vm = relax.vm.VirtualMachine(exec=rt_mod, device=device)
vm.save_function("main", "measure_func", **input_data, include_return=False)
evaluator = vm.time_evaluator(
func_name="measure_func",
dev=device,
repeat=ARGS.num_measurement_repeats,
number=ARGS.num_measurements,
min_repeat_ms=500,
)
return evaluator()


def get_runner():
runner_config = {
"evaluator_config": ms.runner.EvaluatorConfig(
number=3,
repeat=1,
min_repeat_ms=100,
enable_cpu_cache_flush=False,
),
"alloc_repeat": ARGS.alloc_repeat,
}
if ARGS.rpc_config:
runner = ms.runner.RPCRunner(
rpc_config=ARGS.rpc_config, max_workers=ARGS.workers, **runner_config
)
else:
runner = ms.runner.LocalRunner(**runner_config)

return runner


def main():
relay_mod, params, (input_name, input_shape, input_dtype) = get_network(
ARGS.workload,
ARGS.input_shape,
cache_dir=ARGS.cache_dir,
)
input_info = {input_name: input_shape}
input_data = {}
for input_name, input_shape in input_info.items():
print(f" input_name: {input_name}")
print(f" input_shape: {input_shape}")
print(f" input_dtype: {input_dtype}")

# translate the ResNet model from Relay to Relax
relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target)
assert isinstance(relax_mod, tvm.IRModule)

db = ms.relax_integration.tune_relax(
mod=relax_mod,
target=ARGS.target,
params=params,
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
runner=get_runner(),
work_dir=ARGS.work_dir,
)
executable = ms.relax_integration.compile_relax(
db,
mod=relax_mod,
target=ARGS.target,
params=params,
)

for input_name, input_shape in input_info.items():
if input_dtype.startswith("float"):
input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
else:
input_data[input_name] = np.random.randint(
low=0, high=10000, size=input_shape, dtype=input_dtype
)

# for documentation purposes
start_time = datetime.datetime.now()

if ARGS.rpc_config:
result = run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=executable.mod,
dev_type=ARGS.target.kind.name,
args=input_data,
continuation=f_measurement,
)
else:
dev = tvm.device(ARGS.target.kind.name)
result = f_measurement(executable.mod, dev, input_data)

print(result)

if not ARGS.results_file:
return

out_path = os.path.abspath(os.path.expanduser(ARGS.results_file))
with open(out_path, "w") as out_file:
writer = csv.writer(out_file)
# write experiment parameters at the top as a record
writer.writerow(["start", str(start_time)])
writer.writerow(["workload", ARGS.workload])
writer.writerow(["input_shape", ARGS.input_shape])
writer.writerow(["target", ARGS.target])
writer.writerow(["num_measurement_repeats", ARGS.num_measurement_repeats])
for res in result.results:
writer.writerow([str(res)])


if __name__ == "__main__":
main()
57 changes: 57 additions & 0 deletions apps/relax_examples/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example code on creating, compiling, and running an MLP model in relax


import tvm
from tvm import relax, tir, topi
import numpy as np


def build_mlp(data, weight):
bb = relax.BlockBuilder()

with bb.function("mlp", [data, weight]):
gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
gv1 = bb.emit_te(topi.nn.relu, gv0)
bb.emit_func_output(gv1)

mod = bb.get()
return mod


if __name__ == "__main__":
# symbolic dimensions
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
# create data and weight variables
data = relax.Var("data", relax.TensorStructInfo([n, m], "float32"))
weight = relax.Var("weight", relax.TensorStructInfo([m, n], "float32"))

# construct a mlp model
mod = build_mlp(data, weight)

# build and create vm executor
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())

# run the mlp model on relax vm
data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
res = vm["mlp"](data, weight)
print(res)
69 changes: 69 additions & 0 deletions apps/relax_examples/nn_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example code on creating, compiling, and running a neural network with pytorch-like API


import tvm
from tvm.relay import Call
from tvm import relax, tir
from tvm.relax.testing import nn
from tvm.script import relax as R
import numpy as np


if __name__ == "__main__":
builder = relax.BlockBuilder()

# a symbolic variable to represent minibatch size
n = tir.Var("n", "int64")
input_size = 784
hidden_sizes = [128, 32]
output_size = 10

# build a three linear-layer neural network for a classification task
with builder.function("main"):
model = nn.Sequential(
nn.Linear(input_size, hidden_sizes[0]),
nn.ReLU(),
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
nn.ReLU(),
nn.Linear(hidden_sizes[1], output_size),
nn.LogSoftmax(),
)
data = nn.Placeholder((n, input_size), name="data")
output = model(data)
params = [data] + model.parameters()
builder.emit_func_output(output, params=params)

# get and print the IRmodule being built
mod = builder.get()
mod.show()

# build the IRModule and create relax vm
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())

# init parameters
params = nn.init_params(mod)

# run the model on relax vm
# the input data has a minibatch size of 3
data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))
res = vm["main"](data, *params)
print(res)
Loading

0 comments on commit a2a59b0

Please sign in to comment.