Skip to content

Commit

Permalink
sch and compute for reduce ops
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Nov 26, 2018
1 parent b903834 commit 96bdd3d
Showing 1 changed file with 49 additions and 10 deletions.
59 changes: 49 additions & 10 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ def test_where():
assert zz.checked_type == relay.TensorType((3, 4), "float32")


def verify_reduce(test_func, data, axis, keepdims, exclude, output):
x = relay.var("x", relay.TensorType(data, "float32"))
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]

x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
zz = relay.ir_pass.infer_type(z)
if axis:
Expand All @@ -116,18 +119,54 @@ def verify_reduce(test_func, data, axis, keepdims, exclude, output):
assert "keepdims=" in z.astext()
if exclude:
assert "exclude=" in z.astext()
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32"
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
assert zz.checked_type == relay.ty.TensorType(output, out_type)

if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0:
return

func = relay.Function([x], z)
x_data = np.random.uniform(size=data).astype(dtype)
if ref_func in [np.sum]:
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
elif ref_func in [np.max, np.min, np.mean, np.prod]:
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
else: #argmin/argmax
if axis and len(axis) > 1:
return
ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)

for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)

def test_reduce_functions():
def _with_keepdims(func):
def _wrapper(data, axis=None, keepdims=False):
if not keepdims:
return func(data, axis=axis)
else:
if axis is not None:
axis = axis[0]
out_shape = list(data.shape)
out_shape[axis] = 1
else:
out_shape = [1 for _ in range(len(data.shape))]
return func(data, axis=axis).reshape(out_shape)
return _wrapper

d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
for func in [relay.sum,
relay.max,
relay.min,
relay.mean,
relay.prod,
relay.argmin,
relay.argmax]:
for func in [[relay.sum, np.sum],
[relay.max, np.max],
[relay.min, np.min],
[relay.mean, np.mean],
[relay.prod, np.prod],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
Expand Down

0 comments on commit 96bdd3d

Please sign in to comment.