Skip to content

Commit

Permalink
Better to check Infer result with topi results at build time instead …
Browse files Browse the repository at this point in the history
…of leaving to a runtime error. (#476)
  • Loading branch information
srkreddy1238 authored and tqchen committed May 29, 2018
1 parent 099dccb commit 3865410
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions nnvm/src/compiler/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,22 @@ class CompileEngine {
Array<Tensor> out = fcompute[inode.source->op()](
inode.source->attrs, op_inputs, out_info);
CHECK_EQ(out.size(), inode.source->num_outputs());

// check output dimentions also match
// This check is to make sure the NNVM operator Infer match with Compute result.
// Missing this check may pass the build but leads to runtime errors.
for (uint32_t i = 0; i < out.size(); ++i) {
CHECK_EQ(out[i].ndim(), out_info[i].ndim()) << inode.source->op()->name;
tvm::Tensor inferred_tensor = out[i];
tvm::Tensor computed_tensor = out_info[i];
for (uint32_t j = 0; j < inferred_tensor->shape.size(); ++j) {
if ((as_const_int(inferred_tensor->shape[j])) &&
(as_const_int(computed_tensor->shape[j])))
CHECK_EQ((*as_const_int(inferred_tensor->shape[j])),
(*as_const_int(computed_tensor->shape[j]))) << inode.source->op()->name;
}
}

// schedule on root node, and use master's schedule
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
Expand Down

0 comments on commit 3865410

Please sign in to comment.