
Commit

[AutoScheduler] Update layout rewrite option setting for measuring (apache#7156)

* Add layout rewrite options for measure

* Update schedule for inserted transform stage

* Set layout rewrite when tuning for network

* Update the log version
jcf94 authored Dec 28, 2020
1 parent 2f8187a commit 2dec2dd
Showing 10 changed files with 110 additions and 29 deletions.
2 changes: 1 addition & 1 deletion include/tvm/auto_scheduler/measure_record.h
@@ -34,7 +34,7 @@
namespace tvm {
namespace auto_scheduler {

const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*)
const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*)

/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
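
Note: bumping the log version to v0.5 invalidates records written by older
releases, since each record now also serializes the task's layout rewrite
option. A minimal sketch of reading the new records from Python (the file
name is illustrative; `load_records` is the existing reader from
`tvm.auto_scheduler.measure_record`):

from tvm import auto_scheduler

# Iterate over (MeasureInput, MeasureResult) pairs from a v0.5 log file.
for inp, res in auto_scheduler.load_records("matmul_tuning.json"):
    # The deserialized task carries the new layout_rewrite_option field.
    print(inp.task.workload_key, inp.task.layout_rewrite_option)
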
6 changes: 5 additions & 1 deletion include/tvm/auto_scheduler/search_task.h
@@ -118,13 +118,16 @@ class SearchTaskNode : public Object {
Target target_host;
/*! \brief Hardware parameters used in this search task. */
HardwareParams hardware_params;
/*! \brief The layout rewrite option used for measuring programs. */
LayoutRewriteOption layout_rewrite_option;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("compute_dag", &compute_dag);
v->Visit("workload_key", &workload_key);
v->Visit("target", &target);
v->Visit("target_host", &target_host);
v->Visit("hardware_params", &hardware_params);
v->Visit("layout_rewrite_option", &layout_rewrite_option);
}

static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -144,9 +147,10 @@ class SearchTask : public ObjectRef {
* \param target The target device of this search task.
* \param target_host The target host device of this search task.
* \param hardware_params Hardware parameters used in this search task.
* \param layout_rewrite_option The layout rewrite option used for measuring programs.
*/
SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
Optional<HardwareParams> hardware_params);
Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);

TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
};
36 changes: 35 additions & 1 deletion python/tvm/auto_scheduler/compute_dag.py
@@ -32,7 +32,12 @@


class LayoutRewriteOption:
"""Options for applying layout rewrite."""
"""
Options for applying layout rewrite.
The NO_REWRITE and INSERT_TRANSFORM_STAGE options are expected to be used when tuning a
standalone op, and REWRITE_FOR_PRE_TRANSFORMED when tuning ops inside a network.
"""

# Do not perform layout rewrite
NO_REWRITE = 0
@@ -44,6 +49,35 @@ class LayoutRewriteOption:
# so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
REWRITE_FOR_PRE_TRANSFORMED = 2

@staticmethod
def get_target_default(target, in_relay_integration=False):
"""Get the default layout rewrite option for the specified target.
Layout rewrite is currently enabled only for the cpu / mali backends.

Parameters
----------
target: tvm.target.Target
The compilation target.
in_relay_integration: bool
Whether this check is requested for Relay integration.

Returns
-------
layout_rewrite_option: LayoutRewriteOption
The default layout rewrite option for the specified target.
"""
layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
if target.kind.name == "llvm" or (
"device" in target.attrs and target.attrs["device"] == "mali"
):
layout_rewrite_option = (
LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
if in_relay_integration
else LayoutRewriteOption.INSERT_TRANSFORM_STAGE
)

return layout_rewrite_option


@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
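
A short usage sketch of the new helper (target strings are illustrative):

import tvm
from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

llvm = tvm.target.Target("llvm")
# Standalone-op tuning on CPU defaults to inserting a transform stage.
assert LayoutRewriteOption.get_target_default(llvm) == LayoutRewriteOption.INSERT_TRANSFORM_STAGE
# Relay integration on CPU defaults to rewriting for pre-transformed inputs.
assert (
    LayoutRewriteOption.get_target_default(llvm, in_relay_integration=True)
    == LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
)
# Other backends (e.g. CUDA) keep NO_REWRITE for now.
assert (
    LayoutRewriteOption.get_target_default(tvm.target.Target("cuda"))
    == LayoutRewriteOption.NO_REWRITE
)
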
4 changes: 2 additions & 2 deletions python/tvm/auto_scheduler/measure.py
@@ -53,7 +53,6 @@
make_traceback_info,
request_remote,
)
from .compute_dag import LayoutRewriteOption
from .workload_registry import (
serialize_workload_registry_entry,
deserialize_workload_registry_entry,
@@ -211,6 +210,7 @@ def recover_measure_input(inp, rebuild_state=False):
target=task.target,
target_host=task.target_host,
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
)

if rebuild_state:
@@ -576,7 +576,7 @@ def _timed_func(inp_serialized, build_func, verbose):

try:
sch, args = task.compute_dag.apply_steps_from_state(
inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
inp.state, layout_rewrite=task.layout_rewrite_option
)
# pylint: disable=broad-except
except Exception:
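
The builder now honors whatever option the task was created with instead of
unconditionally applying REWRITE_FOR_PRE_TRANSFORMED, and
recover_measure_input preserves that option when records are deserialized.
A hedged sketch (`raw_inp` stands for a MeasureInput read back from a log):

from tvm.auto_scheduler.measure import recover_measure_input

# Rebuild the full task (ComputeDAG, hardware params, ...) behind a record.
inp = recover_measure_input(raw_inp, rebuild_state=True)
# The layout rewrite option survives the round trip.
assert inp.task.layout_rewrite_option == raw_inp.task.layout_rewrite_option
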
16 changes: 8 additions & 8 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -33,7 +33,7 @@
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
from . import _ffi_api
from .compute_dag import ComputeDAG
from .compute_dag import ComputeDAG, LayoutRewriteOption
from .dispatcher import DispatchContext
from .search_task import SearchTask
from .workload_registry import register_workload_tensors
@@ -126,6 +129,9 @@ def extract_tasks(
target=target,
target_host=target_host,
hardware_params=hardware_params,
# When the auto-scheduler is used on an end-to-end network, try to apply
# layout rewrite to improve the overall performance
layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
)
)
weights.append(use_count_dict[ccache_key] + 1)
@@ -259,13 +262,7 @@ def auto_schedule_topi(outs, has_complex_op):

key = register_workload_tensors(dag.hash_key(), io_tensors)

# only enable layout rewrite for cpu / mali backend
target = tvm.target.Target.current()
enable_layout_rewrite_targets = ["cpu", "mali"]
enable_layout_rewrite = any(
enable_layout_rewrite_target in target.keys
for enable_layout_rewrite_target in enable_layout_rewrite_targets
)

env = TracingEnvironment.current
if env is None:
Expand All @@ -284,7 +281,10 @@ def auto_schedule_topi(outs, has_complex_op):
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# in prepare_layout_rewrite mode
if enable_layout_rewrite and has_layout_free:
if (
LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE
and has_layout_free
):
dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
if state is None:
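
For example, a sketch of task extraction on CPU (`mod` and `params` stand
for a Relay module and its parameter dict):

import tvm
from tvm import auto_scheduler

target = tvm.target.Target("llvm")
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
for task in tasks:
    # On CPU every extracted task now defaults to REWRITE_FOR_PRE_TRANSFORMED,
    # matching the AutoSchedulerLayoutRewrite pass applied during compilation.
    print(task.workload_key, task.layout_rewrite_option)
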
27 changes: 21 additions & 6 deletions python/tvm/auto_scheduler/search_task.py
@@ -178,6 +178,13 @@ class SearchTask(Object):
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
layout_rewrite_option : Optional[LayoutRewriteOption]
The layout rewrite option used for measuring programs. If None, the default value will be
set depending on the specified target.
The auto-scheduler will search for a schedule suited to the specified layout rewrite option.
The NO_REWRITE and INSERT_TRANSFORM_STAGE options are expected to be used when tuning a
standalone op, and REWRITE_FOR_PRE_TRANSFORMED when tuning ops inside a network.

Examples
--------
@@ -204,6 +211,7 @@ def __init__(
target=None,
target_host=None,
hardware_params=None,
layout_rewrite_option=None,
):
assert (
func is not None or workload_key is not None
@@ -221,7 +229,13 @@ def __init__(
target_host = Target(target_host)

self.__init_handle_by_constructor__(
_ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params
_ffi_api.SearchTask,
compute_dag,
workload_key,
target,
target_host,
hardware_params,
layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
)

def tune(self, tuning_options, search_policy=None):
@@ -250,6 +264,7 @@ def apply_best(self, log_file, layout_rewrite_option=None):
layout_rewrite_option : Optional[LayoutRewriteOption]
The layout rewrite option. If None, the task's own `layout_rewrite_option` is used.

Returns
-------
A `te.Schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
@@ -260,11 +275,9 @@
"Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
)

if layout_rewrite_option is None:
layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
if self.target.kind.name == "llvm":
layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE
sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option)
sch, args = self.compute_dag.apply_steps_from_state(
inp.state, layout_rewrite_option or self.layout_rewrite_option
)
return sch, args

def print_best(self, log_file, print_mode="schedule"):
@@ -305,6 +318,7 @@ def __getstate__(self):
"target": self.target,
"target_host": self.target_host,
"hardware_params": self.hardware_params,
"layout_rewrite_option": self.layout_rewrite_option,
}

def __setstate__(self, state):
@@ -327,6 +341,7 @@ def __setstate__(self, state):
state["target"],
state["target_host"],
state["hardware_params"],
state["layout_rewrite_option"],
)


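
A standalone-op tuning sketch with the new parameter (the matmul workload
and log file name are illustrative):

import tvm
from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def matmul(N, M, K):
    A = te.placeholder((N, K), name="A")
    B = te.placeholder((K, M), name="B")
    k = te.reduce_axis((0, K), name="k")
    C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    return [A, B, C]

# Omitting layout_rewrite_option picks the target default
# (INSERT_TRANSFORM_STAGE for a standalone op on llvm).
task = auto_scheduler.SearchTask(
    func=matmul, args=(128, 128, 128), target=tvm.target.Target("llvm")
)
# apply_best now falls back to the task's own option when none is given.
sch, args = task.apply_best("matmul_tuning.json")
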
20 changes: 16 additions & 4 deletions src/auto_scheduler/compute_dag.cc
@@ -998,11 +998,20 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
transform_steps->Set(i, std::move(step));
}
}

// Add a schedule for the newly added transform stage
Array<Integer> to_fuse;
for (size_t i = 0; i < new_shape.size() - 1; i++) {
to_fuse.push_back(i);

if (new_shape.size() >= 5) {
to_fuse.push_back(0);
to_fuse.push_back(1);
to_fuse.push_back(2);
transform_steps->push_back(FuseStep(stage_id, to_fuse));
} else if (new_shape.size() >= 3) {
to_fuse.push_back(0);
to_fuse.push_back(1);
transform_steps->push_back(FuseStep(stage_id, to_fuse));
}
transform_steps->push_back(FuseStep(stage_id, to_fuse));
transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
}

@@ -1024,7 +1033,10 @@
}
original_compute_op = op;
CHECK(!new_compute_op.defined());
new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
auto new_attrs = pop->attrs;
new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
}
}
}
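
The intent of the new branches: rather than fusing all but the innermost
axis, fuse just enough leading loops of the inserted transform stage that
the parallel annotation gets one reasonably large outer loop. A hedged
Python restatement of the axis-selection rule:

def axes_to_fuse(new_shape):
    """Mirror the C++ heuristic above: fuse the first three axes of
    rank >= 5 layouts, the first two of rank >= 3, none otherwise."""
    if len(new_shape) >= 5:
        return [0, 1, 2]
    if len(new_shape) >= 3:
        return [0, 1]
    return []

# e.g. a rank-6 rewritten weight layout fuses its three outermost loops
assert axes_to_fuse((4, 8, 8, 3, 3, 32)) == [0, 1, 2]
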
6 changes: 4 additions & 2 deletions src/auto_scheduler/feature.cc
@@ -1398,7 +1398,8 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int
// rebuild task
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target,
cur_inp->task->target_host, cur_inp->task->hardware_params);
cur_inp->task->target_host, cur_inp->task->hardware_params,
cur_inp->task->layout_rewrite_option);
task_id = task_cache.size();

// compute min cost for each task
@@ -1465,7 +1466,8 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
// rebuild task for incomplete measure pairs read from file
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
inputs[i]->task->target_host, inputs[i]->task->hardware_params);
inputs[i]->task->target_host, inputs[i]->task->hardware_params,
inputs[i]->task->layout_rewrite_option);
}
task_id = task_cache.size();

12 changes: 11 additions & 1 deletion src/auto_scheduler/measure_record.cc
@@ -165,12 +165,16 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
writer->WriteArrayItem(*data.hardware_params.get());
if (data.target_host.defined()) {
writer->WriteArrayItem(data.target_host->str());
} else {
writer->WriteArrayItem(std::string(""));
}
writer->WriteArrayItem(static_cast<int>(data.layout_rewrite_option));
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
bool s;
std::string str_value;
int int_value;
auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>();
reader->BeginArray();
s = reader->NextArrayItem();
@@ -188,7 +192,13 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node);
if (s) {
reader->Read(&str_value);
data->target_host = ::tvm::Target(str_value);
if (!str_value.empty()) {
data->target_host = ::tvm::Target(str_value);
}
s = reader->NextArrayItem();
ICHECK(s);
reader->Read(&int_value);
data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value);
s = reader->NextArrayItem();
ICHECK(!s);
}
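
With this change a serialized SearchTask gains an explicit (possibly empty)
target_host string and a trailing integer encoding the layout rewrite
option. A hedged sketch of decoding that trailing field (the two-element
tail below is illustrative, not the full record format):

from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

OPTION_NAMES = {
    LayoutRewriteOption.NO_REWRITE: "NO_REWRITE",
    LayoutRewriteOption.INSERT_TRANSFORM_STAGE: "INSERT_TRANSFORM_STAGE",
    LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED: "REWRITE_FOR_PRE_TRANSFORMED",
}

target_host, option = ["", 1]  # "" means no target_host was recorded
print(target_host or "<none>", OPTION_NAMES[option])
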
10 changes: 7 additions & 3 deletions src/auto_scheduler/search_task.cc
@@ -113,7 +113,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
}

SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
Target target_host, Optional<HardwareParams> hardware_params) {
Target target_host, Optional<HardwareParams> hardware_params,
LayoutRewriteOption layout_rewrite_option) {
auto node = make_object<SearchTaskNode>();
node->compute_dag = std::move(compute_dag);
node->workload_key = std::move(workload_key);
Expand All @@ -125,6 +126,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe
node->hardware_params =
HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host);
}
node->layout_rewrite_option = layout_rewrite_option;
data_ = std::move(node);
}

@@ -139,8 +141,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")

TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
.set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
Target target_host, Optional<HardwareParams> hardware_params) {
return SearchTask(compute_dag, workload_key, target, target_host, hardware_params);
Target target_host, Optional<HardwareParams> hardware_params,
int layout_rewrite_option) {
return SearchTask(compute_dag, workload_key, target, target_host, hardware_params,
LayoutRewriteOption(layout_rewrite_option));
});

} // namespace auto_scheduler
