Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finish support for list-of-targets #11382

Merged
merged 12 commits into from
May 23, 2022
21 changes: 17 additions & 4 deletions include/tvm/target/compilation_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <tvm/target/virtual_device.h>

#include <string>

namespace tvm {

/*!
Expand Down Expand Up @@ -68,14 +70,20 @@ class CompilationConfigNode : public Object {
* \p host_target, however the \p host_target should be used for all host computations and data.
* Each \p Target will have \p host_target as its 'host'.
*
* Primitive targets must be unique by their kind name. In this way the
* \p FindPrimitiveTargetForKind method will find the unique target for the given kind name.
* This method is used when transitioning from an external codegen "Compiler" attribute value
* to the external codegen target representing that compiler.
*
* It is possible to have multiple primitive targets for the same device type. However given
* primitive targets left and right where:
* - left appears before right in the array
* - left->kind->device_type == right->kind->device_type
* then:
* - right.IsExternalCodegenFor(left) must be true
* In this way the FindPrimitiveTargetOrFail method will find the 'most general' target for
* the requested device type.
* In this way the \p FindPrimitiveTargetForDeviceOrFail method will find the 'most general'
* target for the requested device type. This method is used when transitioning from a device
* constraint to the target needed to compile for that device.
*
* In the homogeneous case primitive_targets will have just one entry, which will be pointer equal
* to optional_homogeneous_target.
Expand Down Expand Up @@ -114,11 +122,16 @@ class CompilationConfigNode : public Object {
void VisitAttrs(AttrVisitor* v);

/*!
* \brief Return the unique \p Target to use for \p device_type. Fail if no such target exists.
 * \brief Returns the unique \p Target to use for \p device_type. Fails if no such target exists.
*
* This will be the first primitive target with matching device type.
*/
Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const;
Target FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const;

/*!
 * \brief Returns the unique \p Target to use for \p kind_name. Returns null if no such target exists.
*/
Optional<Target> FindPrimitiveTargetForKind(const std::string& kind_name) const;

/*!
* \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained
Expand Down
8 changes: 2 additions & 6 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,7 @@ def recover_measure_input(inp, rebuild_state=False):
from .search_task import SearchTask # lazily import to avoid recursive dependency

task = inp.task
task.target, task.target_host = Target.check_and_update_host_consist(
task.target, task.target_host
)
task.target, task.target_host = Target.canon_target_and_host(task.target, task.target_host)
new_task = SearchTask(
workload_key=task.workload_key,
target=task.target,
Expand Down Expand Up @@ -612,9 +610,7 @@ def _local_build_worker(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
task.target, task.target_host = Target.check_and_update_host_consist(
task.target, task.target_host
)
task.target, task.target_host = Target.canon_target_and_host(task.target, task.target_host)

error_no = MeasureErrorNo.NO_ERROR
error_msg = None
Expand Down
9 changes: 1 addition & 8 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import logging
import threading
import traceback
import warnings

import tvm
from tvm import autotvm, transform
Expand Down Expand Up @@ -115,13 +114,7 @@ def extract_tasks(
The weight (i.e. the number of appearance) of extracted tasks
"""
# pylint: disable=import-outside-toplevel
if target_host is not None:
warnings.warn(
"target_host parameter is going to be deprecated. "
"Please pass in tvm.target.Target(target, host=target_host) instead."
)

target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

# Run the compiler to collect all TOPI calls during compilation.
env = TracingEnvironment(
Expand Down
12 changes: 5 additions & 7 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@ class SearchTask(Object):
The ComputeDAG for the corresponding compute declaration.
workload_key : str
The workload key for the corresponding compute declaration.
target : tvm.target.Target
target : any target-like object, see Target.canon_target
The target device of this search task.
target_host : Optional[tvm.target.Target]
target_host : None or any target-like object, see Target.canon_target
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(

assert target is not None, "Must specify a target."

target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

if layout_rewrite_option is None:
layout_rewrite_option = LayoutRewriteOption.get_target_default(target)
Expand Down Expand Up @@ -559,9 +559,7 @@ def print_best(self, log_file, print_mode="schedule"):
raise ValueError("Invalid print_mode: %s" % print_mode)

def __getstate__(self):
self.target, self.target_host = Target.check_and_update_host_consist(
self.target, self.target_host
)
self.target, self.target_host = Target.canon_target_and_host(self.target, self.target_host)
return {
"compute_dag": self.compute_dag,
"workload_key": self.workload_key,
Expand All @@ -587,7 +585,7 @@ def __setstate__(self, state):
if workload[0] not in WORKLOAD_FUNC_REGISTRY:
register_workload_tensors(state["workload_key"], state["compute_dag"].tensors)

state["target"], state["target_host"] = Target.check_and_update_host_consist(
state["target"], state["target_host"] = Target.canon_target_and_host(
state["target"], state["target_host"]
)
self.__init_handle_by_constructor__(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def benchmark_layout_transform(
Accept a user-supplied runner
"""
self._logger.info("Start to benchmark layout transformation...")
self._target, target_host = Target.check_and_update_host_consist(self._target, target_host)
self._target, target_host = Target.canon_target_and_host(self._target, target_host)

if layout_records is None and infer_layout:
raise RuntimeError("Requires some records to infer layout transformation time.")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def set_task(self, task):
def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input
target, task.target_host = Target.check_and_update_host_consist(target, task.target_host)
target, task.target_host = Target.canon_target_and_host(target, task.target_host)

with target:
s, args = task.instantiate(config)
Expand Down
10 changes: 2 additions & 8 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"""
import threading
import logging
import warnings

import tvm
from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
Expand Down Expand Up @@ -81,12 +80,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
task: Array of autotvm.task.Task
collected tasks
"""
if target_host is not None:
warnings.warn(
"target_host parameter is going to be deprecated. "
"Please pass in tvm.target.Target(target, host=target_host) instead."
)
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)
return extract_from_multiple_program([mod], [params], target, ops=ops)


Expand Down Expand Up @@ -121,7 +115,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
env = TaskExtractEnv.get()

# merge target and target host
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

# run compiler to collect all TOPI calls during compilation
env.reset(ops)
Expand Down
11 changes: 3 additions & 8 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,7 @@ def __getstate__(self):
# and restore the function by name when unpickling it.
import cloudpickle # pylint: disable=import-outside-toplevel

self.target, self.target_host = Target.check_and_update_host_consist(
self.target, self.target_host
)
self.target, self.target_host = Target.canon_target_and_host(self.target, self.target_host)
return {
"name": self.name,
"args": self.args,
Expand All @@ -200,7 +198,7 @@ def __setstate__(self, state):
self.config_space = state["config_space"]
self.func = cloudpickle.loads(state["func"])
self.flop = state["flop"]
self.target, self.target_host = Target.check_and_update_host_consist(
self.target, self.target_host = Target.canon_target_and_host(
state["target"], state["target_host"]
)

Expand Down Expand Up @@ -471,10 +469,7 @@ def create(task_name, args, target, target_host=None):
args = serialize_args(args)
ret = Task(task_name, args)

if isinstance(target, str):
target = Target(target)

target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

# init config space
ret.config_space = ConfigSpace()
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactory
if not hasattr(module, "target"):
self._requires_cpu_device = False
else:
assert len(module.target.values()) == 1
for target in module.target.values():
assert len(module.target) == 1
for target in module.target:
target_type = str(target).split()[0]

if target_type == "llvm":
Expand Down Expand Up @@ -319,13 +319,13 @@ def _aot_executor_from_factory(

hexagon_arch = set(
target.mcpu.replace("hexagon", "")
for target in module.target.values()
for target in module.target
if "hexagon" in target.keys
)

self._set_device_type(module)

for target in module.target.values():
for target in module.target:
target_type = str(target).split()[0]

assert hexagon_arch, "No hexagon target architecture found"
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/contrib/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def measure_bandwidth_sum(
GBPS: float
gigabyte per second
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

n, m = total_item, item_per_thread
n //= lanes
Expand Down Expand Up @@ -154,7 +154,7 @@ def measure_bandwidth_all_types(
result: list
a list of (type_name, GBPS) pairs
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)
max_threads = target.max_num_threads

result = []
Expand Down Expand Up @@ -225,7 +225,7 @@ def measure_compute_mad(
GOPS: float
giga operation per second
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

n = total_item

Expand Down Expand Up @@ -318,7 +318,7 @@ def measure_compute_all_types(
result: list
a list of (type_name, GFLOPS/GIOPS) pairs
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

result = []
for base_type in ["float", "int"]:
Expand Down Expand Up @@ -364,7 +364,7 @@ def measure_peak_all(target, target_host, host, port):
port: int
"""

target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)
remote = rpc.connect(host, port)
n_times = 20

Expand Down
17 changes: 5 additions & 12 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

# pylint: disable=invalid-name
"""The build utils in python."""
import warnings

from typing import Union, Optional, List, Mapping

import tvm.tir
Expand Down Expand Up @@ -238,12 +236,6 @@ def build(
f"but got {type(inputs)}."
)

if target_host is not None:
warnings.warn(
"target_host parameter is going to be deprecated. "
"Please pass in tvm.target.Target(target, host=target_host) instead."
)

if not isinstance(inputs, (dict, container.Map)):
target = Target.current() if target is None else target
target = target if target else "llvm"
Expand All @@ -261,23 +253,24 @@ def build(
raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.")
annotated_mods[tar] = mod.with_attr("runtime", runtime)

annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)
annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)

# TODO(mbs): CompilationConfig implements the same host target defaulting logic, but
# tir_to_runtime currently bypasses that.
if not target_host:
for tar, mod in annotated_mods.items():
tar = Target(tar)
device_type = ndarray.device(tar.kind.name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = tar
break
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)
annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)

rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)

annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)
annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)

if not isinstance(target_host, Target):
target_host = Target(target_host)
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def tune_model(
The path to the produced tuning log file.
"""
target, extra_targets = target_from_cli(target, additional_target_options)
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)
# TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source
# model is fixed. For now, creating a clone avoids the issue.
mod = deepcopy(tvmc_model.mod)
Expand Down Expand Up @@ -524,7 +524,7 @@ def autotvm_get_tuning_tasks(
tasks : list of autotvm.Tasks
list of tasks to be tuned
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

if alter_layout:
mod = convert_graph_layout(mod, alter_layout)
Expand Down Expand Up @@ -573,7 +573,7 @@ def autoscheduler_get_tuning_tasks(
weights : List[int]
the weight (i.e. the number of appearance) of extracted tasks
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
target, target_host = Target.canon_target_and_host(target, target_host)

if alter_layout:
mod = convert_graph_layout(mod, alter_layout)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def compile_model(
mod = convert_graph_layout(mod, desired_layout)

tvm_target, extra_targets = target_from_cli(target, additional_target_options)
tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host)
tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host)

for codegen_from_cli in extra_targets:
codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
Expand Down
4 changes: 1 addition & 3 deletions python/tvm/exec/measure_peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def main():
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)

args.target, args.target_host = Target.check_and_update_host_consist(
args.target, args.target_host
)
args.target, args.target_host = Target.canon_target_and_host(args.target, args.target_host)
measure_peak_all(args.target, args.target_host, args.rpc_host, args.rpc_port)


Expand Down
Loading