Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] Add function name in message #7703

Merged
merged 2 commits on
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DispatchContext(object):
def __init__(self):
self._old_ctx = DispatchContext.current

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
"""
Query the context to get the specific config for a workload.
If cannot find the result inside this context, this function will query it
Expand All @@ -66,15 +66,17 @@ def query(self, target, workload_key, has_complex_op, dag):
Whether this workload has at least one complex op.
dag: ComputeDAG
The ComputeDAG of the workload.
func_name: str
The function name of this workload.

Returns
-------
state : StateObject
The state that stores schedule configuration for the workload
"""
ret = self._query_inside(target, workload_key)
ret = self._query_inside(target, workload_key, func_name)
if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
return ret

def update(self, target, workload_key, state):
Expand All @@ -92,7 +94,7 @@ def update(self, target, workload_key, state):
"""
raise NotImplementedError()

def _query_inside(self, target, workload_key):
def _query_inside(self, target, workload_key, func_name):
"""
Query the context to get the specific config for a workload.
This function only query config inside this context.
Expand All @@ -103,6 +105,8 @@ def _query_inside(self, target, workload_key):
The current target
workload_key : str
The current workload_key.
func_name: str
The function name of this workload.

Returns
-------
Expand Down Expand Up @@ -241,7 +245,7 @@ def load(self, records, n_lines=None):

logger.debug("Finish loading %d records", counter)

def _query_inside(self, target, workload_key):
def _query_inside(self, target, workload_key, func_name):
if target is None:
raise RuntimeError(
"Need a target context to find the history best. "
Expand Down Expand Up @@ -343,18 +347,20 @@ def __init__(
records, n_lines=None, include_compatible=True
)

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
if has_complex_op or self.sample_simple_workloads:
ret = self._query_inside(target, workload_key)
ret = self._query_inside(target, workload_key, func_name)
else:
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
target, workload_key, func_name
)

if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
return ret

def _query_inside(self, target, workload_key):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
def _query_inside(self, target, workload_key, func_name):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key, func_name)
if ret is not None:
return ret

Expand Down Expand Up @@ -386,7 +392,9 @@ def _query_inside(self, target, workload_key):

# Load the sampled records and query again.
self.load(log_file)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
target, workload_key, func_name
)

del measure_ctx
return ret
Expand All @@ -411,18 +419,19 @@ def __init__(self):
# a set to prevent print duplicated message
self.messages = set()

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
key = (str(target), workload_key)
if key in self.memory:
return self.memory[key]

if self.verbose == 2 or (has_complex_op and self.verbose == 1):
msg = (
"-----------------------------------\n"
"Cannot find tuned schedules for target=%s, workload_key=%s. "
"A fallback TOPI schedule is used, "
"which may bring great performance regression or even compilation failure. "
"Compute DAG info:\n%s" % (target, workload_key, dag)
f"-----------------------------------\n"
f"{func_name}\n"
f"Cannot find tuned schedules for target={target}, workload_key={workload_key}. "
f"A fallback TOPI schedule is used, "
f"which may bring great performance regression or even compilation failure. "
f"Compute DAG info:\n{dag}"
)
if msg not in self.messages:
self.messages.add(msg)
Expand All @@ -434,8 +443,8 @@ def query(self, target, workload_key, has_complex_op, dag):
self.memory[key] = state
return state

def _query_inside(self, target, workload_key):
_ = target = workload_key
def _query_inside(self, target, workload_key, func_name):
_ = target = workload_key = func_name
raise RuntimeError("This function should never be called")

def update(self, target, workload_key, state):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,17 @@ def traverse(t):


@tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute")
def auto_schedule_topi(outs):
def auto_schedule_topi(func_name, outs):
"""Use auto-scheduler to schedule any topi compute function.

Note: This is used internally for relay integration. Do
not use this as a general user-facing API.

Parameters
----------
func_name: str
The name of the function being scheduled.

outs: List[Tensor]
The output tensors of topi compute functions

Expand All @@ -289,7 +292,7 @@ def auto_schedule_topi(outs):
target = tvm.target.Target.current()

dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
state = dispatch_ctx.query(target, key, has_complex_op, dag, func_name)
schedule = None

env = TracingEnvironment.current
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(tensor_outs);
ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
Expand Down