diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 1ef39f851203..a9d4aa2d016a 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -202,6 +202,22 @@ class CompileEngine { Array out = fcompute[inode.source->op()]( inode.source->attrs, op_inputs, out_info); CHECK_EQ(out.size(), inode.source->num_outputs()); + + // Check that the output dimensions also match. + // This check makes sure the NNVM operator's Infer result matches the Compute result. + // Without this check the build may pass but fail later with runtime errors. NOTE(review): `out` is what fcompute returned, so the local names inferred_tensor/computed_tensor below appear swapped - confirm. + for (uint32_t i = 0; i < out.size(); ++i) { + CHECK_EQ(out[i].ndim(), out_info[i].ndim()) << inode.source->op()->name; + tvm::Tensor inferred_tensor = out[i]; + tvm::Tensor computed_tensor = out_info[i]; + for (uint32_t j = 0; j < inferred_tensor->shape.size(); ++j) { + if ((as_const_int(inferred_tensor->shape[j])) && + (as_const_int(computed_tensor->shape[j]))) + CHECK_EQ((*as_const_int(inferred_tensor->shape[j])), + (*as_const_int(computed_tensor->shape[j]))) << inode.source->op()->name; + } + } + // schedule on root node, and use master's schedule for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { uint32_t eid = idx.entry_id(nid, index);