From 3e2059db6fc7547a0651d4bd5151d7e957714c67 Mon Sep 17 00:00:00 2001
From: levi131
Date: Wed, 2 Jun 2021 12:10:07 +0000
Subject: [PATCH 01/17] draft for ssa_program to graph with sub_graphs

---
 paddle/fluid/framework/ir/graph.cc | 66 ++++++++++++++++++++++--------
 paddle/fluid/framework/ir/graph.h  | 16 +++++++-
 2 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 706df467d3535..63195d9e5e7f2 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -22,23 +22,40 @@ namespace framework {
 namespace ir {
 
 Graph::Graph(const ProgramDesc &program) : program_(program) {
-  auto var_nodes = InitFromProgram(program_);
+  PADDLE_ENFORCE_GE(
+      program_.Size(), 1,
+
+      platform::errors::InvalidArgument("Can't construct a graph from this "
+                                        "program, it doesn't have a block"));
+  for (size_t idx = 0; idx < program_.Size(); ++idx) {
+    std::unique_ptr<Graph> sub_graph =
+        std::make_unique<Graph>(program_.Block(idx), this);
+    PADDLE_ENFORCE_EQ(sub_graph->OriginProgram().Size(), 0,
+                      platform::errors::InvalidArgument(
+                          "The sub_graph should have an empty program, but has "
+                          "a program with %d blocks",
+                          sub_graph->OriginProgram().Size()));
+    sub_graphs_.push_back(std::move(sub_graph));
+  }
+}
+
+Graph::Graph(const BlockDesc &block, const Graph *parent) : parent_(parent) {
+  auto var_nodes = InitFromBlock(block);
   ResolveHazard(var_nodes);
 }
 
-std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
-    const ProgramDesc &program) {
-  VLOG(3) << "block in program:" << program_.Size();
+std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
+    const BlockDesc &block) {
   std::unordered_map<std::string, VarDesc *> all_vars;
   // var nodes for each var name, will have multiple versions in SSA
   std::map<std::string, std::vector<ir::Node *>> var_nodes;
-  for (auto *var : program.Block(0).AllVars()) {
+  for (auto *var : block.AllVars()) {
     all_vars.emplace(var->Name(), var);
   }
 
   auto not_visited_vars = all_vars;
 
-  for (auto *op : program.Block(0).AllOps()) {
+  for (auto *op : block.AllOps()) {
     ir::Node *node = CreateOpNode(op);
     // For input args, reuse the same var name if it was created before.
     // Otherwise, create a new one.
@@ -97,9 +114,8 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
     }
   }
 
-  Set<const std::vector<OpDesc *>>(
-      details::kStaleProgramOpDescs,
-      new std::vector<OpDesc *>(program.Block(0).AllOps()));
+  Set<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs,
+                                   new std::vector<OpDesc *>(block.AllOps()));
   return var_nodes;
 }
 
@@ -176,22 +192,36 @@ void Graph::ResolveHazard(
   }
 }
 
 std::shared_ptr<Graph> Graph::Clone() {
+  PADDLE_ENFORCE_EQ(
+      this->parent_, nullptr,
+      platform::errors::InvalidArgument(
+          "This graph is a subgraph, and can't be cloned individually"));
   auto cloned_graph = std::make_shared<Graph>(this->program_);
-  cloned_graph->ReleaseNodes();
-  cloned_graph->num_node_created_ = 0;
+  cloned_graph->ReleaseSubgraphs();
+  for (size_t idx = 0; idx < this->program_.Size(); ++idx) {
+    cloned_graph->AddSubgraph(this->CloneSubgraph(idx));
+  }
+  return cloned_graph;
+}
+
+std::unique_ptr<Graph> Graph::CloneSubgraph(const size_t idx) {
+  std::unique_ptr<Graph> cloned_sub_graph =
+      std::make_unique<Graph>(this->program_.Block(idx), this);
+  cloned_sub_graph->ReleaseNodes();
+  cloned_sub_graph->num_node_created_ = 0;
   std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
-  for (auto *n : this->node_set_) {
+  for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
     PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
                                    "The node to be cloned is nullptr."));
     ir::Node *cloned_node = nullptr;
     if (n->IsCtrlVar()) {
-      cloned_node = cloned_graph->CreateControlDepVar();
+      cloned_node = cloned_sub_graph->CreateControlDepVar();
     } else if (!n->var_desc_ && !n->op_desc_) {  // empty node
-      cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType());
+      cloned_node = cloned_sub_graph->CreateEmptyNode(n->Name(), n->NodeType());
     } else if (n->IsVar()) {
-      cloned_node = cloned_graph->CreateVarNode(n->Var());
+      cloned_node = cloned_sub_graph->CreateVarNode(n->Var());
     } else if (n->IsOp()) {
-      cloned_node = cloned_graph->CreateOpNode(n->Op());
+      cloned_node = cloned_sub_graph->CreateOpNode(n->Op());
     }
     PADDLE_ENFORCE_NOT_NULL(
         cloned_node,
@@ -199,7 +229,7 @@ std::shared_ptr<Graph> Graph::Clone() {
         "Failed to clone new node from original node in graph."));
     origin_to_cloned[n] = cloned_node;
   }
-  for (auto *n : this->node_set_) {
+  for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
     for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
       origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
     }
@@ -207,7 +237,7 @@ std::shared_ptr<Graph> Graph::Clone() {
       origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
     }
   }
-  return cloned_graph;
+  return cloned_sub_graph;
 }
 
 bool IsControlDepVar(const ir::Node &var) {
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 593ac214e56f9..8a4ecc1c21d61 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -79,6 +79,7 @@ namespace ir {
 class Graph {
  public:
   explicit Graph(const ProgramDesc &program);
+  Graph(const BlockDesc &block, const Graph *parent);
 
   virtual ~Graph() {
     for (auto &attr : attrs_) {
@@ -252,11 +253,22 @@ class Graph {
   std::shared_ptr<Graph> Clone();
 
  private:
-  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
-      const ProgramDesc &program);
+  std::map<std::string, std::vector<ir::Node *>> InitFromBlock(
+      const BlockDesc &block);
+
+  void ReleaseSubgraphs() { sub_graphs_.clear(); }
+
+  void AddSubgraph(std::unique_ptr<Graph> sub_graph) {
+    sub_graphs_.push_back(std::move(sub_graph));
+  }
+
+  std::unique_ptr<Graph> CloneSubgraph(const size_t idx);
 
   // NOTE: program_ shouldn't be exposed to user.
   const ProgramDesc program_;
+  const Graph *parent_;  // not owned.
+ std::vector> sub_graphs_; + std::map attrs_; std::map> attr_dels_; std::map> nodes_; From 0c58b91114d6fe33c23ad07cde1d4faf3c0ad6d7 Mon Sep 17 00:00:00 2001 From: levi131 Date: Thu, 3 Jun 2021 06:40:34 +0000 Subject: [PATCH 02/17] try add gflags ssa_program --- paddle/fluid/framework/ir/graph.cc | 1 - paddle/fluid/framework/ir/graph.h | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 63195d9e5e7f2..ab5f38e2bfdbf 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -24,7 +24,6 @@ namespace ir { Graph::Graph(const ProgramDesc &program) : program_(program) { PADDLE_ENFORCE_GE( program_.Size(), 1, - platform::errors::InvalidArgument("Can't construct a graph from this " "program, it doesn't have a block")); for (size_t idx = 0; idx < program_.Size(); ++idx) { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 8a4ecc1c21d61..17787fe6d0d5d 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -25,6 +26,8 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" +DEFINE_bool(ssa_program, true, "Convert all blocks in program into SSAgraphs"); + namespace paddle { namespace framework { class OpDesc; From 69c71edb6bfd00b7e92ef67a31eff60adfad0951 Mon Sep 17 00:00:00 2001 From: levi131 Date: Thu, 3 Jun 2021 07:20:02 +0000 Subject: [PATCH 03/17] use gflags --- paddle/fluid/framework/ir/graph.cc | 123 +++++++++++++++++++++++++++++ paddle/fluid/framework/ir/graph.h | 11 ++- 2 files changed, 133 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index ab5f38e2bfdbf..ce0a87c36853b 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -17,10 +17,13 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/operator.h" +DEFINE_bool(ssa_program, false, "Convert all blocks in program into SSAgraphs"); + namespace paddle { namespace framework { namespace ir { +#if FLAGS_ssa_program Graph::Graph(const ProgramDesc &program) : program_(program) { PADDLE_ENFORCE_GE( program_.Size(), 1, @@ -117,6 +120,89 @@ std::map> Graph::InitFromBlock( new std::vector(block.AllOps())); return var_nodes; } +#else // FLAGS_ssa_program +Graph::Graph(const ProgramDesc &program) : program_(program) { + auto var_nodes = InitFromProgram(program_); + ResolveHazard(var_nodes); +} + +std::map> Graph::InitFromProgram( + const ProgramDesc &program) { + VLOG(3) << "block in program:" << program_.Size(); + std::unordered_map all_vars; + // var nodes for each var name, will have multiple versions in SSA + std::map> var_nodes; + for (auto *var : program.Block(0).AllVars()) { + all_vars.emplace(var->Name(), var); + } + + auto not_visited_vars = all_vars; + + for (auto *op : program.Block(0).AllOps()) { + ir::Node *node = CreateOpNode(op); + // For input args, reuse the same var name if it was created before. + // Otherwise, create a new one. 
+ for (auto &each_var_name : op->InputArgumentNames()) { + not_visited_vars.erase(each_var_name); + ir::Node *var = nullptr; + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name).back(); + } else if (all_vars.count(each_var_name) != 0) { + var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name].push_back(var); + } else { + // Operation input var can be optional (dispensable). Which means + // the operation doesn't really need the var at runtime. In this + // case, the no-existed var is ready at the beginning. + var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); + var_nodes[each_var_name].push_back(var); + } + node->inputs.push_back(var); + var->outputs.push_back(node); + } + // For output args, always create a new var. + std::unordered_set out_arg_set; + for (auto &each_var_name : op->OutputArgumentNames()) { + not_visited_vars.erase(each_var_name); + if (each_var_name != kEmptyVarName) { + PADDLE_ENFORCE_EQ(out_arg_set.count(each_var_name), 0, + platform::errors::InvalidArgument( + "The input Program is invalid. Variable %s occurs" + " in output of %s multiple times.", + each_var_name, op->Type())); + out_arg_set.insert(each_var_name); + } + + ir::Node *var = nullptr; + if (all_vars.count(each_var_name) != 0) { + var = CreateVarNode(all_vars.at(each_var_name)); + } else { + // Operation output vars can be @EMPTY@. For example, while_grad + // can have multi @EMPTY@ outputs with no VarDesc. + // TODO(panyx0718): Add a test. + var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); + } + var_nodes[each_var_name].push_back(var); + node->outputs.push_back(var); + var->inputs.push_back(node); + } + } + + for (auto &pair : not_visited_vars) { + const auto &var_name = pair.first; + auto *var_desc = pair.second; + if (var_name != kEmptyVarName) { + VLOG(10) << "Create isolated var node " << var_name; + var_nodes[var_name].push_back(CreateVarNode(var_desc)); + } + } + + Set>( + details::kStaleProgramOpDescs, + new std::vector(program.Block(0).AllOps())); + return var_nodes; +} +#endif // FLAGS_ssa_program void Graph::ResolveHazard( const std::map> &var_nodes) { @@ -190,6 +276,7 @@ void Graph::ResolveHazard( } } +#if FLAGS_ssa_program std::shared_ptr Graph::Clone() { PADDLE_ENFORCE_EQ( this->parent_, nullptr, @@ -238,6 +325,42 @@ std::unique_ptr Graph::CloneSubgraph(const size_t idx) { } return cloned_sub_graph; } +#else // FLAGS_ssa_program +std::shared_ptr Graph::Clone() { + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseNodes(); + cloned_graph->num_node_created_ = 0; + std::unordered_map origin_to_cloned; + for (auto *n : this->node_set_) { + PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( + "The node to be cloned is nullptr.")); + ir::Node *cloned_node = nullptr; + if (n->IsCtrlVar()) { + cloned_node = cloned_graph->CreateControlDepVar(); + } else if (!n->var_desc_ && !n->op_desc_) { // empty node + cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); + } else if (n->IsVar()) { + cloned_node = cloned_graph->CreateVarNode(n->Var()); + } else if (n->IsOp()) { + cloned_node = cloned_graph->CreateOpNode(n->Op()); + } + PADDLE_ENFORCE_NOT_NULL( + cloned_node, + platform::errors::InvalidArgument( + "Failed to clone new node from original node in graph.")); + origin_to_cloned[n] = cloned_node; + } + for (auto *n : this->node_set_) { + for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { + 
origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); + } + for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) { + origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); + } + } + return cloned_graph; +} +#endif // FLAGS_ssa_program bool IsControlDepVar(const ir::Node &var) { return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 17787fe6d0d5d..d5f231eba3659 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -26,7 +26,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" -DEFINE_bool(ssa_program, true, "Convert all blocks in program into SSAgraphs"); +DECLARE_bool(ssa_program); namespace paddle { namespace framework { @@ -82,7 +82,9 @@ namespace ir { class Graph { public: explicit Graph(const ProgramDesc &program); +#if FLAGS_ssa_program Graph(const BlockDesc &block, const Graph *parent); +#endif // FLAGS_ssa_program virtual ~Graph() { for (auto &attr : attrs_) { @@ -256,6 +258,7 @@ class Graph { std::shared_ptr Clone(); private: +#if FLAGS_ssa_program std::map> InitFromBlock( const BlockDesc &block); @@ -266,11 +269,17 @@ class Graph { } std::unique_ptr CloneSubgraph(const size_t idx); +#else // FLAGS_ssa_program + std::map> InitFromProgram( + const ProgramDesc &program); +#endif // FLAGS_ssa_program // NOTE: program_ shouldn't be exposed to user. const ProgramDesc program_; +#if FLAGS_ssa_program const Graph *parent_; // not owned. std::vector> sub_graphs_; +#endif // FLAGS_ssa_program std::map attrs_; std::map> attr_dels_; From 1ccb8df45175d36364ecbdd1e807664ba95375fa Mon Sep 17 00:00:00 2001 From: levi131 Date: Mon, 7 Jun 2021 08:10:22 +0000 Subject: [PATCH 04/17] rename gflag macro and some member of Graph --- paddle/fluid/framework/ir/graph.cc | 20 +++++++++++--------- paddle/fluid/framework/ir/graph.h | 21 ++++++++++++--------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index ce0a87c36853b..1644d6c34e0b0 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -17,13 +17,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/operator.h" -DEFINE_bool(ssa_program, false, "Convert all blocks in program into SSAgraphs"); +DEFINE_bool(convert_all_blocks, false, + "Convert all blocks in program into SSAgraphs"); namespace paddle { namespace framework { namespace ir { -#if FLAGS_ssa_program +#if FLAGS_convert_all_blocks Graph::Graph(const ProgramDesc &program) : program_(program) { PADDLE_ENFORCE_GE( program_.Size(), 1, @@ -41,7 +42,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { } } -Graph::Graph(const BlockDesc &block, const Graph *parent) : parent_(parent) { +Graph::Graph(const BlockDesc &block, const Graph *main_graph) + : main_graph_(main_graph) { auto var_nodes = InitFromBlock(block); ResolveHazard(var_nodes); } @@ -120,7 +122,7 @@ std::map> Graph::InitFromBlock( new std::vector(block.AllOps())); return var_nodes; } -#else // FLAGS_ssa_program +#else // FLAGS_convert_all_blocks Graph::Graph(const ProgramDesc &program) : program_(program) { auto var_nodes = InitFromProgram(program_); ResolveHazard(var_nodes); @@ -202,7 +204,7 @@ std::map> Graph::InitFromProgram( new std::vector(program.Block(0).AllOps())); return var_nodes; } -#endif // FLAGS_ssa_program +#endif // FLAGS_convert_all_blocks void Graph::ResolveHazard( const std::map> &var_nodes) { @@ -276,10 +278,10 @@ void Graph::ResolveHazard( } } -#if FLAGS_ssa_program +#if FLAGS_convert_all_blocks std::shared_ptr Graph::Clone() { PADDLE_ENFORCE_EQ( - this->parent_, nullptr, + this->main_graph_, nullptr, platform::errors::InvalidArgument( "This graph is a subgraph, and can't be cloned individually")); auto cloned_graph = std::make_shared(this->program_); @@ -325,7 +327,7 @@ std::unique_ptr Graph::CloneSubgraph(const size_t idx) { } return cloned_sub_graph; } -#else // FLAGS_ssa_program +#else // FLAGS_convert_all_blocks std::shared_ptr Graph::Clone() { auto cloned_graph = std::make_shared(this->program_); cloned_graph->ReleaseNodes(); @@ -360,7 +362,7 @@ std::shared_ptr Graph::Clone() { } return cloned_graph; } -#endif // FLAGS_ssa_program +#endif // FLAGS_convert_all_blocks bool IsControlDepVar(const ir::Node &var) { return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index d5f231eba3659..baaa1dd33c421 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -26,7 +26,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" -DECLARE_bool(ssa_program); +DECLARE_bool(convert_all_blocks); namespace paddle { namespace framework { @@ -82,9 +82,9 @@ namespace ir { class Graph { public: explicit Graph(const ProgramDesc &program); -#if FLAGS_ssa_program +#if FLAGS_convert_all_blocks Graph(const BlockDesc &block, const Graph *parent); -#endif // FLAGS_ssa_program +#endif // FLAGS_convert_all_blocks virtual ~Graph() { for (auto &attr : attrs_) { @@ -258,7 +258,7 @@ class Graph { std::shared_ptr Clone(); private: -#if FLAGS_ssa_program +#if FLAGS_convert_all_blocks std::map> InitFromBlock( const BlockDesc &block); @@ -269,17 +269,20 @@ class Graph { } std::unique_ptr CloneSubgraph(const size_t idx); -#else // FLAGS_ssa_program +#else // FLAGS_convert_all_blocks std::map> InitFromProgram( const ProgramDesc &program); -#endif // FLAGS_ssa_program +#endif // FLAGS_convert_all_blocks // NOTE: program_ shouldn't be exposed to user. 
const ProgramDesc program_; -#if FLAGS_ssa_program - const Graph *parent_; // not owned. +#if FLAGS_convert_all_blocks + // NOTE: main_graph_ doesn't hold any node. It's used as a container of + // sub_graphs, + // and the sub_graph holds the nodes. + const Graph *main_graph_; // not owned. std::vector> sub_graphs_; -#endif // FLAGS_ssa_program +#endif // FLAGS_convert_all_blocks std::map attrs_; std::map> attr_dels_; From 2495c0618539a3fad4eb8544d9390a270db88618 Mon Sep 17 00:00:00 2001 From: levi131 Date: Mon, 7 Jun 2021 12:28:03 +0000 Subject: [PATCH 05/17] use nomal if-else instead of #if --- paddle/fluid/framework/ir/graph.cc | 134 +++++++++++++++-------------- paddle/fluid/framework/ir/graph.h | 26 +++--- paddle/fluid/framework/ir/pass.cc | 9 +- 3 files changed, 91 insertions(+), 78 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 1644d6c34e0b0..2891c1b812906 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -24,21 +24,26 @@ namespace paddle { namespace framework { namespace ir { -#if FLAGS_convert_all_blocks Graph::Graph(const ProgramDesc &program) : program_(program) { - PADDLE_ENFORCE_GE( - program_.Size(), 1, - platform::errors::InvalidArgument("Can't construct a graph from this " - "program, it doesn't have a block")); - for (size_t idx = 0; idx < program_.Size(); ++idx) { - std::unique_ptr sub_graph = - std::make_unique(program_.Block(idx), this); - PADDLE_ENFORCE_EQ(sub_graph->OriginProgram().Size(), 0, - platform::errors::InvalidArgument( - "The sub_graph should has an empty program, but has " - "a program with %d blocks", - sub_graph->OriginProgram().Size())); - sub_graphs_.push_back(std::move(sub_graph)); + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_GE( + program_.Size(), 1, + platform::errors::InvalidArgument("Can't construct a graph from this " + "program, it doesn't have a block")); + for (size_t idx = 0; idx < program_.Size(); ++idx) { + std::unique_ptr sub_graph = + std::make_unique(program_.Block(idx), this); + PADDLE_ENFORCE_EQ( + sub_graph->OriginProgram().Size(), 0, + platform::errors::InvalidArgument( + "The sub_graph should has an empty program, but has " + "a program with %d blocks", + sub_graph->OriginProgram().Size())); + sub_graphs_.push_back(std::move(sub_graph)); + } + } else { + auto var_nodes = InitFromProgram(program_); + ResolveHazard(var_nodes); } } @@ -122,12 +127,9 @@ std::map> Graph::InitFromBlock( new std::vector(block.AllOps())); return var_nodes; } -#else // FLAGS_convert_all_blocks -Graph::Graph(const ProgramDesc &program) : program_(program) { - auto var_nodes = InitFromProgram(program_); - ResolveHazard(var_nodes); -} +// TODO(levi): delete this interface after when we can convert all +// blocks into sub_graphs. 
std::map> Graph::InitFromProgram( const ProgramDesc &program) { VLOG(3) << "block in program:" << program_.Size(); @@ -204,7 +206,6 @@ std::map> Graph::InitFromProgram( new std::vector(program.Block(0).AllOps())); return var_nodes; } -#endif // FLAGS_convert_all_blocks void Graph::ResolveHazard( const std::map> &var_nodes) { @@ -278,21 +279,58 @@ void Graph::ResolveHazard( } } -#if FLAGS_convert_all_blocks std::shared_ptr Graph::Clone() { - PADDLE_ENFORCE_EQ( - this->main_graph_, nullptr, - platform::errors::InvalidArgument( - "This graph is a subgraph, and can't be cloned individually")); - auto cloned_graph = std::make_shared(this->program_); - cloned_graph->ReleaseSubgraphs(); - for (size_t idx = 0; idx < this->program_.Size(); ++idx) { - cloned_graph->AddSubgraph(this->CloneSubgraph(idx)); + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument( + "This graph is a subgraph, and can't be cloned individually")); + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseSubgraphs(); + for (size_t idx = 0; idx < this->program_.Size(); ++idx) { + cloned_graph->AddSubgraph(this->CloneSubgraph(idx)); + } + return cloned_graph; + } else { + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseNodes(); + cloned_graph->num_node_created_ = 0; + std::unordered_map origin_to_cloned; + for (auto *n : this->node_set_) { + PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( + "The node to be cloned is nullptr.")); + ir::Node *cloned_node = nullptr; + if (n->IsCtrlVar()) { + cloned_node = cloned_graph->CreateControlDepVar(); + } else if (!n->var_desc_ && !n->op_desc_) { // empty node + cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); + } else if (n->IsVar()) { + cloned_node = cloned_graph->CreateVarNode(n->Var()); + } else if (n->IsOp()) { + cloned_node = cloned_graph->CreateOpNode(n->Op()); + } + PADDLE_ENFORCE_NOT_NULL( + cloned_node, + platform::errors::InvalidArgument( + "Failed to clone new node from original node in graph.")); + origin_to_cloned[n] = cloned_node; + } + for (auto *n : this->node_set_) { + for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { + origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); + } + for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) { + origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); + } + } + return cloned_graph; } - return cloned_graph; } std::unique_ptr Graph::CloneSubgraph(const size_t idx) { + PADDLE_ENFORCE_LT( + idx, this->sub_graphs_.size(), + platform::errors::InvalidArgument("Invalid sub_graph index")); std::unique_ptr cloned_sub_graph = std::make_unique(this->program_.Block(idx), this); cloned_sub_graph->ReleaseNodes(); @@ -327,42 +365,6 @@ std::unique_ptr Graph::CloneSubgraph(const size_t idx) { } return cloned_sub_graph; } -#else // FLAGS_convert_all_blocks -std::shared_ptr Graph::Clone() { - auto cloned_graph = std::make_shared(this->program_); - cloned_graph->ReleaseNodes(); - cloned_graph->num_node_created_ = 0; - std::unordered_map origin_to_cloned; - for (auto *n : this->node_set_) { - PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( - "The node to be cloned is nullptr.")); - ir::Node *cloned_node = nullptr; - if (n->IsCtrlVar()) { - cloned_node = cloned_graph->CreateControlDepVar(); - } else if (!n->var_desc_ && !n->op_desc_) { // empty node - cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); - } else if (n->IsVar()) { - 
cloned_node = cloned_graph->CreateVarNode(n->Var()); - } else if (n->IsOp()) { - cloned_node = cloned_graph->CreateOpNode(n->Op()); - } - PADDLE_ENFORCE_NOT_NULL( - cloned_node, - platform::errors::InvalidArgument( - "Failed to clone new node from original node in graph.")); - origin_to_cloned[n] = cloned_node; - } - for (auto *n : this->node_set_) { - for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { - origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); - } - for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) { - origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); - } - } - return cloned_graph; -} -#endif // FLAGS_convert_all_blocks bool IsControlDepVar(const ir::Node &var) { return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index baaa1dd33c421..263dd4f07d36b 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -82,9 +82,7 @@ namespace ir { class Graph { public: explicit Graph(const ProgramDesc &program); -#if FLAGS_convert_all_blocks Graph(const BlockDesc &block, const Graph *parent); -#endif // FLAGS_convert_all_blocks virtual ~Graph() { for (auto &attr : attrs_) { @@ -257,8 +255,16 @@ class Graph { // WARN: The method only clones the graph structure, not its attributes. std::shared_ptr Clone(); + bool IsMainGraph() const { return main_graph_ == nullptr; } + + Graph *GetSubGraph(const size_t idx) const { + PADDLE_ENFORCE_LT( + idx, sub_graphs_.size(), + platform::errors::InvalidArgument("Invalid sub_graph index")); + return sub_graphs_.at(idx).get(); + } + private: -#if FLAGS_convert_all_blocks std::map> InitFromBlock( const BlockDesc &block); @@ -269,20 +275,18 @@ class Graph { } std::unique_ptr CloneSubgraph(const size_t idx); -#else // FLAGS_convert_all_blocks + + // TODO(levi): delete this interface after when we can convert all + // blocks into sub_graphs. std::map> InitFromProgram( const ProgramDesc &program); -#endif // FLAGS_convert_all_blocks - // NOTE: program_ shouldn't be exposed to user. - const ProgramDesc program_; -#if FLAGS_convert_all_blocks // NOTE: main_graph_ doesn't hold any node. It's used as a container of - // sub_graphs, - // and the sub_graph holds the nodes. + // sub_graphs, and the sub_graph holds the nodes. const Graph *main_graph_; // not owned. + // NOTE: program_ shouldn't be exposed to user. + const ProgramDesc program_; std::vector> sub_graphs_; -#endif // FLAGS_convert_all_blocks std::map attrs_; std::map> attr_dels_; diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 0e5f5867f47b2..763e376b40aaf 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -32,9 +32,16 @@ namespace framework { namespace ir { Graph* Pass::Apply(Graph* graph) const { - CheckPrevPass(); PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + if (FLAGS_convert_all_blocks) { + // NOTE(levi): If graph is main_graph, apply pass on the 1st sub_graph. 
+ if (graph->IsMainGraph()) { + this->Apply(graph->GetSubGraph(0)); + return graph; + } + } + CheckPrevPass(); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE_NE( attrs_.find(attr), attrs_.end(), From a9a55df7d63473bfcbe756d5a4f2c9d63692a0d2 Mon Sep 17 00:00:00 2001 From: levi131 Date: Tue, 8 Jun 2021 12:06:36 +0000 Subject: [PATCH 06/17] add unittest for convert_all_blocks --- paddle/fluid/framework/ir/graph.cc | 18 +-- paddle/fluid/framework/ir/graph.h | 127 +++++++++++++++- paddle/fluid/framework/ir/graph_test.cc | 192 ++++++++++++++++++++++-- 3 files changed, 308 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 2891c1b812906..1215556839be5 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -24,7 +24,8 @@ namespace paddle { namespace framework { namespace ir { -Graph::Graph(const ProgramDesc &program) : program_(program) { +Graph::Graph(const ProgramDesc &program) + : program_(program), main_graph_(nullptr) { if (FLAGS_convert_all_blocks) { PADDLE_ENFORCE_GE( program_.Size(), 1, @@ -33,12 +34,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { for (size_t idx = 0; idx < program_.Size(); ++idx) { std::unique_ptr sub_graph = std::make_unique(program_.Block(idx), this); - PADDLE_ENFORCE_EQ( - sub_graph->OriginProgram().Size(), 0, - platform::errors::InvalidArgument( - "The sub_graph should has an empty program, but has " - "a program with %d blocks", - sub_graph->OriginProgram().Size())); sub_graphs_.push_back(std::move(sub_graph)); } } else { @@ -286,9 +281,9 @@ std::shared_ptr Graph::Clone() { platform::errors::InvalidArgument( "This graph is a subgraph, and can't be cloned individually")); auto cloned_graph = std::make_shared(this->program_); - cloned_graph->ReleaseSubgraphs(); + cloned_graph->ReleaseSubGraphs(); for (size_t idx = 0; idx < this->program_.Size(); ++idx) { - cloned_graph->AddSubgraph(this->CloneSubgraph(idx)); + cloned_graph->AddSubGraph(this->CloneSubGraph(idx)); } return cloned_graph; } else { @@ -327,7 +322,10 @@ std::shared_ptr Graph::Clone() { } } -std::unique_ptr Graph::CloneSubgraph(const size_t idx) { +std::unique_ptr Graph::CloneSubGraph(const size_t idx) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); PADDLE_ENFORCE_LT( idx, this->sub_graphs_.size(), platform::errors::InvalidArgument("Invalid sub_graph index")); diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 263dd4f07d36b..8f0b6e45b0f4e 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -93,11 +93,23 @@ class Graph { } bool Has(const std::string &attr_name) const { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } return attrs_.count(attr_name) > 0; } template AttrType &GetOrInit(const std::string &attr_name) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } if (!Has(attr_name)) { Set(attr_name, new AttrType); } @@ -106,6 +118,12 @@ class Graph { template AttrType &Get(const std::string &attr_name) const { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + 
platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_EQ( Has(attr_name), true, platform::errors::PreconditionNotMet( @@ -122,6 +140,12 @@ class Graph { template void Set(const std::string &attr_name, AttrType *attr) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, platform::errors::AlreadyExists( @@ -136,6 +160,12 @@ class Graph { template void SetNotOwned(const std::string &attr_name, AttrType *attr) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, platform::errors::AlreadyExists("The attribute %s to be set(not owned) " @@ -146,6 +176,12 @@ class Graph { } void Erase(const std::string &attr_name) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_NE( attrs_.count(attr_name), 0, platform::errors::NotFound( @@ -156,10 +192,24 @@ class Graph { attr_dels_.erase(attr_name); } - const std::unordered_set &Nodes() const { return node_set_; } + const std::unordered_set &Nodes() const { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } + return node_set_; + } // Create a normal variable with non-null VarDesc. ir::Node *CreateVarNode(VarDesc *var_desc) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_NOT_NULL( var_desc, platform::errors::InvalidArgument( "The VarDesc used to create variable node is null.")); @@ -170,6 +220,12 @@ class Graph { // Create a normal runnable operator with OpDesc. ir::Node *CreateOpNode(OpDesc *op_desc) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_NOT_NULL( op_desc, platform::errors::InvalidArgument( "The OpDesc used to create operator node is null.")); @@ -182,6 +238,12 @@ class Graph { // var doesn't hold any data. Other than that, it's no different from // other var, considering dependency analysis. ir::Node *CreateControlDepVar() { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } // TODO(panyx0718): control var name should be really unique. const std::string name = string::Sprintf( "%s@%llu", static_cast(ir::Node::kControlDepVarName), @@ -194,6 +256,12 @@ class Graph { // A more free style way of creating a graph node. Mostly use for test // or "copy" from another node. Avoid using it if possible. 
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } auto *x = AddNode(new ir::Node(name, type)); x->SetId(num_node_created_++); return x; @@ -202,6 +270,12 @@ class Graph { // Clear all node information of the graph and return the ownership of the // nodes. std::vector> ReleaseNodes() { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } std::vector> ret; for (auto &n : nodes_) { ret.emplace_back(n.second.release()); @@ -212,6 +286,12 @@ class Graph { } std::unique_ptr RemoveNode(ir::Node *node) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true, platform::errors::PreconditionNotMet( "The node to be removed does not exist.")); @@ -224,6 +304,12 @@ class Graph { // NOTE low performance, but simple and secure. Node *RetrieveNode(int id) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } for (auto &node : nodes_) { if (node.second->id() == id) { return node.second.get(); @@ -236,10 +322,26 @@ class Graph { // WARN: After a series of passes, the current graph can be quite // different from OriginProgram. Caller shouldn't assume much from // the returned OriginProgram. - const ProgramDesc &OriginProgram() const { return program_; } + const ProgramDesc &OriginProgram() const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return program_; + } else { + return main_graph_->OriginProgram(); + } + } else { + return program_; + } + } // This method takes ownership of `node`. 
ir::Node *AddNode(ir::Node *node) { + if (FLAGS_convert_all_blocks) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), false, + platform::errors::InvalidArgument( + "This graph is main_graph, which shouldn't have nodes or attrs")); + } PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true, platform::errors::PreconditionNotMet( "The node to be added already exists.")); @@ -258,6 +360,9 @@ class Graph { bool IsMainGraph() const { return main_graph_ == nullptr; } Graph *GetSubGraph(const size_t idx) const { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); PADDLE_ENFORCE_LT( idx, sub_graphs_.size(), platform::errors::InvalidArgument("Invalid sub_graph index")); @@ -268,24 +373,32 @@ class Graph { std::map> InitFromBlock( const BlockDesc &block); - void ReleaseSubgraphs() { sub_graphs_.clear(); } + void ReleaseSubGraphs() { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + sub_graphs_.clear(); + } - void AddSubgraph(std::unique_ptr sub_graph) { + void AddSubGraph(std::unique_ptr sub_graph) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); sub_graphs_.push_back(std::move(sub_graph)); } - std::unique_ptr CloneSubgraph(const size_t idx); + std::unique_ptr CloneSubGraph(const size_t idx); // TODO(levi): delete this interface after when we can convert all // blocks into sub_graphs. std::map> InitFromProgram( const ProgramDesc &program); + // NOTE: program_ shouldn't be exposed to user. + const ProgramDesc program_; // NOTE: main_graph_ doesn't hold any node. It's used as a container of // sub_graphs, and the sub_graph holds the nodes. const Graph *main_graph_; // not owned. - // NOTE: program_ shouldn't be exposed to user. 
- const ProgramDesc program_; std::vector> sub_graphs_; std::map attrs_; diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 66507fe7cafbb..753f863d45249 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -104,7 +104,13 @@ TEST(GraphTest, Basic) { ASSERT_EQ(proto::VarType::LOD_TENSOR, prog.MutableBlock(0)->Var("test_out")->GetType()); - std::unique_ptr g(new ir::Graph(prog)); + std::unique_ptr _g(new ir::Graph(prog)); + const ir::Graph *g; + if (FLAGS_convert_all_blocks) { + g = _g->GetSubGraph(0); + } else { + g = _g.get(); + } std::vector nodes(g->Nodes().begin(), g->Nodes().end()); for (ir::Node *n : nodes) { if (n->Name() == "sum") { @@ -141,7 +147,13 @@ TEST(GraphTest, WriteAfterRead) { prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); - std::unique_ptr g(new ir::Graph(prog)); + std::unique_ptr _g(new ir::Graph(prog)); + const ir::Graph *g; + if (FLAGS_convert_all_blocks) { + g = _g->GetSubGraph(0); + } else { + g = _g.get(); + } ir::Node *control_dep1 = nullptr; ir::Node *control_dep2 = nullptr; for (ir::Node *n : g->Nodes()) { @@ -180,7 +192,13 @@ TEST(GraphTest, WriteAfterWrite) { prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); - std::unique_ptr g(new ir::Graph(prog)); + std::unique_ptr _g(new ir::Graph(prog)); + const ir::Graph *g; + if (FLAGS_convert_all_blocks) { + g = _g->GetSubGraph(0); + } else { + g = _g.get(); + } ir::Node *control_dep1 = nullptr; ir::Node *control_dep2 = nullptr; for (ir::Node *n : g->Nodes()) { @@ -204,7 +222,13 @@ TEST(GraphTest, WriteAfterWrite) { TEST(GraphTest, TestException) { ProgramDesc prog; - std::unique_ptr g(new ir::Graph(prog)); + std::unique_ptr _g(new ir::Graph(prog)); + ir::Graph *g; + if (FLAGS_convert_all_blocks) { + g = _g->GetSubGraph(0); + } else { + g = _g.get(); + } bool not_met_exception = false; try { @@ -250,18 +274,162 @@ TEST(GraphTest, TestException) { TEST(GraphTest, TestAttrCopy) { ProgramDesc prog; - ir::Graph src_g(prog); - ir::Graph dst_g(prog); + std::unique_ptr _src_g(new ir::Graph(prog)); + std::unique_ptr _dst_g(new ir::Graph(prog)); + ir::Graph *src_g; + ir::Graph *dst_g; + if (FLAGS_convert_all_blocks) { + src_g = _src_g->GetSubGraph(0); + dst_g = _dst_g->GetSubGraph(0); + } else { + src_g = _src_g.get(); + dst_g = _dst_g.get(); + } const std::string kIntValue = "int_value"; const std::string kFloatValue = "float_value"; const int INT_VALUE = 3; - src_g.Set(kIntValue, new int(INT_VALUE)); - details::CopyGraphAttrIfExists(src_g, &dst_g, kIntValue); - details::CopyGraphAttrIfExists(src_g, &dst_g, kFloatValue); + src_g->Set(kIntValue, new int(INT_VALUE)); + details::CopyGraphAttrIfExists(*src_g, dst_g, kIntValue); + details::CopyGraphAttrIfExists(*src_g, dst_g, kFloatValue); + + ASSERT_TRUE(dst_g->Has(kIntValue)); + ASSERT_EQ(dst_g->Get(kIntValue), INT_VALUE); + ASSERT_FALSE(dst_g->Has(kFloatValue)); +} + +TEST(GraphTest, TestMultiBlock) { + if (FLAGS_convert_all_blocks) { + // Step1: Build a program with 3 blocks. + ProgramDesc prog; + ASSERT_EQ(prog.Size(), 1UL); + prog.AppendBlock(prog.Block(0)); + prog.AppendBlock(prog.Block(0)); + ASSERT_EQ(prog.Size(), 3UL); + + // Set contents in block_0. 
+ auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_out"); + + op->InferVarType(prog.MutableBlock(0)); + + ASSERT_EQ(proto::VarType::SELECTED_ROWS, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::LOD_TENSOR, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + // Set contents in block_1. + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"a"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Set contents in block_2. + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs. + std::unique_ptr g(new ir::Graph(prog)); + ASSERT_EQ(g->IsMainGraph(), true); + + // Check contents in sub_graph_0. + const ir::Graph *g0 = g->GetSubGraph(0); + std::vector nodes(g0->Nodes().begin(), g0->Nodes().end()); + for (ir::Node *n : nodes) { + if (n->Name() == "sum") { + ASSERT_EQ(n->inputs.size(), 3UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_a" || n->Name() == "test_b" || + n->Name() == "test_c") { + ASSERT_EQ(n->inputs.size(), 0UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_out") { + ASSERT_EQ(n->inputs.size(), 1UL); + ASSERT_EQ(n->outputs.size(), 0UL); + } + } + ASSERT_EQ(nodes.size(), 5UL); + + // Check contents in sub_graph_1. + const ir::Graph *g1 = g->GetSubGraph(1); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g1->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + control_dep1 = n->outputs[1]; + ASSERT_EQ(n->outputs.size(), 2UL); + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); + } + } + ASSERT_EQ(control_dep1, control_dep2); + + // Check contents in sub_graph_2. 
+ const ir::Graph *g2 = g->GetSubGraph(2); + control_dep1 = nullptr; + control_dep2 = nullptr; + for (ir::Node *n : g2->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + ASSERT_EQ(n->outputs.size(), 2UL); + control_dep1 = n->outputs[1]; + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); + } + } + ASSERT_NE(control_dep1, nullptr); + ASSERT_NE(control_dep2, nullptr); + ASSERT_EQ(control_dep1, control_dep2); - ASSERT_TRUE(dst_g.Has(kIntValue)); - ASSERT_EQ(dst_g.Get(kIntValue), INT_VALUE); - ASSERT_FALSE(dst_g.Has(kFloatValue)); + // Step3: Colne graph. + std::shared_ptr clone_g = g->Clone(); + ASSERT_EQ(clone_g->IsMainGraph(), true); + } } } // namespace framework From 9dffa5b10bb969d44ff29177d1533dc3db7a208a Mon Sep 17 00:00:00 2001 From: levi131 Date: Wed, 9 Jun 2021 08:26:47 +0000 Subject: [PATCH 07/17] use LOG(WARNING) --- paddle/fluid/framework/ir/graph.cc | 8 +- paddle/fluid/framework/ir/graph.h | 159 ++++++++++++++---------- paddle/fluid/framework/ir/graph_test.cc | 57 ++------- paddle/fluid/framework/ir/pass.cc | 9 +- 4 files changed, 111 insertions(+), 122 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 1215556839be5..c7957516b08ac 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -275,11 +275,11 @@ void Graph::ResolveHazard( } std::shared_ptr Graph::Clone() { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument( + "This graph is a subgraph, and can't be cloned individually")); if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), true, - platform::errors::InvalidArgument( - "This graph is a subgraph, and can't be cloned individually")); auto cloned_graph = std::make_shared(this->program_); cloned_graph->ReleaseSubGraphs(); for (size_t idx = 0; idx < this->program_.Size(); ++idx) { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 8f0b6e45b0f4e..e9c8b0ed4bdbc 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -94,10 +94,12 @@ class Graph { bool Has(const std::string &attr_name) const { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function Has() on the 1st sub_graph"; + return GetSubGraph(0)->Has(attr_name); + } } return attrs_.count(attr_name) > 0; } @@ -105,10 +107,12 @@ class Graph { template AttrType &GetOrInit(const std::string &attr_name) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. 
" + << "Applying function GetOrInit() on the 1st sub_graph"; + return GetSubGraph(0)->GetOrInit(attr_name); + } } if (!Has(attr_name)) { Set(attr_name, new AttrType); @@ -119,10 +123,12 @@ class Graph { template AttrType &Get(const std::string &attr_name) const { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function Get() on the 1st sub_graph"; + return GetSubGraph(0)->Get(attr_name); + } } PADDLE_ENFORCE_EQ( Has(attr_name), true, @@ -141,10 +147,12 @@ class Graph { template void Set(const std::string &attr_name, AttrType *attr) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function Set() on the 1st sub_graph"; + return GetSubGraph(0)->Set(attr_name, attr); + } } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, @@ -161,10 +169,12 @@ class Graph { template void SetNotOwned(const std::string &attr_name, AttrType *attr) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function SetNotOwned() on the 1st sub_graph"; + return GetSubGraph(0)->SetNotOwned(attr_name, attr); + } } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, @@ -177,10 +187,12 @@ class Graph { void Erase(const std::string &attr_name) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function Erase() on the 1st sub_graph"; + return GetSubGraph(0)->Erase(attr_name); + } } PADDLE_ENFORCE_NE( attrs_.count(attr_name), 0, @@ -194,10 +206,12 @@ class Graph { const std::unordered_set &Nodes() const { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function Nodes() on the 1st sub_graph"; + return GetSubGraph(0)->Nodes(); + } } return node_set_; } @@ -205,10 +219,12 @@ class Graph { // Create a normal variable with non-null VarDesc. ir::Node *CreateVarNode(VarDesc *var_desc) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function CreateVarNode() on the 1st sub_graph"; + return GetSubGraph(0)->CreateVarNode(var_desc); + } } PADDLE_ENFORCE_NOT_NULL( var_desc, platform::errors::InvalidArgument( @@ -221,10 +237,12 @@ class Graph { // Create a normal runnable operator with OpDesc. 
ir::Node *CreateOpNode(OpDesc *op_desc) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function CreateOpNode() on the 1st sub_graph"; + return GetSubGraph(0)->CreateOpNode(op_desc); + } } PADDLE_ENFORCE_NOT_NULL( op_desc, platform::errors::InvalidArgument( @@ -239,10 +257,12 @@ class Graph { // other var, considering dependency analysis. ir::Node *CreateControlDepVar() { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function CreateControlDepVar() on the 1st sub_graph"; + return GetSubGraph(0)->CreateControlDepVar(); + } } // TODO(panyx0718): control var name should be really unique. const std::string name = string::Sprintf( @@ -257,10 +277,12 @@ class Graph { // or "copy" from another node. Avoid using it if possible. ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function CreateEmptyNode() on the 1st sub_graph"; + return GetSubGraph(0)->CreateEmptyNode(name, type); + } } auto *x = AddNode(new ir::Node(name, type)); x->SetId(num_node_created_++); @@ -271,10 +293,12 @@ class Graph { // nodes. std::vector> ReleaseNodes() { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function ReleaseNodes() on the 1st sub_graph"; + return GetSubGraph(0)->ReleaseNodes(); + } } std::vector> ret; for (auto &n : nodes_) { @@ -287,10 +311,12 @@ class Graph { std::unique_ptr RemoveNode(ir::Node *node) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function RemoveNode() on the 1st sub_graph"; + return GetSubGraph(0)->RemoveNode(node); + } } PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true, platform::errors::PreconditionNotMet( @@ -305,10 +331,12 @@ class Graph { // NOTE low performance, but simple and secure. Node *RetrieveNode(int id) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. 
" + << "Applying function RetrieveNode() on the 1st sub_graph"; + return GetSubGraph(0)->RetrieveNode(id); + } } for (auto &node : nodes_) { if (node.second->id() == id) { @@ -324,23 +352,24 @@ class Graph { // the returned OriginProgram. const ProgramDesc &OriginProgram() const { if (FLAGS_convert_all_blocks) { - if (IsMainGraph()) { - return program_; - } else { + if (!IsMainGraph()) { + LOG(WARNING) << "This graph is sub_graph, will return the program_ of " + "main_graph"; return main_graph_->OriginProgram(); } - } else { - return program_; } + return program_; } // This method takes ownership of `node`. ir::Node *AddNode(ir::Node *node) { if (FLAGS_convert_all_blocks) { - PADDLE_ENFORCE_EQ( - this->IsMainGraph(), false, - platform::errors::InvalidArgument( - "This graph is main_graph, which shouldn't have nodes or attrs")); + if (IsMainGraph()) { + LOG(WARNING) + << "This graph is main_graph, which doesn't have nodes or attrs. " + << "Applying function AddNode() on the 1st sub_graph"; + return GetSubGraph(0)->AddNode(node); + } } PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true, platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 753f863d45249..2032433b91d9d 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -104,13 +104,7 @@ TEST(GraphTest, Basic) { ASSERT_EQ(proto::VarType::LOD_TENSOR, prog.MutableBlock(0)->Var("test_out")->GetType()); - std::unique_ptr _g(new ir::Graph(prog)); - const ir::Graph *g; - if (FLAGS_convert_all_blocks) { - g = _g->GetSubGraph(0); - } else { - g = _g.get(); - } + std::unique_ptr g(new ir::Graph(prog)); std::vector nodes(g->Nodes().begin(), g->Nodes().end()); for (ir::Node *n : nodes) { if (n->Name() == "sum") { @@ -147,13 +141,7 @@ TEST(GraphTest, WriteAfterRead) { prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); - std::unique_ptr _g(new ir::Graph(prog)); - const ir::Graph *g; - if (FLAGS_convert_all_blocks) { - g = _g->GetSubGraph(0); - } else { - g = _g.get(); - } + std::unique_ptr g(new ir::Graph(prog)); ir::Node *control_dep1 = nullptr; ir::Node *control_dep2 = nullptr; for (ir::Node *n : g->Nodes()) { @@ -192,13 +180,7 @@ TEST(GraphTest, WriteAfterWrite) { prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); - std::unique_ptr _g(new ir::Graph(prog)); - const ir::Graph *g; - if (FLAGS_convert_all_blocks) { - g = _g->GetSubGraph(0); - } else { - g = _g.get(); - } + std::unique_ptr g(new ir::Graph(prog)); ir::Node *control_dep1 = nullptr; ir::Node *control_dep2 = nullptr; for (ir::Node *n : g->Nodes()) { @@ -222,13 +204,7 @@ TEST(GraphTest, WriteAfterWrite) { TEST(GraphTest, TestException) { ProgramDesc prog; - std::unique_ptr _g(new ir::Graph(prog)); - ir::Graph *g; - if (FLAGS_convert_all_blocks) { - g = _g->GetSubGraph(0); - } else { - g = _g.get(); - } + std::unique_ptr g(new ir::Graph(prog)); bool not_met_exception = false; try { @@ -274,27 +250,18 @@ TEST(GraphTest, TestException) { TEST(GraphTest, TestAttrCopy) { ProgramDesc prog; - std::unique_ptr _src_g(new ir::Graph(prog)); - std::unique_ptr _dst_g(new ir::Graph(prog)); - ir::Graph *src_g; - ir::Graph *dst_g; - if (FLAGS_convert_all_blocks) { - src_g = _src_g->GetSubGraph(0); - dst_g = _dst_g->GetSubGraph(0); - } else { - src_g = _src_g.get(); - dst_g = _dst_g.get(); - } + 
ir::Graph src_g(prog); + ir::Graph dst_g(prog); const std::string kIntValue = "int_value"; const std::string kFloatValue = "float_value"; const int INT_VALUE = 3; - src_g->Set(kIntValue, new int(INT_VALUE)); - details::CopyGraphAttrIfExists(*src_g, dst_g, kIntValue); - details::CopyGraphAttrIfExists(*src_g, dst_g, kFloatValue); + src_g.Set(kIntValue, new int(INT_VALUE)); + details::CopyGraphAttrIfExists(src_g, &dst_g, kIntValue); + details::CopyGraphAttrIfExists(src_g, &dst_g, kFloatValue); - ASSERT_TRUE(dst_g->Has(kIntValue)); - ASSERT_EQ(dst_g->Get(kIntValue), INT_VALUE); - ASSERT_FALSE(dst_g->Has(kFloatValue)); + ASSERT_TRUE(dst_g.Has(kIntValue)); + ASSERT_EQ(dst_g.Get(kIntValue), INT_VALUE); + ASSERT_FALSE(dst_g.Has(kFloatValue)); } TEST(GraphTest, TestMultiBlock) { diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 763e376b40aaf..0e5f5867f47b2 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -32,16 +32,9 @@ namespace framework { namespace ir { Graph* Pass::Apply(Graph* graph) const { + CheckPrevPass(); PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - if (FLAGS_convert_all_blocks) { - // NOTE(levi): If graph is main_graph, apply pass on the 1st sub_graph. - if (graph->IsMainGraph()) { - this->Apply(graph->GetSubGraph(0)); - return graph; - } - } - CheckPrevPass(); for (const std::string& attr : required_pass_attrs_) { PADDLE_ENFORCE_NE( attrs_.find(attr), attrs_.end(), From 1655150d02d53844c9682f7f3583341cfc730252 Mon Sep 17 00:00:00 2001 From: levi131 Date: Wed, 9 Jun 2021 08:34:15 +0000 Subject: [PATCH 08/17] modify format --- paddle/fluid/framework/ir/graph.h | 85 +++++++++++++++---------------- 1 file changed, 40 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index e9c8b0ed4bdbc..ae02cb6712bb5 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -95,9 +95,8 @@ class Graph { bool Has(const std::string &attr_name) const { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function Has() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function Has() on the 1st sub_graph"; return GetSubGraph(0)->Has(attr_name); } } @@ -108,9 +107,9 @@ class Graph { AttrType &GetOrInit(const std::string &attr_name) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function GetOrInit() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function GetOrInit() on the 1st " + "sub_graph"; return GetSubGraph(0)->GetOrInit(attr_name); } } @@ -124,9 +123,8 @@ class Graph { AttrType &Get(const std::string &attr_name) const { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function Get() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. 
Applying function Get() on the 1st sub_graph"; return GetSubGraph(0)->Get(attr_name); } } @@ -148,9 +146,8 @@ class Graph { void Set(const std::string &attr_name, AttrType *attr) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function Set() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function Set() on the 1st sub_graph"; return GetSubGraph(0)->Set(attr_name, attr); } } @@ -170,9 +167,9 @@ class Graph { void SetNotOwned(const std::string &attr_name, AttrType *attr) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function SetNotOwned() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function SetNotOwned() on the 1st " + "sub_graph"; return GetSubGraph(0)->SetNotOwned(attr_name, attr); } } @@ -188,9 +185,8 @@ class Graph { void Erase(const std::string &attr_name) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function Erase() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function Erase() on the 1st sub_graph"; return GetSubGraph(0)->Erase(attr_name); } } @@ -207,9 +203,8 @@ class Graph { const std::unordered_set &Nodes() const { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function Nodes() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function Nodes() on the 1st sub_graph"; return GetSubGraph(0)->Nodes(); } } @@ -220,9 +215,9 @@ class Graph { ir::Node *CreateVarNode(VarDesc *var_desc) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function CreateVarNode() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function CreateVarNode() on the 1st " + "sub_graph"; return GetSubGraph(0)->CreateVarNode(var_desc); } } @@ -238,9 +233,9 @@ class Graph { ir::Node *CreateOpNode(OpDesc *op_desc) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function CreateOpNode() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function CreateOpNode() on the 1st " + "sub_graph"; return GetSubGraph(0)->CreateOpNode(op_desc); } } @@ -258,9 +253,9 @@ class Graph { ir::Node *CreateControlDepVar() { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function CreateControlDepVar() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. 
Applying function CreateControlDepVar() on the " + "1st sub_graph"; return GetSubGraph(0)->CreateControlDepVar(); } } @@ -278,9 +273,9 @@ class Graph { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function CreateEmptyNode() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function CreateEmptyNode() on the 1st " + "sub_graph"; return GetSubGraph(0)->CreateEmptyNode(name, type); } } @@ -294,9 +289,9 @@ class Graph { std::vector> ReleaseNodes() { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function ReleaseNodes() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function ReleaseNodes() on the 1st " + "sub_graph"; return GetSubGraph(0)->ReleaseNodes(); } } @@ -312,9 +307,9 @@ class Graph { std::unique_ptr RemoveNode(ir::Node *node) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function RemoveNode() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function RemoveNode() on the 1st " + "sub_graph"; return GetSubGraph(0)->RemoveNode(node); } } @@ -332,9 +327,9 @@ class Graph { Node *RetrieveNode(int id) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function RetrieveNode() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. Applying function RetrieveNode() on the 1st " + "sub_graph"; return GetSubGraph(0)->RetrieveNode(id); } } @@ -365,9 +360,9 @@ class Graph { ir::Node *AddNode(ir::Node *node) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) - << "This graph is main_graph, which doesn't have nodes or attrs. " - << "Applying function AddNode() on the 1st sub_graph"; + LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " + "attrs. 
Applying function AddNode() on the 1st " + "sub_graph"; return GetSubGraph(0)->AddNode(node); } } From 0229f37cfb86abb75799ac432c1edaa6003d067b Mon Sep 17 00:00:00 2001 From: levi131 Date: Tue, 15 Jun 2021 03:00:33 +0000 Subject: [PATCH 09/17] small spell modify --- paddle/fluid/framework/ir/graph.cc | 2 +- paddle/fluid/framework/ir/graph.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index c7957516b08ac..b4e13114639c9 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -278,7 +278,7 @@ std::shared_ptr Graph::Clone() { PADDLE_ENFORCE_EQ( this->IsMainGraph(), true, platform::errors::InvalidArgument( - "This graph is a subgraph, and can't be cloned individually")); + "This graph is a sub_graph, and can't be cloned individually")); if (FLAGS_convert_all_blocks) { auto cloned_graph = std::make_shared(this->program_); cloned_graph->ReleaseSubGraphs(); diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ae02cb6712bb5..a3384689f3716 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -82,7 +82,7 @@ namespace ir { class Graph { public: explicit Graph(const ProgramDesc &program); - Graph(const BlockDesc &block, const Graph *parent); + Graph(const BlockDesc &block, const Graph *main_graph); virtual ~Graph() { for (auto &attr : attrs_) { From 98bc7adce6ac6ff2ce494928ebad8e7ffb517315 Mon Sep 17 00:00:00 2001 From: levi131 Date: Tue, 22 Jun 2021 08:45:51 +0000 Subject: [PATCH 10/17] ensure GraphTest.TestMultiBlock works and rm WARNINGs in APIs for class Graph --- paddle/fluid/framework/ir/graph.h | 42 ---- paddle/fluid/framework/ir/graph_test.cc | 259 ++++++++++++------------ 2 files changed, 132 insertions(+), 169 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index a3384689f3716..f03be51cbbdd6 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -95,8 +95,6 @@ class Graph { bool Has(const std::string &attr_name) const { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function Has() on the 1st sub_graph"; return GetSubGraph(0)->Has(attr_name); } } @@ -107,9 +105,6 @@ class Graph { AttrType &GetOrInit(const std::string &attr_name) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function GetOrInit() on the 1st " - "sub_graph"; return GetSubGraph(0)->GetOrInit(attr_name); } } @@ -123,8 +118,6 @@ class Graph { AttrType &Get(const std::string &attr_name) const { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function Get() on the 1st sub_graph"; return GetSubGraph(0)->Get(attr_name); } } @@ -146,8 +139,6 @@ class Graph { void Set(const std::string &attr_name, AttrType *attr) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. 
Applying function Set() on the 1st sub_graph"; return GetSubGraph(0)->Set(attr_name, attr); } } @@ -167,9 +158,6 @@ class Graph { void SetNotOwned(const std::string &attr_name, AttrType *attr) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function SetNotOwned() on the 1st " - "sub_graph"; return GetSubGraph(0)->SetNotOwned(attr_name, attr); } } @@ -185,8 +173,6 @@ class Graph { void Erase(const std::string &attr_name) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function Erase() on the 1st sub_graph"; return GetSubGraph(0)->Erase(attr_name); } } @@ -203,8 +189,6 @@ class Graph { const std::unordered_set &Nodes() const { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function Nodes() on the 1st sub_graph"; return GetSubGraph(0)->Nodes(); } } @@ -215,9 +199,6 @@ class Graph { ir::Node *CreateVarNode(VarDesc *var_desc) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function CreateVarNode() on the 1st " - "sub_graph"; return GetSubGraph(0)->CreateVarNode(var_desc); } } @@ -233,9 +214,6 @@ class Graph { ir::Node *CreateOpNode(OpDesc *op_desc) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function CreateOpNode() on the 1st " - "sub_graph"; return GetSubGraph(0)->CreateOpNode(op_desc); } } @@ -253,9 +231,6 @@ class Graph { ir::Node *CreateControlDepVar() { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function CreateControlDepVar() on the " - "1st sub_graph"; return GetSubGraph(0)->CreateControlDepVar(); } } @@ -273,9 +248,6 @@ class Graph { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function CreateEmptyNode() on the 1st " - "sub_graph"; return GetSubGraph(0)->CreateEmptyNode(name, type); } } @@ -289,9 +261,6 @@ class Graph { std::vector> ReleaseNodes() { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function ReleaseNodes() on the 1st " - "sub_graph"; return GetSubGraph(0)->ReleaseNodes(); } } @@ -307,9 +276,6 @@ class Graph { std::unique_ptr RemoveNode(ir::Node *node) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function RemoveNode() on the 1st " - "sub_graph"; return GetSubGraph(0)->RemoveNode(node); } } @@ -327,9 +293,6 @@ class Graph { Node *RetrieveNode(int id) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. 
Applying function RetrieveNode() on the 1st " - "sub_graph"; return GetSubGraph(0)->RetrieveNode(id); } } @@ -348,8 +311,6 @@ class Graph { const ProgramDesc &OriginProgram() const { if (FLAGS_convert_all_blocks) { if (!IsMainGraph()) { - LOG(WARNING) << "This graph is sub_graph, will return the program_ of " - "main_graph"; return main_graph_->OriginProgram(); } } @@ -360,9 +321,6 @@ class Graph { ir::Node *AddNode(ir::Node *node) { if (FLAGS_convert_all_blocks) { if (IsMainGraph()) { - LOG(WARNING) << "This graph is main_graph, which doesn't have nodes or " - "attrs. Applying function AddNode() on the 1st " - "sub_graph"; return GetSubGraph(0)->AddNode(node); } } diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 2032433b91d9d..05252933b3d00 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -265,138 +265,143 @@ TEST(GraphTest, TestAttrCopy) { } TEST(GraphTest, TestMultiBlock) { - if (FLAGS_convert_all_blocks) { - // Step1: Build a program with 3 blocks. - ProgramDesc prog; - ASSERT_EQ(prog.Size(), 1UL); - prog.AppendBlock(prog.Block(0)); - prog.AppendBlock(prog.Block(0)); - ASSERT_EQ(prog.Size(), 3UL); - - // Set contents in block_0. - auto *op = prog.MutableBlock(0)->AppendOp(); - op->SetType("sum"); - op->SetInput("X", {"test_a", "test_b", "test_c"}); - op->SetOutput("Out", {"test_out"}); - op->SetAttr("op_role", 1); - - prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_out"); - - op->InferVarType(prog.MutableBlock(0)); - - ASSERT_EQ(proto::VarType::SELECTED_ROWS, - prog.MutableBlock(0)->Var("test_out")->GetType()); - - prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); - op->InferVarType(prog.MutableBlock(0)); - ASSERT_EQ(proto::VarType::LOD_TENSOR, - prog.MutableBlock(0)->Var("test_out")->GetType()); - - // Set contents in block_1. - op = prog.MutableBlock(1)->AppendOp(); - op->SetType("sum"); - op->SetInput("X", {"a"}); - op->SetOutput("Out", {"b"}); - op->SetAttr("op_role", 1); - - op = prog.MutableBlock(1)->AppendOp(); - op->SetType("dummy"); - op->SetInput("X", {"c"}); - op->SetOutput("Out", {"a"}); - op->SetAttr("op_role", 1); - - prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR); - prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR); - prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR); - - // Set contents in block_2. - op = prog.MutableBlock(2)->AppendOp(); - op->SetType("sum"); - op->SetInput("X", {"a"}); - op->SetOutput("Out", {"b"}); - op->SetAttr("op_role", 1); - - op = prog.MutableBlock(2)->AppendOp(); - op->SetType("dummy"); - op->SetInput("X", {"c"}); - op->SetOutput("Out", {"b"}); - op->SetAttr("op_role", 1); - - prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR); - prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR); - prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR); - - // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs. - std::unique_ptr g(new ir::Graph(prog)); - ASSERT_EQ(g->IsMainGraph(), true); - - // Check contents in sub_graph_0. 
- const ir::Graph *g0 = g->GetSubGraph(0); - std::vector nodes(g0->Nodes().begin(), g0->Nodes().end()); - for (ir::Node *n : nodes) { - if (n->Name() == "sum") { - ASSERT_EQ(n->inputs.size(), 3UL); - ASSERT_EQ(n->outputs.size(), 1UL); - } else if (n->Name() == "test_a" || n->Name() == "test_b" || - n->Name() == "test_c") { - ASSERT_EQ(n->inputs.size(), 0UL); - ASSERT_EQ(n->outputs.size(), 1UL); - } else if (n->Name() == "test_out") { - ASSERT_EQ(n->inputs.size(), 1UL); - ASSERT_EQ(n->outputs.size(), 0UL); - } + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + // Step1: Build a program with 3 blocks. + ProgramDesc prog; + ASSERT_EQ(prog.Size(), 1UL); + prog.AppendBlock(prog.Block(0)); + prog.AppendBlock(prog.Block(0)); + ASSERT_EQ(prog.Size(), 3UL); + + // Set contents in block_0. + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_out"); + + op->InferVarType(prog.MutableBlock(0)); + + ASSERT_EQ(proto::VarType::SELECTED_ROWS, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::LOD_TENSOR, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + // Set contents in block_1. + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"a"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Set contents in block_2. + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs. + std::unique_ptr g(new ir::Graph(prog)); + ASSERT_EQ(g->IsMainGraph(), true); + + // Check contents in sub_graph_0. 
+ const ir::Graph *g0 = g->GetSubGraph(0); + std::vector nodes(g0->Nodes().begin(), g0->Nodes().end()); + for (ir::Node *n : nodes) { + if (n->Name() == "sum") { + ASSERT_EQ(n->inputs.size(), 3UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_a" || n->Name() == "test_b" || + n->Name() == "test_c") { + ASSERT_EQ(n->inputs.size(), 0UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_out") { + ASSERT_EQ(n->inputs.size(), 1UL); + ASSERT_EQ(n->outputs.size(), 0UL); } - ASSERT_EQ(nodes.size(), 5UL); - - // Check contents in sub_graph_1. - const ir::Graph *g1 = g->GetSubGraph(1); - ir::Node *control_dep1 = nullptr; - ir::Node *control_dep2 = nullptr; - for (ir::Node *n : g1->Nodes()) { - if (n->Name() == "sum") { - ASSERT_EQ(n->outputs[0]->Name(), "b"); - ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); - control_dep1 = n->outputs[1]; - ASSERT_EQ(n->outputs.size(), 2UL); - } - if (n->Name() == "dummy") { - ASSERT_EQ(n->inputs[0]->Name(), "c"); - ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); - control_dep2 = n->inputs[1]; - ASSERT_EQ(n->inputs.size(), 2UL); - } + } + ASSERT_EQ(nodes.size(), 5UL); + + // Check contents in sub_graph_1. + const ir::Graph *g1 = g->GetSubGraph(1); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g1->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + control_dep1 = n->outputs[1]; + ASSERT_EQ(n->outputs.size(), 2UL); } - ASSERT_EQ(control_dep1, control_dep2); - - // Check contents in sub_graph_2. - const ir::Graph *g2 = g->GetSubGraph(2); - control_dep1 = nullptr; - control_dep2 = nullptr; - for (ir::Node *n : g2->Nodes()) { - if (n->Name() == "sum") { - ASSERT_EQ(n->outputs[0]->Name(), "b"); - ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); - ASSERT_EQ(n->outputs.size(), 2UL); - control_dep1 = n->outputs[1]; - } - if (n->Name() == "dummy") { - ASSERT_EQ(n->inputs[0]->Name(), "c"); - ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); - control_dep2 = n->inputs[1]; - ASSERT_EQ(n->inputs.size(), 2UL); - } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); } - ASSERT_NE(control_dep1, nullptr); - ASSERT_NE(control_dep2, nullptr); - ASSERT_EQ(control_dep1, control_dep2); + } + ASSERT_EQ(control_dep1, control_dep2); - // Step3: Colne graph. - std::shared_ptr clone_g = g->Clone(); - ASSERT_EQ(clone_g->IsMainGraph(), true); + // Check contents in sub_graph_2. + const ir::Graph *g2 = g->GetSubGraph(2); + control_dep1 = nullptr; + control_dep2 = nullptr; + for (ir::Node *n : g2->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + ASSERT_EQ(n->outputs.size(), 2UL); + control_dep1 = n->outputs[1]; + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); + } } + ASSERT_NE(control_dep1, nullptr); + ASSERT_NE(control_dep2, nullptr); + ASSERT_EQ(control_dep1, control_dep2); + + // Step3: Colne graph. + std::shared_ptr clone_g = g->Clone(); + ASSERT_EQ(clone_g->IsMainGraph(), true); + + // Recover FLAGS_convert_all_blocks. 
+ FLAGS_convert_all_blocks = flag_temp; } } // namespace framework From c0df219422a5c1725a6cddc5cc79114d8b77bbca Mon Sep 17 00:00:00 2001 From: levi131 Date: Tue, 22 Jun 2021 08:57:19 +0000 Subject: [PATCH 11/17] small change to re-start CI --- paddle/fluid/framework/ir/graph_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 05252933b3d00..6ceb98b900625 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -287,9 +287,7 @@ TEST(GraphTest, TestMultiBlock) { prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_out"); - op->InferVarType(prog.MutableBlock(0)); - ASSERT_EQ(proto::VarType::SELECTED_ROWS, prog.MutableBlock(0)->Var("test_out")->GetType()); From 0b62f2d24e992720266987beeb98c4cce29deea3 Mon Sep 17 00:00:00 2001 From: levi131 Date: Wed, 23 Jun 2021 08:40:57 +0000 Subject: [PATCH 12/17] enable some TestPass cases --- paddle/fluid/framework/ir/graph_test.cc | 2 +- paddle/fluid/framework/ir/pass_test.cc | 109 ++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 6ceb98b900625..163bd996c0010 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -394,7 +394,7 @@ TEST(GraphTest, TestMultiBlock) { ASSERT_NE(control_dep2, nullptr); ASSERT_EQ(control_dep1, control_dep2); - // Step3: Colne graph. + // Step3: Clone graph. std::shared_ptr clone_g = g->Clone(); ASSERT_EQ(clone_g->IsMainGraph(), true); diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc index 65b9c427869ee..616ba7f1a9761 100644 --- a/paddle/fluid/framework/ir/pass_test.cc +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -135,6 +135,93 @@ TEST(PassTest, TestPassAttrCheck) { exception.npos); } +TEST(PassTest, TestPassAttrCheckConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. 
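// The comment above describes the guard idiom every new *ConvertAllBlocks test
// in this series uses: save the global flag, force it on, run the body, then
// restore it. Below is a minimal standalone sketch of that idiom, illustrative
// only; g_convert_all_blocks stands in for FLAGS_convert_all_blocks and is not
// Paddle's actual flag object.
#include <cassert>

static bool g_convert_all_blocks = false;

void RunTestBodyThatNeedsFlagOn() {
  // The body only behaves as intended while the flag is on.
  assert(g_convert_all_blocks);
}

int main() {
  bool flag_temp = g_convert_all_blocks;  // save the current value
  g_convert_all_blocks = true;

  RunTestBodyThatNeedsFlagOn();

  g_convert_all_blocks = flag_temp;  // recover it so other tests are unaffected
  return 0;
}
// An RAII guard that restores the flag in its destructor would also survive
// early returns; the tests here restore it manually at the end.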
+ bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + auto pass = PassRegistry::Instance().Get("test_pass"); + std::unique_ptr graph(new Graph(prog)); + std::string exception; + try { + graph.reset(pass->Apply(graph.release())); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("Required atrribute test_pass_attr for pass < " + "test_pass > is not set") != exception.npos); + + int val = 1; + graph.reset(new Graph(prog)); + pass->SetNotOwned("test_pass_attr", &val); + + for (std::string try_type : {"bool", "const int", "std::string"}) { + try { + if (try_type == "bool") { + pass->Get("test_pass_attr"); + } else if (try_type == "const int") { + pass->Get("test_pass_attr"); + } else if (try_type == "std::string") { + pass->Get("test_pass_attr"); + } + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + std::string msg = "Invalid type for attritube test_pass_attr, expected: " + + try_type + ", actual: int"; + ASSERT_TRUE(exception.find(msg) != exception.npos); + } + + try { + graph.reset(pass->Apply(graph.release())); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find( + "Required atrribute test_graph_attr for graph is not set") != + exception.npos); + + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 1; + graph.reset(pass->Apply(graph.release())); + ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); + ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); + + // Allow apply more than once. + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph.reset(pass->Apply(graph.release())); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->SetNotOwned("test_pass_attr", &val); + graph.reset(new Graph(prog)); + BuildCircleGraph(graph.get()); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 2; + try { + pass->Apply(graph.release()); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->Set("test_pass_attr", new int); + try { + pass->Set("test_pass_attr", new int); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE( + exception.find("Attribute test_pass_attr already set in the pass") != + exception.npos); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + class TestPassWithDefault : public Pass { protected: void ApplyImpl(ir::Graph* graph) const { @@ -160,6 +247,28 @@ TEST(PassTest, TestPassDefaultAttrCheck) { ASSERT_EQ(pass->Get("default_attr"), 3); } +TEST(PassTest, TestPassDefaultAttrCheckConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. 
+ bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + // check if default value is set + auto pass = PassRegistry::Instance().Get("test_pass_default_attr"); + std::unique_ptr graph(new Graph(prog)); + ASSERT_EQ(pass->Get("default_attr"), 1); + graph.reset(pass->Apply(graph.release())); + ASSERT_EQ(graph->Get("copy_default_attr"), 2); + + // check if new value overrides default value + pass = PassRegistry::Instance().Get("test_pass_default_attr"); + pass->Set("default_attr", new int{3}); + ASSERT_EQ(pass->Get("default_attr"), 3); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + TEST(PassTest, TestPassRegistrarDeconstructor) { auto pass_registrary = new PassRegistrar( From ec23091ed7e228ddf0f15851449ff95a0307f38f Mon Sep 17 00:00:00 2001 From: levi131 Date: Wed, 23 Jun 2021 12:14:31 +0000 Subject: [PATCH 13/17] add test for graph interface --- paddle/fluid/framework/ir/graph_test.cc | 36 +++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 163bd996c0010..2d8de33f411a4 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -264,6 +264,42 @@ TEST(GraphTest, TestAttrCopy) { ASSERT_FALSE(dst_g.Has(kFloatValue)); } +TEST(GraphTest, TestInterfaceConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + prog.MutableBlock(0)->Var("init_var")->SetType(proto::VarType::SELECTED_ROWS); + ir::Graph g(prog); + ASSERT_TRUE(g.IsMainGraph()); + + const std::string kIntValue = "int_value"; + const int INT_VALUE = 3; + g.Set(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + ASSERT_EQ(g.GetOrInit(kIntValue), INT_VALUE); + ASSERT_EQ(g.Get(kIntValue), INT_VALUE); + g.Erase(kIntValue); + ASSERT_TRUE(!g.Has(kIntValue)); + g.SetNotOwned(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + g.Erase(kIntValue); + + g.ReleaseNodes(); + ASSERT_EQ(g.Nodes().size(), 0UL); + g.CreateVarNode(new VarDesc("temp_var_desc_name")); + g.CreateOpNode(prog.MutableBlock(0)->AppendOp()); + g.CreateControlDepVar(); + g.CreateEmptyNode("temp_empty_node_name", ir::Node::Type::kVariable); + ASSERT_EQ(g.Nodes().size(), 4UL); + g.RemoveNode(g.RetrieveNode(1)); + ASSERT_EQ(g.Nodes().size(), 3UL); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + TEST(GraphTest, TestMultiBlock) { // Set FLAGS_convert_all_blocks to true to make sure this test works. 
bool flag_temp = FLAGS_convert_all_blocks; From df2c43c1e54aeecd5e1ebc3b5ddcb57f5d7ea592 Mon Sep 17 00:00:00 2001 From: levi131 Date: Wed, 7 Jul 2021 02:49:58 +0000 Subject: [PATCH 14/17] add SubGraphsSize() and use InitFromBlock in InitFromProgram --- paddle/fluid/framework/ir/graph.cc | 73 +------------------------ paddle/fluid/framework/ir/graph.h | 7 +++ paddle/fluid/framework/ir/graph_test.cc | 2 + 3 files changed, 10 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index b4e13114639c9..6bf8a96fecaec 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -128,78 +128,7 @@ std::map> Graph::InitFromBlock( std::map> Graph::InitFromProgram( const ProgramDesc &program) { VLOG(3) << "block in program:" << program_.Size(); - std::unordered_map all_vars; - // var nodes for each var name, will have multiple versions in SSA - std::map> var_nodes; - for (auto *var : program.Block(0).AllVars()) { - all_vars.emplace(var->Name(), var); - } - - auto not_visited_vars = all_vars; - - for (auto *op : program.Block(0).AllOps()) { - ir::Node *node = CreateOpNode(op); - // For input args, reuse the same var name if it was created before. - // Otherwise, create a new one. - for (auto &each_var_name : op->InputArgumentNames()) { - not_visited_vars.erase(each_var_name); - ir::Node *var = nullptr; - if (var_nodes.find(each_var_name) != var_nodes.end()) { - var = var_nodes.at(each_var_name).back(); - } else if (all_vars.count(each_var_name) != 0) { - var = CreateVarNode(all_vars.at(each_var_name)); - var_nodes[each_var_name].push_back(var); - } else { - // Operation input var can be optional (dispensable). Which means - // the operation doesn't really need the var at runtime. In this - // case, the no-existed var is ready at the beginning. - var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); - var_nodes[each_var_name].push_back(var); - } - node->inputs.push_back(var); - var->outputs.push_back(node); - } - // For output args, always create a new var. - std::unordered_set out_arg_set; - for (auto &each_var_name : op->OutputArgumentNames()) { - not_visited_vars.erase(each_var_name); - if (each_var_name != kEmptyVarName) { - PADDLE_ENFORCE_EQ(out_arg_set.count(each_var_name), 0, - platform::errors::InvalidArgument( - "The input Program is invalid. Variable %s occurs" - " in output of %s multiple times.", - each_var_name, op->Type())); - out_arg_set.insert(each_var_name); - } - - ir::Node *var = nullptr; - if (all_vars.count(each_var_name) != 0) { - var = CreateVarNode(all_vars.at(each_var_name)); - } else { - // Operation output vars can be @EMPTY@. For example, while_grad - // can have multi @EMPTY@ outputs with no VarDesc. - // TODO(panyx0718): Add a test. 
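// The lines removed around here were the duplicated SSA construction that now
// lives only in InitFromBlock: each op input reuses the newest version of a
// variable name, while each op output pushes a fresh version onto var_nodes.
// A toy standalone sketch of just that versioning rule follows; Op and VarNode
// are simplified stand-ins for illustration, not ir::Node.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct VarNode {
  std::string name;
  int version;
};

struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  // Two ops: b = f(a); a = g(c);  The second op writes a new version of "a".
  std::vector<Op> ops = {{{"a"}, {"b"}}, {{"c"}, {"a"}}};
  std::map<std::string, std::vector<VarNode>> var_nodes;

  for (const Op &op : ops) {
    for (const std::string &in : op.inputs) {
      if (var_nodes[in].empty()) {
        var_nodes[in].push_back({in, 0});  // first ever use: version 0
      }
      // Inputs reuse the newest version: var_nodes[in].back().
    }
    for (const std::string &out : op.outputs) {
      int next = static_cast<int>(var_nodes[out].size());
      var_nodes[out].push_back({out, next});  // every write: a new version
    }
  }

  for (const auto &kv : var_nodes) {
    std::cout << kv.first << ": " << kv.second.size() << " version(s)\n";
  }
  return 0;  // prints a: 2, b: 1, c: 1
}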
- var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); - } - var_nodes[each_var_name].push_back(var); - node->outputs.push_back(var); - var->inputs.push_back(node); - } - } - - for (auto &pair : not_visited_vars) { - const auto &var_name = pair.first; - auto *var_desc = pair.second; - if (var_name != kEmptyVarName) { - VLOG(10) << "Create isolated var node " << var_name; - var_nodes[var_name].push_back(CreateVarNode(var_desc)); - } - } - - Set>( - details::kStaleProgramOpDescs, - new std::vector(program.Block(0).AllOps())); - return var_nodes; + return InitFromBlock(program.Block(0)); } void Graph::ResolveHazard( diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index f03be51cbbdd6..5971208d7fb0c 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -351,6 +351,13 @@ class Graph { return sub_graphs_.at(idx).get(); } + size_t SubGraphsSize() const { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + return sub_graphs_.size(); + } + private: std::map> InitFromBlock( const BlockDesc &block); diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 2d8de33f411a4..1ff67ae0fe0d9 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -369,6 +369,7 @@ TEST(GraphTest, TestMultiBlock) { // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs. std::unique_ptr g(new ir::Graph(prog)); ASSERT_EQ(g->IsMainGraph(), true); + ASSERT_EQ(g->SubGraphsSize(), 3UL); // Check contents in sub_graph_0. const ir::Graph *g0 = g->GetSubGraph(0); @@ -433,6 +434,7 @@ TEST(GraphTest, TestMultiBlock) { // Step3: Clone graph. std::shared_ptr clone_g = g->Clone(); ASSERT_EQ(clone_g->IsMainGraph(), true); + ASSERT_EQ(clone_g->SubGraphsSize(), 3UL); // Recover FLAGS_convert_all_blocks. FLAGS_convert_all_blocks = flag_temp; From 21a9c9fcf943ccb92f49769673315e56deda8b2c Mon Sep 17 00:00:00 2001 From: levi131 Date: Thu, 15 Jul 2021 06:31:58 +0000 Subject: [PATCH 15/17] merge upstream and resolve conflicts --- paddle/fluid/framework/ir/graph.cc | 145 ++++++++++++++++--- paddle/fluid/framework/ir/graph.h | 157 ++++++++++++++++++++- paddle/fluid/framework/ir/graph_test.cc | 176 ++++++++++++++++++++++++ paddle/fluid/framework/ir/pass_test.cc | 109 +++++++++++++++ 4 files changed, 559 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index e8a3de1a88a9d..ce0f3d5edce88 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -17,6 +17,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/operator.h" +DEFINE_bool(convert_all_blocks, false, + "Convert all blocks in program into SSAgraphs"); + namespace paddle { namespace framework { namespace ir { @@ -24,16 +27,9 @@ namespace ir { Graph::Graph(const ProgramDesc &program) : Graph(program, 0, program.Block(0).AllOps().size()) {} -Graph::Graph(const ProgramDesc &program, int64_t start_op_index, - int64_t end_op_index) - : program_(program) { - auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index); - ResolveHazard(var_nodes); -} - -std::map> Graph::InitFromProgram( - const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index) { - VLOG(3) << "block in program:" << program_.Size(); +Graph::Graph(const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index) + : program_(program), main_graph_(nullptr) { PADDLE_ENFORCE_GE(start_op_index, 0, platform::errors::InvalidArgument( "Required start_op_index >= 0, but received " @@ -44,16 +40,66 @@ std::map> Graph::InitFromProgram( "Required end_op_index >= start_op_index, but received " "end_op_index: %d < start_op_index: %d", end_op_index, start_op_index)); + PADDLE_ENFORCE_GE( + program_.Size(), 1, + platform::errors::InvalidArgument("Can't construct a graph from this " + "program, it doesn't have a block")); + PADDLE_ENFORCE_GE(end_op_index, program_.Block(0).AllOps().size(), + platform::errors::InvalidArgument( + "Required end_op_index <= block_op_size, but received " + "end_op_index: %d > block_op_size: %d", + end_op_index, program_.Block(0).AllOps().size())); + if (FLAGS_convert_all_blocks) { + // NOTE(levi): start_op_index and end_op_index only work on the first + // sub_graph. + std::unique_ptr first_sub_graph = std::make_unique( + program_.Block(0), this, start_op_index, end_op_index); + sub_graphs_.push_back(std::move(first_sub_graph)); + for (size_t idx = 1; idx < program_.Size(); ++idx) { + std::unique_ptr sub_graph = + std::make_unique(program_.Block(idx), this); + sub_graphs_.push_back(std::move(sub_graph)); + } + } else { + auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index); + ResolveHazard(var_nodes); + } +} + +Graph::Graph(const BlockDesc &block, const Graph *main_graph) + : main_graph_(main_graph) { + auto var_nodes = InitFromBlock(block, 0, block.AllOps().size()); + ResolveHazard(var_nodes); +} + +Graph::Graph(const BlockDesc &block, const Graph *main_graph, + const int64_t start_op_index, const int64_t end_op_index) + : main_graph_(main_graph) { + auto var_nodes = InitFromBlock(block, start_op_index, end_op_index); + ResolveHazard(var_nodes); +} +// TODO(levi): delete this interface after when we can convert all +// blocks into sub_graphs. 
+std::map> Graph::InitFromProgram( + const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index) { + VLOG(3) << "block in program:" << program_.Size(); + return InitFromBlock(program.Block(0), start_op_index, end_op_index); +} + +std::map> Graph::InitFromBlock( + const BlockDesc &block, const int64_t start_op_index, + const int64_t end_op_index) { std::unordered_map all_vars; // var nodes for each var name, will have multiple versions in SSA std::map> var_nodes; - for (auto *var : program.Block(0).AllVars()) { + for (auto *var : block.AllVars()) { all_vars.emplace(var->Name(), var); } auto not_visited_vars = all_vars; - auto all_ops = program.Block(0).AllOps(); + auto all_ops = block.AllOps(); PADDLE_ENFORCE_LE( end_op_index, all_ops.size(), platform::errors::InvalidArgument( @@ -210,22 +256,77 @@ void Graph::ResolveHazard( } std::shared_ptr Graph::Clone() { - auto cloned_graph = std::make_shared(this->program_); - cloned_graph->ReleaseNodes(); - cloned_graph->num_node_created_ = 0; + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument( + "This graph is a sub_graph, and can't be cloned individually")); + if (FLAGS_convert_all_blocks) { + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseSubGraphs(); + for (size_t idx = 0; idx < this->program_.Size(); ++idx) { + cloned_graph->AddSubGraph(this->CloneSubGraph(idx)); + } + return cloned_graph; + } else { + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseNodes(); + cloned_graph->num_node_created_ = 0; + std::unordered_map origin_to_cloned; + for (auto *n : this->node_set_) { + PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( + "The node to be cloned is nullptr.")); + ir::Node *cloned_node = nullptr; + if (n->IsCtrlVar()) { + cloned_node = cloned_graph->CreateControlDepVar(); + } else if (!n->var_desc_ && !n->op_desc_) { // empty node + cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); + } else if (n->IsVar()) { + cloned_node = cloned_graph->CreateVarNode(n->Var()); + } else if (n->IsOp()) { + cloned_node = cloned_graph->CreateOpNode(n->Op()); + } + PADDLE_ENFORCE_NOT_NULL( + cloned_node, + platform::errors::InvalidArgument( + "Failed to clone new node from original node in graph.")); + origin_to_cloned[n] = cloned_node; + } + for (auto *n : this->node_set_) { + for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { + origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); + } + for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) { + origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); + } + } + return cloned_graph; + } +} + +std::unique_ptr Graph::CloneSubGraph(const size_t idx) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + PADDLE_ENFORCE_LT( + idx, this->sub_graphs_.size(), + platform::errors::InvalidArgument("Invalid sub_graph index")); + std::unique_ptr cloned_sub_graph = + std::make_unique(this->program_.Block(idx), this); + cloned_sub_graph->ReleaseNodes(); + cloned_sub_graph->num_node_created_ = 0; std::unordered_map origin_to_cloned; - for (auto *n : this->node_set_) { + for (auto *n : this->sub_graphs_.at(idx)->Nodes()) { PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument( "The node to be cloned is nullptr.")); ir::Node *cloned_node = nullptr; if (n->IsCtrlVar()) { - cloned_node = cloned_graph->CreateControlDepVar(); + cloned_node = 
cloned_sub_graph->CreateControlDepVar(); } else if (!n->var_desc_ && !n->op_desc_) { // empty node - cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); + cloned_node = cloned_sub_graph->CreateEmptyNode(n->Name(), n->NodeType()); } else if (n->IsVar()) { - cloned_node = cloned_graph->CreateVarNode(n->Var()); + cloned_node = cloned_sub_graph->CreateVarNode(n->Var()); } else if (n->IsOp()) { - cloned_node = cloned_graph->CreateOpNode(n->Op()); + cloned_node = cloned_sub_graph->CreateOpNode(n->Op()); } PADDLE_ENFORCE_NOT_NULL( cloned_node, @@ -233,7 +334,7 @@ std::shared_ptr Graph::Clone() { "Failed to clone new node from original node in graph.")); origin_to_cloned[n] = cloned_node; } - for (auto *n : this->node_set_) { + for (auto *n : this->sub_graphs_.at(idx)->Nodes()) { for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); } @@ -241,7 +342,7 @@ std::shared_ptr Graph::Clone() { origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); } } - return cloned_graph; + return cloned_sub_graph; } bool IsControlDepVar(const ir::Node &var) { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 26ca64ba821cb..f81590651887f 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -25,6 +26,8 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" +DECLARE_bool(convert_all_blocks); + namespace paddle { namespace framework { class OpDesc; @@ -78,10 +81,21 @@ namespace ir { */ class Graph { public: + // Construct a main_graph with some sub_graphs explicit Graph(const ProgramDesc &program); - // Construct a Graph with ops[start_op_index, end_op_index) - explicit Graph(const ProgramDesc &program, int64_t start_op_index, - int64_t end_op_index); + + // Construct a main_graph with some sub_graphs, and the 1st sub_graph is + // constructed + // with ops[start_op_index, end_op_index) + Graph(const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index); + + // Construct a sub_graph + Graph(const BlockDesc &block, const Graph *main_graph); + + // Construct a sub_graph with ops[start_op_index, end_op_index) + Graph(const BlockDesc &block, const Graph *main_graph, + const int64_t start_op_index, const int64_t end_op_index); virtual ~Graph() { for (auto &attr : attrs_) { @@ -94,11 +108,21 @@ class Graph { bool IsConstructedByPartialProgram() const { return is_partial_; } bool Has(const std::string &attr_name) const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Has(attr_name); + } + } return attrs_.count(attr_name) > 0; } template AttrType &GetOrInit(const std::string &attr_name) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->GetOrInit(attr_name); + } + } if (!Has(attr_name)) { Set(attr_name, new AttrType); } @@ -107,6 +131,11 @@ class Graph { template AttrType &Get(const std::string &attr_name) const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Get(attr_name); + } + } PADDLE_ENFORCE_EQ( Has(attr_name), true, platform::errors::PreconditionNotMet( @@ -123,6 +152,11 @@ class Graph { template void Set(const std::string &attr_name, AttrType *attr) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Set(attr_name, attr); 
+ } + } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, platform::errors::AlreadyExists( @@ -137,6 +171,11 @@ class Graph { template void SetNotOwned(const std::string &attr_name, AttrType *attr) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->SetNotOwned(attr_name, attr); + } + } PADDLE_ENFORCE_EQ( attrs_.count(attr_name), 0, platform::errors::AlreadyExists("The attribute %s to be set(not owned) " @@ -147,6 +186,11 @@ class Graph { } void Erase(const std::string &attr_name) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Erase(attr_name); + } + } PADDLE_ENFORCE_NE( attrs_.count(attr_name), 0, platform::errors::NotFound( @@ -157,10 +201,22 @@ class Graph { attr_dels_.erase(attr_name); } - const std::unordered_set &Nodes() const { return node_set_; } + const std::unordered_set &Nodes() const { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->Nodes(); + } + } + return node_set_; + } // Create a normal variable with non-null VarDesc. ir::Node *CreateVarNode(VarDesc *var_desc) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateVarNode(var_desc); + } + } PADDLE_ENFORCE_NOT_NULL( var_desc, platform::errors::InvalidArgument( "The VarDesc used to create variable node is null.")); @@ -171,6 +227,11 @@ class Graph { // Create a normal runnable operator with OpDesc. ir::Node *CreateOpNode(OpDesc *op_desc) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateOpNode(op_desc); + } + } PADDLE_ENFORCE_NOT_NULL( op_desc, platform::errors::InvalidArgument( "The OpDesc used to create operator node is null.")); @@ -183,6 +244,11 @@ class Graph { // var doesn't hold any data. Other than that, it's no different from // other var, considering dependency analysis. ir::Node *CreateControlDepVar() { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateControlDepVar(); + } + } // TODO(panyx0718): control var name should be really unique. const std::string name = string::Sprintf( "%s@%llu", static_cast(ir::Node::kControlDepVarName), @@ -195,6 +261,11 @@ class Graph { // A more free style way of creating a graph node. Mostly use for test // or "copy" from another node. Avoid using it if possible. ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->CreateEmptyNode(name, type); + } + } auto *x = AddNode(new ir::Node(name, type)); x->SetId(num_node_created_++); return x; @@ -203,6 +274,11 @@ class Graph { // Clear all node information of the graph and return the ownership of the // nodes. std::vector> ReleaseNodes() { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->ReleaseNodes(); + } + } std::vector> ret; for (auto &n : nodes_) { ret.emplace_back(n.second.release()); @@ -213,6 +289,11 @@ class Graph { } std::unique_ptr RemoveNode(ir::Node *node) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->RemoveNode(node); + } + } PADDLE_ENFORCE_EQ(node_set_.find(node) != node_set_.end(), true, platform::errors::PreconditionNotMet( "The node to be removed does not exist.")); @@ -225,6 +306,11 @@ class Graph { // NOTE low performance, but simple and secure. 
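// The NOTE above refers to RetrieveNode() just below, which keeps its linear
// scan over nodes_ and only gains the main-graph forwarding. Purely to
// illustrate the trade-off the NOTE mentions, and not something this patch
// changes, here is a standalone sketch comparing the scan with a side id
// index; Node is a stand-in struct, not paddle::framework::ir::Node.
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>

struct Node {
  int id;
  explicit Node(int i) : id(i) {}
};

int main() {
  std::vector<std::unique_ptr<Node>> nodes;
  for (int i = 0; i < 5; ++i) nodes.push_back(std::make_unique<Node>(i));

  // O(n) scan, as in the patch: simple, and always consistent with the set.
  auto retrieve_linear = [&](int id) -> Node * {
    for (auto &n : nodes) {
      if (n->id == id) return n.get();
    }
    return nullptr;
  };

  // O(1) lookup through a side index; it must be kept in sync on every
  // add/remove, which is exactly the bookkeeping the simple version avoids.
  std::unordered_map<int, Node *> index;
  for (auto &n : nodes) index[n->id] = n.get();

  std::cout << retrieve_linear(3)->id << " " << index.at(3)->id << "\n";
  return 0;
}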
Node *RetrieveNode(int id) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->RetrieveNode(id); + } + } for (auto &node : nodes_) { if (node.second->id() == id) { return node.second.get(); @@ -237,10 +323,22 @@ class Graph { // WARN: After a series of passes, the current graph can be quite // different from OriginProgram. Caller shouldn't assume much from // the returned OriginProgram. - const ProgramDesc &OriginProgram() const { return program_; } + const ProgramDesc &OriginProgram() const { + if (FLAGS_convert_all_blocks) { + if (!IsMainGraph()) { + return main_graph_->OriginProgram(); + } + } + return program_; + } // This method takes ownership of `node`. ir::Node *AddNode(ir::Node *node) { + if (FLAGS_convert_all_blocks) { + if (IsMainGraph()) { + return GetSubGraph(0)->AddNode(node); + } + } PADDLE_ENFORCE_EQ(node_set_.find(node) == node_set_.end(), true, platform::errors::PreconditionNotMet( "The node to be added already exists.")); @@ -256,12 +354,59 @@ class Graph { // WARN: The method only clones the graph structure, not its attributes. std::shared_ptr Clone(); + bool IsMainGraph() const { return main_graph_ == nullptr; } + + Graph *GetSubGraph(const size_t idx) const { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + PADDLE_ENFORCE_LT( + idx, sub_graphs_.size(), + platform::errors::InvalidArgument("Invalid sub_graph index")); + return sub_graphs_.at(idx).get(); + } + + size_t SubGraphsSize() const { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + return sub_graphs_.size(); + } + private: + // TODO(levi): delete this interface after when we can convert all + // blocks into sub_graphs. std::map> InitFromProgram( - const ProgramDesc &program, int64_t start_op_index, int64_t end_op_index); + const ProgramDesc &program, const int64_t start_op_index, + const int64_t end_op_index); + + std::map> InitFromBlock( + const BlockDesc &block, const int64_t start_op_index, + const int64_t end_op_index); + + void ReleaseSubGraphs() { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + sub_graphs_.clear(); + } + + void AddSubGraph(std::unique_ptr sub_graph) { + PADDLE_ENFORCE_EQ( + this->IsMainGraph(), true, + platform::errors::InvalidArgument("This graph is not main_graph")); + sub_graphs_.push_back(std::move(sub_graph)); + } + + std::unique_ptr CloneSubGraph(const size_t idx); // NOTE: program_ shouldn't be exposed to user. const ProgramDesc program_; + // NOTE: main_graph_ doesn't hold any node. It's used as a container of + // sub_graphs, and the sub_graph holds the nodes. + const Graph *main_graph_; // not owned. + std::vector> sub_graphs_; + std::map attrs_; std::map> attr_dels_; std::map> nodes_; diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 66507fe7cafbb..1ff67ae0fe0d9 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -264,5 +264,181 @@ TEST(GraphTest, TestAttrCopy) { ASSERT_FALSE(dst_g.Has(kFloatValue)); } +TEST(GraphTest, TestInterfaceConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. 
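// This new test exercises the forwarding set up in graph.h above: with
// FLAGS_convert_all_blocks on, the main graph owns no nodes or attrs, and
// every node or attr call lands on sub-graph 0. A condensed, self-contained
// sketch of that facade shape follows; MiniGraph and g_convert_all_blocks are
// illustrative stand-ins, not the real ir::Graph or gflag.
#include <cassert>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

static bool g_convert_all_blocks = true;

class MiniGraph {
 public:
  MiniGraph() : main_graph_(nullptr) {  // main graph: holds only sub-graphs
    sub_graphs_.push_back(std::unique_ptr<MiniGraph>(new MiniGraph(this)));
  }
  explicit MiniGraph(const MiniGraph *main) : main_graph_(main) {}  // sub-graph

  bool IsMainGraph() const { return main_graph_ == nullptr; }
  MiniGraph *GetSubGraph(size_t idx) const { return sub_graphs_.at(idx).get(); }

  // Node-level API: the main graph owns no nodes, so it forwards.
  void AddNode(const std::string &name) {
    if (g_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->AddNode(name);
    }
    nodes_.insert(name);
  }
  size_t NodeNum() const {
    if (g_convert_all_blocks && IsMainGraph()) {
      return GetSubGraph(0)->NodeNum();
    }
    return nodes_.size();
  }

 private:
  const MiniGraph *main_graph_;  // nullptr marks the main graph
  std::vector<std::unique_ptr<MiniGraph>> sub_graphs_;
  std::unordered_set<std::string> nodes_;
};

int main() {
  MiniGraph g;
  g.AddNode("x");                            // actually stored in sub-graph 0
  assert(g.NodeNum() == 1);                  // observed through the main graph
  assert(g.GetSubGraph(0)->NodeNum() == 1);  // and directly on the sub-graph
  return 0;
}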
+ bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + prog.MutableBlock(0)->Var("init_var")->SetType(proto::VarType::SELECTED_ROWS); + ir::Graph g(prog); + ASSERT_TRUE(g.IsMainGraph()); + + const std::string kIntValue = "int_value"; + const int INT_VALUE = 3; + g.Set(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + ASSERT_EQ(g.GetOrInit(kIntValue), INT_VALUE); + ASSERT_EQ(g.Get(kIntValue), INT_VALUE); + g.Erase(kIntValue); + ASSERT_TRUE(!g.Has(kIntValue)); + g.SetNotOwned(kIntValue, new int(INT_VALUE)); + ASSERT_TRUE(g.Has(kIntValue)); + g.Erase(kIntValue); + + g.ReleaseNodes(); + ASSERT_EQ(g.Nodes().size(), 0UL); + g.CreateVarNode(new VarDesc("temp_var_desc_name")); + g.CreateOpNode(prog.MutableBlock(0)->AppendOp()); + g.CreateControlDepVar(); + g.CreateEmptyNode("temp_empty_node_name", ir::Node::Type::kVariable); + ASSERT_EQ(g.Nodes().size(), 4UL); + g.RemoveNode(g.RetrieveNode(1)); + ASSERT_EQ(g.Nodes().size(), 3UL); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + +TEST(GraphTest, TestMultiBlock) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. + bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + // Step1: Build a program with 3 blocks. + ProgramDesc prog; + ASSERT_EQ(prog.Size(), 1UL); + prog.AppendBlock(prog.Block(0)); + prog.AppendBlock(prog.Block(0)); + ASSERT_EQ(prog.Size(), 3UL); + + // Set contents in block_0. + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_out"); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::SELECTED_ROWS, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::LOD_TENSOR, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + // Set contents in block_1. + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(1)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"a"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Set contents in block_2. + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(2)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs. 
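// Step2 relies on the constructor added earlier in this patch: with the flag
// on, Graph(ProgramDesc) builds one sub-graph per block, so this three-block
// program is expected to report SubGraphsSize() == 3. A minimal sketch of that
// construction loop follows; MiniProgram, MiniBlock and MiniGraph are
// illustrative stand-ins only, and the [start_op_index, end_op_index) slice
// that the real constructor applies to the first block is omitted here.
#include <cassert>
#include <memory>
#include <vector>

struct MiniBlock {};

struct MiniProgram {
  std::vector<MiniBlock> blocks;
  size_t Size() const { return blocks.size(); }
  const MiniBlock &Block(size_t i) const { return blocks.at(i); }
};

struct MiniGraph {
  std::vector<std::unique_ptr<MiniGraph>> sub_graphs;

  explicit MiniGraph(const MiniProgram &program) {
    // Every block of the program becomes one sub-graph.
    for (size_t idx = 0; idx < program.Size(); ++idx) {
      sub_graphs.push_back(
          std::unique_ptr<MiniGraph>(new MiniGraph(program.Block(idx))));
    }
  }
  explicit MiniGraph(const MiniBlock & /*block*/) {}  // sub-graph from a block

  size_t SubGraphsSize() const { return sub_graphs.size(); }
};

int main() {
  MiniProgram prog;
  prog.blocks.resize(3);  // like AppendBlock() twice on a one-block program
  MiniGraph g(prog);
  assert(g.SubGraphsSize() == 3);
  return 0;
}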
+ std::unique_ptr g(new ir::Graph(prog)); + ASSERT_EQ(g->IsMainGraph(), true); + ASSERT_EQ(g->SubGraphsSize(), 3UL); + + // Check contents in sub_graph_0. + const ir::Graph *g0 = g->GetSubGraph(0); + std::vector nodes(g0->Nodes().begin(), g0->Nodes().end()); + for (ir::Node *n : nodes) { + if (n->Name() == "sum") { + ASSERT_EQ(n->inputs.size(), 3UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_a" || n->Name() == "test_b" || + n->Name() == "test_c") { + ASSERT_EQ(n->inputs.size(), 0UL); + ASSERT_EQ(n->outputs.size(), 1UL); + } else if (n->Name() == "test_out") { + ASSERT_EQ(n->inputs.size(), 1UL); + ASSERT_EQ(n->outputs.size(), 0UL); + } + } + ASSERT_EQ(nodes.size(), 5UL); + + // Check contents in sub_graph_1. + const ir::Graph *g1 = g->GetSubGraph(1); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g1->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + control_dep1 = n->outputs[1]; + ASSERT_EQ(n->outputs.size(), 2UL); + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); + } + } + ASSERT_EQ(control_dep1, control_dep2); + + // Check contents in sub_graph_2. + const ir::Graph *g2 = g->GetSubGraph(2); + control_dep1 = nullptr; + control_dep2 = nullptr; + for (ir::Node *n : g2->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + ASSERT_EQ(n->outputs.size(), 2UL); + control_dep1 = n->outputs[1]; + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2UL); + } + } + ASSERT_NE(control_dep1, nullptr); + ASSERT_NE(control_dep2, nullptr); + ASSERT_EQ(control_dep1, control_dep2); + + // Step3: Clone graph. + std::shared_ptr clone_g = g->Clone(); + ASSERT_EQ(clone_g->IsMainGraph(), true); + ASSERT_EQ(clone_g->SubGraphsSize(), 3UL); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc index 65b9c427869ee..616ba7f1a9761 100644 --- a/paddle/fluid/framework/ir/pass_test.cc +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -135,6 +135,93 @@ TEST(PassTest, TestPassAttrCheck) { exception.npos); } +TEST(PassTest, TestPassAttrCheckConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. 
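The assertions below depend on the test_pass fixture registered earlier in pass_test.cc, which this diff does not show. Its rough shape is reconstructed here only as an approximation, so the expected error strings and the == 2 checks are easier to follow; the exact body and formatting are assumptions:

    // Approximate fixture (assumed, not verbatim): one required pass attribute,
    // one required graph attribute; the pass writes attr + 1 into copy_* attributes.
    class TestPass : public Pass {
     protected:
      void ApplyImpl(ir::Graph *graph) const override {
        graph->Set<int>("copy_test_pass_attr", new int);
        graph->Set<int>("copy_test_graph_attr", new int);
        graph->Get<int>("copy_test_pass_attr") = Get<int>("test_pass_attr") + 1;
        graph->Get<int>("copy_test_graph_attr") =
            graph->Get<int>("test_graph_attr") + 1;
      }
    };
    REGISTER_PASS(test_pass, TestPass)
        .RequirePassAttr("test_pass_attr")
        .RequireGraphAttr("test_graph_attr");

The RequirePassAttr/RequireGraphAttr declarations are what produce the "Required atrribute ... is not set" messages the test matches against.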
+ bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + auto pass = PassRegistry::Instance().Get("test_pass"); + std::unique_ptr graph(new Graph(prog)); + std::string exception; + try { + graph.reset(pass->Apply(graph.release())); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("Required atrribute test_pass_attr for pass < " + "test_pass > is not set") != exception.npos); + + int val = 1; + graph.reset(new Graph(prog)); + pass->SetNotOwned("test_pass_attr", &val); + + for (std::string try_type : {"bool", "const int", "std::string"}) { + try { + if (try_type == "bool") { + pass->Get("test_pass_attr"); + } else if (try_type == "const int") { + pass->Get("test_pass_attr"); + } else if (try_type == "std::string") { + pass->Get("test_pass_attr"); + } + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + std::string msg = "Invalid type for attritube test_pass_attr, expected: " + + try_type + ", actual: int"; + ASSERT_TRUE(exception.find(msg) != exception.npos); + } + + try { + graph.reset(pass->Apply(graph.release())); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find( + "Required atrribute test_graph_attr for graph is not set") != + exception.npos); + + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 1; + graph.reset(pass->Apply(graph.release())); + ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); + ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); + + // Allow apply more than once. + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + graph.reset(pass->Apply(graph.release())); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->SetNotOwned("test_pass_attr", &val); + graph.reset(new Graph(prog)); + BuildCircleGraph(graph.get()); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 2; + try { + pass->Apply(graph.release()); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("shouldn't contain cycle") != exception.npos); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->Set("test_pass_attr", new int); + try { + pass->Set("test_pass_attr", new int); + } catch (paddle::platform::EnforceNotMet& e) { + exception = std::string(e.what()); + } + ASSERT_TRUE( + exception.find("Attribute test_pass_attr already set in the pass") != + exception.npos); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + class TestPassWithDefault : public Pass { protected: void ApplyImpl(ir::Graph* graph) const { @@ -160,6 +247,28 @@ TEST(PassTest, TestPassDefaultAttrCheck) { ASSERT_EQ(pass->Get("default_attr"), 3); } +TEST(PassTest, TestPassDefaultAttrCheckConvertAllBlocks) { + // Set FLAGS_convert_all_blocks to true to make sure this test works. 
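For reference, the flow that both of these pass tests exercise is: fetch a registered pass from the registry, satisfy its required attributes, hand it a graph built from a ProgramDesc, and take ownership of the returned graph. A compact, fully-typed sketch, assuming FLAGS_convert_all_blocks has already been switched on and reusing the attribute names of the fixtures in this file:

    // Sketch only: `prog` is any ProgramDesc.
    auto pass = PassRegistry::Instance().Get("test_pass");
    int pass_attr_value = 1;
    pass->SetNotOwned<int>("test_pass_attr", &pass_attr_value);

    std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
    graph->Set<int>("test_graph_attr", new int(1));
    graph.reset(pass->Apply(graph.release()));  // pass takes and returns ownership

With the flag on, the graph handed to the pass is a main graph, and node-level access inside the pass is forwarded to sub-graph 0 as shown in the graph.h hunk earlier in this patch.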
+ bool flag_temp = FLAGS_convert_all_blocks; + FLAGS_convert_all_blocks = true; + + ProgramDesc prog; + // check if default value is set + auto pass = PassRegistry::Instance().Get("test_pass_default_attr"); + std::unique_ptr graph(new Graph(prog)); + ASSERT_EQ(pass->Get("default_attr"), 1); + graph.reset(pass->Apply(graph.release())); + ASSERT_EQ(graph->Get("copy_default_attr"), 2); + + // check if new value overrides default value + pass = PassRegistry::Instance().Get("test_pass_default_attr"); + pass->Set("default_attr", new int{3}); + ASSERT_EQ(pass->Get("default_attr"), 3); + + // Recover FLAGS_convert_all_blocks. + FLAGS_convert_all_blocks = flag_temp; +} + TEST(PassTest, TestPassRegistrarDeconstructor) { auto pass_registrary = new PassRegistrar( From d6e1d7ece3bdda8bf0d7728c78adf5d9620b8114 Mon Sep 17 00:00:00 2001 From: levi131 Date: Thu, 15 Jul 2021 09:27:34 +0000 Subject: [PATCH 16/17] small fix: avoid compare between int64_t and size_t --- paddle/fluid/framework/ir/graph.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index ce0f3d5edce88..5bd52e42daf9a 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -44,11 +44,13 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index, program_.Size(), 1, platform::errors::InvalidArgument("Can't construct a graph from this " "program, it doesn't have a block")); - PADDLE_ENFORCE_GE(end_op_index, program_.Block(0).AllOps().size(), + + const int64_t block_op_size = program_.Block(0).AllOps().size(); + PADDLE_ENFORCE_GE(end_op_index, block_op_size, platform::errors::InvalidArgument( "Required end_op_index <= block_op_size, but received " "end_op_index: %d > block_op_size: %d", - end_op_index, program_.Block(0).AllOps().size())); + end_op_index, block_op_size)); if (FLAGS_convert_all_blocks) { // NOTE(levi): start_op_index and end_op_index only work on the first // sub_graph. @@ -67,10 +69,7 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index, } Graph::Graph(const BlockDesc &block, const Graph *main_graph) - : main_graph_(main_graph) { - auto var_nodes = InitFromBlock(block, 0, block.AllOps().size()); - ResolveHazard(var_nodes); -} + : Graph(block, main_graph, 0, block.AllOps().size()) {} Graph::Graph(const BlockDesc &block, const Graph *main_graph, const int64_t start_op_index, const int64_t end_op_index) From e59c6c0a3e897b9ece336de837378b91034a31a5 Mon Sep 17 00:00:00 2001 From: levi131 Date: Thu, 15 Jul 2021 09:29:48 +0000 Subject: [PATCH 17/17] small fix --- paddle/fluid/framework/ir/graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 5bd52e42daf9a..1f55f0aa3cbad 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -46,7 +46,7 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index, "program, it doesn't have a block")); const int64_t block_op_size = program_.Block(0).AllOps().size(); - PADDLE_ENFORCE_GE(end_op_index, block_op_size, + PADDLE_ENFORCE_LE(end_op_index, block_op_size, platform::errors::InvalidArgument( "Required end_op_index <= block_op_size, but received " "end_op_index: %d > block_op_size: %d",