Optimizing autotvm task extraction speed #4138

Merged 8 commits on Oct 29, 2019
21 changes: 13 additions & 8 deletions python/tvm/autotvm/task/relay_integration.py
@@ -31,23 +31,28 @@
logger = logging.getLogger('autotvm')


-# TODO(moreau89) find a more elegant way to build for VTAs
-def _build(func,
+# TODO(moreau89) find a more elegant way to lower for VTAs
+def _lower(func,
[Inline review thread on `_lower`]

Contributor:
Do we need test cases for extract_from_program? I didn't find any existing one.

Contributor:
We do have unit tests for extract_from_program. My just-merged PR is related to this one: #4173

Contributor (Author):
> Do we need test cases for extract_from_program? I didn't find any existing one.

I do have my own test case for justification and performance evaluation.

> We do have unit tests for extract_from_program. My just-merged PR is related to this one: #4173

As that PR has been merged, maybe I should just skip mine and directly use this one?

Contributor:
I think it's fine. After resolving the conflict, your changes should be covered by the existing unit tests (for functionality).

Contributor (Author):
> I think it's fine. After resolving the conflict, your changes should be covered by the existing unit tests (for functionality).

I'll make a commit and trigger the test.
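For reference, a minimal test along the lines discussed here could look like the sketch below. This is illustrative only: the workload, op list, and assertion are assumptions, not the actual test from #4173.

```python
from tvm import autotvm, relay
from tvm.relay.testing import resnet

def test_task_extraction():
    # Illustrative workload; any Relay model containing conv2d would do.
    mod, params = resnet.get_workload(num_layers=18, batch_size=1)
    tasks = autotvm.task.extract_from_program(
        mod["main"], params=params, ops=(relay.op.nn.conv2d,), target="llvm")
    # Extraction should register one tuning task per unique conv2d workload.
    assert len(tasks) > 0
```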

           target,
-           target_host,
           params):
-    """ Helper to build VTA properly.
+    """ Helper to lower VTA properly.
    """

    from tvm import relay
+    from tvm.relay.backend import graph_runtime_codegen

    if hasattr(target, 'device_name') and target.device_name == "vta":
+        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
            import vta
            with vta.build_config():
-                return relay.build(func, target, target_host, params)
+                _, mod, _ = relay.optimize(func, target, params)
+                grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+                return grc.codegen(mod["main"])
    # default case
-    return relay.build(func, target, target_host, params)
+    _, mod, _ = relay.optimize(func, target, params)
+    grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+    return grc.codegen(mod["main"])


def extract_from_program(func, params, ops, target, target_host=None):
    """ Extract tuning tasks from a relay program.
@@ -133,8 +138,8 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
        relay.backend.compile_engine.get().clear()
        # wrap build call in thread to avoid multiprocessing problems
        mod = relay.Module.from_expr(func)
-        build_thread = threading.Thread(target=_build,
-                                        args=(mod, target, target_host, param))
+        build_thread = threading.Thread(target=_lower,
+                                        args=(mod, target, param))
        build_thread.start()
        build_thread.join()
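The idea behind the change: relay.build runs optimization, lowering, and machine-code generation, but task extraction only needs the topi compute/schedule calls that fire while operators are lowered, so everything after lowering is wasted work. Below is a minimal sketch of the non-VTA path that `_lower` now takes, assuming `func` is a relay.Function (or Module), `target` a build target, and `params` a weight dict:

```python
from tvm import relay
from tvm.relay.backend import graph_runtime_codegen

# Module-level optimization passes only; nothing is compiled yet.
_, opt_mod, _ = relay.optimize(func, target, params)
# Graph codegen lowers each operator; autotvm's task-extraction hooks
# observe the topi compute/schedule calls made during this step.
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
```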
2 changes: 1 addition & 1 deletion python/tvm/relay/__init__.py
@@ -28,7 +28,7 @@
from . import adt
from . import analysis
from . import transform
-from .build_module import build, create_executor
+from .build_module import build, create_executor, optimize
from .transform import build_config
from . import prelude
from . import parser
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/graph_runtime_codegen.py
@@ -36,15 +36,15 @@
from __future__ import absolute_import

from tvm.ndarray import empty
-from tvm.relay import build_module
+from tvm.relay import _build_module
from tvm import target as _target
from tvm import expr as _expr

class GraphRuntimeCodegen(object):
    """The compiler from Relay to the TVM runtime system."""

    def __init__(self, mod, target):
-        self._mod = build_module._GraphRuntimeCodegen()
+        self._mod = _build_module._GraphRuntimeCodegen()
        self._init = self._mod["init"]
        self._codegen = self._mod["codegen"]
        self._get_graph_json = self._mod["get_graph_json"]
96 changes: 96 additions & 0 deletions python/tvm/relay/build_module.py
@@ -60,6 +60,7 @@ def __init__(self):
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]
+        self._optimize = self.mod["optimize"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]
@@ -113,6 +114,46 @@ def build(self, func, target=None, target_host=None, params=None):

        return graph_json, mod, params

+    def optimize(self, func, target=None, params=None):
+        """
+        Parameters
+        ----------
+        func: relay.Function
+            The function to optimize.
+
+        target : str, :any:`tvm.target.Target`, or dict of str (i.e.
+            device/context name) to str/tvm.target.Target, optional
+            For heterogeneous compilation, it is a dictionary indicating context
+            to target mapping. For homogeneous compilation, it is a build target.
+
+        params : dict of str to NDArray
+            Input parameters to the graph that do not change
+            during inference time. Used for constant folding.
+
+        Returns
+        -------
+        graph_json : str
+            The json string that can be accepted by graph runtime.
+
+        mod : relay.Module
+            The optimized relay module.
+
+        params : dict
+            The parameters of the final graph.
+        """
+        target = _update_target(target)
+
+        # Setup the params.
+        if params:
+            self._set_params(params)
+        mod = self._optimize(func, target)
+        # Get artifacts
+        graph_json = self.get_json()
+        params = self.get_params()
+
+        return graph_json, mod, params
+

    def _set_params(self, params):
        inputs = {}
        for name, param in params.items():
@@ -208,6 +249,61 @@ def build(mod, target=None, target_host=None, params=None):
    return graph_json, mod, params


+def optimize(mod, target=None, params=None):
+    """Helper function that optimizes a Relay module for the TVM graph
+    runtime.
+
+    Parameters
+    ----------
+    mod : relay.Module
+        The module to optimize. Using relay.Function is deprecated.
+
+    target : str, :any:`tvm.target.Target`, or dict of str (i.e. device/context
+        name) to str/tvm.target.Target, optional
+        For heterogeneous compilation, it is a dictionary indicating context to
+        target mapping. For homogeneous compilation, it is a build target.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    graph_json : str
+        The json string that can be accepted by graph runtime.
+
+    mod : relay.Module
+        The optimized relay module.
+
+    params : dict
+        The parameters of the final graph.
+    """
+    if isinstance(mod, _Module):
+        func = mod["main"]
+    elif isinstance(mod, _expr.Function):
+        func = mod
+        warnings.warn(
+            "Please use input parameter mod (tvm.relay.module.Module) "
+            "instead of deprecated parameter func (tvm.relay.expr.Function)",
+            DeprecationWarning)
+    else:
+        raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")
+
+    target = _update_target(target)
+
+    # If current dispatch context is fallback context (the default root context),
+    # then load pre-tuned parameters from TopHub
+    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
+        tophub_context = autotvm.tophub.context(list(target.values()))
+    else:
+        tophub_context = autotvm.util.EmptyContext()
+
+    with tophub_context:
+        bld_mod = BuildModule()
+        graph_json, mod, params = bld_mod.optimize(func, target, params)
+    return graph_json, mod, params


class GraphExecutor(_interpreter.Executor):
    """Wrapper around Executor interface.

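A hedged usage sketch of the new top-level API, assuming `mod` is a relay.Module and `params` its weight dict. Per the docstring above, it returns the same triple as relay.build, except the second element is optimized Relay IR rather than a compiled runtime module:

```python
from tvm import relay

# Optimize only; suitable when the compiled artifacts are not needed.
graph_json, opt_mod, opt_params = relay.optimize(mod, target="llvm", params=params)
```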
13 changes: 13 additions & 0 deletions src/relay/backend/build_module.cc
@@ -148,6 +148,19 @@
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = this->graph_codegen_->GetLoweredFunc();
    });
+  } else if (name == "optimize") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.num_args, 2);
+      Function func = args[0];
+      if (this->params_.size()) {
+        func = this->BindParamsByName(func, this->params_);
+      }
+      // Perform Module->Module optimizations.
+      relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+      relay_module = Optimize(relay_module, args[1], this->params_);
+      CHECK(relay_module.defined());
+      *rv = relay_module;
+    });
  } else {
    LOG(FATAL) << "Unknown packed function: " << name;
    return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
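On the Python side, this branch is reached through the module's packed-function table: BuildModule.__init__ looks up self.mod["optimize"], and BuildModule.optimize calls it with exactly the two arguments that CHECK_EQ(args.num_args, 2) expects. A minimal sketch, assuming `func` is a relay.Function and `target` a target dict as produced by _update_target:

```python
from tvm.relay import build_module

bld_mod = build_module.BuildModule()
# Dispatches to the C++ lambda registered under "optimize" above.
relay_mod = bld_mod._optimize(func, target)
```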