From 782e3e0f2720ffef0ac19e2830a6ce05556a04eb Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Wed, 12 Aug 2020 00:32:50 +0000
Subject: [PATCH] [TOPI] Fix reduction

---
Reviewer notes (this area is outside the commit message):
* reduction.py: traverse_after_reduce gains a branch that treats
  tvm.te.PlaceholderOp operators as a no-op instead of falling through to
  the "Unsupported operator" RuntimeError.
* test_pass_fuse_ops.py: adds test_fuse_bcast_reduce_scalar, which checks
  that less + min over a scalar int32 input fuse into one primitive
  function and that the fused module builds for the "cuda" target.
* Line breaks and indentation were reconstructed from a flattened copy of
  this patch. The absolute indentation of the reduction.py hunk is an
  assumption (traverse_after_reduce taken to be a nested function) --
  TODO confirm against the tree, or apply with
  `git apply --ignore-whitespace`.

 python/tvm/topi/cuda/reduction.py        |  2 ++
 tests/python/relay/test_pass_fuse_ops.py | 28 ++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/python/tvm/topi/cuda/reduction.py b/python/tvm/topi/cuda/reduction.py
index 38e30867b7917..664ea441141b7 100644
--- a/python/tvm/topi/cuda/reduction.py
+++ b/python/tvm/topi/cuda/reduction.py
@@ -139,6 +139,8 @@ def traverse_after_reduce(operator):
             for tensor in input_tensors:
                 if tensor.op not in scheduled_ops:
                     traverse_before_reduce(tensor.op)
+        elif isinstance(operator, tvm.te.PlaceholderOp):
+            pass
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
 
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index f4369c1f1d904..986f51b459d54 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -694,6 +694,33 @@ def expected():
     assert tvm.ir.structural_equal(m["main"], after)
 
 
+def test_fuse_bcast_reduce_scalar():
+    """Test fusion case with broadcast and reduction involving scalar"""
+
+    def before():
+        x = relay.var("x", shape=(), dtype="int32")
+        less = relay.less(x, relay.const(10, dtype="int32"))
+        z = relay.min(less)
+        return relay.Function([x], z)
+
+    def expected():
+        p0 = relay.var("p0", shape=(), dtype="int32")
+        less = relay.less(p0, relay.const(10, dtype="int32"))
+        z0 = relay.min(less)
+        f0 = relay.Function([p0], z0)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+        x = relay.var("x", shape=(), dtype="int32")
+        f = relay.Call(f0, [x])
+        return relay.Function([x], f)
+
+    orig = before()
+    m = fuse2(tvm.IRModule.from_expr(orig))
+    relay.build(m, "cuda")
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(m["main"], after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -712,3 +739,4 @@ def expected():
     test_fuse_max()
     test_fuse_take()
     test_fuse_gather_nd()
+    test_fuse_bcast_reduce_scalar()