Skip to content

Commit

Permalink
[Relay] shape func for zeros, zeros_like, ones, ones_like (apache#4448)
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww authored and tmoreau89 committed Dec 3, 2019
1 parent 02b54a2 commit f276728
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
18 changes: 18 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ def cast_shape_func(attrs, inputs, out_ndims):
return [_cast_shape_function(*inputs)]

# shape func
@script
def _full_shape_func(x):
out_ndim = len(x)
out = output_tensor((out_ndim,), "int64")
for i in const_range(out_ndim):
out[i] = x[i]
return out

def full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros, zeros_like, ones, ones_like.
"""
return [_full_shape_func(*inputs)]

@script
def _broadcast_shape_func(x, y, ndim):
out = output_tensor((ndim,), "int64")
Expand Down Expand Up @@ -162,6 +176,10 @@ def elemwise_shape_func(attrs, inputs, _):
return [topi.math.identity(inputs[0])]

register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, full_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones_like", False, full_shape_func)

register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
Expand Down
38 changes: 33 additions & 5 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
mod["main"] = relay.Function([x, y], op(x, y))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
y_np = np.random.uniform(size=y_np_shape).astype(dtype)
res_np = np_op(x_np, y_np)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))
tvm.testing.assert_allclose(result.asnumpy(), res_np)

def test_any_broadcast():
# Test broadcast with 1s
Expand Down Expand Up @@ -77,6 +78,32 @@ def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add)


def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
x = relay.var('x', shape=x_shape, dtype=dtype)
mod = relay.module.Module()
mod['main'] = relay.Function([x], relay.zeros_like(x))
x_np = np.random.uniform(size=x_np_shape).astype(dtype)
res_np = np.zeros_like(x_np)
for kind in ['debug', 'vm']:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm')
result = ex.evaluate()(x_np).asnumpy()
tvm.testing.assert_allclose(result, res_np)

def test_any_full():
# zeros, zeros_like, ones, ones_like
verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones, "int32")
verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32")
verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32")
verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32")

def test_any_concat():
x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
y = relay.var('y', shape=(1, 2), dtype="float32")
Expand All @@ -85,10 +112,10 @@ def test_any_concat():
mod["main"] = relay.Function([x, y], z)
x_np = np.random.uniform(size=(3, 2)).astype('float32')
y_np = np.random.uniform(size=(1, 2)).astype('float32')
ref = np.concatenate([x_np, y_np], axis=0)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(x_np, y_np)
ref = np.concatenate([x_np, y_np], axis=0)
tvm.testing.assert_allclose(result.asnumpy(), ref)

def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):
Expand Down Expand Up @@ -116,10 +143,10 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
expected = np.argwhere(data)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data).asnumpy()
expected = np.argwhere(data)
assert result.shape == expected.shape
tvm.testing.assert_allclose(result.flatten(), expected.flatten())

Expand Down Expand Up @@ -412,10 +439,10 @@ def verify_any_pad(data_shape, pad_width, static_data_shape):
y = relay.nn.pad(data, pad_width)
mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
ref_out = np.pad(data_np, pad_width)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data_np)
ref_out = np.pad(data_np, pad_width)
tvm.testing.assert_allclose(result.asnumpy(), ref_out)

def test_any_pad():
Expand Down Expand Up @@ -497,12 +524,12 @@ def _body(i, st):
mod = relay.module.Module()
mod["main"] = func
data = np.array(0.0, dtype='int32')
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
# TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
# so currently we cannot run this test case on VM
for kind in ["debug"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
np.testing.assert_allclose(result.asnumpy(), ref)

def test_recursive_concat_with_wrong_annotation():
Expand Down Expand Up @@ -553,6 +580,7 @@ def _body(i, st):
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)

if __name__ == "__main__":
test_any_full()
test_any_broadcast()
test_any_broadcast_fail()
test_any_concat()
Expand Down

0 comments on commit f276728

Please sign in to comment.