[Relay] Fix interpreter for dynamic shape input of ndarray_size (apac…
lixiaoquan authored and Trevor Morris committed Aug 26, 2020
1 parent a816bd3 commit 9c6ad45
Showing 2 changed files with 22 additions and 14 deletions.
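In short: the interpreter used to decide whether a primitive call needs the dynamic-shape path by checking the whole function type for dynamism, then special-casing shape_of. An op such as ndarray_size, whose result is a static scalar even when its input has Any dimensions, was still routed down the dynamic path. The fix checks only the call's return type. A minimal sketch of the pattern this enables, mirroring the new test below (API usage as in tests/python/relay/test_any.py):

import numpy as np
import tvm
from tvm import relay

# Input whose dimension is unknown until run time (relay.Any()).
v = relay.var("v", shape=(relay.Any(),), dtype="float32")
mod = tvm.IRModule()
mod["main"] = relay.Function([v], relay.ndarray_size(v, dtype="int32"))

# "debug" selects the interpreter patched in this commit.
ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
print(ex.evaluate()(np.zeros((4,), dtype="float32")))  # -> 4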
14 changes: 2 additions & 12 deletions src/relay/backend/interpreter.cc
@@ -213,11 +213,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
                     PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
  public:
   Interpreter(IRModule mod, DLContext context, Target target)
-      : mod_(mod),
-        context_(context),
-        target_(target),
-        debug_op_(Op::Get("debug")),
-        shape_of_op_(Op::Get("shape_of")) {
+      : mod_(mod), context_(context), target_(target), debug_op_(Op::Get("debug")) {
     engine_ = CompileEngine::Global();
   }

@@ -481,12 +477,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

     Array<Shape> out_shapes;
     auto ret_type = func->body->checked_type();
-    bool is_dyn = IsDynamic(func->checked_type());
-    if (call_node->op == shape_of_op_) {
-      // The output shape of shape_of must be static since Relay doesn't support
-      // dynamic rank tensors.
-      is_dyn = false;
-    }
+    bool is_dyn = IsDynamic(ret_type);

     if (is_dyn) {
       CHECK(func->HasNonzeroAttr(attr::kPrimitive));
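Why checking ret_type suffices: IsDynamic(func->checked_type()) is true whenever any input shape mentions Any, which is why the old code had to whitelist shape_of; IsDynamic(ret_type) asks only whether the output shape is static, which holds for shape_of, ndarray_size, and any other op that maps a dynamically shaped tensor to a statically shaped result. A sketch of the type-level fact from the Python side (printed forms are indicative, assuming a TVM build of this era):

import tvm
from tvm import relay

# A function over a tensor with dynamic (Any) dims whose result,
# ndarray_size, is a static scalar.
v = relay.var("v", shape=(relay.Any(), relay.Any()), dtype="float32")
f = relay.Function([v], relay.ndarray_size(v, dtype="int32"))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))

fn = mod["main"]
print(fn.checked_type)       # fn (Tensor[(?, ?), float32]) -> Tensor[(), int32]
print(fn.body.checked_type)  # Tensor[(), int32]
# The function type mentions Any, so the old whole-type check flagged the
# call dynamic; the body's type is static, so the new check does not.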
@@ -722,7 +713,6 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
   CompileEngine engine_;
   // Cache ops that need to be frequently used later to reduce lookup overhead.
   const Op& debug_op_;
-  const Op& shape_of_op_;
 };

 TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLContext context, Target target) {
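With the return-type check in place no operator needs special-casing, so the cached shape_of_op_ handle became dead and is dropped here along with its initialization in the constructor above.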
22 changes: 20 additions & 2 deletions tests/python/relay/test_any.py
@@ -814,7 +814,7 @@ def test_mixed_input_type():
     assert result.asnumpy().shape == ref_out_shape, \
         "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))

-def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size, 
+def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
                                layout, static_boxes, static_box_indices_shape, ref_out_shape):
     mod = tvm.IRModule()
     dtype = "float32"
@@ -872,6 +872,24 @@ def test_any_mirror_pad():
                           static_data_shape=(1, 256, 232, 232),
                           ref_out_shape=(1, 256, 234, 234))

+def verify_any_ndarray_size(data_np_shape):
+    v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32')
+    n = relay.ndarray_size(v, dtype='int32')
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([v], n)
+    np_data = np.zeros(data_np_shape, dtype='float32')
+    ref_res = np.size(np_data)
+
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(np_data)
+        tvm.testing.assert_allclose(result.asnumpy(), ref_res)
+
+def test_any_ndarray_size():
+    verify_any_ndarray_size((2,))
+    verify_any_ndarray_size((2, 2))
+    verify_any_ndarray_size((1, 2, 3, 4))
+
 if __name__ == "__main__":
     test_any_full()
     test_any_full_like()
@@ -908,4 +926,4 @@ def test_any_mirror_pad():
     test_mixed_input_type()
     test_any_crop_and_resize()
     test_any_mirror_pad()
-
+    test_any_ndarray_size()
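The new test drives the module through both the "debug" executor (the interpreter fixed above) and the "vm" executor, so the static-output path is checked against the VM's existing dynamic-shape handling. Running tests/python/relay/test_any.py directly executes it via the __main__ block along with the rest of the suite.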