From ba267925a230e69cb59ad283469fa5b6f43aebe5 Mon Sep 17 00:00:00 2001 From: Li Xiaoquan Date: Mon, 3 Aug 2020 11:43:34 +0800 Subject: [PATCH] [Relay] Allow relay.full to output scalar when using vm --- src/relay/backend/vm/compiler.cc | 7 ------- tests/python/relay/test_op_level3.py | 9 ++++++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b811911b4053..81b1aed9b465 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -302,13 +302,6 @@ class VMFunctionCompiler : ExprFunctor { void VisitExpr_(const ConstantNode* const_node) { // Check the shape is valid NDArray data = const_node->data; - const DLTensor* tensor = data.operator->(); - if (tensor->ndim > 0) { - int64_t* shapes = reinterpret_cast(tensor->shape); - for (auto i = 0; i < tensor->ndim; i++) { - CHECK_GT(shapes[i], 0U); - } - } size_t konst_idx = context_->constants.size(); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 76f10d6c1a18..833b66cd19d7 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -446,13 +446,16 @@ def verify_full(fill_value, src_shape, dtype): z = relay.full(x, src_shape, dtype) func = relay.Function([x], z) ref_res = np.full(src_shape, fill_value) + mod = tvm.IRModule() + mod['main'] = func for target, ctx in ctx_list(): - for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(np.array(fill_value, dtype)) + for kind in ["graph", "debug", "vm"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(np.array(fill_value, dtype)) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) verify_full(4, (1, 3, 4, 4), "int32") verify_full(4.0, (1, 4), "float32") + verify_full(4.0, (), "float32") def test_full_like_infer_type():