diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index b29ade88a034..220c1c11943a 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -107,6 +107,21 @@ using FInferType = FInferNodeEntryAttr; using FBackwardOutToInIndex = std::function< std::vector (const NodeAttrs& attrs)>; +/*! + * \brief Marks an op as an explicit backward operator and returns the indices of + * its inputs that correspond to the outputs of the forward operator. + * + * If FBackwardInGradIndex is registered for an op: + * - The first control_deps of the node points to the corresponding forward operator. + * - The FBackwardInGradIndex[i]-th input of the backward op corresponds to the i-th + * output of the forward operator (at most one entry per forward output). + * + * \note Register under "FBackwardInGradIndex". + * This enables easier shape/type inference for backward operators. + */ +using FBackwardInGradIndex = std::function< + std::vector (const NodeAttrs& attrs)>; + /*! * \brief Get possible inplace options. * This function enables optimization to reuse memory of inputs in output. diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 445787f7e13f..d9f1a0ebe34f 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -27,6 +27,8 @@ Graph InferAttr(Graph &&ret, Op::GetAttr >(infer_name); static auto& backward_map = Op::GetAttr("FBackwardOutToInIndex"); + static auto& backward_in_grad = + Op::GetAttr("FBackwardInGradIndex"); // reshape shape vector AttrVector rshape; if (ret.attrs.count(attr_name) != 0) { @@ -54,7 +56,6 @@ Graph InferAttr(Graph &&ret, } // Temp space for shape inference. std::vector ishape, oshape; - size_t num_unknown; // inference step function for nid auto infer_step = [&](uint32_t nid) { @@ -76,21 +77,29 @@ Graph InferAttr(Graph &&ret, } else if (backward_map.count(inode.source->op())) { // Backward operator inference. 
CHECK_GE(inode.control_deps.size(), 1) - << "BackwardOp need to have control_deps to its forward op"; + << "BackwardOp need to have control_deps to its forward op"; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; // Inference the outputs of backward operator (equal to the inputs // of its corresponding forward operator). std::vector out_map = backward_map[inode.source->op()](inode.source->attrs); - bool known = true; for (size_t i = 0; i < out_map.size(); ++i) { uint32_t in_id = out_map[i]; CHECK_LT(in_id, fnode.inputs.size()); rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[in_id])]; - if (fis_none(rshape[idx.entry_id(nid, i)])) known = false; } - num_unknown += !known; + if (backward_in_grad.count(inode.source->op())) { + std::vector in_grad = + backward_in_grad[inode.source->op()](inode.source->attrs); + CHECK_LE(in_grad.size(), fnode.source->num_outputs()); + for (size_t i = 0; i < in_grad.size(); ++i) { + uint32_t eid = idx.entry_id(inode.inputs[in_grad[i]]); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], i)]; + } + } + } } else { bool forward_known = true; // Forward operator inference. @@ -112,7 +121,6 @@ Graph InferAttr(Graph &&ret, // Call inference function of the operator. forward_known = finfer(inode.source->attrs, &ishape, &oshape); } - num_unknown += !forward_known; // Save to the result map. 
for (uint32_t i = 0; i < num_inputs; ++i) { rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; @@ -123,16 +131,24 @@ Graph InferAttr(Graph &&ret, } }; - num_unknown = 0; - for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { - infer_step(nid); - } - if (num_unknown != 0) { + size_t num_unknown = 0; + const int kMaxStep = 3; + for (int i = 0; i < kMaxStep; ++i) { + if (i % 2 == 0) { + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + infer_step(nid); + } + } else { + // backward inference + for (uint32_t i = idx.num_nodes(); i != 0; --i) { + infer_step(i - 1); + } + } num_unknown = 0; - // backward inference - for (uint32_t i = idx.num_nodes(); i != 0; --i) { - infer_step(i - 1); + for (size_t i = 0; i < idx.num_node_entries(); ++i) { + if (fis_none(rshape[i])) ++num_unknown; } + if (num_unknown == 0) break; } // set the shapes ret.attrs[attr_name] = std::make_shared(std::move(rshape));