Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Add support for target object with host field compatible with previous api #7534

Merged
merged 74 commits into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
6401b6f
Fix legacy code on target host
zxybazh Feb 25, 2021
0167a5f
Modify legacy code for target host change
zxybazh Feb 25, 2021
2a3c502
Add tests and fix merge issue
zxybazh Feb 25, 2021
511ce56
Add condition for same host
zxybazh Feb 25, 2021
69601a7
Modify all files for new target host api compatibility
zxybazh Feb 26, 2021
23187d8
Add newline
zxybazh Feb 26, 2021
85b27db
Change import format
zxybazh Feb 26, 2021
7e4eb0a
Optimize test file
zxybazh Feb 26, 2021
59457f6
Add match error info for unit tests
zxybazh Feb 26, 2021
b7e4c71
Fix for heterogeneous targets
zxybazh Mar 2, 2021
f5ccc50
Fix format for dict iteration
zxybazh Mar 2, 2021
11c77ba
Fix target host type error
zxybazh Mar 2, 2021
ca95bfd
Merge branch 'main' of https://github.com/zxybazh/tvm into target
zxybazh Mar 2, 2021
7543422
Skip one testcase for tvm infinite loop bug
zxybazh Mar 3, 2021
fbd597a
Fixed bug for target map compatibility
zxybazh Mar 3, 2021
4d11b7b
Fix another TargetsMap issue
zxybazh Mar 3, 2021
5a0f06b
Fix typo and infinite loop error
zxybazh Mar 3, 2021
0e01e13
Temporary fix for handle issue
zxybazh Mar 3, 2021
7db8327
Fix vm target
zxybazh Mar 4, 2021
f214410
Add condition support for str case
zxybazh Mar 4, 2021
38c4ec0
Add GetHost function and fix previous bugs
zxybazh Mar 4, 2021
8bacc8d
Fix measure_record.cc
zxybazh Mar 4, 2021
36153dd
Fix search_task.cc
zxybazh Mar 4, 2021
df1f6a1
Fix compiler.cc, memory_alloc.cc
zxybazh Mar 5, 2021
4539cff
Fix driver_api.cc
zxybazh Mar 5, 2021
b328525
Fix format
zxybazh Mar 5, 2021
ba427ec
Fix bugs and GetHost function usage
zxybazh Mar 5, 2021
915e3d3
Fix clang format
zxybazh Mar 5, 2021
1a9dcb5
Fix bug
zxybazh Mar 6, 2021
efacf81
Merged main branch, resolve conflicts
zxybazh Mar 6, 2021
606ec71
Modify python tests
zxybazh Mar 7, 2021
71e01d0
Change python unit tests to new target api
zxybazh Mar 7, 2021
95539d9
Fi test_runtime_heterogeneous.py
zxybazh Mar 8, 2021
858d901
Modify tutorials & remove extra print
zxybazh Mar 8, 2021
d99b560
Update more tests to new api
zxybazh Mar 8, 2021
62ec2d3
Refine the tutorial target usage
zxybazh Mar 8, 2021
6916758
change argument name for Target constructor function
zxybazh Mar 8, 2021
a762d7d
Fix target export function
zxybazh Mar 9, 2021
b01f6cc
Fix and validate all tutorial usage
zxybazh Mar 9, 2021
b480bee
Remove unused argument
zxybazh Mar 9, 2021
c17a18e
Fix format
zxybazh Mar 9, 2021
a64efd6
Fix bug in driver/build_module.py for heterogeneous target
zxybazh Mar 9, 2021
fa982a9
Fix bug in driver/build_module.py for heterogeneous target more
zxybazh Mar 9, 2021
33c4057
Fix target host type error
zxybazh Mar 10, 2021
88d2379
Merge branch 'main' of https://github.com/apache/tvm into target
zxybazh Mar 10, 2021
75d0f44
Fix cudnn target host bug
zxybazh Mar 10, 2021
47bcc4c
Fix according to reviews, add helper function in python
zxybazh Mar 13, 2021
5d8201e
Refactor code as helper function
zxybazh Mar 16, 2021
c9e1c9b
Expand helper function
zxybazh Mar 16, 2021
ec664ee
Fix bug add and update python helper function
zxybazh Mar 16, 2021
983108c
Update target hosts
zxybazh Mar 16, 2021
ddfdeb2
Fix format & refresh function
zxybazh Mar 16, 2021
cb206ec
Fix unit test bug
zxybazh Mar 16, 2021
ae4ca68
Fix bug in refreshing host
zxybazh Mar 16, 2021
26a8647
Fix bug
zxybazh Mar 16, 2021
83f290b
Add SetHost function
zxybazh Mar 16, 2021
47b072c
Update export function
zxybazh Mar 16, 2021
bef6fbb
Fix format
zxybazh Mar 17, 2021
6771f2d
Fix export bug in target
zxybazh Mar 17, 2021
4442fba
Fix bug on host referencing
zxybazh Mar 17, 2021
542c927
Addtional tests
zxybazh Mar 17, 2021
8a537b4
Address review issues
zxybazh Mar 18, 2021
6f76c1d
Fix format target.py
zxybazh Mar 18, 2021
f46626f
Fix issues and format
zxybazh Mar 30, 2021
244cc40
Add some 3rd party dependencies
zxybazh Mar 30, 2021
fdfb93a
Merge main branch
zxybazh Mar 30, 2021
7f509bd
Merge branch 'main' into target
zxybazh Mar 30, 2021
3804269
Fix target.h format
zxybazh Mar 30, 2021
dd3787c
Remove redundent import
zxybazh Mar 30, 2021
6e114ca
Fix function name
zxybazh Mar 30, 2021
adec87f
Add parameter name
zxybazh Mar 30, 2021
34f1dac
Merge branch 'main' into target
zxybazh Mar 31, 2021
b71bd1a
Fix new code bug
zxybazh Mar 31, 2021
3a8080e
Fix bug in lowering
zxybazh Mar 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#ifndef TVM_TARGET_TARGET_H_
#define TVM_TARGET_TARGET_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/support/with.h>
#include <tvm/target/target_kind.h>
Expand All @@ -35,6 +37,7 @@
namespace tvm {

class TargetInternal;
class Target;

/*!
* \brief Compilation target.
Expand All @@ -60,6 +63,8 @@ class TargetNode : public Object {
TVM_DLL const std::string& str() const;
/*! \return Export target to JSON-like configuration */
TVM_DLL Map<String, ObjectRef> Export() const;
/*! \return The Optional<Target> typed target host of the TargetNode */
TVM_DLL Optional<Target> GetHost() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
Expand Down Expand Up @@ -150,6 +155,13 @@ class Target : public ObjectRef {
*/
TVM_DLL explicit Target(Target target, Target host);
TVM_DEFINE_OBJECT_REF_METHODS(Target, ObjectRef, TargetNode);
/*!
* \brief Create a new Target object with given target (w.o host) and target host.
* \param target The current Target typed object target, with or without host field.
* \param host The given Target typed object target host
* \return The new Target object with the given target and host field of given host.
*/
static Target WithHost(const Target& target, const Target& host);

private:
// enable with syntax.
Expand All @@ -167,6 +179,29 @@ class Target : public ObjectRef {
*/
TVM_DLL void ExitWithScope();
};

/*!
* \brief Check and update host field of the given legacy target and target host pair.
* Note that this function is for legacy target api compatibility issue only, not
* recommended for other use.
* \param target The pointer to a Target typed object with host field to be updated
* \param host The pointer to a Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Target* target, Target* host);
/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param target The pointer to a Map objects with values being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Integer, Target>* target, Target* host);
/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param target The pointer to a Map objects with keys being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* target, Target* host);
} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
13 changes: 9 additions & 4 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.target import Target


from . import _ffi_api
from .loop_state import StateObject
Expand Down Expand Up @@ -221,10 +223,12 @@ def recover_measure_input(inp, rebuild_state=False):
from .search_task import SearchTask # lazily import to avoid recursive dependency

task = inp.task
task.target, task.target_host = Target.check_and_update_host_consist(
task.target, task.target_host
)
new_task = SearchTask(
workload_key=task.workload_key,
target=task.target,
target_host=task.target_host,
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
task_inputs=list(task.task_input_names),
Expand Down Expand Up @@ -602,6 +606,9 @@ def _timed_func(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
task.target, task.target_host = Target.check_and_update_host_consist(
task.target, task.target_host
)

error_no = MeasureErrorNo.NO_ERROR
error_msg = None
Expand All @@ -622,9 +629,7 @@ def _timed_func(inp_serialized, build_func, verbose):

try:
with transform.PassContext():
func = build_module.build(
sch, args, target=task.target, target_host=task.target_host
)
func = build_module.build(sch, args, target=task.target)
func.export_library(filename, build_func)
# pylint: disable=broad-except
except Exception:
Expand Down
8 changes: 3 additions & 5 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from tvm import autotvm, transform
from tvm.ir.transform import PassContext
from tvm.runtime import convert_to_object

from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import Reduce
from tvm.tir import expr as _expr
from tvm.target import Target

from . import _ffi_api
from .compute_dag import ComputeDAG, LayoutRewriteOption
Expand Down Expand Up @@ -108,10 +110,7 @@ def extract_tasks(
"""
# pylint: disable=import-outside-toplevel

if isinstance(target, str):
target = tvm.target.Target(target)
if isinstance(target_host, str):
target_host = tvm.target.Target(target_host)
target, target_host = Target.check_and_update_host_consist(target, target_host)

# Run the compiler to collect all TOPI calls during compilation.
env = TracingEnvironment(
Expand All @@ -137,7 +136,6 @@ def extract_tasks(
SearchTask(
workload_key=wkl_key,
target=target,
target_host=target_host,
hardware_params=hardware_params,
# When auto scheduler is used in end to end network, try to apply layout rewrite
# to improve the overall performance
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,8 @@ def __init__(
compute_dag = ComputeDAG(workload_key)

assert target is not None, "Must specify a target."
if isinstance(target, str):
target = Target(target)
if isinstance(target_host, str):
target_host = Target(target_host)

target, target_host = Target.check_and_update_host_consist(target, target_host)

if layout_rewrite_option is None:
layout_rewrite_option = LayoutRewriteOption.get_target_default(target)
Expand Down Expand Up @@ -511,6 +509,9 @@ def print_best(self, log_file, print_mode="schedule"):
raise ValueError("Invalid print_mode: %s" % print_mode)

def __getstate__(self):
self.target, self.target_host = Target.check_and_update_host_consist(
self.target, self.target_host
)
return {
"compute_dag": self.compute_dag,
"workload_key": self.workload_key,
Expand All @@ -535,12 +536,15 @@ def __setstate__(self, state):
if workload[0] not in WORKLOAD_FUNC_REGISTRY:
register_workload_tensors(state["workload_key"], state["compute_dag"].tensors)

state["target"], state["target_host"] = Target.check_and_update_host_consist(
state["target"], state["target_host"]
)
self.__init_handle_by_constructor__(
_ffi_api.SearchTask,
state["compute_dag"],
state["workload_key"],
state["target"],
state["target_host"],
state["target"].host,
state["hardware_params"],
state["layout_rewrite_option"],
state["task_input_names"],
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tvm.autotvm.task import get_config
from tvm.autotvm.record import encode, load_from_file
from tvm.autotvm.measure import MeasureResult, MeasureInput
from tvm.target import Target

from ...target import Target
from .utils import (
Expand Down Expand Up @@ -439,6 +440,8 @@ def benchmark_layout_transform(
This might bring performance loss comparing to benchmarking layout transformation.
"""
self._logger.info("Start to benchmark layout transformation...")
self._target, target_host = Target.check_and_update_host_consist(self._target, target_host)

if layout_records is None and infer_layout:
raise RuntimeError("Requires some records to infer layout transformation time.")

Expand Down Expand Up @@ -525,9 +528,7 @@ def _callback(_, inputs, results):
continue

records = []
task = autotvm.task.create(
"layout_transform", args=args, target=self._target, target_host=target_host
)
task = autotvm.task.create("layout_transform", args=args, target=self._target)
tuner = autotvm.tuner.GridSearchTuner(task)
tuner.tune(n_trial=1, measure_option=measure_option, callbacks=[_log_to_list(records)])
if not isinstance(records[0][1].costs[0], float):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args, target="llvm", target_host=None)
task = autotvm.task.create(task_name, args, target="llvm")
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from tvm.error import TVMError
from tvm.driver import build
from tvm.contrib import nvcc, ndk, tar
from tvm.target import Target

from ..utils import get_const_tuple
from ..env import AutotvmGlobalScope
Expand Down Expand Up @@ -418,6 +419,8 @@ def set_task(self, task):
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input
target, task.target_host = Target.check_and_update_host_consist(target, task.target_host)

with target:
s, args = task.instantiate(config)

Expand Down
9 changes: 7 additions & 2 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import tvm
from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
from tvm.target import Target
from .task import create
from .topi_integration import TaskExtractEnv

Expand Down Expand Up @@ -89,7 +90,8 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
task: Array of autotvm.task.Task
collected tasks
"""
return extract_from_multiple_program([mod], [params], target, target_host, ops)
target, target_host = Target.check_and_update_host_consist(target, target_host)
return extract_from_multiple_program([mod], [params], target, ops=ops)


def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
Expand Down Expand Up @@ -122,6 +124,9 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No

env = TaskExtractEnv.get()

# merge target and target host
target, target_host = Target.check_and_update_host_consist(target, target_host)

# run compiler to collect all TOPI calls during compilation
env.reset(ops)
with env:
Expand Down Expand Up @@ -152,7 +157,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
tasks = []
for task_name, args in env.get_tasks():
try:
tsk = create(task_name, args, target=target, target_host=target_host)
tsk = create(task_name, args, target=target)
tasks.append(tsk)
except topi.InvalidShapeError:
logger.warning("Invalid shape during AutoTVM task creation")
Expand Down
14 changes: 10 additions & 4 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,17 @@ def __getstate__(self):
# and restore the function by name when unpickling it.
import cloudpickle # pylint: disable=import-outside-toplevel

self.target, self.target_host = Target.check_and_update_host_consist(
self.target, self.target_host
)
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host,
"target_host": self.target.host,
"func": cloudpickle.dumps(self.func),
}

Expand All @@ -195,8 +198,9 @@ def __setstate__(self, state):
self.config_space = state["config_space"]
self.func = cloudpickle.loads(state["func"])
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
self.target, self.target_host = Target.check_and_update_host_consist(
state["target"], state["target_host"]
)

def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
Expand Down Expand Up @@ -448,6 +452,8 @@ def create(task_name, args, target, target_host=None):
if isinstance(target, str):
target = Target(target)

target, target_host = Target.check_and_update_host_consist(target, target_host)

# init config space
ret.config_space = ConfigSpace()

Expand All @@ -459,7 +465,7 @@ def create(task_name, args, target, target_host=None):

ret.flop = ret.config_space.flop or compute_flop(sch)
ret.target = target
ret.target_host = target_host
ret.target_host = target.host

return ret

Expand Down
13 changes: 10 additions & 3 deletions python/tvm/contrib/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
import tvm
from tvm import te
from tvm.target import Target
from . import utils
from .. import rpc

Expand Down Expand Up @@ -86,6 +87,8 @@ def measure_bandwidth_sum(
GBPS: float
gigabyte per second
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)

n, m = total_item, item_per_thread
n //= lanes

Expand All @@ -107,7 +110,7 @@ def measure_bandwidth_sum(
s[y].unroll(k)

try:
func = tvm.build(s, [x, y], target, target_host=target_host)
func = tvm.build(s, [x, y], target)

x = tvm.nd.empty((n,), dtype=dtype, device=dev)
y = tvm.nd.empty((n // m,), dtype=dtype, device=dev)
Expand Down Expand Up @@ -151,6 +154,7 @@ def measure_bandwidth_all_types(
result: list
a list of (type_name, GBPS) pairs
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
max_threads = target.max_num_threads

result = []
Expand Down Expand Up @@ -221,6 +225,7 @@ def measure_compute_mad(
GOPS: float
giga operation per second
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)

n = total_item

Expand Down Expand Up @@ -272,7 +277,7 @@ def mad_func(x, y):
s = te.create_schedule(y.op)

try:
func = tvm.build(s, [y], target, target_host=target_host)
func = tvm.build(s, [y], target)
func = _convert_to_remote(func, remote)
time_f = func.time_evaluator(func.entry_name, dev, number=n_times)
y = tvm.nd.empty((n,), dtype=dtype, device=dev)
Expand Down Expand Up @@ -313,6 +318,8 @@ def measure_compute_all_types(
result: list
a list of (type_name, GFLOPS/GIOPS) pairs
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)

result = []
for base_type in ["float", "int"]:
for bits in [16, 32, 64]:
Expand Down Expand Up @@ -357,7 +364,7 @@ def measure_peak_all(target, target_host, host, port):
port: int
"""

target = tvm.target.Target(target)
target, target_host = Target.check_and_update_host_consist(target, target_host)
remote = rpc.connect(host, port)
n_times = 20

Expand Down
Loading