diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index baf42f9367b4..a6dd39f79b1f 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -457,11 +457,11 @@ ComputeLoopNest ComputeLoopNest::make(
     ret.init_vmap[iv] = ret.main_vmap.at(iv);
   }
   ret.num_common_loop = begin_loop;
-  // skip loops that does not relates to axis.
+  // skip loops that are related to reduction and are unrelated to axis.
   std::unordered_set<IterVar> skip_iter;
   for (auto kv : update_state) {
     int flag = kv.second;
-    if ((flag & 1) == 0) skip_iter.insert(kv.first);
+    if (flag == 2) skip_iter.insert(kv.first);
   }
   ret.init_nest = op::MakeLoopNest(
       stage, dom_map, begin_loop, true,
diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc
index f9b8188d4685..0262db7d8fc5 100644
--- a/src/op/tensor_compute_op.cc
+++ b/src/op/tensor_compute_op.cc
@@ -215,11 +215,11 @@ ComputeLoopNest MakeLoopNest(
     ret.init_vmap[iv] = ret.main_vmap.at(iv);
   }
   ret.num_common_loop = begin_loop;
-  // skip loops that does not relates to axis.
+  // skip loops that are related to reduction and are unrelated to axis.
   std::unordered_set<IterVar> skip_iter;
   for (auto kv : update_state) {
     int flag = kv.second;
-    if ((flag & 1) == 0) skip_iter.insert(kv.first);
+    if (flag == 2) skip_iter.insert(kv.first);
   }
   ret.init_nest = op::MakeLoopNest(
       stage, dom_map, begin_loop, true,
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 006dc2fc9f1e..c7cf1c142dd0 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -1,5 +1,5 @@
 import tvm
-
+import numpy as np
 
 def test_schedule0():
     m = tvm.var('m')
@@ -432,6 +432,32 @@ def f(n):
     s.cache_write(Y, 'local')
     f = tvm.build(s, [X, Y])
 
+def test_reduction_and_dummy_fuse_split():
+    n = 10
+    X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
+    k = tvm.reduce_axis((0, n))
+    Y = tvm.compute((), lambda: tvm.sum(X[k], k), name="Y")
+    s = tvm.create_schedule([Y.op])
+    ax = s[Y.op].fuse(*Y.op.axis)
+    axo, axi = s[Y.op].split(ax, nparts=20)
+    f = tvm.build(s, [Y, X])
+
+    args = [tvm.nd.empty((), 'int32')] + [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
+    f(*args)
+    assert args[0].asnumpy() == n
+
+    n = 10
+    X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
+    k = tvm.reduce_axis((0, n))
+    Y = tvm.compute((n,), lambda i: tvm.sum(X[k], k), name="Y")
+    s = tvm.create_schedule([Y.op])
+    ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
+    f = tvm.build(s, [Y, X])
+
+    args = [tvm.ndarray.array(np.ones((n,), dtype='int32'))] + \
+        [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
+    f(*args)
+    assert np.all(args[0].asnumpy() == n)
 
 if __name__ == "__main__":
     test_loop_dep_reduce()
@@ -456,3 +482,4 @@ def f(n):
     test_schedule_tensor_compute1()
     test_schedule_tensor_compute2()
     test_schedule_tensor_compute3()
+    test_reduction_and_dummy_fuse_split()