[TVM][BUGFIX] Fix missing reduction init predicates (apache#2495)
* [TVM][BUGFIX] Fix reductions in split axes

* A test case for the problem

* Fix the fix: skip loops that are related to reduction AND are unrelated to axis
sgrechanik-h authored and merrymercy committed Feb 17, 2019
1 parent 5372ba0 commit 0b5c9af
Showing 3 changed files with 32 additions and 5 deletions.
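Reading the diff together with the new test, the problem appears to involve leaf iteration variables that relate neither to the output axes nor to the reduction, e.g. the dummy variable produced by fusing the (empty) axis list of a scalar reduction and then splitting it. Below is a minimal repro sketch; it is not part of the commit, it just restates the pattern exercised by the new test using the same pre-0.7 tvm Python API, and lowers the schedule so the reduction init and its predicate can be inspected:

import tvm

# Scalar reduction: Y = sum(X[k]); the output tensor has no spatial axes.
n = 10
X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
k = tvm.reduce_axis((0, n), name="k")
Y = tvm.compute((), lambda: tvm.sum(X[k], k), name="Y")

s = tvm.create_schedule([Y.op])
# Fusing the empty spatial axis list yields a dummy iteration variable;
# splitting it creates loops related neither to the output axes nor to k.
ax = s[Y.op].fuse(*Y.op.axis)
axo, axi = s[Y.op].split(ax, nparts=20)

# Per the commit title, one would expect the reduction init to be guarded by
# a predicate over these dummy loops so it runs exactly once; before the fix
# that predicate was missing.
print(tvm.lower(s, [Y, X], simple_mode=True))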
4 changes: 2 additions & 2 deletions src/op/compute_op.cc
@@ -457,11 +457,11 @@ ComputeLoopNest ComputeLoopNest::make(
       ret.init_vmap[iv] = ret.main_vmap.at(iv);
     }
     ret.num_common_loop = begin_loop;
-    // skip loops that does not relates to axis.
+    // skip loops that are related to reduction and are unrelated to axis.
     std::unordered_set<IterVar> skip_iter;
     for (auto kv : update_state) {
       int flag = kv.second;
-      if ((flag & 1) == 0) skip_iter.insert(kv.first);
+      if (flag == 2) skip_iter.insert(kv.first);
     }
     ret.init_nest = op::MakeLoopNest(
         stage, dom_map, begin_loop, true,
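The meaning of flag is not spelled out in this hunk; reading the old and new conditions together with the updated comment suggests a bitmask in which value 1 marks iteration variables related to the output axes and value 2 marks those related to the reduction. A toy sketch of the old versus new skip condition under that assumption (the names AXIS and REDUCE and the helper functions are illustrative, not TVM code):

AXIS = 1     # iteration variable relates to an output axis
REDUCE = 2   # iteration variable relates to a reduce axis

def old_skip(flag):
    # Old condition: skip every loop not related to the output axes.
    # This also skips flag == 0 (loops from fuse/split related to neither),
    # so the init nest was built without predicates for them, i.e. the
    # "missing reduction init predicates" of the commit title.
    return (flag & AXIS) == 0

def new_skip(flag):
    # Fixed condition: skip only loops that are purely reduction-related.
    return flag == REDUCE

for flag in (0, AXIS, REDUCE, AXIS | REDUCE):
    print(flag, old_skip(flag), new_skip(flag))
# Only flag == 0 differs: old_skip(0) is True, new_skip(0) is False.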
4 changes: 2 additions & 2 deletions src/op/tensor_compute_op.cc
@@ -215,11 +215,11 @@ ComputeLoopNest MakeLoopNest(
       ret.init_vmap[iv] = ret.main_vmap.at(iv);
     }
     ret.num_common_loop = begin_loop;
-    // skip loops that does not relates to axis.
+    // skip loops that are related to reduction and are unrelated to axis.
     std::unordered_set<IterVar> skip_iter;
     for (auto kv : update_state) {
       int flag = kv.second;
-      if ((flag & 1) == 0) skip_iter.insert(kv.first);
+      if (flag == 2) skip_iter.insert(kv.first);
     }
     ret.init_nest = op::MakeLoopNest(
         stage, dom_map, begin_loop, true,
29 changes: 28 additions & 1 deletion tests/python/unittest/test_schedule_schedule_ops.py
@@ -1,5 +1,5 @@
 import tvm
-
+import numpy as np
 
 def test_schedule0():
     m = tvm.var('m')
@@ -432,6 +432,32 @@ def f(n):
     s.cache_write(Y, 'local')
     f = tvm.build(s, [X, Y])
 
+def test_reduction_and_dummy_fuse_split():
+    n = 10
+    X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
+    k = tvm.reduce_axis((0, n))
+    Y = tvm.compute((), lambda: tvm.sum(X[k], k), name="Y")
+    s = tvm.create_schedule([Y.op])
+    ax = s[Y.op].fuse(*Y.op.axis)
+    axo, axi = s[Y.op].split(ax, nparts=20)
+    f = tvm.build(s, [Y, X])
+
+    args = [tvm.nd.empty((), 'int32')] + [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
+    f(*args)
+    assert args[0].asnumpy() == n
+
+    n = 10
+    X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
+    k = tvm.reduce_axis((0, n))
+    Y = tvm.compute((n,), lambda i: tvm.sum(X[k], k), name="Y")
+    s = tvm.create_schedule([Y.op])
+    ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))
+    f = tvm.build(s, [Y, X])
+
+    args = [tvm.ndarray.array(np.ones((n,), dtype='int32'))] + \
+           [tvm.ndarray.array(np.ones((n,), dtype='int32'))]
+    f(*args)
+    assert np.all(args[0].asnumpy() == n)
 
 if __name__ == "__main__":
     test_loop_dep_reduce()
@@ -456,3 +482,4 @@ def f(n):
     test_schedule_tensor_compute1()
     test_schedule_tensor_compute2()
     test_schedule_tensor_compute3()
+    test_reduction_and_dummy_fuse_split()
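The second case in the new test fuses a spatial axis with the reduce axis, giving a leaf variable related to both the output axes and the reduction (AXIS | REDUCE in the toy model above); per the updated comment such loops are not skipped when building the init nest. A sketch for inspecting that case, under the same API assumptions as the first sketch and mirroring the test's numeric check:

import numpy as np
import tvm

n = 10
X = tvm.placeholder(shape=(n,), dtype='int32', name="X")
k = tvm.reduce_axis((0, n), name="k")
Y = tvm.compute((n,), lambda i: tvm.sum(X[k], k), name="Y")

s = tvm.create_schedule([Y.op])
# Fuse the output axis i with the reduce axis k into one loop.
ax = s[Y.op].fuse(*(list(Y.op.axis) + list(Y.op.reduce_axis)))

# Print the lowered IR to see where the init of Y lands relative to the
# fused loop, then check the numeric result as the test does.
print(tvm.lower(s, [Y, X], simple_mode=True))

f = tvm.build(s, [Y, X])
y = tvm.ndarray.array(np.ones((n,), dtype='int32'))
x = tvm.ndarray.array(np.ones((n,), dtype='int32'))
f(y, x)
assert np.all(y.asnumpy() == n)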
