Skip to content

Commit

Permalink
Early checking added and new test cases added for schedule fuse (apac…
Browse files Browse the repository at this point in the history
…he#5010)

* [1] New test case added for fuse

* [2] New test case added for fuse

* [3] New test case added for fuse

* [4] New test case added for fuse

* [5] Early check added
  • Loading branch information
ANSHUMAN TRIPATHY authored and zhiics committed Apr 17, 2020
1 parent 12cd550 commit 12f9a0a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/te/schedule/schedule_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT
std::swap(outer, inner);
std::swap(pos_inner, pos_outer);
}
self->relations.push_back(FuseNode::make(outer, inner, fused));
all_vars->data.push_back(fused);
CHECK_EQ(pos_inner, pos_outer + 1)
<< "Can only fuse iterations that are consecutive between each other";
self->relations.push_back(FuseNode::make(outer, inner, fused));
all_vars->data.push_back(fused);
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
leaf_vars->data.begin() + pos_inner + 1);
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
Expand Down
46 changes: 46 additions & 0 deletions tests/python/unittest/test_lang_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,49 @@ def test_fuse():
assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)

def test_fuse_with_split():
m = te.size_var('m')
n = te.size_var('n')
A = te.placeholder((m, n), name='A')
T = te.compute((m, n), lambda i, j: A[i, j])

s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
fused = s[T].fuse(xi, y)
assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (xo, fused)

@pytest.mark.xfail
def test_fuse_with_out_of_order_axis():
m = te.size_var('m')
n = te.size_var('n')
A = te.placeholder((m, n), name='A')
T = te.compute((m, n), lambda i, j: A[i, j])

s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
fused = s[T].fuse(xo, y) # should throw here

@pytest.mark.xfail
def test_fuse_with_out_of_order_axis_with_reorder():
m = te.size_var('m')
n = te.size_var('n')
A = te.placeholder((m, n), name='A')
T = te.compute((m, n), lambda i, j: A[i, j])

s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
s[T].reorder(y, xo, xi)
fused = s[T].fuse(y, xo) # should be ok

s = te.create_schedule(T.op)
y = T.op.axis[1]
xo, xi = s[T].split(T.op.axis[0], factor=10)
s[T].reorder(y, xo, xi)
fused = s[T].fuse(y, xi) # should throw here

def test_singleton():
print("test singleton")
Expand Down Expand Up @@ -257,5 +300,8 @@ def intrin_func(ins, outs, sp):
test_tile()
test_split()
test_fuse()
test_fuse_with_split()
test_fuse_with_out_of_order_axis()
test_fuse_with_out_of_order_axis_with_reorder()
test_vectorize()
test_vectorize_commreduce()

0 comments on commit 12f9a0a

Please sign in to comment.