graph_to_program topology sort #33949

Merged

22 commits
770923f
try whether topology sort can solve the speed decrease problem
thisjiang Jul 4, 2021
1450bda
add multi block topo function
thisjiang Jul 6, 2021
8f9e886
add convert_all_blocks flag
thisjiang Jul 6, 2021
3e8f358
fix syntax bug
thisjiang Jul 6, 2021
1f3101c
polish some descriptions
thisjiang Jul 6, 2021
e9e98fe
optimize sort function
thisjiang Jul 7, 2021
89b98e8
rename functions and fix some bugs
thisjiang Jul 7, 2021
af57fdd
add unit test script
thisjiang Jul 8, 2021
3de8418
remove IsParameter and StopGradient tests since PR33771 is not merged
thisjiang Jul 9, 2021
a109c74
use a self-defined string instead of protobuf's ShortDebugString
thisjiang Jul 12, 2021
0484189
resolve conflicts
thisjiang Jul 16, 2021
84c8ed2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thisjiang Jul 16, 2021
eb5b8cd
Fix bug: node.ToString should handle OpHandle, an Op that has no OpDesc.
thisjiang Jul 16, 2021
d154917
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
thisjiang Jul 16, 2021
2f4f1b7
add flag to control TopologySort function
thisjiang Jul 16, 2021
9e691a5
replace scale_loss_grad op with fill_constant in graph2program
thisjiang Jul 23, 2021
819c1e2
add unit test script for CI coverage
thisjiang Jul 23, 2021
ec23b38
fix node test bug
thisjiang Jul 23, 2021
b3ebb83
complete node test script for node.ToString test
thisjiang Jul 24, 2021
c89fa02
avoid redundant copy when replacing the scale_loss_grad op
thisjiang Jul 26, 2021
4fa0bdf
remove GetBlockId
thisjiang Jul 26, 2021
45ed5ba
merge branch zhhsplendid:ci_test_convert_all_block https://github.com…
thisjiang Jul 27, 2021
40 changes: 33 additions & 7 deletions paddle/fluid/framework/ir/graph.cc
@@ -56,10 +56,12 @@ Graph::Graph(const ProgramDesc &program, const int64_t start_op_index,
// sub_graph.
std::unique_ptr<Graph> first_sub_graph = std::make_unique<Graph>(
program_.Block(0), this, start_op_index, end_op_index);
first_sub_graph->block_id_ = 0;
sub_graphs_.push_back(std::move(first_sub_graph));
for (size_t idx = 1; idx < program_.Size(); ++idx) {
std::unique_ptr<Graph> sub_graph =
std::make_unique<Graph>(program_.Block(idx), this);
sub_graph->block_id_ = idx;
sub_graphs_.push_back(std::move(sub_graph));
}
} else {
@@ -90,14 +92,32 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
const BlockDesc &block, const int64_t start_op_index,
const int64_t end_op_index) {
std::unordered_map<std::string, VarDesc *> all_vars;
std::unordered_map<std::string, std::pair<VarDesc *, int>>
name_to_desc_block_id;

const BlockDesc *block_var_visible = &block;
while (block_var_visible != nullptr) {
for (auto *var : block_var_visible->AllVars()) {
name_to_desc_block_id.emplace(
var->Name(), std::make_pair(var, block_var_visible->ID()));
}
const BlockDesc *forward_block = block_var_visible->ForwardBlock();
if (forward_block != nullptr) {
for (auto *var : forward_block->AllVars()) {
name_to_desc_block_id.emplace(var->Name(),
std::make_pair(var, forward_block->ID()));
}
}
block_var_visible = block_var_visible->ParentBlock();
}
// var nodes for each var name, will have multiple versions in SSA
std::map<std::string, std::vector<ir::Node *>> var_nodes;
std::unordered_map<std::string, VarDesc *> not_visited_vars;
for (auto *var : block.AllVars()) {
all_vars.emplace(var->Name(), var);
not_visited_vars.emplace(var->Name(), var);
}

auto not_visited_vars = all_vars;
int desc_order = 0;
auto all_ops = block.AllOps();
PADDLE_ENFORCE_LE(
end_op_index, all_ops.size(),
@@ -109,15 +129,18 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
auto *op = all_ops[i];
VLOG(3) << "create OpNode by " << op->Type();
ir::Node *node = CreateOpNode(op);
node->SetDescOrder(desc_order);
++desc_order;
Contributor

Line 45 can be merged with line 44.

Contributor Author

I think writing them separately is clearer; one line more or less doesn't matter.
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
for (auto &each_var_name : op->InputArgumentNames()) {
not_visited_vars.erase(each_var_name);
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
var = var_nodes.at(each_var_name).back();
} else if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
} else if (name_to_desc_block_id.count(each_var_name) != 0) {
auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
var_nodes[each_var_name].push_back(var);
} else {
// Operation input var can be optional (dispensable). Which means
@@ -143,8 +166,9 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
}

ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
if (name_to_desc_block_id.count(each_var_name) != 0) {
auto desc_and_block_id = name_to_desc_block_id.at(each_var_name);
var = CreateVarNode(desc_and_block_id.first, desc_and_block_id.second);
} else {
// Operation output vars can be @EMPTY@. For example, while_grad
// can have multi @EMPTY@ outputs with no VarDesc.
@@ -270,6 +294,7 @@ std::shared_ptr<Graph> Graph::Clone() {
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseNodes();
cloned_graph->num_node_created_ = 0;
cloned_graph->block_id_ = this->block_id_;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
for (auto *n : this->node_set_) {
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
@@ -313,6 +338,7 @@ std::unique_ptr<Graph> Graph::CloneSubGraph(const size_t idx) {
std::make_unique<Graph>(this->program_.Block(idx), this);
cloned_sub_graph->ReleaseNodes();
cloned_sub_graph->num_node_created_ = 0;
cloned_sub_graph->block_id_ = idx;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
for (auto *n : this->sub_graphs_.at(idx)->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(n, platform::errors::InvalidArgument(
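The heart of the graph.cc change is the variable-visibility walk: an op inside a block may reference variables declared in that block, in any of its ancestor blocks, or in the forward block of a backward block, so InitFromBlock now maps every visible name to a (VarDesc *, block id) pair instead of scanning only the current block. Below is a minimal, self-contained sketch of that walk; the simplified VarDesc and BlockDesc structs are hypothetical stand-ins for the framework classes, not the real API.

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct VarDesc { std::string name; };                // stand-in for framework::VarDesc
struct BlockDesc {                                   // stand-in for framework::BlockDesc
  int id;
  std::vector<VarDesc *> vars;                       // AllVars()
  const BlockDesc *parent = nullptr;                 // ParentBlock()
  const BlockDesc *forward = nullptr;                // ForwardBlock()
};

std::unordered_map<std::string, std::pair<VarDesc *, int>> CollectVisibleVars(
    const BlockDesc *block) {
  std::unordered_map<std::string, std::pair<VarDesc *, int>> visible;
  while (block != nullptr) {
    for (auto *var : block->vars) {
      // emplace inserts only if the name is absent, so the innermost
      // declaration wins and shadows declarations in outer blocks
      visible.emplace(var->name, std::make_pair(var, block->id));
    }
    if (block->forward != nullptr) {
      for (auto *var : block->forward->vars) {
        visible.emplace(var->name, std::make_pair(var, block->forward->id));
      }
    }
    block = block->parent;
  }
  return visible;
}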
32 changes: 27 additions & 5 deletions paddle/fluid/framework/ir/graph.h
@@ -104,7 +104,14 @@ class Graph {
attr_dels_.clear();
}

bool IsConstructedByPartialProgram() const { return is_partial_; }
bool IsConstructedByPartialProgram() const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->IsConstructedByPartialProgram();
}
}
return is_partial_;
}

bool Has(const std::string &attr_name) const {
if (FLAGS_convert_all_blocks) {
@@ -210,7 +217,7 @@
}

// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) {
ir::Node *CreateVarNode(VarDesc *var_desc, int block_id = -1) {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->CreateVarNode(var_desc);
@@ -219,7 +226,8 @@
PADDLE_ENFORCE_NOT_NULL(
var_desc, platform::errors::InvalidArgument(
"The VarDesc used to create variable node is null."));
auto *x = AddNode(new ir::Node(var_desc));
auto *x =
AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
x->SetId(num_node_created_++);
return x;
}
@@ -252,7 +260,7 @@
const std::string name = string::Sprintf(
"%s@%llu", static_cast<const char *>(ir::Node::kControlDepVarName),
num_node_created_);
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable));
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
x->SetId(num_node_created_++);
return x;
}
@@ -265,7 +273,7 @@
return GetSubGraph(0)->CreateEmptyNode(name, type);
}
}
auto *x = AddNode(new ir::Node(name, type));
auto *x = AddNode(new ir::Node(name, type, block_id_));
x->SetId(num_node_created_++);
return x;
}
@@ -365,6 +373,15 @@ class Graph {
return sub_graphs_.at(idx).get();
}

int GetBlockId() const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->block_id_;
}
}
return block_id_;
}

size_t SubGraphsSize() const {
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
@@ -394,6 +411,9 @@
PADDLE_ENFORCE_EQ(
this->IsMainGraph(), true,
platform::errors::InvalidArgument("This graph is not main_graph"));
PADDLE_ENFORCE_EQ(sub_graphs_.size(), sub_graph->block_id_,
platform::errors::InvalidArgument(
"sub_graph idx is not equal to block_id_"));
sub_graphs_.push_back(std::move(sub_graph));
}

@@ -416,6 +436,8 @@
// parts: forward graph and backward graph, which can be executed
// independently.
bool is_partial_{false};
// The block this SubGraph belongs to.
int block_id_{0};
};

bool IsControlDepVar(const ir::Node &var);
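Two details in graph.h are worth calling out: every node-creation path now stamps the creating graph's block_id_ onto the node, and AddSubGraph enforces that sub-graphs are appended in block order, so a sub-graph's index always equals its block id. A hedged usage sketch under that invariant (assuming FLAGS_convert_all_blocks is enabled and a multi-block ProgramDesc named program is in scope):

#include <cassert>
#include "paddle/fluid/framework/ir/graph.h"

// program is a multi-block ProgramDesc (assumed available)
paddle::framework::ir::Graph graph(program);  // builds one sub-graph per block
if (graph.IsMainGraph()) {
  for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) {
    // invariant enforced by AddSubGraph: sub-graph index == block id
    assert(graph.GetSubGraph(idx)->GetBlockId() == static_cast<int>(idx));
  }
}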
80 changes: 80 additions & 0 deletions paddle/fluid/framework/ir/graph_helper.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/graph_helper.h"
#include <queue>
#include <stack>

DEFINE_string(print_sub_graph_dir, "",
@@ -395,6 +396,85 @@ std::vector<Node *> TopologyVarientSort(const Graph &graph,
}
}

class DescOrderComparator {
public:
bool operator()(const Node *n1, const Node *n2) {
return (n1->DescOrder() > n2->DescOrder()) ||
((n1->DescOrder() == n2->DescOrder()) &&
(n1->ToString() > n2->ToString()));
}
};

std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph) {
std::vector<ir::Node *> sorted_ops;
std::priority_queue<Node *, std::vector<Node *>, DescOrderComparator> q;
std::unordered_map<Node *, std::unordered_set<Node *>> in_ops;
std::unordered_map<Node *, std::unordered_set<Node *>> out_ops;

// ensure all op nodes are recorded in 'in_ops' and 'out_ops'
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;

in_ops.emplace(n, std::unordered_set<Node *>());
out_ops.emplace(n, std::unordered_set<Node *>());
}

// record each op's input ops and output ops
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;

// traverse all input ops
for (const auto &var : n->inputs) {
for (const auto &in : var->inputs) {
// use at() instead of [] so an unrecorded op node throws instead of being silently inserted
in_ops.at(n).insert(in);
out_ops.at(in).insert(n);
}
}
}

// find the entry points of the sort: ops with zero in-degree
for (const auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;

if (in_ops.at(n).empty()) {
q.push(n);
}
}

// topological sorting
while (!q.empty()) {
// Do not bind by reference! The element is popped from the queue below.
const auto cur_op = q.top();
q.pop();

sorted_ops.push_back(cur_op);
for (const auto &out : out_ops.at(cur_op)) {
PADDLE_ENFORCE_GT(in_ops.at(out).count(cur_op), 0,
platform::errors::InvalidArgument(
"We find %s in %s's output list, "
"but cannot find %s in %s's input list. "
"Please ensure graph completely.",
out->Name().c_str(), cur_op->Name().c_str(),
cur_op->Name().c_str(), out->Name().c_str()));
in_ops.at(out).erase(cur_op);

// push if in-degree is 0
if (in_ops.at(out).empty()) {
q.push(out);
}
}
}

PADDLE_ENFORCE_EQ(
sorted_ops.size(), in_ops.size(),
platform::errors::InvalidArgument("Topological sorting incompletely, "
"only sorted %zd op but total %zd.",
sorted_ops.size(), in_ops.size()));

return sorted_ops;
}

} // namespace ir
} // namespace framework
} // namespace paddle
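Note the direction of DescOrderComparator: std::priority_queue pops its "largest" element first, so a greater-than comparator turns it into a min-heap, and ops therefore pop in ascending DescOrder, with ToString as a deterministic tie-break. A minimal sketch of that inversion, with plain ints standing in for Node pointers:

#include <functional>
#include <iostream>
#include <queue>
#include <vector>

int main() {
  // std::greater makes the priority_queue a min-heap, just as the
  // greater-than DescOrderComparator pops ops in ascending DescOrder
  std::priority_queue<int, std::vector<int>, std::greater<int>> q;
  for (int v : {3, 1, 2}) q.push(v);
  while (!q.empty()) {
    std::cout << q.top() << ' ';  // prints: 1 2 3
    q.pop();
  }
  return 0;
}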
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/graph_helper.h
@@ -87,6 +87,8 @@ std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
return ret;
}

std::vector<ir::Node *> TopologySortGraphByDescOrder(const Graph &graph);

} // namespace ir
} // namespace framework
} // namespace paddle
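A hedged usage sketch of the new helper (assuming graph refers to a single-block framework::ir::Graph; with FLAGS_convert_all_blocks, each sub-graph would be sorted separately): the returned vector is a valid topological order in which ready ops are emitted in original program (DescOrder) order.

#include "paddle/fluid/framework/ir/graph_helper.h"

std::vector<paddle::framework::ir::Node *> sorted_ops =
    paddle::framework::ir::TopologySortGraphByDescOrder(graph);
for (auto *op : sorted_ops) {
  VLOG(3) << op->Name();  // ops appear in program (DescOrder) order
}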