Skip to content

Commit

Permalink
add test for graph interface
Browse files Browse the repository at this point in the history
  • Loading branch information
levi131 committed Jun 23, 2021
1 parent 0b62f2d commit ec23091
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions paddle/fluid/framework/ir/graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,42 @@ TEST(GraphTest, TestAttrCopy) {
ASSERT_FALSE(dst_g.Has(kFloatValue));
}

TEST(GraphTest, TestInterfaceConvertAllBlocks) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;

ProgramDesc prog;
prog.MutableBlock(0)->Var("init_var")->SetType(proto::VarType::SELECTED_ROWS);
ir::Graph g(prog);
ASSERT_TRUE(g.IsMainGraph());

const std::string kIntValue = "int_value";
const int INT_VALUE = 3;
g.Set<int>(kIntValue, new int(INT_VALUE));
ASSERT_TRUE(g.Has(kIntValue));
ASSERT_EQ(g.GetOrInit<int>(kIntValue), INT_VALUE);
ASSERT_EQ(g.Get<int>(kIntValue), INT_VALUE);
g.Erase(kIntValue);
ASSERT_TRUE(!g.Has(kIntValue));
g.SetNotOwned<int>(kIntValue, new int(INT_VALUE));
ASSERT_TRUE(g.Has(kIntValue));
g.Erase(kIntValue);

g.ReleaseNodes();
ASSERT_EQ(g.Nodes().size(), 0UL);
g.CreateVarNode(new VarDesc("temp_var_desc_name"));
g.CreateOpNode(prog.MutableBlock(0)->AppendOp());
g.CreateControlDepVar();
g.CreateEmptyNode("temp_empty_node_name", ir::Node::Type::kVariable);
ASSERT_EQ(g.Nodes().size(), 4UL);
g.RemoveNode(g.RetrieveNode(1));
ASSERT_EQ(g.Nodes().size(), 3UL);

// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}

TEST(GraphTest, TestMultiBlock) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
Expand Down

0 comments on commit ec23091

Please sign in to comment.