Skip to content

Commit

Permalink
Check iter_type in vectorize (apache#1921)
Browse files Browse the repository at this point in the history
  • Loading branch information
izgzhen authored and AWS Neo committed Feb 20, 2019
1 parent 4477957 commit 858aab0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/schedule/schedule_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,13 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type)
}

Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
CHECK(var->iter_type == kDataPar ||
var->iter_type == kOpaque ||
var->iter_type == kUnrolled ||
var->iter_type == kVectorized ||
var->iter_type == kTensorized ||
var->iter_type == kParallelized)
<< "Cannot vectorize on " << IterVarType2String(var->iter_type);
SetAttrIterType(operator->(), var, kVectorized);
return *this;
}
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_lang_schedule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from nose.tools import raises
import tvm
import pickle as pkl

Expand Down Expand Up @@ -112,6 +113,13 @@ def test_vectorize():
assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE

@raises(Exception)
def test_vectorize_commreduce():
V = tvm.placeholder((128,), name='V')
ax = tvm.reduce_axis((0, 128), name='ax')
O = tvm.compute((1,), lambda _: tvm.sum(V[ax], axis=[ax]))
s = tvm.create_schedule(O.op)
s[O].vectorize(ax) # should throw here

def test_pragma():
m = 100
Expand Down Expand Up @@ -197,3 +205,4 @@ def intrin_func(ins, outs):
test_split()
test_fuse()
test_vectorize()
test_vectorize_commreduce()

0 comments on commit 858aab0

Please sign in to comment.