diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 2a6bd00e0e8e..0f3f57fd7cf4 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -115,6 +115,8 @@ Graph Gradient(Graph src) { } std::vector input_grads = grad_fun_map[ptr->op()] (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads); + CHECK_EQ((*rit)->inputs.size(), input_grads.size()) + << "Gradient function not returning enough gradient"; auto git = input_grads.begin(); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));