Skip to content

Commit

Permalink
[TOPI] Fix reduction (apache#6250)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and wjliu1998 committed Aug 13, 2020
1 parent b3ff294 commit 6807703
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/tvm/topi/cuda/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def traverse_after_reduce(operator):
for tensor in input_tensors:
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
elif isinstance(operator, tvm.te.PlaceholderOp):
pass
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)

Expand Down
29 changes: 29 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,34 @@ def expected():
assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_bcast_reduce_scalar():
    """Fusion of a broadcast op followed by a reduction on a scalar input.

    Builds ``min(x < 10)`` over a 0-d int32 tensor, runs the fusion pass,
    and checks that the whole graph collapses into a single fused primitive
    function.  Also compiles the fused module on every available target to
    make sure codegen handles the scalar broadcast+reduce pattern.
    """

    def build_graph():
        # min(x < 10) on a scalar (rank-0) input.
        inp = relay.var("x", shape=(), dtype="int32")
        cmp_out = relay.less(inp, relay.const(10, dtype="int32"))
        reduced = relay.min(cmp_out)
        return relay.Function([inp], reduced)

    def build_reference():
        # Expected result: one primitive function containing the whole graph.
        param = relay.var("p0", shape=(), dtype="int32")
        cmp_out = relay.less(param, relay.const(10, dtype="int32"))
        reduced = relay.min(cmp_out)
        fused = relay.Function([param], reduced)
        fused = fused.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        outer_arg = relay.var("x", shape=(), dtype="int32")
        call = relay.Call(fused, [outer_arg])
        return relay.Function([outer_arg], call)

    mod = fuse2(tvm.IRModule.from_expr(build_graph()))
    # Compile for each target so the fused kernel is actually generated.
    for target, _ in tvm.relay.testing.config.ctx_list():
        relay.build(mod, target)
    reference = run_opt_pass(build_reference(), transform.InferType())
    assert tvm.ir.structural_equal(mod["main"], reference)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -712,3 +740,4 @@ def expected():
test_fuse_max()
test_fuse_take()
test_fuse_gather_nd()
test_fuse_bcast_reduce_scalar()

0 comments on commit 6807703

Please sign in to comment.