diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 163bd996c0010..2d8de33f411a4 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -264,6 +264,42 @@ TEST(GraphTest, TestAttrCopy) { ASSERT_FALSE(dst_g.Has(kFloatValue)); } +TEST(GraphTest, TestInterfaceConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + prog.MutableBlock(0)->Var("init_var")->SetType(proto::VarType::SELECTED_ROWS); + ir::Graph g(prog); + ASSERT_TRUE(g.IsMainGraph()); + + const std::string kIntValue = "int_value"; + const int INT_VALUE = 3; + g.Set(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + ASSERT_EQ(g.GetOrInit(kIntValue), INT_VALUE); + ASSERT_EQ(g.Get(kIntValue), INT_VALUE); + g.Erase(kIntValue); + ASSERT_TRUE(!g.Has(kIntValue)); + g.SetNotOwned(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + g.Erase(kIntValue); + + g.ReleaseNodes(); + ASSERT_EQ(g.Nodes().size(), 0UL); + g.CreateVarNode(new VarDesc("temp_var_desc_name")); + g.CreateOpNode(prog.MutableBlock(0)->AppendOp()); + g.CreateControlDepVar(); + g.CreateEmptyNode("temp_empty_node_name", ir::Node::Type::kVariable); + ASSERT_EQ(g.Nodes().size(), 4UL); + g.RemoveNode(g.RetrieveNode(1)); + ASSERT_EQ(g.Nodes().size(), 3UL); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + TEST(GraphTest, TestMultiBlock) { // Set FLAGS_convert_all_blocks to true to make sure this test works. bool flag_temp = FLAGS_convert_all_blocks;