From 96bdd3d36fb3193e5540c3dd7e312e6a6e037b65 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Tue, 13 Nov 2018 15:49:22 +0530 Subject: [PATCH 1/2] sch and compute for reduce ops --- tests/python/relay/test_op_level4.py | 59 +++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index dd12dc7cff3a..fbfc7d729df4 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -106,8 +106,11 @@ def test_where(): assert zz.checked_type == relay.TensorType((3, 4), "float32") -def verify_reduce(test_func, data, axis, keepdims, exclude, output): - x = relay.var("x", relay.TensorType(data, "float32")) +def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"): + test_func = funcs[0] + ref_func = funcs[1] + + x = relay.var("x", relay.TensorType(data, dtype)) z = test_func(x, axis, keepdims, exclude) zz = relay.ir_pass.infer_type(z) if axis: @@ -116,18 +119,54 @@ def verify_reduce(test_func, data, axis, keepdims, exclude, output): assert "keepdims=" in z.astext() if exclude: assert "exclude=" in z.astext() - out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32" + out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype assert zz.checked_type == relay.ty.TensorType(output, out_type) + if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0: + return + + func = relay.Function([x], z) + x_data = np.random.uniform(size=data).astype(dtype) + if ref_func in [np.sum]: + ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims) + elif ref_func in [np.max, np.min, np.mean, np.prod]: + ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims) + else: #argmin/argmax + if axis and len(axis) > 1: + return + ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) + op_res2 = intrp2.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) + def test_reduce_functions(): + def _with_keepdims(func): + def _wrapper(data, axis=None, keepdims=False): + if not keepdims: + return func(data, axis=axis) + else: + if axis is not None: + axis = axis[0] + out_shape = list(data.shape) + out_shape[axis] = 1 + else: + out_shape = [1 for _ in range(len(data.shape))] + return func(data, axis=axis).reshape(out_shape) + return _wrapper + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") - for func in [relay.sum, - relay.max, - relay.min, - relay.mean, - relay.prod, - relay.argmin, - relay.argmax]: + for func in [[relay.sum, np.sum], + [relay.max, np.max], + [relay.min, np.min], + [relay.mean, np.mean], + [relay.prod, np.prod], + [relay.argmin, _with_keepdims(np.argmin)], + [relay.argmax, _with_keepdims(np.argmax)]]: verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4)) verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3)) verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1)) From 08bb5a0c86ee350ebb85b8754945c42af8984461 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Mon, 26 Nov 2018 09:51:36 +0530 Subject: [PATCH 2/2] rebased --- python/tvm/relay/op/_reduce.py | 1 + tests/python/relay/test_op_level4.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index fd18c0e71d53..5c720256bbd6 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -15,5 +15,6 @@ def _schedule_reduce(_, outs, target): _reg.register_schedule("argmin", _schedule_reduce) _reg.register_schedule("sum", _schedule_reduce) _reg.register_schedule("max", _schedule_reduce) +_reg.register_schedule("min", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce) _reg.register_schedule("mean", _schedule_reduce) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index fbfc7d729df4..e5da48f107eb 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -173,7 +173,6 @@ def _wrapper(data, axis=None, keepdims=False): verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3)) verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ()) - verify_reduce(func, (4, 4, 3), None, True, False, (1, 1, 1)) verify_reduce(func, (4, 4, 3), None, False, True, ()) verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,)) verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))