Merge upstream develop to fix conflicts
veyron95 committed Sep 17, 2021
2 parents 4f6dc7f + 0eaab80 commit a99dcd8
Showing 106 changed files with 4,571 additions and 817 deletions.
2 changes: 0 additions & 2 deletions paddle/fluid/extension/src/ext_tensor.cc
@@ -90,7 +90,6 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
PADDLE_THROW(platform::errors::Unavailable(
"Only GPU related Copy can reach this func."));
}
cudaStreamSynchronize(dev_ctx->stream());
#elif defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = paddle::platform::GetCurrentDeviceId();
@@ -110,7 +109,6 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
PADDLE_THROW(platform::errors::Unavailable(
"Only GPU related Copy can reach this func."));
}
hipStreamSynchronize(dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"This function can only be used if compiled with"
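For context, the two deletions above remove a per-copy `cudaStreamSynchronize`/`hipStreamSynchronize`, so `DeviceCopy` no longer blocks the host after enqueuing a GPU copy. A minimal sketch of the pattern the removed calls implemented, using only standard CUDA runtime calls (placeholder names; not the Paddle implementation):

#include <cuda_runtime.h>
#include <cstddef>

// Hedged sketch: an asynchronous device copy followed by the explicit host
// synchronization that this commit removes from DeviceCopy. After the
// change, callers needing completion guarantees must synchronize themselves.
void CopyThenSync(void* dst, const void* src, std::size_t bytes,
                  cudaStream_t stream) {
  // Enqueue the copy on `stream`; the call returns immediately.
  cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream);
  // Block the host until all work queued on `stream` has finished.
  cudaStreamSynchronize(stream);  // <- the call the diff deletes
}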
1 change: 1 addition & 0 deletions paddle/fluid/framework/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(io)
add_subdirectory(new_executor)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto)

proto_library(op_def_proto SRCS op_def.proto DEPS framework_proto)
cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto boost)
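The new `proto_library` target compiles `pass_desc.proto` into C++ message classes that downstream targets pull in through `DEPS`. As a hedged sketch of how the generated API might be used, the snippet below builds a `MultiPassDesc` programmatically; the accessor names (`add_pass_descs`, `add_var_maps`, `set_pattern_var`, `set_replace_var`) are inferred from their read-side counterparts in `generate_pass.cc` later in this diff, not confirmed against `pass_desc.proto`:

#include "paddle/fluid/framework/pass_desc.pb.h"

// Hedged sketch: constructing a MultiPassDesc. Field names are assumptions
// inferred from generate_pass.cc (pass_descs, var_maps, pattern_var,
// replace_var) following standard protobuf accessor conventions.
paddle::framework::proto::MultiPassDesc MakeExampleDesc() {
  paddle::framework::proto::MultiPassDesc multi_desc;
  paddle::framework::proto::PassDesc* pass_desc = multi_desc.add_pass_descs();
  paddle::framework::proto::PassDesc::VarMap* var_map =
      pass_desc->add_var_maps();
  var_map->set_pattern_var("X");  // variable name in the pattern subgraph
  var_map->set_replace_var("X");  // corresponding name in the replacement
  return multi_desc;
}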
14 changes: 14 additions & 0 deletions paddle/fluid/framework/hogwild_worker.cc
@@ -216,6 +216,7 @@ void HogwildWorker::TrainFiles() {
// how to accumulate fetched values here
device_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
for (auto &op : ops_) {
bool need_skip = false;
@@ -230,13 +231,26 @@ void HogwildWorker::TrainFiles() {
}
}

if (need_dump_field_) {
DumpField(*thread_scope_, dump_mode_, dump_interval_);
}
if (need_dump_param_ && thread_id_ == 0) {
DumpParam(*thread_scope_, batch_cnt);
}

total_ins_num += cur_batch;
++batch_cnt;
PrintFetchVars();
thread_scope_->DropKids();
}
timeline.Pause();
VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
<< " seconds, ins_num: " << total_ins_num;

if (need_dump_field_ || need_dump_param_) {
writer_.Flush();
}

#if defined PADDLE_WITH_PSCORE
if (thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
3 changes: 3 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -95,6 +95,8 @@ pass_library(multihead_matmul_fuse_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
if(WITH_GPU OR WITH_ROCM)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
pass_library(embedding_eltwise_layernorm_fuse_pass inference)
@@ -156,6 +158,7 @@ cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_
cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass)
cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass)
cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor)
cc_test(test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto)
if(WITH_GPU OR WITH_ROCM)
cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass)
cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
229 changes: 229 additions & 0 deletions paddle/fluid/framework/ir/generate_pass.cc
@@ -0,0 +1,229 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/generate_pass.h"

namespace paddle {
namespace framework {
namespace ir {

void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
// Traverse all operators to create subgraph.
for (int index = 0; index < block.ops_size(); ++index) {
const proto::OpDesc& op = block.ops(index);
// Create a PDNode for the current operator. Use the index as the node name
// to disambiguate multiple operators of the same type; the rewrite phase
// retrieves this PDNode from the pattern subgraph by the same index.
PDNode* op_pdnode =
pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
// Create PDNodes for inputs of current operator.
for (const proto::OpDesc::Var& var : op.inputs()) {
for (const std::string& argument : var.arguments()) {
// The input may be the output of another operator.
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsInput();
} else if (var_pdnode->IsOutput()) {
var_pdnode->AsIntermediate();
}
var_pdnode->assert_is_op_input(op.type());
pattern->AddEdge(var_pdnode, op_pdnode);
}
}
// Create PDNodes for outputs of current operator.
for (const proto::OpDesc::Var& var : op.outputs()) {
for (const std::string& argument : var.arguments()) {
// The output may be the input of another operator.
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsOutput();
} else if (var_pdnode->IsInput()) {
var_pdnode->AsIntermediate();
}
var_pdnode->assert_is_op_output(op.type());
pattern->AddEdge(op_pdnode, var_pdnode);
}
}
// Set attribute condition for current operator.
for (const proto::OpDesc::Attr& attr : op.attrs()) {
op_pdnode->assert_more([&](Node* x) {
if (x && x->IsOp()) {
OpDesc* op_desc = x->Op();
if (op_desc->HasAttr(attr.name())) {
return GetAttrValue(attr) == op_desc->GetAttr(attr.name());
}
return false;
}
return false;
});
}
}
}

GraphPatternDetector::handle_t GetGenerateRewrite(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t subgraph, Graph* graph) {
// Matched subgraphs may overlap, so a node may already have been removed
// by an earlier rewrite; skip this match if so.
for (auto iter : subgraph) {
if (nullptr == graph->RetrieveNode(iter.second->id())) {
VLOG(3) << "Node [" << iter.second->Name()
<< "] of subgraph has been removed. So skip this optimize.";
return;
}
}
const proto::BlockDesc& block = pass_desc.replace().blocks(0);
// `var_node_maps` records the mapping from replace-subgraph variable names
// to nodes of the matched pattern subgraph.
std::map<std::string, Node*> var_node_maps;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
var_node_maps.insert({var_map.replace_var(), node});
}
// Traverse all operators to create subgraph.
for (const proto::OpDesc& op : block.ops()) {
OpDesc op_desc;
std::vector<Node *> in_nodes, out_nodes;
op_desc.SetType(op.type());
// Create Nodes for inputs of current operator.
for (const proto::OpDesc::Var& var : op.inputs()) {
std::vector<std::string> arguments;
for (const std::string& argument : var.arguments()) {
// The input may map to a node of the matched pattern subgraph.
Node* node = nullptr;
auto iter = var_node_maps.find(argument);
if (var_node_maps.end() == iter) {
VarDesc var_desc(patterns::UniqueKey(argument));
node = graph->CreateVarNode(&var_desc);
var_node_maps.insert({argument, node});
} else {
node = iter->second;
}
in_nodes.push_back(node);
arguments.push_back(node->Name());
}
op_desc.SetInput(var.parameter(), arguments);
}
// Create Nodes for outputs of current operator.
for (const proto::OpDesc::Var& var : op.outputs()) {
std::vector<std::string> arguments;
for (const std::string& argument : var.arguments()) {
// The output may map to a node of the matched pattern subgraph.
Node* node = nullptr;
auto iter = var_node_maps.find(argument);
if (var_node_maps.end() == iter) {
VarDesc var_desc(patterns::UniqueKey(argument));
node = graph->CreateVarNode(&var_desc);
var_node_maps.insert({argument, node});
} else {
node = iter->second;
}
out_nodes.push_back(node);
arguments.push_back(node->Name());
}
op_desc.SetOutput(var.parameter(), arguments);
}
// Set attribute for current operator.
for (const proto::OpDesc::Attr& attr : op.attrs()) {
op_desc.SetAttr(attr.name(), GetAttrValue(attr));
}
// Create a Node for current operator.
Node* op_node = graph->CreateOpNode(&op_desc);
for (Node* node : in_nodes) {
IR_NODE_LINK_TO(node, op_node);
}
for (Node* node : out_nodes) {
IR_NODE_LINK_TO(op_node, node);
}
}
// Remove the matched pattern nodes that are not reused by the replacement.
std::unordered_set<const Node*> remove_nodes;
for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) {
remove_nodes.emplace(subgraph.at(pdnode.get()));
}
for (auto iter : var_node_maps) {
remove_nodes.erase(iter.second);
}
GraphSafeRemoveNodes(graph, remove_nodes);
};
return handler;
}

GeneratePass::GeneratePass(const std::string& binary_str) {
multi_pass_desc_.ParseFromString(binary_str);
VerifyDesc();
}

GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc)
: multi_pass_desc_(multi_pass_desc) {
VerifyDesc();
}

void GeneratePass::ApplyImpl(Graph* graph) const {
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
GraphPatternDetector detector;
InitGeneratePattern(pass_desc, detector.mutable_pattern());
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
// The rewritten graph needs to be verified. The current pass should be
// skipped if verification fails: the rewrite is applied in place on the
// original graph, so there is no way to roll it back.
VerifyGraph(*graph);
}
}

void GeneratePass::VerifyDesc() const {
PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0,
platform::errors::InvalidArgument(
"Size of PassDesc should not be empty."));
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
// Check that external inputs/outputs of each subgraph are covered by `var_maps`.
std::set<std::string> pattern_var_sets, replace_var_sets;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
pattern_var_sets.emplace(var_map.pattern_var());
replace_var_sets.emplace(var_map.replace_var());
}
auto check_vars = [=](std::set<std::string>* var_sets,
const proto::BlockDesc& block) {
for (const proto::OpDesc& op : block.ops()) {
for (const proto::OpDesc::Var& var : op.outputs()) {
for (const std::string& argument : var.arguments()) {
var_sets->emplace(argument);
}
}
}
for (const proto::OpDesc& op : block.ops()) {
for (const proto::OpDesc::Var& var : op.inputs()) {
for (const std::string& argument : var.arguments()) {
PADDLE_ENFORCE_NE(
var_sets->find(argument), var_sets->end(),
platform::errors::InvalidArgument(
"Subgraph of PassDesc has argument [%s] not in `var_maps`.",
argument));
}
}
}
};
check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0));
check_vars(&replace_var_sets, pass_desc.replace().blocks(0));
}
}

bool GeneratePass::VerifyGraph(const Graph& graph) {
// Return true temporarily.
return true;
}

} // namespace ir
} // namespace framework
} // namespace paddle
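For intuition, `InitGeneratePattern` automates the `PDPattern` construction that hand-written fuse passes perform manually. A hedged equivalent for a single `relu` operator with input `X` and output `Out`, using only the `PDPattern`/`PDNode` calls that appear in the file above (the op type and variable names are illustrative):

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

// Hedged sketch: the pattern InitGeneratePattern would build for one "relu"
// op. The op PDNode is named by its index ("0"), exactly as in the loop above.
void BuildSingleReluPattern(paddle::framework::ir::PDPattern* pattern) {
  using paddle::framework::ir::PDNode;
  PDNode* op = pattern->NewNode("0")->assert_is_op("relu");
  PDNode* x = pattern->NewNode("X")->AsInput()->assert_is_op_input("relu");
  PDNode* out = pattern->NewNode("Out")->AsOutput()->assert_is_op_output("relu");
  pattern->AddEdge(x, op);
  pattern->AddEdge(op, out);
}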
48 changes: 48 additions & 0 deletions paddle/fluid/framework/ir/generate_pass.h
@@ -0,0 +1,48 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/pass_desc.pb.h"

namespace paddle {
namespace framework {
namespace ir {

// A pass generated from a protobuf description of a pattern subgraph and its
// replacement.
class GeneratePass : public Pass {
public:
// Construct from a serialized MultiPassDesc binary string.
explicit GeneratePass(const std::string& binary_str);
// Construct from a MultiPassDesc message.
explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc);

protected:
void ApplyImpl(Graph* graph) const override;

private:
GeneratePass() = delete;
DISABLE_COPY_AND_ASSIGN(GeneratePass);
// Verify desc
void VerifyDesc() const;
// Verify graph
static bool VerifyGraph(const Graph& graph);

proto::MultiPassDesc multi_pass_desc_;
};

} // namespace ir
} // namespace framework
} // namespace paddle
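A minimal usage sketch, assuming the base `Pass` exposes the conventional public `Apply(Graph*)` entry point that dispatches to `ApplyImpl`, and that `graph` and `desc` are built elsewhere (e.g., as in the `MultiPassDesc` snippet earlier):

#include "paddle/fluid/framework/ir/generate_pass.h"

// Hedged sketch: constructing a GeneratePass from a MultiPassDesc and
// applying it. The constructor runs VerifyDesc(); Apply() performs one
// pattern-match-and-rewrite round per PassDesc in the description.
void RunGeneratePass(paddle::framework::ir::Graph* graph,
                     const paddle::framework::proto::MultiPassDesc& desc) {
  paddle::framework::ir::GeneratePass pass(desc);
  pass.Apply(graph);  // assumed public entry of the Pass base class
}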