Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert all blocks in program into SSA graphs. #33320

Merged
merged 27 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3e2059d
draft for ssa_program to graph with sub_graphs
Jun 2, 2021
0c58b91
try add gflags ssa_program
Jun 3, 2021
69c71ed
use gflags
Jun 3, 2021
1ccb8df
rename gflag macro and some member of Graph
Jun 7, 2021
2495c06
use normal if-else instead of #if
Jun 7, 2021
4e5ec81
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jun 8, 2021
a9a55df
add unittest for convert_all_blocks
Jun 8, 2021
3f5d561
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jun 9, 2021
9dffa5b
use LOG(WARNING)
Jun 9, 2021
1655150
modify format
Jun 9, 2021
0229f37
small spell modify
Jun 15, 2021
86a74dc
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jun 16, 2021
5f35c3f
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jun 22, 2021
98bc7ad
ensure GraphTest.TestMultiBlock works and rm WARNINGs in APIs for cla…
Jun 22, 2021
c0df219
small change to re-start CI
Jun 22, 2021
cc4c312
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jun 22, 2021
0b62f2d
enable some TestPass cases
Jun 23, 2021
ec23091
add test for graph interface
Jun 23, 2021
5d018fb
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jun 29, 2021
3f7aa28
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jun 29, 2021
df2c43c
add SubGraphsSize() and use InitFromBlock in InitFromProgram
Jul 7, 2021
21a9c9f
merge upstream and resolve conflicts
Jul 15, 2021
3af89c0
format fix
Jul 15, 2021
e19a194
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jul 15, 2021
fa14213
Merge remote-tracking branch 'upstream/develop' into levi/ssaprogram2…
Jul 15, 2021
d6e1d7e
small fix: avoid compare between int64_t and size_t
Jul 15, 2021
e59c6c0
small fix
Jul 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 166 additions & 12 deletions paddle/fluid/framework/ir/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,114 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/operator.h"

// Gflag gating the new multi-block path: when true, every block of the
// ProgramDesc is converted into its own SSA sub-graph; when false (default),
// only the legacy single-block conversion via InitFromProgram is used.
DEFINE_bool(convert_all_blocks, false,
"Convert all blocks in program into SSAgraphs");

namespace paddle {
namespace framework {
namespace ir {

Graph::Graph(const ProgramDesc &program) : program_(program) {
auto var_nodes = InitFromProgram(program_);
// Constructs the main graph from a whole ProgramDesc. main_graph_ is set to
// nullptr to mark this instance as the main graph (see IsMainGraph usage in
// Clone/CloneSubGraph). With FLAGS_convert_all_blocks on, each block of the
// program becomes one owned sub-graph; otherwise the legacy path converts the
// program via InitFromProgram and resolves write hazards.
Graph::Graph(const ProgramDesc &program)
: program_(program), main_graph_(nullptr) {
if (FLAGS_convert_all_blocks) {
// A usable program must contain at least one block (block 0).
PADDLE_ENFORCE_GE(
program_.Size(), 1,
platform::errors::InvalidArgument("Can't construct a graph from this "
"program, it doesn't have a block"));
// One sub-graph per block, each back-pointing to this main graph.
for (size_t idx = 0; idx < program_.Size(); ++idx) {
std::unique_ptr<Graph> sub_graph =
std::make_unique<Graph>(program_.Block(idx), this);
sub_graphs_.push_back(std::move(sub_graph));
}
} else {
// Legacy single-graph path: build nodes from the program, then resolve
// read-after-write / write-after-write hazards on the SSA var versions.
auto var_nodes = InitFromProgram(program_);
ResolveHazard(var_nodes);
}
}

// Constructs a sub-graph from a single BlockDesc. `main_graph` is the owning
// main graph (non-null here, which is what makes IsMainGraph() false for
// instances built through this overload). Node construction is delegated to
// InitFromBlock, followed by SSA hazard resolution.
Graph::Graph(const BlockDesc &block, const Graph *main_graph)
: main_graph_(main_graph) {
auto var_nodes = InitFromBlock(block);
ResolveHazard(var_nodes);
}

// Builds this graph's nodes from one BlockDesc: one op node per OpDesc, and
// per-name version lists of var nodes (SSA form — every write to a name
// appends a new version). Returns the name -> versions map so the caller
// (the sub-graph constructor) can run ResolveHazard on it. Also records the
// block's ops under details::kStaleProgramOpDescs as a graph attribute.
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
const BlockDesc &block) {
std::unordered_map<std::string, VarDesc *> all_vars;
// Var nodes for each var name; a name accumulates multiple versions in SSA.
std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *var : block.AllVars()) {
all_vars.emplace(var->Name(), var);
}

// Names still present here after the op loop were never read or written by
// any op; they become isolated var nodes at the end.
auto not_visited_vars = all_vars;

for (auto *op : block.AllOps()) {
ir::Node *node = CreateOpNode(op);
// For input args, reuse the latest version of the var if one was created
// before. Otherwise, create a new one.
for (auto &each_var_name : op->InputArgumentNames()) {
not_visited_vars.erase(each_var_name);
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name).back();
} else if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name].push_back(var);
} else {
// An operation's input var can be optional (dispensable), meaning the
// operation doesn't really need the var at runtime. In that case the
// non-existent var is represented by an empty node that is considered
// ready from the beginning.
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
var_nodes[each_var_name].push_back(var);
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
// For output args, always create a new var (a new SSA version). A name may
// appear at most once among one op's outputs, except for kEmptyVarName.
std::unordered_set<std::string> out_arg_set;
for (auto &each_var_name : op->OutputArgumentNames()) {
not_visited_vars.erase(each_var_name);
if (each_var_name != kEmptyVarName) {
PADDLE_ENFORCE_EQ(out_arg_set.count(each_var_name), 0,
platform::errors::InvalidArgument(
"The input Program is invalid. Variable %s occurs"
" in output of %s multiple times.",
each_var_name, op->Type()));
out_arg_set.insert(each_var_name);
}

ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
} else {
// Operation output vars can be @EMPTY@. For example, while_grad
// can have multi @EMPTY@ outputs with no VarDesc.
// TODO(panyx0718): Add a test.
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
}
var_nodes[each_var_name].push_back(var);
node->outputs.push_back(var);
var->inputs.push_back(node);
}
}

// Materialize vars that no op touched as isolated var nodes so they are not
// silently dropped from the graph.
for (auto &pair : not_visited_vars) {
const auto &var_name = pair.first;
auto *var_desc = pair.second;
if (var_name != kEmptyVarName) {
VLOG(10) << "Create isolated var node " << var_name;
var_nodes[var_name].push_back(CreateVarNode(var_desc));
}
}

// Keep a snapshot of the block's op descs as a graph attribute.
Set<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs,
new std::vector<OpDesc *>(block.AllOps()));
return var_nodes;
}

// TODO(levi): delete this interface after when we can convert all
// blocks into sub_graphs.
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program) {
VLOG(3) << "block in program:" << program_.Size();
levi131 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -176,38 +275,93 @@ void Graph::ResolveHazard(
}

// Deep-copies the whole graph. Only the main graph may be cloned (enforced
// below); with FLAGS_convert_all_blocks on, each sub-graph is cloned via
// CloneSubGraph, otherwise nodes are cloned and re-linked directly.
std::shared_ptr<Graph> Graph::Clone() {
// NOTE(review): this span is a rendered diff — the next three lines are the
// pre-change body retained by the diff view; the current implementation
// begins at the PADDLE_ENFORCE_EQ below.
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseNodes();
cloned_graph->num_node_created_ = 0;
// Sub-graphs share state with their main graph and cannot be cloned alone.
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument(
"This graph is a sub_graph, and can't be cloned individually"));
if (FLAGS_convert_all_blocks) {
// Constructing from program_ already builds fresh sub-graphs; drop them
// and replace each with a deep clone of the corresponding source one.
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseSubGraphs();
for (size_t idx = 0; idx < this->program_.Size(); ++idx) {
cloned_graph->AddSubGraph(this->CloneSubGraph(idx));
}
return cloned_graph;
} else {
// Legacy path: discard the constructor-built nodes, then clone this
// graph's node set and re-wire the input/output edges via the mapping.
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseNodes();
cloned_graph->num_node_created_ = 0;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
// First pass: clone every node, remembering origin -> clone.
for (auto *n : this->node_set_) {
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
"The node to be cloned is nullptr."));
ir::Node *cloned_node = nullptr;
if (n->IsCtrlVar()) {
cloned_node = cloned_graph->CreateControlDepVar();
} else if (!n->var_desc_ && !n->op_desc_) { // empty node
cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType());
} else if (n->IsVar()) {
cloned_node = cloned_graph->CreateVarNode(n->Var());
} else if (n->IsOp()) {
cloned_node = cloned_graph->CreateOpNode(n->Op());
}
PADDLE_ENFORCE_NOT_NULL(
cloned_node,
platform::errors::InvalidArgument(
"Failed to clone new node from original node in graph."));
origin_to_cloned[n] = cloned_node;
}
// Second pass: replicate the edge lists using the mapping.
for (auto *n : this->node_set_) {
for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
}
for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) {
origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
}
}
return cloned_graph;
}
}

// Deep-copies sub-graph `idx` of this main graph: builds a fresh graph from
// program_.Block(idx), drops its constructor-built nodes, then clones the
// source sub-graph's node set and re-wires edges via an origin->clone map.
// NOTE(review): this span is a rendered diff — each removed line appears
// directly above its replacement (e.g. the paired `node_set_` /
// `sub_graphs_.at(idx)->Nodes()` loop headers and the `cloned_graph` /
// `cloned_sub_graph` calls); the current code uses the second of each pair.
std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
// Only the main graph owns sub_graphs_, so only it may clone one.
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph"));
PADDLE_ENFORCE_LT(
idx, this->sub_graphs_.size(),
platform::errors::InvalidArgument("Invalid sub_graph index"));
std::unique_ptr<Graph> cloned_sub_graph =
std::make_unique<Graph>(this->program_.Block(idx), this);
// Discard nodes built by the constructor; they are re-created below as
// clones of the source sub-graph's nodes.
cloned_sub_graph->ReleaseNodes();
cloned_sub_graph->num_node_created_ = 0;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
// First pass: clone every node of the source sub-graph.
for (auto *n : this->node_set_) {
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
"The node to be cloned is nullptr."));
ir::Node *cloned_node = nullptr;
if (n->IsCtrlVar()) {
cloned_node = cloned_graph->CreateControlDepVar();
cloned_node = cloned_sub_graph->CreateControlDepVar();
} else if (!n->var_desc_ && !n->op_desc_) { // empty node
cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType());
cloned_node = cloned_sub_graph->CreateEmptyNode(n->Name(), n->NodeType());
} else if (n->IsVar()) {
cloned_node = cloned_graph->CreateVarNode(n->Var());
cloned_node = cloned_sub_graph->CreateVarNode(n->Var());
} else if (n->IsOp()) {
cloned_node = cloned_graph->CreateOpNode(n->Op());
cloned_node = cloned_sub_graph->CreateOpNode(n->Op());
}
PADDLE_ENFORCE_NOT_NULL(
cloned_node,
platform::errors::InvalidArgument(
"Failed to clone new node from original node in graph."));
origin_to_cloned[n] = cloned_node;
}
// Second pass: replicate the edge lists using the mapping.
for (auto *n : this->node_set_) {
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
}
for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) {
origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
}
}
return cloned_graph;
return cloned_sub_graph;
}

bool IsControlDepVar(const ir::Node &var) {
Expand Down
Loading