From aa999f2c9c6984a738c081911c18cf957e523f4a Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 31 Mar 2021 10:44:39 -0700 Subject: [PATCH] [Target] Add support for target object with host field compatible with previous api (#7534) * Fix legacy code on target host * Modify legacy code for target host change * Add tests and fix merge issue * Add condition for same host * Modify all files for new target host api compatibility * Add newline * Change import format * Optimize test file * Add match error info for unit tests * Fix for heterogeneous targets * Fix format for dict iteration * Fix target host type error * Skip one testcase for tvm infinite loop bug * Fixed bug for target map compatibility * Fix another TargetsMap issue * Fix typo and infinite loop error * Temporary fix for handle issue * Fix vm target * Add condition support for str case * Add GetHost function and fix previous bugs * Fix measure_record.cc * Fix search_task.cc * Fix compiler.cc, memory_alloc.cc * Fix driver_api.cc * Fix format * Fix bugs and GetHost function usage * Fix clang format * Fix bug * Modify python tests * Change python unit tests to new target api * Fi test_runtime_heterogeneous.py * Modify tutorials & remove extra print * Update more tests to new api * Refine the tutorial target usage * change argument name for Target constructor function * Fix target export function * Fix and validate all tutorial usage * Remove unused argument * Fix format * Fix bug in driver/build_module.py for heterogeneous target * Fix bug in driver/build_module.py for heterogeneous target more * Fix target host type error * Fix cudnn target host bug * Fix according to reviews, add helper function in python * Refactor code as helper function * Expand helper function * Fix bug add and update python helper function * Update target hosts * Fix format & refresh function * Fix unit test bug * Fix bug in refreshing host * Fix bug * Add SetHost function * Update export function * Fix format * Fix export bug in target * Fix bug on host referencing * Addtional tests * Address review issues * Fix format target.py * Fix issues and format * Add some 3rd party dependencies * Merge main branch * Fix target.h format * Remove redundent import * Fix function name * Add parameter name * Fix new code bug * Fix bug in lowering --- include/tvm/target/target.h | 37 +++++++++- python/tvm/auto_scheduler/measure.py | 13 ++-- .../tvm/auto_scheduler/relay_integration.py | 8 +-- python/tvm/auto_scheduler/search_task.py | 14 ++-- .../autotvm/graph_tuner/base_graph_tuner.py | 7 +- .../graph_tuner/utils/traverse_graph.py | 2 +- python/tvm/autotvm/measure/measure_methods.py | 3 + python/tvm/autotvm/task/relay_integration.py | 9 ++- python/tvm/autotvm/task/task.py | 14 ++-- python/tvm/contrib/peak.py | 13 +++- python/tvm/driver/build_module.py | 13 +++- python/tvm/driver/tvmc/autotuner.py | 11 +-- python/tvm/driver/tvmc/compiler.py | 12 ++-- python/tvm/exec/measure_peak.py | 4 ++ python/tvm/relay/backend/_backend.py | 7 +- python/tvm/relay/backend/vm.py | 17 ++++- python/tvm/relay/build_module.py | 15 ++-- python/tvm/target/target.py | 62 +++++++++++++---- src/auto_scheduler/feature.cc | 18 +++-- src/auto_scheduler/measure_record.cc | 8 ++- src/auto_scheduler/search_task.cc | 3 + src/driver/driver_api.cc | 47 ++++++++----- src/relay/backend/build_module.cc | 17 +++-- src/relay/backend/vm/compiler.cc | 10 ++- src/relay/transforms/memory_alloc.cc | 2 + src/target/target.cc | 57 ++++++++++++--- tests/micro/zephyr/test_zephyr.py | 3 +- 
tests/python/contrib/test_cudnn.py | 8 +-- tests/python/contrib/test_dlpack.py | 4 +- tests/python/contrib/test_miopen.py | 4 +- tests/python/driver/tvmc/test_compiler.py | 3 +- .../frontend/tensorflow/test_forward.py | 3 +- tests/python/integration/test_reduce.py | 4 +- tests/python/integration/test_tuning.py | 7 +- tests/python/relay/test_vm.py | 5 +- .../unittest/test_auto_scheduler_measure.py | 6 +- .../test_auto_scheduler_search_task.py | 16 ++--- tests/python/unittest/test_crt.py | 3 +- .../unittest/test_runtime_heterogeneous.py | 6 +- tests/python/unittest/test_runtime_rpc.py | 2 +- .../unittest/test_target_codegen_blob.py | 2 +- .../unittest/test_target_codegen_device.py | 2 +- .../unittest/test_target_codegen_hexagon.py | 6 +- tests/python/unittest/test_target_target.py | 69 +++++++++++++++---- ...tir_transform_instrument_bound_checkers.py | 13 ++-- tutorials/auto_scheduler/tune_network_mali.py | 11 +-- tutorials/autotvm/tune_relay_mobile_gpu.py | 10 +-- tutorials/frontend/deploy_model_on_android.py | 14 ++-- tutorials/frontend/from_darknet.py | 5 +- tutorials/frontend/from_pytorch.py | 5 +- tutorials/frontend/from_tensorflow.py | 8 +-- .../get_started/cross_compilation_and_rpc.py | 4 +- .../get_started/tensor_expr_get_started.py | 27 ++++---- 53 files changed, 463 insertions(+), 210 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 64bd251c0dedc..9c1fe55749e49 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -24,6 +24,8 @@ #ifndef TVM_TARGET_TARGET_H_ #define TVM_TARGET_TARGET_H_ +#include +#include #include #include #include @@ -35,6 +37,7 @@ namespace tvm { class TargetInternal; +class Target; /*! * \brief Compilation target. @@ -60,6 +63,8 @@ class TargetNode : public Object { TVM_DLL const std::string& str() const; /*! \return Export target to JSON-like configuration */ TVM_DLL Map Export() const; + /*! \return The Optional typed target host of the TargetNode */ + TVM_DLL Optional GetHost() const; void VisitAttrs(AttrVisitor* v) { v->Visit("kind", &kind); @@ -150,6 +155,13 @@ class Target : public ObjectRef { */ TVM_DLL explicit Target(Target target, Target host); TVM_DEFINE_OBJECT_REF_METHODS(Target, ObjectRef, TargetNode); + /*! + * \brief Create a new Target object with given target (w.o host) and target host. + * \param target The current Target typed object target, with or without host field. + * \param host The given Target typed object target host + * \return The new Target object with the given target and host field of given host. + */ + static Target WithHost(const Target& target, const Target& host); private: // enable with syntax. @@ -167,6 +179,29 @@ class Target : public ObjectRef { */ TVM_DLL void ExitWithScope(); }; - +/*! + * \brief Check and update host field of the given legacy target and target host pair. + * Note that this function is for legacy target api compatibility issue only, not + * recommended for other use. + * \param target The pointer to a Target typed object with host field to be updated + * \param host The pointer to a Target typed object for target host to be updated + */ +void CheckAndUpdateHostConsistency(Target* target, Target* host); +/*! + * \brief Check and update host field of the given legacy heterogeneous targets and + * target host.Note that this function is for legacy target api compatibility issue only, + * not recommended for other use. 
+ * \param target The pointer to a Map objects with values being Target objects + * \param host The Target typed object for target host to be updated + */ +void CheckAndUpdateHostConsistency(Map* target, Target* host); +/*! + * \brief Check and update host field of the given legacy heterogeneous targets and + * target host.Note that this function is for legacy target api compatibility issue only, + * not recommended for other use. + * \param target The pointer to a Map objects with keys being Target objects + * \param host The Target typed object for target host to be updated + */ +void CheckAndUpdateHostConsistency(Map* target, Target* host); } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 95d3942465e7d..83f1bcec7ebcf 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -44,6 +44,8 @@ from tvm.ir import transform from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk +from tvm.target import Target + from . import _ffi_api from .loop_state import StateObject @@ -221,10 +223,12 @@ def recover_measure_input(inp, rebuild_state=False): from .search_task import SearchTask # lazily import to avoid recursive dependency task = inp.task + task.target, task.target_host = Target.check_and_update_host_consist( + task.target, task.target_host + ) new_task = SearchTask( workload_key=task.workload_key, target=task.target, - target_host=task.target_host, hardware_params=task.hardware_params, layout_rewrite_option=task.layout_rewrite_option, task_inputs=list(task.task_input_names), @@ -602,6 +606,9 @@ def _timed_func(inp_serialized, build_func, verbose): tic = time.time() inp = MeasureInput.deserialize(inp_serialized) task = inp.task + task.target, task.target_host = Target.check_and_update_host_consist( + task.target, task.target_host + ) error_no = MeasureErrorNo.NO_ERROR error_msg = None @@ -622,9 +629,7 @@ def _timed_func(inp_serialized, build_func, verbose): try: with transform.PassContext(): - func = build_module.build( - sch, args, target=task.target, target_host=task.target_host - ) + func = build_module.build(sch, args, target=task.target) func.export_library(filename, build_func) # pylint: disable=broad-except except Exception: diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index d10f0fb5ecd8b..aea2ee182ceeb 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -29,9 +29,11 @@ from tvm import autotvm, transform from tvm.ir.transform import PassContext from tvm.runtime import convert_to_object + from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor from tvm.tir import Reduce from tvm.tir import expr as _expr +from tvm.target import Target from . import _ffi_api from .compute_dag import ComputeDAG, LayoutRewriteOption @@ -108,10 +110,7 @@ def extract_tasks( """ # pylint: disable=import-outside-toplevel - if isinstance(target, str): - target = tvm.target.Target(target) - if isinstance(target_host, str): - target_host = tvm.target.Target(target_host) + target, target_host = Target.check_and_update_host_consist(target, target_host) # Run the compiler to collect all TOPI calls during compilation. 
env = TracingEnvironment( @@ -137,7 +136,6 @@ def extract_tasks( SearchTask( workload_key=wkl_key, target=target, - target_host=target_host, hardware_params=hardware_params, # When auto scheduler is used in end to end network, try to apply layout rewrite # to improve the overall performance diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index c5c2b5b44451b..d8fa4380a6beb 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -398,10 +398,8 @@ def __init__( compute_dag = ComputeDAG(workload_key) assert target is not None, "Must specify a target." - if isinstance(target, str): - target = Target(target) - if isinstance(target_host, str): - target_host = Target(target_host) + + target, target_host = Target.check_and_update_host_consist(target, target_host) if layout_rewrite_option is None: layout_rewrite_option = LayoutRewriteOption.get_target_default(target) @@ -511,6 +509,9 @@ def print_best(self, log_file, print_mode="schedule"): raise ValueError("Invalid print_mode: %s" % print_mode) def __getstate__(self): + self.target, self.target_host = Target.check_and_update_host_consist( + self.target, self.target_host + ) return { "compute_dag": self.compute_dag, "workload_key": self.workload_key, @@ -535,12 +536,15 @@ def __setstate__(self, state): if workload[0] not in WORKLOAD_FUNC_REGISTRY: register_workload_tensors(state["workload_key"], state["compute_dag"].tensors) + state["target"], state["target_host"] = Target.check_and_update_host_consist( + state["target"], state["target_host"] + ) self.__init_handle_by_constructor__( _ffi_api.SearchTask, state["compute_dag"], state["workload_key"], state["target"], - state["target_host"], + state["target"].host, state["hardware_params"], state["layout_rewrite_option"], state["task_input_names"], diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index 741b05f4c453c..b307130780a72 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -28,6 +28,7 @@ from tvm.autotvm.task import get_config from tvm.autotvm.record import encode, load_from_file from tvm.autotvm.measure import MeasureResult, MeasureInput +from tvm.target import Target from ...target import Target from .utils import ( @@ -439,6 +440,8 @@ def benchmark_layout_transform( This might bring performance loss comparing to benchmarking layout transformation. 
""" self._logger.info("Start to benchmark layout transformation...") + self._target, target_host = Target.check_and_update_host_consist(self._target, target_host) + if layout_records is None and infer_layout: raise RuntimeError("Requires some records to infer layout transformation time.") @@ -525,9 +528,7 @@ def _callback(_, inputs, results): continue records = [] - task = autotvm.task.create( - "layout_transform", args=args, target=self._target, target_host=target_host - ) + task = autotvm.task.create("layout_transform", args=args, target=self._target) tuner = autotvm.tuner.GridSearchTuner(task) tuner.tune(n_trial=1, measure_option=measure_option, callbacks=[_log_to_list(records)]) if not isinstance(records[0][1].costs[0], float): diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index fd2612f20371c..f61d34284e01b 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -63,7 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list): for node_entry in node_list: if node_entry["op"] in target_ops: task_name, args = env.task_collection[task_pos] - task = autotvm.task.create(task_name, args, target="llvm", target_host=None) + task = autotvm.task.create(task_name, args, target="llvm") node_entry["workloads"] = [task.workload] node_entry["topi_op"] = [task_name] task_pos += 1 diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 261ecabe49cda..d212e5f26f20a 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -40,6 +40,7 @@ from tvm.error import TVMError from tvm.driver import build from tvm.contrib import nvcc, ndk, tar +from tvm.target import Target from ..utils import get_const_tuple from ..env import AutotvmGlobalScope @@ -418,6 +419,8 @@ def set_task(self, task): def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None): """Common part for building a configuration""" target, task, config = measure_input + target, task.target_host = Target.check_and_update_host_consist(target, task.target_host) + with target: s, args = task.instantiate(config) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 12e057e01da62..9117ce398d492 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -25,6 +25,7 @@ import tvm from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext +from tvm.target import Target from .task import create from .topi_integration import TaskExtractEnv @@ -89,7 +90,8 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): task: Array of autotvm.task.Task collected tasks """ - return extract_from_multiple_program([mod], [params], target, target_host, ops) + target, target_host = Target.check_and_update_host_consist(target, target_host) + return extract_from_multiple_program([mod], [params], target, ops=ops) def extract_from_multiple_program(mods, params, target, target_host=None, ops=None): @@ -122,6 +124,9 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No env = TaskExtractEnv.get() + # merge target and target host + target, target_host = Target.check_and_update_host_consist(target, target_host) + # run compiler to collect all TOPI calls during compilation env.reset(ops) with env: @@ -152,7 +157,7 @@ 
def extract_from_multiple_program(mods, params, target, target_host=None, ops=No tasks = [] for task_name, args in env.get_tasks(): try: - tsk = create(task_name, args, target=target, target_host=target_host) + tsk = create(task_name, args, target=target) tasks.append(tsk) except topi.InvalidShapeError: logger.warning("Invalid shape during AutoTVM task creation") diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 52f0996c800c3..0d60ca929d7bc 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -175,6 +175,9 @@ def __getstate__(self): # and restore the function by name when unpickling it. import cloudpickle # pylint: disable=import-outside-toplevel + self.target, self.target_host = Target.check_and_update_host_consist( + self.target, self.target_host + ) return { "name": self.name, "args": self.args, @@ -182,7 +185,7 @@ def __getstate__(self): "config_space": self.config_space, "flop": self.flop, "target": self.target, - "target_host": self.target_host, + "target_host": self.target.host, "func": cloudpickle.dumps(self.func), } @@ -195,8 +198,9 @@ def __setstate__(self, state): self.config_space = state["config_space"] self.func = cloudpickle.loads(state["func"]) self.flop = state["flop"] - self.target = state["target"] - self.target_host = state["target_host"] + self.target, self.target_host = Target.check_and_update_host_consist( + state["target"], state["target_host"] + ) def __repr__(self): return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % ( @@ -448,6 +452,8 @@ def create(task_name, args, target, target_host=None): if isinstance(target, str): target = Target(target) + target, target_host = Target.check_and_update_host_consist(target, target_host) + # init config space ret.config_space = ConfigSpace() @@ -459,7 +465,7 @@ def create(task_name, args, target, target_host=None): ret.flop = ret.config_space.flop or compute_flop(sch) ret.target = target - ret.target_host = target_host + ret.target_host = target.host return ret diff --git a/python/tvm/contrib/peak.py b/python/tvm/contrib/peak.py index 833a505f6d425..8e8e158b07401 100644 --- a/python/tvm/contrib/peak.py +++ b/python/tvm/contrib/peak.py @@ -20,6 +20,7 @@ import logging import tvm from tvm import te +from tvm.target import Target from . import utils from .. 
import rpc @@ -86,6 +87,8 @@ def measure_bandwidth_sum( GBPS: float gigabyte per second """ + target, target_host = Target.check_and_update_host_consist(target, target_host) + n, m = total_item, item_per_thread n //= lanes @@ -107,7 +110,7 @@ def measure_bandwidth_sum( s[y].unroll(k) try: - func = tvm.build(s, [x, y], target, target_host=target_host) + func = tvm.build(s, [x, y], target) x = tvm.nd.empty((n,), dtype=dtype, device=dev) y = tvm.nd.empty((n // m,), dtype=dtype, device=dev) @@ -151,6 +154,7 @@ def measure_bandwidth_all_types( result: list a list of (type_name, GBPS) pairs """ + target, target_host = Target.check_and_update_host_consist(target, target_host) max_threads = target.max_num_threads result = [] @@ -221,6 +225,7 @@ def measure_compute_mad( GOPS: float giga operation per second """ + target, target_host = Target.check_and_update_host_consist(target, target_host) n = total_item @@ -272,7 +277,7 @@ def mad_func(x, y): s = te.create_schedule(y.op) try: - func = tvm.build(s, [y], target, target_host=target_host) + func = tvm.build(s, [y], target) func = _convert_to_remote(func, remote) time_f = func.time_evaluator(func.entry_name, dev, number=n_times) y = tvm.nd.empty((n,), dtype=dtype, device=dev) @@ -313,6 +318,8 @@ def measure_compute_all_types( result: list a list of (type_name, GFLOPS/GIOPS) pairs """ + target, target_host = Target.check_and_update_host_consist(target, target_host) + result = [] for base_type in ["float", "int"]: for bits in [16, 32, 64]: @@ -357,7 +364,7 @@ def measure_peak_all(target, target_host, host, port): port: int """ - target = tvm.target.Target(target) + target, target_host = Target.check_and_update_host_consist(target, target_host) remote = rpc.connect(host, port) n_times = 20 diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 79c9cef801c30..9f56a9b82a7eb 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -231,8 +231,7 @@ def _build_for_device(input_mod, target, target_host): mdev : tvm.module A module that contains device code. 
""" - target = Target(target) - target_host = Target(target_host) + target, target_host = Target.check_and_update_host_consist(target, target_host) device_type = ndarray.device(target.kind.name, 0).device_type mod_mixed = input_mod @@ -399,8 +398,12 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi if not isinstance(mod, tvm.IRModule): raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) + if not target_host: - for tar, _ in target_input_mod.items(): + for tar, mod in target_input_mod.items(): tar = Target(tar) device_type = ndarray.device(tar.kind.name, 0).device_type if device_type == ndarray.cpu(0).device_type: @@ -409,6 +412,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + target_input_mod, target_host = Target.check_and_update_host_consist( + target_input_mod, target_host + ) + mod_host_all = tvm.IRModule({}) device_modules = [] diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 187b7c5d2a315..99ed117893640 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -28,6 +28,7 @@ from tvm.autotvm.tuner import GridSearchTuner from tvm.autotvm.tuner import RandomTuner from tvm.autotvm.tuner import XGBTuner +from tvm.target import Target from . import common, composite_target, frontends from .common import TVMCException @@ -242,6 +243,8 @@ def drive_tune(args): ) target, extra_targets = common.target_from_cli(args.target) + target_host = args.target_host + target, target_host = Target.check_and_update_host_consist(target, target_host) mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes) for codegen_from_cli in extra_targets: @@ -298,7 +301,6 @@ def drive_tune(args): mod=mod, params=params, target=target, - target_host=args.target_host, alter_layout=args.desired_layout, hardware_params=hardware_params, include_simple_tasks=args.include_simple_tasks, @@ -321,7 +323,6 @@ def drive_tune(args): mod=mod, params=params, target=target, - target_host=args.target_host, alter_layout=args.desired_layout, ) @@ -362,13 +363,14 @@ def autotvm_get_tuning_tasks(mod, params, target, target_host=None, alter_layout tasks : list of autotvm.Tasks list of tasks to be tuned """ + target, target_host = Target.check_and_update_host_consist(target, target_host) + if alter_layout: mod = common.convert_graph_layout(mod, alter_layout) tasks = autotvm.task.extract_from_program( mod["main"], target=target, - target_host=target_host, params=params, ) @@ -410,6 +412,8 @@ def autoscheduler_get_tuning_tasks( weights : List[int] the weight (i.e. 
the number of appearance) of extracted tasks """ + target, target_host = Target.check_and_update_host_consist(target, target_host) + if alter_layout: mod = common.convert_graph_layout(mod, alter_layout) @@ -418,7 +422,6 @@ def autoscheduler_get_tuning_tasks( mod["main"], params, target=target, - target_host=target_host, hardware_params=hardware_params, include_simple_tasks=include_simple_tasks, ) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 5bdb578f2c16c..f484290bb5d02 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -27,6 +27,7 @@ from tvm import relay, runtime from tvm.contrib import cc from tvm.contrib import utils +from tvm.target import Target from . import common, composite_target, frontends from .main import register_parser @@ -192,6 +193,7 @@ def compile_model( tvm_target, extra_targets = common.target_from_cli(target) target_host = tvm_target if not target_host else target_host + tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host) for codegen_from_cli in extra_targets: codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"]) @@ -214,20 +216,16 @@ def compile_model( config["relay.backend.use_auto_scheduler"] = True with tvm.transform.PassContext(opt_level=3, config=config): logger.debug("building relay graph with autoscheduler") - graph_module = relay.build( - mod, target=target, params=params, target_host=target_host - ) + graph_module = relay.build(mod, target=target, params=params) else: with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext(opt_level=3, config=config): logger.debug("building relay graph with tuning records") - graph_module = relay.build( - mod, tvm_target, params=params, target_host=target_host - ) + graph_module = relay.build(mod, tvm_target, params=params) else: with tvm.transform.PassContext(opt_level=3, config=config): logger.debug("building relay graph (no tuning records provided)") - graph_module = relay.build(mod, tvm_target, params=params, target_host=target_host) + graph_module = relay.build(mod, tvm_target, params=params) # Generate output dump files with sources dump_code = dump_code or [] diff --git a/python/tvm/exec/measure_peak.py b/python/tvm/exec/measure_peak.py index 3b502a96d09c6..d8840fadd802e 100644 --- a/python/tvm/exec/measure_peak.py +++ b/python/tvm/exec/measure_peak.py @@ -25,6 +25,7 @@ import argparse import logging +from tvm.target import Target from ..contrib.peak import measure_peak_all @@ -43,6 +44,9 @@ def main(): args = parser.parse_args() logging.basicConfig(level=logging.INFO) + args.target, args.target_host = Target.check_and_update_host_consist( + args.target, args.target_host + ) measure_peak_all(args.target, args.target_host, args.rpc_host, args.rpc_port) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 65b0c0ba87c7a..6df83559645dd 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -17,6 +17,7 @@ """The interface of expr function exposed from C++.""" import tvm._ffi import tvm.driver +from tvm.target import Target @tvm._ffi.register_func("relay.backend.lower") @@ -78,9 +79,9 @@ def build(mod, target, target_host=None): module : tvm.Module The runtime module. 
""" - if target_host == "": - target_host = None - return tvm.driver.build(mod, target=target, target_host=target_host) + target_host = None if target_host == "" else target_host + target, target_host = Target.check_and_update_host_consist(target, target_host) + return tvm.driver.build(mod, target=target) @tvm._ffi.register_func("relay._tensor_value_repr") diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 7e0d4acc5453b..0b6d1372d0505 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -28,6 +28,7 @@ from tvm import autotvm from tvm.relay import expr as _expr from tvm.relay.backend.interpreter import Executor +from tvm.target import Target from . import _vm @@ -62,10 +63,13 @@ def compile(mod, target=None, target_host=None, params=None): exec : tvm.runtime.vm.Executable The VM executable that contains both library code and bytecode. """ + target, target_host = Target.check_and_update_host_consist( + target, target_host, target_is_dict_key=False + ) compiler = VMCompiler() if params: compiler.set_params(params) - compiler.lower(mod, target, target_host) + compiler.lower(mod, target) compiler.codegen() return compiler.get_exec() @@ -130,6 +134,10 @@ def lower(self, mod, target=None, target_host=None): """ target = self._update_target(target) target_host = self._update_target_host(target, target_host) + target, target_host = Target.check_and_update_host_consist( + target, target_host, target_is_dict_key=False + ) + tophub_context = self._tophub_context(target) with tophub_context: self._lower(mod, target, target_host) @@ -167,6 +175,10 @@ def optimize(self, mod, target=None, target_host=None, params=None): """ target = self._update_target(target) target_host = self._update_target_host(target, target_host) + target, target_host = Target.check_and_update_host_consist( + target, target_host, target_is_dict_key=False + ) + if params: self.set_params(params) return self._optimize(mod, target, target_host), self.get_params() @@ -206,6 +218,9 @@ def _update_target_host(self, target, target_host): """Update target host.""" target_host = None if target_host == "" else target_host if not target_host: + for _, tgt in target.items(): + if tgt.host is not None: + return tgt.host for device_type, tgt in target.items(): if device_type.value == tvm.nd.cpu(0).device_type: target_host = tgt diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 4795a2d386857..ed59ad9bdc8f9 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -25,6 +25,7 @@ from tvm.ir.transform import PassContext from tvm.tir import expr as tvm_expr +from tvm.target import Target from .. import nd as _nd, autotvm, register_func from ..target import Target from ..contrib import graph_executor as _graph_rt @@ -114,6 +115,9 @@ def build(self, mod, target=None, target_host=None, params=None): The runtime factory for the TVM graph executor. """ target = _update_target(target) + target, target_host = Target.check_and_update_host_consist( + target, target_host, target_is_dict_key=False + ) # Setup the params. if params: @@ -205,7 +209,8 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. 
""" - return build(mod, target, target_host, params, mod_name).module + target, target_host = Target.check_and_update_host_consist(target, target_host) + return build(mod, target, params=params, mod_name=mod_name).module def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"): @@ -263,14 +268,16 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" "instead of deprecated parameter mod (tvm.relay.function.Function)", DeprecationWarning, ) - target = _update_target(target) - if isinstance(target_host, (str, Target)): target_host = Target(target_host) elif target_host: raise ValueError("target host must be the type of str, " + "tvm.target.Target, or None") + target, target_host = Target.check_and_update_host_consist( + target, target_host, target_is_dict_key=False + ) + # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): @@ -280,7 +287,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" with tophub_context: bld_mod = BuildModule() - graph_json, runtime_mod, params = bld_mod.build(ir_mod, target, target_host, params) + graph_json, runtime_mod, params = bld_mod.build(mod=ir_mod, target=target, params=params) executor_factory = _graph_executor_factory.GraphExecutorFactoryModule( ir_mod, target, graph_json, runtime_mod, mod_name, params ) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index e3ef51158c5a7..6d0a0635221e4 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -46,7 +46,7 @@ class Target(Object): - :py:func:`tvm.target.intel_graphics` create Intel Graphics target """ - def __init__(self, tag_or_str_or_dict, host_tag_or_str_or_dict=None): + def __init__(self, target, host=None): """Construct a TVM target object from 1) Raw target string 2) Target config dict @@ -54,7 +54,7 @@ def __init__(self, tag_or_str_or_dict, host_tag_or_str_or_dict=None): Parameters ---------- - tag_or_str_or_dict : Union[str, Dict[str, Any]] + target : Union[str, Dict[str, Any]] Can be one of a literal target string, a json string describing a configuration, or a dictionary of configuration options. When using a dictionary or json string to configure target, the @@ -87,21 +87,21 @@ def __init__(self, tag_or_str_or_dict, host_tag_or_str_or_dict=None): An llvm setting that is one of 'hard' or 'soft' indicating whether to use hardware or software floating-point operations. host : Union[str, Dict[str, Any]] (optional) - Description for target host. Can be recursive. Similar to tag_or_str_or_dict. - host_tag_or_str_or_dict : Optional[Union[str, Dict[str, Any]]] - Similar to tag_or_str_or_dict but for target host. Can be one of a literal - target host string, a json string describing a configuration, or a dictionary of - configuration options. When using a dictionary or json string to configure target, - the possible values are same as tag_or_str_or_dict. + Description for target host. Can be recursive. Similar to target. + host : Optional[Union[str, Dict[str, Any]]] + Similar to target but for target host. Can be one of a literal target host string, + a json string describing a configuration, or a dictionary of configuration options. + When using a dictionary or json string to configure target, the possible values are + same as target. 
""" - if not isinstance(tag_or_str_or_dict, (dict, str, Target)): + if target is None or not isinstance(target, (dict, str, Target)): raise ValueError("target has to be a string or dictionary.") - if host_tag_or_str_or_dict is not None: - self.__init_handle_by_constructor__( - _ffi_api.Target, Target(tag_or_str_or_dict), Target(host_tag_or_str_or_dict) - ) + if host is not None: + if not isinstance(host, (dict, str, Target)): + raise ValueError("target host has to be a string or dictionary.") + self.__init_handle_by_constructor__(_ffi_api.Target, Target(target), Target(host)) else: - self.__init_handle_by_constructor__(_ffi_api.Target, tag_or_str_or_dict) + self.__init_handle_by_constructor__(_ffi_api.Target, target) def __enter__(self): _ffi_api.TargetEnterScope(self) @@ -113,6 +113,9 @@ def __exit__(self, ptype, value, trace): def export(self): return _ffi_api.TargetExport(self) + def with_host(self, host=None): + return _ffi_api.WithHost(self, Target(host)) + @staticmethod def current(allow_none=True): """Returns the current target. @@ -164,6 +167,37 @@ def list_kinds(): """Returns the list of available target names.""" return list(_ffi_api.ListTargetKinds()) + @staticmethod + def check_and_update_host_consist(target, host=None, target_is_dict_key=True): + """A helper function that merges a legacy "target, target_host" pair, then returns + the merged target and its host field. The function is for legacy target and target + host pair only, and should not be used in the new target system. + + Parameters + ---------- + target : Union[str, Dict[str, Any], Target] + The target or heterogeneous target + host : Union[str, Dict[str, Any], Target, None] + The target host + target_is_dict_key : Bool + When the type of target is dict, whether Target is the key (Otherwise the value) + """ + if isinstance(target, dict) and "kind" not in target: + new_target = {} + for tgt, mod in target.items(): + if not target_is_dict_key: + tgt, mod = mod, tgt + if isinstance(tgt, (dict, str, Target)): + tgt, host = Target.check_and_update_host_consist(tgt, host) + if not target_is_dict_key: + tgt, mod = mod, tgt + new_target[tgt] = mod + target = new_target + else: + target = Target(target, host) + host = target.host + return target, host + # TODO(@tvm-team): Deprecate the helper functions below. Encourage the usage of config dict instead. 
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index b3c62f01c7c8d..be78bc4aa9f95 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1397,9 +1397,12 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int if (find_res == task_cache.end()) { // rebuild task Array tensors = (*workload_key_to_tensors)(workload_key); - task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target, - cur_inp->task->target_host, cur_inp->task->hardware_params, - cur_inp->task->layout_rewrite_option, cur_inp->task->task_input_names); + Target target = cur_inp->task->target; + Target target_host = cur_inp->task->target_host; + CheckAndUpdateHostConsistency(&target, &target_host); + task = SearchTask(ComputeDAG(tensors), workload_key, target, target_host, + cur_inp->task->hardware_params, cur_inp->task->layout_rewrite_option, + cur_inp->task->task_input_names); task_id = task_cache.size(); // compute min cost for each task @@ -1466,10 +1469,13 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, // The measure input is incomplete, rebuild task for incomplete measure pairs read from file try { Array tensors = (*workload_key_to_tensors)(workload_key); + Target target = inputs[i]->task->target; + Target target_host = inputs[i]->task->target_host; + CheckAndUpdateHostConsistency(&target, &target_host); task = - SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, - inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_input_names); + SearchTask(ComputeDAG(tensors), workload_key, target, target_host, + inputs[i]->task->hardware_params, inputs[i]->task->layout_rewrite_option, + inputs[i]->task->task_input_names); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 5dafa8d987020..af37443d91e2a 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -163,8 +163,11 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArrayItem(std::string(data.workload_key)); writer->WriteArrayItem(data.target->str()); writer->WriteArrayItem(*data.hardware_params.get()); - if (data.target_host.defined()) { - writer->WriteArrayItem(data.target_host->str()); + ::tvm::Target target = data.target; + ::tvm::Target target_host = data.target_host; + ::tvm::CheckAndUpdateHostConsistency(&target, &target_host); + if (target_host.defined()) { + writer->WriteArrayItem(target_host->str()); } else { writer->WriteArrayItem(std::string("")); } @@ -200,6 +203,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(&str_value); if (!str_value.empty()) { data->target_host = ::tvm::Target(str_value); + ::tvm::CheckAndUpdateHostConsistency(&data->target, &data->target_host); } s = reader->NextArrayItem(); ICHECK(s); diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 58bdb6ca8359f..db53a325fdc4a 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -53,6 +53,7 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target, const Target& target_host) { + // There is no use of target_host so no updates here in the function. 
const auto device_type = target->kind->device_type; if (device_type == kDLCPU) { return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0); @@ -138,6 +139,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option, Array task_input_names) { + CheckAndUpdateHostConsistency(&target, &target_host); auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -167,6 +169,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, int layout_rewrite_option, Array task_input_names) { + CheckAndUpdateHostConsistency(&target, &target_host); return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, LayoutRewriteOption(layout_rewrite_option), task_input_names); }); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bbbb7e3f9eb51..f30cecbf7f059 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -185,9 +185,11 @@ IRModule lower(te::Schedule sch, const Array& args, const std::strin return mod; } -std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target, - const Target& target_host, +std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg, + const Target& target_host_arg, const transform::PassContext& pass_ctx) { + Target target = target_arg, target_host = target_host_arg; + CheckAndUpdateHostConsistency(&target, &target_host); Array mixed_pass_list = {BindTarget(target), tir::transform::VerifyMemory()}; @@ -253,31 +255,39 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target } // Build for heterogeneous execution. 
-runtime::Module build(const Map& inputs, const Target& target_host) { +runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { auto pass_ctx = transform::PassContext::Current(); std::vector device_modules; - Target target_host_val = target_host; + Map inputs = inputs_arg; + Target target_host = target_host_arg; + + // Fetch previous defined target host in targets + CheckAndUpdateHostConsistency(&inputs, &target_host); + if (!target_host.defined()) { for (const auto& it : inputs) { if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) { - target_host_val = it.first; + target_host = it.first; break; } } } - if (!target_host_val.defined()) { - target_host_val = DefaultTargetHost(target_host_val); + if (!target_host.defined()) { + target_host = DefaultTargetHost(target_host); } + // Update target host for all targets + CheckAndUpdateHostConsistency(&inputs, &target_host); + IRModule mhost_all = IRModule(Map()); ICHECK(mhost_all.defined()) << "The host module must be defined"; for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx); + auto pair = SplitDevHostFuncs(it.second, it.first, target_host, pass_ctx); auto& mhost = pair.first; auto& mdevice = pair.second; @@ -293,7 +303,7 @@ runtime::Module build(const Map& inputs, const Target& target_ } } - runtime::Module mhost = codegen::Build(mhost_all, target_host_val); + runtime::Module mhost = codegen::Build(mhost_all, target_host); // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { @@ -304,21 +314,26 @@ runtime::Module build(const Map& inputs, const Target& target_ } // Build for heterogeneous execution when target is a string. -runtime::Module build(const Map& inputs, const Target& target_host) { - Map updated_input; - for (const auto& it : inputs) { - auto target = Target(it.first); +runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { + Map updated_inputs; + Target target_host = target_host_arg; + for (const auto& it : inputs_arg) { + Target target = Target(it.first); + CheckAndUpdateHostConsistency(&target, &target_host); Optional device = target->GetAttr("device"); if (device.defined() && device.value() == "vta") { target = Target("ext_dev"); } - updated_input.Set(target, it.second); + updated_inputs.Set(target, it.second); } - return build(updated_input, target_host); + return build(updated_inputs, target_host); } // Build for homogeneous execution. 
-runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host) { +runtime::Module build(const IRModule& funcs, const Target& target_arg, + const Target& target_host_arg) { + auto target = target_arg, target_host = target_host_arg; + CheckAndUpdateHostConsistency(&target, &target_host); Map inputs = {{target, funcs}}; return build(inputs, target_host); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 3995d5ab3568d..07bb51150bee5 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -235,8 +235,10 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target_host Host target device */ void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { + // Create protected variable targets_ from ground up targets_ = targets; target_host_ = target_host; + CheckAndUpdateHostConsistency(&targets_, &target_host_); BuildRelay(mod, params_); // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. CompileEngine::Global()->Clear(); @@ -460,6 +462,15 @@ class RelayBuildModule : public runtime::ModuleNode { */ void BuildRelay(IRModule relay_module, const std::unordered_map& params) { + Target target_host = GetTargetHost(); + // If no target_host has been set, we choose a default one, which is + // llvm if "codegen.LLVMModuleCreate" is accessible. + const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); + if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); + + // Update all the targets in the targets_ TargetsMap + CheckAndUpdateHostConsistency(&targets_, &target_host); + // Relay IRModule -> IRModule optimizations. relay_module = Optimize(relay_module, targets_, params); // Get the updated function. @@ -475,12 +486,6 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = graph_codegen_->GetIRModule(); - Target target_host = GetTargetHost(); - // If no target_host has been set, we choose a default one, which is - // llvm if "codegen.LLVMModuleCreate" is accessible. - const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); - if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - // Generate a placeholder function that attaches linked params as its arguments. if (target_host->GetAttr("link-params").value_or(Bool(false))) { CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 906250c1bb0d0..1e231e65424d8 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -255,9 +255,11 @@ class VMFunctionCompiler : ExprFunctor { context_(context), target_host_(target_host), expr_device_map_(std::move(expr_device_map)) { + CheckAndUpdateHostConsistency(&targets, &target_host); for (const auto& it : targets) { targets_[it.first->value] = it.second; } + target_host_ = target_host; } VMFunction Compile(const GlobalVar& var, const Function& func) { @@ -900,6 +902,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe exec_ = make_object(); targets_ = targets; target_host_ = target_host; + CheckAndUpdateHostConsistency(&targets_, &target_host_); // Run the optimizations necessary to target the VM. 
context_.module = OptimizeModule(mod, targets_, target_host_); @@ -1001,8 +1004,11 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { return transform::Sequential(pass_seqs); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets, - const Target& target_host) { +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, + const Target& target_host_arg) { + TargetsMap targets = targets_arg; + Target target_host = target_host_arg; + CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); ICHECK(base_func->IsInstance()) diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index dd0cfc85a5108..1dc204d43ba13 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -414,6 +414,7 @@ class DialectRewriter : public ExprMutator { namespace transform { Pass ManifestAlloc(Target target_host, Map targets) { + CheckAndUpdateHostConsistency(&targets, &target_host); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { DLOG(INFO) << "tvm::relay::transform::ManifestAlloc"; @@ -457,6 +458,7 @@ Pass ManifestAlloc(Target target_host, Map targets) { TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc") .set_body_typed([](Target target_host, Map targets) { + CheckAndUpdateHostConsistency(&targets, &target_host); return ManifestAlloc(target_host, targets); }); diff --git a/src/target/target.cc b/src/target/target.cc index 55ef5f1a4e24f..396e264ede4d4 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -51,9 +51,42 @@ class TargetInternal { static ObjectPtr FromRawString(const String& target_str); static ObjectPtr FromConfig(std::unordered_map config); static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv); + static Target WithHost(const Target& target, const Target& target_host) { + ObjectPtr n = make_object(*target.get()); + n->host = target_host; + return (Target)n; + } }; /********** Helper functions **********/ +Target Target::WithHost(const Target& target, const Target& host) { + return TargetInternal::WithHost(target, host); +} + +void CheckAndUpdateHostConsistency(Target* target, Target* host) { + *target = Target(*target, *host); + *host = (*target)->GetHost().value_or(Target()); +} + +void CheckAndUpdateHostConsistency(Map* targets, Target* host) { + Map new_targets; + for (auto& it : *targets) { + auto target = it.second; + CheckAndUpdateHostConsistency(&target, host); + new_targets.Set(it.first, target); + } + *targets = new_targets; +} + +void CheckAndUpdateHostConsistency(Map* targets, Target* host) { + Map new_targets; + for (auto& it : *targets) { + auto target = it.first; + CheckAndUpdateHostConsistency(&target, host); + new_targets.Set(target, it.second); + } + *targets = new_targets; +} static std::vector DeduplicateKeys(const std::vector& keys) { std::vector new_keys; @@ -374,7 +407,7 @@ Target::Target(const Map& config) { Target::Target(Target target, Target host) { ObjectPtr n = make_object(*target.get()); - CHECK(!n->host.defined()) + CHECK(!n->host.defined() || n->host.same_as(host)) << "ValueError: Adding a host to a target whose host field has been defined"; // add target host into host field n->host = std::move(host); @@ -407,12 +440,19 @@ Map TargetNode::Export() const { {"tag", this->tag}, {"keys", this->keys}, }; + if (this->host.defined()) { + result.Set("host", this->GetHost().value_or(Target())->Export()); + } for 
(const auto& kv : attrs) { result.Set(kv.first, kv.second); } return result; } +Optional TargetNode::GetHost() const { + return GetRef>(this->host.as()); +} + /*! \brief Entry to hold the Target context stack. */ struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ @@ -606,6 +646,13 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_mapkeys = DeduplicateKeys(keys); config.erase(kKeys); } + // parse host + if (config.count(kHost)) { + target->host = PackedFunc(ConstructorDispatcher)(config[kHost]).AsObjectRef(); + config.erase(kHost); + } else { + target->host = NullOpt; + } // parse attrs std::unordered_map attrs; for (const auto& cfg_kv : config) { @@ -618,13 +665,6 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_maphost = PackedFunc(ConstructorDispatcher)(config[kHost]).AsObjectRef(); - config.erase(kHost); - } else { - target->host = NullOpt; - } // set default attribute values if they do not exist for (const auto& kv : target->kind->key2default_) { if (!attrs.count(kv.first)) { @@ -647,6 +687,7 @@ TVM_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::En TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); +TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index b4731f16d99f5..1db3d505f4909 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -51,8 +51,9 @@ def _make_sess_from_op(model, zephyr_board, west_cmd, op_name, sched, arg_bufs): target = tvm.target.target.micro(model) + target = tvm.target.Target(target=target, host=target) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.build(sched, arg_bufs, target, target_host=target, name=op_name) + mod = tvm.build(sched, arg_bufs, target=target, name=op_name) return _make_session(model, target, zephyr_board, west_cmd, mod) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 87629518566b5..772d374bc5112 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -72,7 +72,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): # validation dev = tvm.gpu(0) - f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") + f = tvm.build(s, [X, W, Y], "cuda --host=llvm", name="conv2d") x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) y_np = np.zeros(yshape).astype(data_dtype) @@ -151,7 +151,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): # validation dev = tvm.gpu(0) - f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d") + f = tvm.build(s, [X, W, Y], target="cuda --host=llvm", name="conv3d") x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) y_np = np.zeros(yshape).astype(data_dtype) @@ -183,7 +183,7 @@ def verify_softmax(shape, axis, dtype="float32"): b_np = tvm.topi.testing.softmax_python(a_np) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) - f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax") + f = 
tvm.build(s, [A, B], target="cuda --host=llvm", name="softmax") f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3) @@ -200,7 +200,7 @@ def verify_softmax_4d(shape, dtype="float32"): b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) - f = tvm.build(s, [A, B], "cuda", target_host="llvm", name="softmax") + f = tvm.build(s, [A, B], target="cuda --host=llvm", name="softmax") f(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3) diff --git a/tests/python/contrib/test_dlpack.py b/tests/python/contrib/test_dlpack.py index 6ff2529f75702..8bf9069b78cfc 100644 --- a/tests/python/contrib/test_dlpack.py +++ b/tests/python/contrib/test_dlpack.py @@ -49,7 +49,9 @@ def test(): k = te.reduce_axis((0, n), name="k") ZZ = te.compute((n, n), lambda i, j: te.sum(XX[i, k] * YY[k, j], axis=k)) s = te.create_schedule(ZZ.op) - f = tvm.build(s, [XX, YY, ZZ], target_host="llvm", name="f") + # No need to speficy target_host if it's llvm + # Otherwise you will need to specify the target and target_host + f = tvm.build(s, [XX, YY, ZZ], name="f") f_pytorch = to_pytorch_func(f) zz2 = torch.empty(137, 137) diff --git a/tests/python/contrib/test_miopen.py b/tests/python/contrib/test_miopen.py index 4847c0e1b7bc9..630bfc0110382 100644 --- a/tests/python/contrib/test_miopen.py +++ b/tests/python/contrib/test_miopen.py @@ -53,7 +53,7 @@ def test_conv2d(): def verify(): dev = tvm.rocm(0) - f = tvm.build(s, [X, W, Y], "rocm", target_host="llvm", name="conv2d") + f = tvm.build(s, [X, W, Y], "rocm --host=llvm", name="conv2d") x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(np.float32), dev) w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(np.float32), dev) y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), dev) @@ -63,7 +63,7 @@ def verify(): X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w) ) s_ref = te.create_schedule(Y_ref.op) - f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm", target_host="llvm") + f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm --host=llvm") y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), dev) f_ref(x, w, y_ref) print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy()))) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 0180c35d7a267..17b2834feb11f 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -176,8 +176,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): graph, lib, params, dumps = tvmc.compile( tflite_mobilenet_v1_0_25_128, - target="opencl", - target_host="llvm", + target="opencl --host=llvm", alter_layout="NCHW", ) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 6a410ea51e161..6c5f7f3690575 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -164,7 +164,8 @@ def run_tvm_graph( return vmobj_to_list(result) else: with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): - graph, lib, params = relay.build(mod, target, target_host, params) + target = tvm.target.Target(target, target_host) + graph, lib, params = relay.build(mod, target=target, params=params) from tvm.contrib import graph_executor m = graph_executor.create(graph, lib, dev) diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 
cf140be94b86a..19bd03ec79ce2 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -45,7 +45,9 @@ def check_device(device, host="llvm"): if not tvm.testing.device_enabled(device): print("skip because %s is not enabled.." % device) return - freduce = tvm.build(s, args=[A, B], target=device, target_host=host, name="myreduce") + freduce = tvm.build( + s, args=[A, B], target=tvm.target.Target(device, host), name="myreduce" + ) # launch the kernel. n = 1028 m = 129 diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 170b4709262b8..45e0958a02405 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -131,12 +131,11 @@ def teardown_module(): def get_sample_task(target=tvm.target.cuda(), target_host=None): + target = tvm.target.Target(target, target_host) + target_host = target.host """return a sample task for testing""" task = autotvm.task.create( - "testing/conv2d_no_batching", - args=(1, 7, 7, 512, 512, 3, 3), - target=target, - target_host=target_host, + "testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target ) return task, target diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index c1bdc3ff9fd06..7ca06c5c97e04 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -808,8 +808,7 @@ def test_vm_rpc(): upload it to a remote machine using RPC and then execute it on the other machine. """ - target = "llvm" - target_host = "llvm" + target = tvm.target.Target("llvm --host=llvm") # Build a IRModule. x = relay.var("x", shape=(10, 1)) @@ -817,7 +816,7 @@ def test_vm_rpc(): mod = IRModule.from_expr(f) # Compile to VMExecutable. - vm_exec = vm.compile(mod, target=target, target_host=target_host) + vm_exec = vm.compile(mod, target=target) # Export to Disk temp = utils.tempdir() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 7605b70be6f49..d82cfd447a403 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -26,7 +26,6 @@ import tempfile import tvm.testing import pickle - from test_auto_scheduler_common import matmul_auto_scheduler_test from tvm.auto_scheduler import workload_registry @@ -336,8 +335,7 @@ def test_measure_target_host(): task = auto_scheduler.SearchTask( func=matmul_auto_scheduler_test, args=(512, 512, 512), - target="llvm", - target_host="llvm -mtriple=aarch64-linux-gnu", + target=tvm.target.Target("llvm", "llvm -mtriple=aarch64-linux-gnu"), ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) @@ -353,7 +351,7 @@ def test_measure_target_host(): raw_inp = inputs[0] recovered_inp = auto_scheduler.measure.recover_measure_input(raw_inp) - assert str(recovered_inp.task.target_host) == str(inp.task.target_host) + assert str(recovered_inp.task.target.host) == str(inp.task.target.host) @tvm.testing.requires_llvm diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index 78e85dc213e04..cd47f1e468ff4 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -70,7 +70,7 @@ def test_search_task_record(): # TODO(jcf94): Check the compute dag & hardware parameter assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - 
assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option # Log with 1 task input @@ -86,7 +86,7 @@ def test_search_task_record(): new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_input_names) == 1 assert new_task.task_input_names[0] == "test_input_0" @@ -107,7 +107,7 @@ def test_search_task_record(): new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_input_names) == 2 assert new_task.task_input_names[0] == "test_input_0" @@ -118,7 +118,7 @@ def test_search_task_record(): new_task = auto_scheduler._ffi_api.DeserializeSearchTask(v5_log) assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_input_names) == 0 @@ -139,7 +139,7 @@ def test_recover_measure_input_with_task_input(): new_task = measure_log[0].task assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option # Log with 1 task input @@ -160,7 +160,7 @@ def test_recover_measure_input_with_task_input(): new_task = measure_log[0].task assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_input_names) == 1 assert new_task.task_input_names[0] == "test_input_0" @@ -184,7 +184,7 @@ def test_recover_measure_input_with_task_input(): new_task = measure_log[0].task assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_input_names) == 2 assert new_task.task_input_names[0] == "test_input_0" @@ -196,7 +196,7 @@ def test_recover_measure_input_with_task_input(): new_task = measure_log[0].task assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) - assert str(task.target_host) == str(new_task.target_host) + assert str(task.target.host) == str(new_task.target.host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_input_names) == 0 diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py 
index e605ddd05036a..ac07fa34f46bf 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -32,6 +32,7 @@ import tvm import tvm.relay import tvm.testing +from tvm.target import Target from tvm.topi.utils import get_const_tuple from tvm.topi.testing import conv2d_nchw_python @@ -44,7 +45,7 @@ def _make_sess_from_op(workspace, op_name, sched, arg_bufs): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.build(sched, arg_bufs, TARGET, target_host=TARGET, name=op_name) + mod = tvm.build(sched, arg_bufs, Target(TARGET, TARGET), name=op_name) return _make_session(workspace, mod) diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index e97b349af36ed..5388dee2fa587 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -170,7 +170,8 @@ def check_device(device, target_device): ) target_flist = {target_device: lower_add, target_host: lower_sub} - mhost = tvm.build(target_flist, target_host=target_host) + target = tvm.target.Target(target, target_host) + mhost = tvm.build(target_flist, target=target) dev = [host_dev, device_dev] mod = graph_executor.create(graph, mhost, dev) params = {} @@ -399,7 +400,8 @@ def check_device(device, target_device): lower_add0.update(lower_add1) target_flist = {target_device: lower_add0, target_host: lower_sub} - mhost = tvm.build(target_flist, target_host=target_host) + target = tvm.target.Target(target, target_host) + mhost = tvm.build(target_flist, target=target) dev = [host_dev, device_dev] params = {} params["A"] = tensor_a = np.random.uniform(size=shape).astype(tensor_a.dtype) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 7ec09d8b9b486..256fd33387bf7 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -309,7 +309,7 @@ def check_remote_link_cl(remote): xo, xi = s[B].split(B.op.axis[0], factor=32) s[B].bind(xo, te.thread_axis("blockIdx.x")) s[B].bind(xi, te.thread_axis("threadIdx.x")) - f = tvm.build(s, [A, B], "opencl", target_host="llvm", name="myadd") + f = tvm.build(s, [A, B], "opencl --host=llvm", name="myadd") # Option 1: save modules separately and rely on remote compiler path_o = temp.relpath("myadd.o") path_cl = temp.relpath("myadd.cl") diff --git a/tests/python/unittest/test_target_codegen_blob.py b/tests/python/unittest/test_target_codegen_blob.py index f1290ddd1e51a..c7698197c111b 100644 --- a/tests/python/unittest/test_target_codegen_blob.py +++ b/tests/python/unittest/test_target_codegen_blob.py @@ -85,7 +85,7 @@ def test_cuda_lib(): from tvm.contrib import utils temp = utils.tempdir() - fn_add = tvm.build(s, [A, B], target="cuda", target_host="llvm", name="add") + fn_add = tvm.build(s, [A, B], target="cuda --host=llvm", name="add") path_lib = temp.relpath("deploy_lib.so") fn_add.export_library(path_lib) m = tvm.runtime.load_module(path_lib) diff --git a/tests/python/unittest/test_target_codegen_device.py b/tests/python/unittest/test_target_codegen_device.py index b1b14f448b4ef..4ce7a021981de 100644 --- a/tests/python/unittest/test_target_codegen_device.py +++ b/tests/python/unittest/test_target_codegen_device.py @@ -71,7 +71,7 @@ def check_target(device, host="stackvm"): if not tvm.testing.device_enabled(device) or not tvm.testing.device_enabled(host): return dev = tvm.device(device, 0) - mhost = tvm.driver.build(s, [A, B, 
D], target=device, target_host=host) + mhost = tvm.driver.build(s, [A, B, D], target=tvm.target.Target(device, host)) f = mhost.entry_func # launch the kernel. n = 1027 diff --git a/tests/python/unittest/test_target_codegen_hexagon.py b/tests/python/unittest/test_target_codegen_hexagon.py index b74d487f3fa7a..6ffb2f4741e80 100644 --- a/tests/python/unittest/test_target_codegen_hexagon.py +++ b/tests/python/unittest/test_target_codegen_hexagon.py @@ -53,7 +53,9 @@ def check_add(offload): m = tvm.build(s, [C, A, B], target=target, name="offload_add") hexm = m.imported_modules[0] else: - hexm = tvm.build(s, [C, A, B], target=target, target_host=target, name="native_add") + hexm = tvm.build( + s, [C, A, B], target=tvm.target.Target(target, target), name="native_add" + ) asm = hexm.get_source("s") vadds = re.findall(r"v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)", asm) @@ -71,7 +73,7 @@ def test_llvm_target_features(): A = tvm.te.placeholder((128,), dtype="uint8", name="A") C = tvm.te.compute((128,), lambda i: A[i] + 1, name="C") s = tvm.te.create_schedule(C.op) - m = tvm.build(s, [C, A], target=target, target_host=target, name="add_one") + m = tvm.build(s, [C, A], target=tvm.target.Target(target, target), name="add_one") llvm_ir = m.get_source("ll") # Make sure we find +hvx-length128b in "attributes". fs = re.findall(r"attributes.*\+hvx-length128b", llvm_ir) diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 7b998bef34a5d..2f885d39335b2 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. import json -import tvm +import sys import pytest -from tvm import te -from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, hexagon +import tvm +from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost @tvm.target.generic_func @@ -121,7 +121,7 @@ def test_config_map(): def test_composite_target(): tgt = tvm.target.Target("composite --host=llvm --devices=cuda,opencl") assert tgt.kind.name == "composite" - assert tgt.attrs["host"].kind.name == "llvm" + assert tgt.host.kind.name == "llvm" assert len(tgt.attrs["devices"]) == 2 cuda_device, opencl_device = tgt.attrs["devices"] assert cuda_device.kind.name == "cuda" @@ -215,15 +215,58 @@ def test_target_host_warning(): Confirm that constructing a target with invalid attributes fails as expected. 
""" - with pytest.raises(ValueError): - tgt = tvm.target.Target("cuda --host nvidia/jetson-nano", "llvm") + with pytest.raises( + ValueError, match="Adding a host to a target whose host field has been defined" + ): + tvm.target.Target("cuda --host nvidia/jetson-nano", "llvm") + + +def test_target_host_merge_0(): + tgt = tvm.target.Target(tvm.target.Target("cuda --host nvidia/jetson-nano"), None) + assert tgt.kind.name == "cuda" + assert tgt.host.kind.name == "cuda" + assert tgt.host.attrs["arch"] == "sm_53" + assert tgt.host.attrs["shared_memory_per_block"] == 49152 + assert tgt.host.attrs["max_threads_per_block"] == 1024 + assert tgt.host.attrs["thread_warp_size"] == 32 + assert tgt.host.attrs["registers_per_block"] == 32768 + + +def test_target_host_merge_1(): + tgt = tvm.target.Target("cuda --host llvm") + tgt = tvm.target.Target(tgt, tgt.host) + assert tgt.kind.name == "cuda" + assert tgt.host.kind.name == "llvm" + + +def test_target_host_merge_2(): + with pytest.raises( + ValueError, match="Adding a host to a target whose host field has been defined" + ): + tvm.target.Target(tvm.target.Target("cuda --host llvm"), tvm.target.Target("llvm")) + + +@pytest.mark.skip(reason="Causing infinite loop because of pytest and handle issue") +def test_target_host_merge_3(): + with pytest.raises(ValueError, match=r"target host has to be a string or dictionary."): + tvm.target.Target(tvm.target.Target("cuda --host llvm"), 12.34) + + +def test_target_with_host(): + tgt = tvm.target.Target("cuda") + llvm = tvm.target.Target("llvm") + tgt = tgt.with_host(llvm) + assert tgt.kind.name == "cuda" + assert tgt.host.kind.name == "llvm" + cuda_host = tvm.target.Target("nvidia/jetson-nano") + tgt = tgt.with_host(cuda_host) + assert tgt.host.kind.name == "cuda" + assert tgt.host.attrs["arch"] == "sm_53" + assert tgt.host.attrs["shared_memory_per_block"] == 49152 + assert tgt.host.attrs["max_threads_per_block"] == 1024 + assert tgt.host.attrs["thread_warp_size"] == 32 + assert tgt.host.attrs["registers_per_block"] == 32768 if __name__ == "__main__": - test_target_dispatch() - test_target_string_parse() - test_target_create() - test_target_config() - test_config_map() - test_composite_target() - test_list_kinds() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index e0cba8421e837..c035fd063dba1 100644 --- a/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -39,8 +39,9 @@ def test_out_of_bounds_llvm(index_a, index_b): tgt_host = "llvm" stmt = tvm.lower(s, [A, B, C], simple_mode=True) print(stmt) - fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") - dev = tvm.device(tgt, 0) + tgt = tvm.target.Target(tgt, tgt_host) + fadd = tvm.build(s, [A, B, C], target=tgt, name="myadd") + dev = tvm.device(tgt.kind.name, 0) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev) b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), dev) c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), dev) @@ -57,8 +58,9 @@ def test_in_bounds_llvm(): tgt = "llvm" tgt_host = "llvm" stmt = tvm.lower(s, [A, B, C], simple_mode=True) - fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") - dev = tvm.device(tgt, 0) + tgt = tvm.target.Target(tgt, tgt_host) + fadd = tvm.build(s, [A, B, C], target=tgt, name="myadd") + dev = 
tvm.device(tgt.kind.name, 0) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), dev) b = tvm.nd.array(np.random.uniform(size=1024).astype(B.dtype), dev) c = tvm.nd.array(np.zeros(1024, dtype=C.dtype), dev) @@ -79,7 +81,8 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b): tgt = "llvm" tgt_host = "llvm" stmt = tvm.lower(s, [a, b, c], simple_mode=True) - f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec") + tgt = tvm.target.Target(tgt, tgt_host) + f = tvm.build(s, [a, b, c], target=tgt, name="myaddvec") dev = tvm.cpu(0) n = nn a = tvm.nd.array(np.random.uniform(size=(n)).astype(a.dtype), dev) diff --git a/tutorials/auto_scheduler/tune_network_mali.py b/tutorials/auto_scheduler/tune_network_mali.py index 13d1e4793ffa3..35751fa11f17d 100644 --- a/tutorials/auto_scheduler/tune_network_mali.py +++ b/tutorials/auto_scheduler/tune_network_mali.py @@ -139,8 +139,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): use_ndk = True # Path to cross compiler os.environ["TVM_NDK_CC"] = "/usr/bin/aarch64-linux-gnu-g++" -target_host = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu") -target = tvm.target.Target("opencl -device=mali") +target = tvm.target.Target("opencl -device=mali", host="llvm -mtriple=aarch64-linux-gnu") dtype = "float32" log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, target.kind.name) @@ -170,7 +169,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # Extract tasks from the network print("Extract tasks...") mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) -tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target, target_host) +tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) for idx, task in enumerate(tasks): print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key)) @@ -198,7 +197,9 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # # .. code-block:: python # -# tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target, target_host, hardware_params) +# tasks, task_weights = auto_scheduler.extract_tasks( +# mod["main"], params, target, hardware_params = hardware_params +# ) # ################################################################# @@ -240,7 +241,7 @@ def tune_and_evaluate(): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_auto_scheduler": True} ): - lib = relay.build(mod, target=target, target_host=target_host, params=params) + lib = relay.build(mod, target, params=params) # Create graph executor print("=============== Request Remote ===============") diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index aefa600e3c3f7..2b109873c7506 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -201,12 +201,9 @@ def get_network(name, batch_size): # set :code:`use_android` to True if you use android phone. #### DEVICE CONFIG #### - -target = tvm.target.Target("opencl -device=mali") - # Replace "aarch64-linux-gnu" with the correct target of your board. # This target host is used for cross compilation. You can query it by :code:`gcc -v` on your device. 
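The hunk below collapses the tutorial's separate Mali device target and aarch64 host into a single Target object. For reference, a minimal sketch of the resulting object and the properties it exposes, assuming the same Mali GPU / aarch64 host pair used in this tutorial; the assertions mirror the host checks added in test_target_target.py above:

    import tvm

    # Device target and its compilation host now travel together in one object.
    target = tvm.target.Target("opencl -device=mali", host="llvm -mtriple=aarch64-linux-gnu")
    assert target.kind.name == "opencl"
    assert target.host.kind.name == "llvm"
    # The same object is later passed to autotvm.task.extract_from_program and
    # relay.build without a separate target_host argument.
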
-target_host = "llvm -mtriple=aarch64-linux-gnu" +target = tvm.target.Target("opencl -device=mali", host="llvm -mtriple=aarch64-linux-gnu") # Also replace this with the device key in your tracker device_key = "rk3399" @@ -317,7 +314,6 @@ def tune_and_evaluate(tuning_opt): tasks = autotvm.task.extract_from_program( mod["main"], target=target, - target_host=target_host, params=params, ops=(relay.op.get("nn.conv2d"),), ) @@ -330,9 +326,7 @@ def tune_and_evaluate(tuning_opt): with autotvm.apply_history_best(log_file): print("Compile...") with tvm.transform.PassContext(opt_level=3): - lib = relay.build_module.build( - mod, target=target, params=params, target_host=target_host - ) + lib = relay.build_module.build(mod, target=target, params=params) # export library tmp = tempdir() if use_android: diff --git a/tutorials/frontend/deploy_model_on_android.py b/tutorials/frontend/deploy_model_on_android.py index 8efcb706b3800..158280fe9447e 100644 --- a/tutorials/frontend/deploy_model_on_android.py +++ b/tutorials/frontend/deploy_model_on_android.py @@ -257,25 +257,21 @@ def transform_image(image): # Change target configuration. # Run `adb shell cat /proc/cpuinfo` to find the arch. arch = "arm64" -target = "llvm -mtriple=%s-linux-android" % arch -target_host = None +target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch) if local_demo: - target_host = None - target = "llvm" + target = tvm.target.Target("llvm") elif test_target == "opencl": - target_host = target - target = "opencl" + target = tvm.target.Target("opencl", host=target) elif test_target == "vulkan": - target_host = target - target = "vulkan" + target = tvm.target.Target("vulkan", host=target) input_name = "input_1" shape_dict = {input_name: x.shape} mod, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict) with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target=target, target_host=target_host, params=params) + lib = relay.build(mod, target=target, params=params) # After `relay.build`, you will get three return values: graph, # library and the new parameter, since we do some optimization that will diff --git a/tutorials/frontend/from_darknet.py b/tutorials/frontend/from_darknet.py index 356dc16bedf0b..b29ed3d962c73 100644 --- a/tutorials/frontend/from_darknet.py +++ b/tutorials/frontend/from_darknet.py @@ -94,14 +94,13 @@ # Import the graph to Relay # ------------------------- # compile the model -target = "llvm" -target_host = "llvm" +target = tvm.target.Target("llvm", host="llvm") dev = tvm.cpu(0) data = np.empty([batch_size, net.c, net.h, net.w], dtype) shape = {"data": data.shape} print("Compiling the model...") with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target=target, target_host=target_host, params=params) + lib = relay.build(mod, target=target, params=params) [neth, netw] = shape["data"][2:] # Current image shape is 608x608 ###################################################################### diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index a0db518025e30..5f515e656bc80 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -100,11 +100,10 @@ # Relay Build # ----------- # Compile the graph to llvm target with given input specification. 
-target = "llvm" -target_host = "llvm" +target = tvm.target.Target("llvm", host="llvm") dev = tvm.cpu(0) with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target=target, target_host=target_host, params=params) + lib = relay.build(mod, target=target, params=params) ###################################################################### # Execute the portable graph on TVM diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py index 96c001e4fd416..9c8d0f65878c9 100644 --- a/tutorials/frontend/from_tensorflow.py +++ b/tutorials/frontend/from_tensorflow.py @@ -70,12 +70,10 @@ # Target settings # Use these commented settings to build for cuda. -# target = 'cuda' -# target_host = 'llvm' +# target = tvm.target.Target("cuda", host="llvm") # layout = "NCHW" # dev = tvm.gpu(0) -target = "llvm" -target_host = "llvm" +target = tvm.target.Target("llvm", host="llvm") layout = None dev = tvm.cpu(0) @@ -145,7 +143,7 @@ # lib: target library which can be deployed on target with TVM runtime. with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target=target, target_host=target_host, params=params) + lib = relay.build(mod, target, params=params) ###################################################################### # Execute the portable graph on TVM diff --git a/tutorials/get_started/cross_compilation_and_rpc.py b/tutorials/get_started/cross_compilation_and_rpc.py index 75985fccf1f3b..3c23c49562623 100644 --- a/tutorials/get_started/cross_compilation_and_rpc.py +++ b/tutorials/get_started/cross_compilation_and_rpc.py @@ -225,16 +225,16 @@ def run_opencl(): # NOTE: This is the setting for my rk3399 board. You need to modify # them according to your environment. - target_host = "llvm -mtriple=aarch64-linux-gnu" opencl_device_host = "10.77.1.145" opencl_device_port = 9090 + target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-gnu") # create schedule for the above "add one" compute declaration s = te.create_schedule(B.op) xo, xi = s[B].split(B.op.axis[0], factor=32) s[B].bind(xo, te.thread_axis("blockIdx.x")) s[B].bind(xi, te.thread_axis("threadIdx.x")) - func = tvm.build(s, [A, B], "opencl", target_host=target_host) + func = tvm.build(s, [A, B], target=target) remote = rpc.connect(opencl_device_host, opencl_device_port) diff --git a/tutorials/get_started/tensor_expr_get_started.py b/tutorials/get_started/tensor_expr_get_started.py index a9952c2422e02..c63a068360f2e 100644 --- a/tutorials/get_started/tensor_expr_get_started.py +++ b/tutorials/get_started/tensor_expr_get_started.py @@ -36,9 +36,8 @@ # Global declarations of environment. -tgt_host = "llvm" -# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm -tgt = "cuda" +# Change target to respective GPU if gpu is enabled Ex: cuda, opencl, rocm +tgt = tvm.target.Target(target="cuda", host="llvm") ###################################################################### # Vector Add Example @@ -117,7 +116,7 @@ # compute grid. These are GPU specific constructs that allow us # to generate code that runs on GPU. # -if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"): +if tgt.kind.name == "cuda" or tgt.kind.name == "rocm" or tgt.kind.name.startswith("opencl"): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) @@ -138,7 +137,7 @@ # function. fadd is the generated host wrapper function, it contains # a reference to the generated device function internally. 
# -fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") +fadd = tvm.build(s, [A, B, C], target=tgt, name="myadd") ###################################################################### # Run the Function @@ -154,7 +153,7 @@ # - fadd runs the actual computation. # - asnumpy() copies the GPU array back to the CPU and we can use this to verify correctness # -dev = tvm.device(tgt, 0) +dev = tvm.device(tgt.kind.name, 0) n = 1024 a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) @@ -172,7 +171,7 @@ # # The following code fetches the device module and prints the content code. # -if tgt == "cuda" or tgt == "rocm" or tgt.startswith("opencl"): +if tgt.kind.name == "cuda" or tgt.kind.name == "rocm" or tgt.kind.name.startswith("opencl"): dev_module = fadd.imported_modules[0] print("-----GPU code-----") print(dev_module.get_source()) @@ -214,11 +213,11 @@ temp = utils.tempdir() fadd.save(temp.relpath("myadd.o")) -if tgt == "cuda": +if tgt.kind.name == "cuda": fadd.imported_modules[0].save(temp.relpath("myadd.ptx")) -if tgt == "rocm": +if tgt.kind.name == "rocm": fadd.imported_modules[0].save(temp.relpath("myadd.hsaco")) -if tgt.startswith("opencl"): +if tgt.kind.name.startswith("opencl"): fadd.imported_modules[0].save(temp.relpath("myadd.cl")) cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")]) print(temp.listdir()) @@ -240,15 +239,15 @@ # re-links them together. We can verify that the newly loaded function works. # fadd1 = tvm.runtime.load_module(temp.relpath("myadd.so")) -if tgt == "cuda": +if tgt.kind.name == "cuda": fadd1_dev = tvm.runtime.load_module(temp.relpath("myadd.ptx")) fadd1.import_module(fadd1_dev) -if tgt == "rocm": +if tgt.kind.name == "rocm": fadd1_dev = tvm.runtime.load_module(temp.relpath("myadd.hsaco")) fadd1.import_module(fadd1_dev) -if tgt.startswith("opencl"): +if tgt.kind.name.startswith("opencl"): fadd1_dev = tvm.runtime.load_module(temp.relpath("myadd.cl")) fadd1.import_module(fadd1_dev) @@ -290,7 +289,7 @@ # The following code blocks generate OpenCL code, creates array on an OpenCL # device, and verifies the correctness of the code. # -if tgt.startswith("opencl"): +if tgt.kind.name.startswith("opencl"): fadd_cl = tvm.build(s, [A, B, C], tgt, name="myadd") print("------opencl code------") print(fadd_cl.imported_modules[0].get_source())
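
Taken together, the Python-side changes in this patch converge on one pattern: construct a single Target that optionally carries its host, and stop threading target_host through the build APIs. A minimal sketch of the new surface exercised by the tests above; the keyword constructor, the string --host form, and with_host() all appear in test_target_target.py, so only the combinations shown here are illustrative:

    import tvm

    # Host given at construction time.
    tgt = tvm.target.Target("cuda", host="llvm")
    assert tgt.kind.name == "cuda" and tgt.host.kind.name == "llvm"

    # Host parsed from the target string.
    tgt = tvm.target.Target("cuda --host=llvm")
    assert tgt.host.kind.name == "llvm"

    # Host attached (or replaced) after the fact.
    tgt = tvm.target.Target("cuda").with_host(tvm.target.Target("llvm"))
    assert tgt.host.kind.name == "llvm"

    # Passing another host to a target whose host is already set raises a
    # ValueError, as checked in test_target_host_warning above.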