From dde9707071be035aa81c62e7470b4ec46a1b845a Mon Sep 17 00:00:00 2001 From: adstraw Date: Wed, 29 Nov 2017 13:34:14 -0800 Subject: [PATCH] Adding test to catch zero length subgraph bug --- tests/cpp/ngraph/test_ngraph_graph.cc | 11 +++++---- tests/cpp/ngraph/test_ngraph_graph.h | 32 +++++++++++++-------------- 2 files changed, 23 insertions(+), 20 deletions(-) mode change 100644 => 100755 tests/cpp/ngraph/test_ngraph_graph.cc mode change 100644 => 100755 tests/cpp/ngraph/test_ngraph_graph.h diff --git a/tests/cpp/ngraph/test_ngraph_graph.cc b/tests/cpp/ngraph/test_ngraph_graph.cc old mode 100644 new mode 100755 index 212fc9eeb..0e2860d11 --- a/tests/cpp/ngraph/test_ngraph_graph.cc +++ b/tests/cpp/ngraph/test_ngraph_graph.cc @@ -114,15 +114,18 @@ TEST_F(NGRAPH_GRAPH, GRAPH_IDENTIFY_SUBGRAPHS) { EXPECT_EQ(branching_graph.nodes_[3]->subgraph_, -1); EXPECT_EQ(branching_graph.nodes_[4]->subgraph_, 1); EXPECT_EQ(branching_graph.nodes_[5]->subgraph_, 1); + EXPECT_EQ(branching_graph.nodes_[6]->subgraph_, 0); } TEST_F(NGRAPH_GRAPH, GRAPH_COLLAPSE_SUBGRAPHS) { IdentifySubgraphs(branching_graph, isop); CollapseSubgraphs(branching_graph); - EXPECT_EQ(branching_graph.nodes_.size(), 4); - EXPECT_EQ(std::dynamic_pointer_cast(branching_graph.nodes_.back()) - ->nodes_.size(), - 3); + auto size = branching_graph.nodes_.size(); + EXPECT_EQ(size, 5); + auto subgraph = + std::dynamic_pointer_cast(branching_graph.nodes_[size - 2]); + EXPECT_NE(subgraph, nullptr); + EXPECT_EQ(subgraph->nodes_.size(), 3); } // TEST(NGRAPH_GRAPH, PARSENNVM) { diff --git a/tests/cpp/ngraph/test_ngraph_graph.h b/tests/cpp/ngraph/test_ngraph_graph.h old mode 100644 new mode 100755 index 22542b1e6..6b69b57fe --- a/tests/cpp/ngraph/test_ngraph_graph.h +++ b/tests/cpp/ngraph/test_ngraph_graph.h @@ -19,13 +19,13 @@ namespace ngraph_bridge { class NGRAPH_NODE : public ::testing::Test { -protected: - virtual void SetUp() { + protected: + virtual void SetUp() { auto var_node = std::make_shared(test_input, "test_input"); test_inputs.push_back(var_node); }; - virtual void TearDown() {}; + virtual void TearDown(){}; nnvm::NodePtr test_node; nnvm::NodePtr test_input; @@ -35,8 +35,7 @@ class NGRAPH_NODE : public ::testing::Test { }; class NGRAPH_GRAPH : public ::testing::Test { -protected: - + protected: static bool isop(NodePtr s) { return (s->type_ == NodeType::kOp); }; void CreateLinear() { @@ -56,14 +55,15 @@ class NGRAPH_GRAPH : public ::testing::Test { new OpNode(nullptr, "op0", opnames[0], {branching_graph.nodes_[0]}))); branching_graph.AddNode(std::shared_ptr( new OpNode(nullptr, "op1", opnames[1], {branching_graph.nodes_[1]}))); - branching_graph.AddNode(std::shared_ptr(new VariableNode( - nullptr, "variable1", {branching_graph.nodes_[1]}))); + branching_graph.AddNode(std::shared_ptr( + new VariableNode(nullptr, "variable1", {branching_graph.nodes_[1]}))); branching_graph.AddNode(std::shared_ptr( new OpNode(nullptr, "op2", opnames[2], {branching_graph.nodes_[2], branching_graph.nodes_[3]}))); branching_graph.AddNode(std::shared_ptr( - new OpNode(nullptr, "op3", opnames[3], - {branching_graph.nodes_[4]}))); + new OpNode(nullptr, "op3", opnames[3], {branching_graph.nodes_[4]}))); + branching_graph.AddNode(std::shared_ptr( + new VariableNode(nullptr, "variable2", {branching_graph.nodes_[5]}))); }; void CreateMultiOut() { @@ -77,8 +77,8 @@ class NGRAPH_GRAPH : public ::testing::Test { for (auto n : inputs[i]) input_nodes.push_back(multi_graph.nodes_[n]); if (is_op[i]) { - multi_graph.AddNode(std::shared_ptr( - new OpNode(nullptr, "op" + std::to_string(i), "tanh", input_nodes))); + multi_graph.AddNode(std::shared_ptr(new OpNode( + nullptr, "op" + std::to_string(i), "tanh", input_nodes))); } else { multi_graph.AddNode(std::make_shared( nullptr, "variable" + std::to_string(i), input_nodes)); @@ -97,14 +97,14 @@ class NGRAPH_GRAPH : public ::testing::Test { {5, 6}, {7}, {8}, {9}, {10}, {11}, {12}, {13, 14}, {14, 15}, {16, 17}, {17, 18}, {}, {19, 20, 21}, {21, 22, 23}}; - + for (size_t i = 0; i < is_op.size(); ++i) { std::vector input_nodes; for (auto n : inputs[i]) input_nodes.push_back(complex_graph.nodes_[n]); if (is_op[i]) { - complex_graph.AddNode(std::shared_ptr( - new OpNode(nullptr, "op" + std::to_string(i), "tanh", input_nodes))); + complex_graph.AddNode(std::shared_ptr(new OpNode( + nullptr, "op" + std::to_string(i), "tanh", input_nodes))); } else { complex_graph.AddNode(std::make_shared( nullptr, "variable" + std::to_string(i), input_nodes)); @@ -123,7 +123,7 @@ class NGRAPH_GRAPH : public ::testing::Test { CreateComplexGraph(); }; - virtual void TearDown() {}; + virtual void TearDown(){}; nnvm::NodePtr test_node; @@ -139,4 +139,4 @@ class NGRAPH_GRAPH : public ::testing::Test { Graph complex_graph; }; -} \ No newline at end of file +} // namespace ngraph_bridge