Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Adding test to catch zero length subgraph bug #84

Merged
merged 1 commit into from
Nov 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions tests/cpp/ngraph/test_ngraph_graph.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,18 @@ TEST_F(NGRAPH_GRAPH, GRAPH_IDENTIFY_SUBGRAPHS) {
EXPECT_EQ(branching_graph.nodes_[3]->subgraph_, -1);
EXPECT_EQ(branching_graph.nodes_[4]->subgraph_, 1);
EXPECT_EQ(branching_graph.nodes_[5]->subgraph_, 1);
EXPECT_EQ(branching_graph.nodes_[6]->subgraph_, 0);
}

TEST_F(NGRAPH_GRAPH, GRAPH_COLLAPSE_SUBGRAPHS) {
IdentifySubgraphs(branching_graph, isop);
CollapseSubgraphs(branching_graph);
EXPECT_EQ(branching_graph.nodes_.size(), 4);
EXPECT_EQ(std::dynamic_pointer_cast<Graph>(branching_graph.nodes_.back())
->nodes_.size(),
3);
auto size = branching_graph.nodes_.size();
EXPECT_EQ(size, 5);
auto subgraph =
std::dynamic_pointer_cast<Graph>(branching_graph.nodes_[size - 2]);
EXPECT_NE(subgraph, nullptr);
EXPECT_EQ(subgraph->nodes_.size(), 3);
}

// TEST(NGRAPH_GRAPH, PARSENNVM) {
Expand Down
32 changes: 16 additions & 16 deletions tests/cpp/ngraph/test_ngraph_graph.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
namespace ngraph_bridge {

class NGRAPH_NODE : public ::testing::Test {
protected:
virtual void SetUp() {
protected:
virtual void SetUp() {
auto var_node = std::make_shared<VariableNode>(test_input, "test_input");
test_inputs.push_back(var_node);
};

virtual void TearDown() {};
virtual void TearDown(){};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised clang format removed a space in this case and kept a space in other cases.


nnvm::NodePtr test_node;
nnvm::NodePtr test_input;
Expand All @@ -35,8 +35,7 @@ class NGRAPH_NODE : public ::testing::Test {
};

class NGRAPH_GRAPH : public ::testing::Test {
protected:

protected:
static bool isop(NodePtr s) { return (s->type_ == NodeType::kOp); };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, here it didn't touch the space. Just wondering out loud.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:( Who knows.


void CreateLinear() {
Expand All @@ -56,14 +55,15 @@ class NGRAPH_GRAPH : public ::testing::Test {
new OpNode(nullptr, "op0", opnames[0], {branching_graph.nodes_[0]})));
branching_graph.AddNode(std::shared_ptr<OpNode>(
new OpNode(nullptr, "op1", opnames[1], {branching_graph.nodes_[1]})));
branching_graph.AddNode(std::shared_ptr<VariableNode>(new VariableNode(
nullptr, "variable1", {branching_graph.nodes_[1]})));
branching_graph.AddNode(std::shared_ptr<VariableNode>(
new VariableNode(nullptr, "variable1", {branching_graph.nodes_[1]})));
branching_graph.AddNode(std::shared_ptr<OpNode>(
new OpNode(nullptr, "op2", opnames[2],
{branching_graph.nodes_[2], branching_graph.nodes_[3]})));
branching_graph.AddNode(std::shared_ptr<OpNode>(
new OpNode(nullptr, "op3", opnames[3],
{branching_graph.nodes_[4]})));
new OpNode(nullptr, "op3", opnames[3], {branching_graph.nodes_[4]})));
branching_graph.AddNode(std::shared_ptr<VariableNode>(
new VariableNode(nullptr, "variable2", {branching_graph.nodes_[5]})));
};

void CreateMultiOut() {
Expand All @@ -77,8 +77,8 @@ class NGRAPH_GRAPH : public ::testing::Test {
for (auto n : inputs[i]) input_nodes.push_back(multi_graph.nodes_[n]);

if (is_op[i]) {
multi_graph.AddNode(std::shared_ptr<OpNode>(
new OpNode(nullptr, "op" + std::to_string(i), "tanh", input_nodes)));
multi_graph.AddNode(std::shared_ptr<OpNode>(new OpNode(
nullptr, "op" + std::to_string(i), "tanh", input_nodes)));
} else {
multi_graph.AddNode(std::make_shared<VariableNode>(
nullptr, "variable" + std::to_string(i), input_nodes));
Expand All @@ -97,14 +97,14 @@ class NGRAPH_GRAPH : public ::testing::Test {
{5, 6}, {7}, {8}, {9}, {10}, {11},
{12}, {13, 14}, {14, 15}, {16, 17}, {17, 18}, {},
{19, 20, 21}, {21, 22, 23}};

for (size_t i = 0; i < is_op.size(); ++i) {
std::vector<NodePtr> input_nodes;
for (auto n : inputs[i]) input_nodes.push_back(complex_graph.nodes_[n]);

if (is_op[i]) {
complex_graph.AddNode(std::shared_ptr<OpNode>(
new OpNode(nullptr, "op" + std::to_string(i), "tanh", input_nodes)));
complex_graph.AddNode(std::shared_ptr<OpNode>(new OpNode(
nullptr, "op" + std::to_string(i), "tanh", input_nodes)));
} else {
complex_graph.AddNode(std::make_shared<VariableNode>(
nullptr, "variable" + std::to_string(i), input_nodes));
Expand All @@ -123,7 +123,7 @@ class NGRAPH_GRAPH : public ::testing::Test {
CreateComplexGraph();
};

virtual void TearDown() {};
virtual void TearDown(){};

nnvm::NodePtr test_node;

Expand All @@ -139,4 +139,4 @@ class NGRAPH_GRAPH : public ::testing::Test {
Graph complex_graph;
};

}
} // namespace ngraph_bridge