Skip to content

Commit

Permalink
[Relay] Fix interpreter for dynamic shape input of ndarray_size
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan committed Jul 23, 2020
1 parent 6e1b09e commit 1644b8a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
context_(context),
target_(target),
debug_op_(Op::Get("debug")),
shape_of_op_(Op::Get("shape_of")) {
shape_of_op_(Op::Get("shape_of")),
ndarray_size_op_(Op::Get("ndarray_size")) {
engine_ = CompileEngine::Global();
}

Expand Down Expand Up @@ -486,6 +487,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
is_dyn = false;
} else if (call_node->op == ndarray_size_op_) {
// The output shape of ndarray_size is static
is_dyn = false;
}

if (is_dyn) {
Expand Down Expand Up @@ -723,6 +727,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
const Op& shape_of_op_;
const Op& ndarray_size_op_;
};

TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLContext context, Target target) {
Expand Down
19 changes: 19 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,24 @@ def test_mixed_input_type():
assert result.asnumpy().shape == ref_out_shape, \
"Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))

def verify_any_ndarray_size(data_np_shape):
    """Check relay.ndarray_size on an input whose every dimension is relay.Any().

    Builds main(v) = ndarray_size(v), feeds a zero tensor of the given concrete
    shape, and verifies both the interpreter ("debug") and the VM agree with
    numpy's element count.
    """
    dyn_input = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32')
    size_expr = relay.ndarray_size(dyn_input, dtype='int32')
    mod = tvm.IRModule()
    mod['main'] = relay.Function([dyn_input], size_expr)

    input_np = np.zeros(data_np_shape, dtype='float32')
    expected = np.size(input_np)

    # The interpreter path is exactly what the parent commit fixes for
    # ndarray_size; the VM path already handled dynamic shapes.
    for kind in ["debug", "vm"]:
        executor = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        out = executor.evaluate()(input_np)
        tvm.testing.assert_allclose(out.asnumpy(), expected)

def test_any_ndarray_size():
    """Exercise dynamic-shape ndarray_size across ranks 1, 2, and 4."""
    for shape in [(2,), (2, 2), (1, 2, 3, 4)]:
        verify_any_ndarray_size(shape)

if __name__ == "__main__":
test_any_full()
test_any_full_like()
Expand Down Expand Up @@ -850,3 +868,4 @@ def test_mixed_input_type():
test_recursive_concat_with_wrong_annotation()
test_tuple_get_item()
test_mixed_input_type()
test_any_ndarray_size()

0 comments on commit 1644b8a

Please sign in to comment.