Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] New layout rewrite option: Weight pre-transpose #6750

Merged
merged 15 commits into from
Nov 2, 2020
Prev Previous commit
Next Next commit
Update
  • Loading branch information
jcf94 committed Oct 27, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 8b2f716b7f866a1746fc2a7f989bd64585619e5b
16 changes: 9 additions & 7 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
@@ -195,19 +195,21 @@ class ComputeDAGNode : public Object {
};

/*!
* \brief Several options for applying layout rewrite.
* This is a optimization to rewrite the shape of input tensor according to the schedule we get.
* \brief Options for applying layout rewrite.
* This is an optimization to rewrite the layout of input tensors according to the schedule we get.
*/
enum class LayoutRewriteOption : int {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enum class LayoutRewriteOption : uint8 should be enough.

/*! \brief Do not process layout rewrite. */
NoRewrite = 0,
/*! \brief Insert layout transformation stages for input placeholders in the compute DAG */
InsertTransformStage = 1,
/*!
* \brief Modify the placeholder to suit the schedule.
* \note This should be used along with the graph optimization in Relay.
* \brief Do not insert layout transformation stages and assume the input placeholders
* are pre-transformed.
* \note The lowered function with this option does not accept the original input shapes,
* so this option must be used along with a layout conversion pass in Relay.
*/
RewriteWithPlaceholder = 1,
/*! \brief Insert a pre-transpose stage between placeholder and compute op to suit the schedule. */
RewriteWithPreTranspose = 2
RewriteForPreTransformed = 2,
};

/*!
2 changes: 1 addition & 1 deletion include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
@@ -187,7 +187,7 @@ class Step : public ObjectRef {
* This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
* steps.
* \return A base StepNode pointer, need to cast to its real StepNode type before doing any
* modifies.
* modifications.
* \code
*
* SplitStep ref;
12 changes: 6 additions & 6 deletions python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
@@ -50,11 +50,11 @@ class ComputeDAG(Object):
compute : Union[List[Tensor], str, Schedule]
Input/output tensors or workload key for a compute declaration.
"""
LAYOUT_REWRITE_TABLE = {
"NoRewrite": 0,
"RewriteWithPlaceholder": 1,
"RewriteWithPreTranspose": 2,
}

# Layout Rewrite Options
NoRewrite = 0
InsertTransformStage = 1
RewriteForPreTransformed = 2

def __init__(self, compute_or_sche):
if isinstance(compute_or_sche, str):
@@ -86,7 +86,7 @@ def get_init_state(self):
"""
return State(self.init_state, self)

def apply_steps_from_state(self, state, layout_rewrite=LAYOUT_REWRITE_TABLE["NoRewrite"]):
def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
"""
Apply the history transform steps from a State to get a TVM schedule.

12 changes: 6 additions & 6 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
@@ -863,6 +863,8 @@ std::string GetNewLayout(const State& state, const int stage_id, const Stage& st

ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
LayoutRewriteOption layout_rewrite) const {
CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite)
<< "Call ComputeDAG::RewriteLayout with NoRewrite.";
ComputeDAG new_dag = *this;
ComputeDAGNode* p_dag = new_dag.CopyOnWrite();

@@ -921,11 +923,11 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,

// Process op updates
te::Operation new_op_to_update;
if (layout_rewrite == LayoutRewriteOption::RewriteWithPlaceholder) {
if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) {
// Create new placeholder
new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape,
placeholder_op.as<te::PlaceholderOpNode>()->dtype);
} else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) {
} else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
// Process index strides
std::unordered_map<std::string, PrimExpr> axes_stride;
for (const auto& i : origin_axes) {
@@ -980,8 +982,6 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
}
transform_steps->push_back(FuseStep(stage_id, to_fuse));
transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
} else {
LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite.";
}

te::Operation new_compute_op, original_compute_op;
@@ -1015,7 +1015,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
for (size_t i = 0; i < original_ops.size(); ++i) {
const auto& original_op = original_ops[i];
if (original_op == placeholder_op) {
if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) {
if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
p_dag->ops.push_back(placeholder_op);
}
p_dag->ops.push_back(new_op_to_update);
@@ -1062,7 +1062,7 @@ ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite();
for (size_t i = 0; i < old_tensors.size(); ++i) {
const auto& old_tensor = old_tensors[i];
if (layout_rewrite != LayoutRewriteOption::RewriteWithPlaceholder &&
if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed &&
old_tensor->op->IsInstance<te::PlaceholderOpNode>()) {
continue;
}
24 changes: 16 additions & 8 deletions tests/python/unittest/test_auto_scheduler_layout_rewrite.py
Original file line number Diff line number Diff line change
@@ -31,19 +31,22 @@ def test_apply_steps_with_layout_rewrite():
_, bufs = dag.apply_steps_from_state(s)
assert bufs[1].shape[0] == 512
assert bufs[1].shape[1] == 512
_, bufs = dag.apply_steps_from_state(s,
layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPlaceholder"])
_, bufs = dag.apply_steps_from_state(
s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed
)
assert bufs[1].shape[0] == 4
assert bufs[1].shape[1] == 8
assert bufs[1].shape[2] == 4
assert bufs[1].shape[3] == 4
assert bufs[1].shape[4] == 512
_, bufs = dag.apply_steps_from_state(s,
layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"])
_, bufs = dag.apply_steps_from_state(
s, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage
)
assert bufs[1].shape[0] == 512
assert bufs[1].shape[1] == 512


@tvm.testing.requires_llvm
def test_correctness_layout_rewrite_with_placeholder():
N = 128
target = tvm.target.Target("llvm")
@@ -64,8 +67,9 @@ def test_correctness_layout_rewrite_with_placeholder():
)
auto_scheduler.auto_schedule(task, search_policy, tuning_options)
inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
s, bufs = dag.apply_steps_from_state(inp.state,
layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPlaceholder"])
s, bufs = dag.apply_steps_from_state(
inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.RewriteForPreTransformed
)
s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
np_args_ref = [np.array(x) for x in np_args]
@@ -109,8 +113,10 @@ def test_correctness_layout_rewrite_with_placeholder():

np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy())
np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy())
del measure_ctx


@tvm.testing.requires_llvm
def test_correctness_layout_rewrite_with_pre_transpose():
N = 128
target = tvm.target.Target("llvm")
@@ -131,8 +137,9 @@ def test_correctness_layout_rewrite_with_pre_transpose():
)
auto_scheduler.auto_schedule(task, search_policy, tuning_options)
inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
s, bufs = dag.apply_steps_from_state(inp.state,
layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.LAYOUT_REWRITE_TABLE["RewriteWithPreTranspose"])
s, bufs = dag.apply_steps_from_state(
inp.state, layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.InsertTransformStage
)

s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
np_args = [np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
@@ -154,6 +161,7 @@ def test_correctness_layout_rewrite_with_pre_transpose():
np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy())
np.testing.assert_allclose(args[1].asnumpy(), args_ref[1].asnumpy())
np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy())
del measure_ctx


if __name__ == "__main__":