diff --git a/nnvm/python/nnvm/graph.py b/nnvm/python/nnvm/graph.py index b8cb655c9b65..5a7b36342a36 100644 --- a/nnvm/python/nnvm/graph.py +++ b/nnvm/python/nnvm/graph.py @@ -177,8 +177,23 @@ def index(self): self._index = GraphIndex(self) return self._index - def graphir(self): - """Get text form of graph ir.""" + def ir(self, join_entry_attrs=None, join_node_attrs=None): + """Get text form of graph ir. + + Parameters + ---------- + join_entry_attrs : list of str + List of graph NodeEntry attribute to be + printed along each operator. + + join_node_attrs : list of str + List of graph node attribute to be + printed along each operator. + """ + if join_entry_attrs: + self._set_json_attr("join_entry_attrs", join_entry_attrs, "list_str") + if join_node_attrs: + self._set_json_attr("join_node_attrs", join_node_attrs, "list_str") return self.apply("PrintGraphIR").json_attr("graphir") def apply(self, passes): diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 0332532717e6..fd9b77c42c3f 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -67,6 +67,8 @@ Graph InferAttr(Graph &&ret, shape_attr_key = ret.GetAttr(attr_key_name); // erase the provided arguments ret.attrs.erase(attr_key_name); + } else { + shape_attr_key = attr_name; } // Temp space for shape inference. std::vector ishape, oshape; diff --git a/nnvm/src/pass/print_graph_ir.cc b/nnvm/src/pass/print_graph_ir.cc index a29ee922b644..6a42aabce616 100644 --- a/nnvm/src/pass/print_graph_ir.cc +++ b/nnvm/src/pass/print_graph_ir.cc @@ -5,14 +5,80 @@ */ #include #include +#include #include namespace nnvm { namespace pass { +using AttrPrinter = std::function; // NOLINT(*) + +template +AttrPrinter GetVectorPrinter_(const T& vec) { + return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*) + os << vec[index]; + }; +} + +AttrPrinter GetVectorPrinter(const Graph& graph, + const std::string& key) { + auto it = graph.attrs.find(key); + CHECK(it != graph.attrs.end()) + << "Cannot find " << key << " in graph attr"; + const any& value = *(it->second); + if (value.type() == typeid(std::vector)) { + return GetVectorPrinter_( + nnvm::get >(value)); + } else if (value.type() == typeid(std::vector)) { + return GetVectorPrinter_( + nnvm::get >(value)); + } else if (value.type() == typeid(std::vector)) { + return GetVectorPrinter_( + nnvm::get >(value)); + } else { + LOG(FATAL) << "Cannot handle type " << value.type().name(); + return nullptr; + } +} + + // print the graph ir in readable format -void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) +void PrintGraphIR_(Graph src, + const std::vector& join_entry_attrs, + const std::vector& join_node_attrs, + std::ostream& os) { // NOLINT(*) const IndexedGraph& idx = src.indexed_graph(); + std::vector > trigger; // NOLINT(*) + + for (const std::string& key : join_entry_attrs) { + AttrPrinter fp = GetVectorPrinter(src, key); + auto fprint = [&idx, key, fp]( + uint32_t nid, std::ostream& os) { // NOLINT(*) + const IndexedGraph::Node& inode = idx[nid]; + os << ", " << key << "="; + if (inode.source->num_outputs() != 1) { + os << '['; + for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { + if (i != 0) os << ", "; + fp(idx.entry_id(nid, i), os); + } + os << ']'; + } else { + fp(idx.entry_id(nid, 0), os); + } + }; + trigger.push_back(fprint); + } + for (const std::string& key : join_node_attrs) { + AttrPrinter fp = GetVectorPrinter(src, key); + auto fprint = [&idx, key, fp]( + uint32_t nid, std::ostream& os) { // NOLINT(*) + os << key << "="; + fp(idx.entry_id(nid, 0), os); + }; + trigger.push_back(fprint); + } + os << "Graph("; if (idx.input_nodes().size() < 4) { for (size_t i = 0; i < idx.input_nodes().size(); ++i) { @@ -79,6 +145,10 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) } os << "]"; } + // additional attribute trigger + for (const auto& fp : trigger) { + fp(nid, os); + } os << "\n"; } os << " ret "; @@ -112,7 +182,16 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) // save a graph to json Graph PrintGraphIR(Graph src) { std::ostringstream os; - PrintGraphIR_(src, os); + std::vector join_entry_attrs, join_node_attrs; + if (src.attrs.count("join_entry_attrs") != 0) { + join_entry_attrs = src.MoveCopyAttr >( + "join_entry_attrs"); + } + if (src.attrs.count("join_node_attrs") != 0) { + join_node_attrs = src.MoveCopyAttr >( + "join_node_attrs"); + } + PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os); Graph ret; ret.attrs["graphir"] = std::make_shared(os.str()); return ret; diff --git a/nnvm/tests/python/compiler/test_simplify_batchnorm.py b/nnvm/tests/python/compiler/test_simplify_batchnorm.py index 54cb8bc0bc57..67d582769ef2 100644 --- a/nnvm/tests/python/compiler/test_simplify_batchnorm.py +++ b/nnvm/tests/python/compiler/test_simplify_batchnorm.py @@ -38,7 +38,7 @@ def check(dim, axis, nstep): graph_attr.set_shape_inputs(g, ishape) g1 = g.apply("InferShape").apply("SimplifyBatchNormInference") # Some prints for debug - # print(g1.graphir()) + # print(g1.ir()) # assert graph equals as expected graph_pass.check_graph_equal(g1, g2) diff --git a/nnvm/tests/python/unittest/test_graph.py b/nnvm/tests/python/unittest/test_graph.py index e706ae1b428c..f41d62538817 100644 --- a/nnvm/tests/python/unittest/test_graph.py +++ b/nnvm/tests/python/unittest/test_graph.py @@ -99,8 +99,19 @@ def test_plan_memory(): assert (storage_id[jnode_row_ptr[nindex["add2"]]] == storage_id[jnode_row_ptr[nindex["reshapek"]]]) +def test_print_graph_ir(): + x = sym.Variable("x", shape=(1, 1, 10, 20)) + y = sym.conv2d(x + 1, name="y", channels=10, kernel_size=(3,3)) + g = graph.create(y) + g = g.apply("InferShape") + ir1 = g.ir() + ir2 = g.ir(join_entry_attrs=["shape"]) + assert("y_bias" in ir1) + assert("shape=" in ir2) + if __name__ == "__main__": + test_print_graph_ir() test_json_pass_with_attr() test_graph_json_attr() test_json_pass()