From 1644b8aa2c562a6a2b21333679678096106320e5 Mon Sep 17 00:00:00 2001 From: Li Xiaoquan Date: Fri, 17 Jul 2020 14:18:30 +0800 Subject: [PATCH] [Relay] Fix interpreter for dyanmic shape input of ndarray_size --- src/relay/backend/interpreter.cc | 7 ++++++- tests/python/relay/test_any.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 9a75c0ab76eed..00bf5eb75152b 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -217,7 +217,8 @@ class Interpreter : public ExprFunctor, context_(context), target_(target), debug_op_(Op::Get("debug")), - shape_of_op_(Op::Get("shape_of")) { + shape_of_op_(Op::Get("shape_of")), + ndarray_size_op_(Op::Get("ndarray_size")) { engine_ = CompileEngine::Global(); } @@ -486,6 +487,9 @@ class Interpreter : public ExprFunctor, // The output shape of shape_of must be static since Relay doesn't support // dynamic rank tensors. is_dyn = false; + } else if (call_node->op == ndarray_size_op_) { + // The output shape of ndarray_size is static + is_dyn = false; } if (is_dyn) { @@ -723,6 +727,7 @@ class Interpreter : public ExprFunctor, // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; const Op& shape_of_op_; + const Op& ndarray_size_op_; }; TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, Target target) { diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 6d940a5635664..31f60dc4797e5 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -816,6 +816,24 @@ def test_mixed_input_type(): assert result.asnumpy().shape == ref_out_shape, \ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) +def verify_any_ndarray_size(data_np_shape): + v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32') + n = relay.ndarray_size(v, dtype='int32') + mod = tvm.IRModule() + mod['main'] = relay.Function([v], n) + np_data = np.zeros(data_np_shape, dtype='float32') + ref_res = np.size(np_data) + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(np_data) + tvm.testing.assert_allclose(result.asnumpy(), ref_res) + +def test_any_ndarray_size(): + verify_any_ndarray_size((2,)) + verify_any_ndarray_size((2, 2)) + verify_any_ndarray_size((1, 2, 3, 4)) + if __name__ == "__main__": test_any_full() test_any_full_like() @@ -850,3 +868,4 @@ def test_mixed_input_type(): test_recursive_concat_with_wrong_annotation() test_tuple_get_item() test_mixed_input_type() + test_any_ndarray_size()