[AutoScheduler] Update layout rewrite option setting for measuring #7156

Merged · 14 commits · Dec 28, 2020
include/tvm/auto_scheduler/measure_record.h (2 changes: 1 addition & 1 deletion)

@@ -34,7 +34,7 @@
namespace tvm {
namespace auto_scheduler {

- const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4";  // NOLINT(*)
+ const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5";  // NOLINT(*)

/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
include/tvm/auto_scheduler/search_task.h (6 changes: 5 additions & 1 deletion)

@@ -118,13 +118,16 @@ class SearchTaskNode : public Object {
Target target_host;
/*! \brief Hardware parameters used in this search task. */
HardwareParams hardware_params;
+ /*! \brief The layout rewrite option used for measuring programs. */
+ LayoutRewriteOption layout_rewrite_option;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("compute_dag", &compute_dag);
v->Visit("workload_key", &workload_key);
v->Visit("target", &target);
v->Visit("target_host", &target_host);
v->Visit("hardware_params", &hardware_params);
+ v->Visit("layout_rewrite_option", &layout_rewrite_option);
}

static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -144,9 +147,10 @@ class SearchTask : public ObjectRef {
* \param target The target device of this search task.
* \param target_host The target host device of this search task.
* \param hardware_params Hardware parameters used in this search task.
+ * \param layout_rewrite_option The layout rewrite option used for measuring programs.
*/
SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
- Optional<HardwareParams> hardware_params);
+ Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);

TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
};
python/tvm/auto_scheduler/compute_dag.py (36 changes: 35 additions & 1 deletion)

@@ -32,7 +32,12 @@


class LayoutRewriteOption:
- """Options for applying layout rewrite."""
+ """
+ Options for applying layout rewrite.
+
+ NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op,
+ while REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network.
+ """

# Do not perform layout rewrite
NO_REWRITE = 0
@@ -44,6 +49,35 @@ class LayoutRewriteOption:
# so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
REWRITE_FOR_PRE_TRANSFORMED = 2

+ @staticmethod
+ def get_target_default(target, in_relay_integration=False):
+     """Get the default layout rewrite option for the specified target.
+     Layout rewrite is currently only enabled for the cpu / mali backends.
+
+     Parameters
+     ----------
+     target: tvm.target.Target
+         The compilation target.
+     in_relay_integration: bool
+         Whether this query comes from Relay integration.
+
+     Returns
+     -------
+     layout_rewrite_option: LayoutRewriteOption
+         The default layout rewrite option for the specified target.
+     """
+     layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
+     if target.kind.name == "llvm" or (
+         "device" in target.attrs and target.attrs["device"] == "mali"
+     ):
+         layout_rewrite_option = (
+             LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+             if in_relay_integration
+             else LayoutRewriteOption.INSERT_TRANSFORM_STAGE
+         )
+
+     return layout_rewrite_option
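
To make the new default concrete, here is a minimal usage sketch (illustrative, not part of the diff); `LayoutRewriteOption` is imported from the module this diff touches:

# Minimal sketch: how the per-target default resolves.
import tvm
from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

llvm = tvm.target.Target("llvm")
cuda = tvm.target.Target("cuda")

# Standalone-op tuning on CPU inserts a layout transform stage into the DAG.
assert LayoutRewriteOption.get_target_default(llvm) == LayoutRewriteOption.INSERT_TRANSFORM_STAGE

# Relay (end-to-end network) tuning on CPU assumes pre-transformed weights.
assert (
    LayoutRewriteOption.get_target_default(llvm, in_relay_integration=True)
    == LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
)

# Targets other than cpu / mali keep layout rewrite disabled.
assert LayoutRewriteOption.get_target_default(cuda) == LayoutRewriteOption.NO_REWRITE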


@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
python/tvm/auto_scheduler/measure.py (4 changes: 2 additions & 2 deletions)

@@ -53,7 +53,6 @@
make_traceback_info,
request_remote,
)
- from .compute_dag import LayoutRewriteOption
from .workload_registry import (
serialize_workload_registry_entry,
deserialize_workload_registry_entry,
@@ -186,6 +185,7 @@ def recover_measure_input(inp, rebuild_state=False):
target=task.target,
target_host=task.target_host,
hardware_params=task.hardware_params,
+ layout_rewrite_option=task.layout_rewrite_option,
)

if rebuild_state:
@@ -551,7 +551,7 @@ def _timed_func(inp_serialized, build_func, verbose):

try:
sch, args = task.compute_dag.apply_steps_from_state(
- inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
+ inp.state, layout_rewrite=task.layout_rewrite_option
)
# pylint: disable=broad-except
except Exception:
python/tvm/auto_scheduler/relay_integration.py (16 changes: 8 additions & 8 deletions)

@@ -33,7 +33,7 @@
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
from . import _ffi_api
- from .compute_dag import ComputeDAG
+ from .compute_dag import ComputeDAG, LayoutRewriteOption
from .dispatcher import DispatchContext
from .search_task import SearchTask
from .workload_registry import register_workload_tensors
@@ -126,6 +126,9 @@ def extract_tasks(
target=target,
target_host=target_host,
hardware_params=hardware_params,
+ # When the auto-scheduler is used on an end-to-end network, try to apply
+ # layout rewrite to improve the overall performance.
+ layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
)
)
weights.append(use_count_dict[ccache_key] + 1)
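
A hedged sketch of the effect on task extraction; `mod` and `params` stand in for an existing Relay module and its weights:

# Sketch: extracted tasks now carry the Relay-integration default option.
from tvm import auto_scheduler
from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target="llvm")
for task in tasks:
    # On cpu / mali this resolves to REWRITE_FOR_PRE_TRANSFORMED.
    assert task.layout_rewrite_option == LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED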
@@ -259,13 +262,7 @@ def auto_schedule_topi(outs, has_complex_op):

key = register_workload_tensors(dag.hash_key(), io_tensors)

- # only enable layout rewrite for cpu / mali backend
target = tvm.target.Target.current()
- enable_layout_rewrite_targets = ["cpu", "mali"]
- enable_layout_rewrite = any(
-     enable_layout_rewrite_target in target.keys
-     for enable_layout_rewrite_target in enable_layout_rewrite_targets
- )

env = TracingEnvironment.current
if env is None:
@@ -284,7 +281,10 @@ def auto_schedule_topi(outs, has_complex_op):
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# in prepare_layout_rewrite mode
- if enable_layout_rewrite and has_layout_free:
+ if (
+     LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE
+     and has_layout_free
+ ):
dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
if state is None:
python/tvm/auto_scheduler/search_task.py (27 changes: 21 additions & 6 deletions)

@@ -178,6 +178,13 @@ class SearchTask(Object):
The target host device of this search task.
hardware_params : Optional[HardwareParams]
Hardware parameters used in this search task.
+ layout_rewrite_option : Optional[LayoutRewriteOption]
+     The layout rewrite option used for measuring programs. If None, the default value
+     is set depending on the specified target.
+     The auto-scheduler will search for a schedule tailored to the specified layout
+     rewrite option. NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when
+     tuning a standalone op, while REWRITE_FOR_PRE_TRANSFORMED is expected to be used
+     when tuning ops inside a network.

Examples
--------
@@ -204,6 +211,7 @@ def __init__(
target=None,
target_host=None,
hardware_params=None,
+ layout_rewrite_option=None,
):
assert (
func is not None or workload_key is not None
@@ -221,7 +229,13 @@ def __init__(
target_host = Target(target_host)

self.__init_handle_by_constructor__(
- _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params
+ _ffi_api.SearchTask,
+ compute_dag,
+ workload_key,
+ target,
+ target_host,
+ hardware_params,
+ layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
)
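
A sketch of overriding the per-target default at construction time; `matmul` stands in for a function registered with `@auto_scheduler.register_workload`:

# Sketch: pass an explicit option instead of relying on get_target_default().
import tvm
from tvm import auto_scheduler
from tvm.auto_scheduler.compute_dag import LayoutRewriteOption

task = auto_scheduler.SearchTask(
    func=matmul,  # hypothetical registered workload
    args=(128, 128, 128),
    target=tvm.target.Target("llvm"),
    layout_rewrite_option=LayoutRewriteOption.NO_REWRITE,  # override the llvm default
)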

def tune(self, tuning_options, search_policy=None):
@@ -250,6 +264,7 @@ def apply_best(self, log_file, layout_rewrite_option=None):
layout_rewrite_option : Optional[LayoutRewriteOption]
The layout rewrite option.


Returns
-------
A `te.Schedule` and a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
@@ -260,11 +275,9 @@ def apply_best(self, log_file, layout_rewrite_option=None):
"Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
)

- if layout_rewrite_option is None:
-     layout_rewrite_option = LayoutRewriteOption.NO_REWRITE
-     if self.target.kind.name == "llvm":
-         layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE
- sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option)
+ sch, args = self.compute_dag.apply_steps_from_state(
+     inp.state, layout_rewrite_option or self.layout_rewrite_option
+ )
return sch, args
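
A usage sketch; the log file name is hypothetical. Omitting `layout_rewrite_option` now falls back to the option stored on the task instead of a target-kind check duplicated here. Note that because NO_REWRITE is 0 (falsy), passing it explicitly also falls through the `or` to the stored option:

# Sketch: apply the best record from a (hypothetical) tuning log.
sch, args = task.apply_best("matmul_tuning.json")
lib = tvm.build(sch, args, target=task.target)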

def print_best(self, log_file, print_mode="schedule"):
@@ -305,6 +318,7 @@ def __getstate__(self):
"target": self.target,
"target_host": self.target_host,
"hardware_params": self.hardware_params,
+ "layout_rewrite_option": self.layout_rewrite_option,
}

def __setstate__(self, state):
@@ -327,6 +341,7 @@ def __setstate__(self, state):
state["target"],
state["target_host"],
state["hardware_params"],
+ state["layout_rewrite_option"],
)
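
Since `__getstate__`/`__setstate__` now carry the field, it survives serialization; a small sketch, reusing the `task` from above:

# Sketch: the option round-trips through pickling.
import pickle

restored = pickle.loads(pickle.dumps(task))
assert restored.layout_rewrite_option == task.layout_rewrite_option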


src/auto_scheduler/compute_dag.cc (20 changes: 16 additions & 4 deletions)

@@ -997,11 +997,20 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
transform_steps->Set(i, std::move(step));
}
}

// Add schedule for the new added transform stage
Array<Integer> to_fuse;
- for (size_t i = 0; i < new_shape.size() - 1; i++) {
-   to_fuse.push_back(i);
+
+ if (new_shape.size() >= 5) {
+   to_fuse.push_back(0);
+   to_fuse.push_back(1);
+   to_fuse.push_back(2);
+   transform_steps->push_back(FuseStep(stage_id, to_fuse));
+ } else if (new_shape.size() >= 3) {
+   to_fuse.push_back(0);
+   to_fuse.push_back(1);
+   transform_steps->push_back(FuseStep(stage_id, to_fuse));
}
- transform_steps->push_back(FuseStep(stage_id, to_fuse));
transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
}
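
The scheduling of the inserted transform stage changes here: instead of fusing all but the innermost axis, only the two or three leading loops are fused (depending on the rank of the rewritten layout) before the fused loop is parallelized. A Python paraphrase of the branch, where `rank` stands for `new_shape.size()`:

# Paraphrase of the C++ branch above (not a TVM API).
def axes_to_fuse(rank):
    if rank >= 5:
        return [0, 1, 2]
    if rank >= 3:
        return [0, 1]
    return []  # low-rank stages get no fuse step

assert axes_to_fuse(6) == [0, 1, 2]  # e.g. a rank-6 packed layout
assert axes_to_fuse(4) == [0, 1]
assert axes_to_fuse(2) == []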

@@ -1023,7 +1032,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
}
original_compute_op = op;
CHECK(!new_compute_op.defined());
- new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body);
+ auto new_attrs = pop->attrs;
+ new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
+ new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
Review thread on lines +1036 to +1037:

[Member] Delete this?

[Contributor, author] I would need the layout information when exporting the kernel to run in an environment outside of TVM. Emm... it's also fine to remove them here.

[Member] Ok. We can leave them here.
+ new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
}
}
}
src/auto_scheduler/feature.cc (6 changes: 4 additions & 2 deletions)

@@ -1398,7 +1398,8 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int
// rebuild task
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target,
- cur_inp->task->target_host, cur_inp->task->hardware_params);
+ cur_inp->task->target_host, cur_inp->task->hardware_params,
+ cur_inp->task->layout_rewrite_option);
task_id = task_cache.size();

// compute min cost for each task
@@ -1465,7 +1466,8 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
// rebuild task for incomplete measure pairs read from file
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
- inputs[i]->task->target_host, inputs[i]->task->hardware_params);
+ inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+ inputs[i]->task->layout_rewrite_option);
}
task_id = task_cache.size();

src/auto_scheduler/measure_record.cc (12 changes: 11 additions & 1 deletion)

@@ -165,12 +165,16 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
writer->WriteArrayItem(*data.hardware_params.get());
if (data.target_host.defined()) {
writer->WriteArrayItem(data.target_host->str());
+ } else {
+   writer->WriteArrayItem(std::string(""));
}
+ writer->WriteArrayItem(static_cast<int>(data.layout_rewrite_option));
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
bool s;
std::string str_value;
+ int int_value;
auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>();
reader->BeginArray();
s = reader->NextArrayItem();
@@ -188,7 +192,13 @@
data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node);
if (s) {
reader->Read(&str_value);
- data->target_host = ::tvm::Target(str_value);
+ if (!str_value.empty()) {
+   data->target_host = ::tvm::Target(str_value);
+ }
+ s = reader->NextArrayItem();
+ ICHECK(s);
+ reader->Read(&int_value);
+ data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value);
s = reader->NextArrayItem();
ICHECK(!s);
}
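
With the log version bumped to v0.5, each serialized task now ends with the target_host string (empty when undefined) and the integer rewrite option, so the field round-trips through records. A hedged reading sketch; the file name is hypothetical:

# Sketch: v0.5 records preserve the layout rewrite option.
from tvm import auto_scheduler

inputs, _results = auto_scheduler.RecordReader("records.json").read_lines()
for inp in inputs:
    print(inp.task.workload_key, inp.task.layout_rewrite_option)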
src/auto_scheduler/search_task.cc (10 changes: 7 additions & 3 deletions)

@@ -113,7 +113,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
}

SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
- Target target_host, Optional<HardwareParams> hardware_params) {
+ Target target_host, Optional<HardwareParams> hardware_params,
+ LayoutRewriteOption layout_rewrite_option) {
auto node = make_object<SearchTaskNode>();
node->compute_dag = std::move(compute_dag);
node->workload_key = std::move(workload_key);
@@ -125,6 +126,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe
node->hardware_params =
HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host);
}
+ node->layout_rewrite_option = layout_rewrite_option;
data_ = std::move(node);
}

@@ -139,8 +141,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")

TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
.set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
- Target target_host, Optional<HardwareParams> hardware_params) {
- return SearchTask(compute_dag, workload_key, target, target_host, hardware_params);
+ Target target_host, Optional<HardwareParams> hardware_params,
+ int layout_rewrite_option) {
+ return SearchTask(compute_dag, workload_key, target, target_host, hardware_params,
+                   LayoutRewriteOption(layout_rewrite_option));
});

} // namespace auto_scheduler