Skip to content

Commit

Permalink
enable some TestPass cases
Browse files Browse the repository at this point in the history
  • Loading branch information
levi131 committed Jun 23, 2021
1 parent cc4c312 commit 0b62f2d
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 1 deletion.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ TEST(GraphTest, TestMultiBlock) {
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);

// Step3: Colne graph.
// Step3: Clone graph.
std::shared_ptr<ir::Graph> clone_g = g->Clone();
ASSERT_EQ(clone_g->IsMainGraph(), true);

Expand Down
109 changes: 109 additions & 0 deletions paddle/fluid/framework/ir/pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,93 @@ TEST(PassTest, TestPassAttrCheck) {
exception.npos);
}

TEST(PassTest, TestPassAttrCheckConvertAllBlocks) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;

ProgramDesc prog;
auto pass = PassRegistry::Instance().Get("test_pass");
std::unique_ptr<Graph> graph(new Graph(prog));
std::string exception;
try {
graph.reset(pass->Apply(graph.release()));
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("Required atrribute test_pass_attr for pass < "
"test_pass > is not set") != exception.npos);

int val = 1;
graph.reset(new Graph(prog));
pass->SetNotOwned<int>("test_pass_attr", &val);

for (std::string try_type : {"bool", "const int", "std::string"}) {
try {
if (try_type == "bool") {
pass->Get<bool>("test_pass_attr");
} else if (try_type == "const int") {
pass->Get<const int>("test_pass_attr");
} else if (try_type == "std::string") {
pass->Get<std::string>("test_pass_attr");
}
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
std::string msg = "Invalid type for attritube test_pass_attr, expected: " +
try_type + ", actual: int";
ASSERT_TRUE(exception.find(msg) != exception.npos);
}

try {
graph.reset(pass->Apply(graph.release()));
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find(
"Required atrribute test_graph_attr for graph is not set") !=
exception.npos);

graph.reset(new Graph(prog));
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 1;
graph.reset(pass->Apply(graph.release()));
ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);

// Allow apply more than once.
graph.reset(new Graph(prog));
graph->Set<int>("test_graph_attr", new int);
graph.reset(pass->Apply(graph.release()));

pass = PassRegistry::Instance().Get("test_pass");
pass->SetNotOwned<int>("test_pass_attr", &val);
graph.reset(new Graph(prog));
BuildCircleGraph(graph.get());
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 2;
try {
pass->Apply(graph.release());
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos);

pass = PassRegistry::Instance().Get("test_pass");
pass->Set<int>("test_pass_attr", new int);
try {
pass->Set<int>("test_pass_attr", new int);
} catch (paddle::platform::EnforceNotMet& e) {
exception = std::string(e.what());
}
ASSERT_TRUE(
exception.find("Attribute test_pass_attr already set in the pass") !=
exception.npos);

// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}

class TestPassWithDefault : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const {
Expand All @@ -160,6 +247,28 @@ TEST(PassTest, TestPassDefaultAttrCheck) {
ASSERT_EQ(pass->Get<int>("default_attr"), 3);
}

TEST(PassTest, TestPassDefaultAttrCheckConvertAllBlocks) {
// Set FLAGS_convert_all_blocks to true to make sure this test works.
bool flag_temp = FLAGS_convert_all_blocks;
FLAGS_convert_all_blocks = true;

ProgramDesc prog;
// check if default value is set
auto pass = PassRegistry::Instance().Get("test_pass_default_attr");
std::unique_ptr<Graph> graph(new Graph(prog));
ASSERT_EQ(pass->Get<int>("default_attr"), 1);
graph.reset(pass->Apply(graph.release()));
ASSERT_EQ(graph->Get<int>("copy_default_attr"), 2);

// check if new value overrides default value
pass = PassRegistry::Instance().Get("test_pass_default_attr");
pass->Set<int>("default_attr", new int{3});
ASSERT_EQ(pass->Get<int>("default_attr"), 3);

// Recover FLAGS_convert_all_blocks.
FLAGS_convert_all_blocks = flag_temp;
}

TEST(PassTest, TestPassRegistrarDeconstructor) {
auto pass_registrary =
new PassRegistrar<paddle::framework::ir::TestPassWithDefault>(
Expand Down

0 comments on commit 0b62f2d

Please sign in to comment.