forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor Paddle inference (PaddlePaddle#129)
* refactor Paddle inference * cmake
- Loading branch information
Showing
17 changed files
with
577 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
paddle/fluid/framework/ir/ipu/inference_attr_extract_pass.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/ir/ipu/inference_attr_extract_pass.h" | ||
|
||
#include "paddle/fluid/framework/ipu/ipu_backend.h" | ||
#include "paddle/fluid/framework/ipu/ipu_strategy.h" | ||
|
||
#include "paddle/fluid/framework/ir/fuse_pass_base.h" | ||
#include "paddle/fluid/framework/ir/pass_tester_helper.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
// Copies the IPU inference settings stored on `graph` into the process-wide
// IpuBackend instance.
//
// Reads from `graph`:
//   - kParamScopeAttr      : parameter scope, forwarded to the backend.
//   - "num_ipus"           : int, number of IPUs; values > 1 select manual
//                            virtual-graph mode, otherwise it is switched off.
//   - "enable_pipelining"  : bool, forwarded to the popart options.
//   - "batches_per_step"   : int, only read when pipelining is enabled.
//
// Raises a Paddle enforce error when the scope attribute is missing or when
// pipelining is enabled and batches_per_step < num_ipus.
void InferenceAttrExtractPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceAttrExtractPass::ApplyImpl";

  std::shared_ptr<ipu::IpuBackend> ipu_backend = ipu::IpuBackend::GetInstance();

  // Get scope
  if (graph->Has(kParamScopeAttr)) {
    auto& scope = graph->Get<Scope>(kParamScopeAttr);
    ipu_backend->SetScope(scope);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented("Can not find the scope."));
  }

  // TODO(yaozhixin): ipu_backend manages ipu_strategy
  // Function-local static: the same strategy object is reused (and
  // overwritten below) on every invocation of this pass.
  static std::shared_ptr<ipu::IpuStrategy> ipu_strategy_instance_ =
      std::make_shared<ipu::IpuStrategy>();

  ipu_strategy_instance_->is_training = false;
  const auto num_ipus = graph->Get<int>("num_ipus");
  ipu_strategy_instance_->num_ipus = num_ipus;
  ipu_strategy_instance_->popart_options_.virtualGraphMode =
      num_ipus > 1 ? ipu::VirtualGraphMode::Manual
                   : ipu::VirtualGraphMode::Off;

  const auto enable_pipelining = graph->Get<bool>("enable_pipelining");
  ipu_strategy_instance_->popart_options_.enablePipelining = enable_pipelining;
  if (enable_pipelining) {
    // "batches_per_step" is only consulted when pipelining is requested.
    const auto batches_per_step = graph->Get<int>("batches_per_step");
    PADDLE_ENFORCE_GE(
        batches_per_step, num_ipus,
        platform::errors::InvalidArgument("Batches per step should be equal or "
                                          "greater than the number of IPUs"));
    ipu_strategy_instance_->batches_per_step = batches_per_step;
  }

  ipu_backend->SetIpuStrategy(*ipu_strategy_instance_);

  VLOG(10) << "leave InferenceAttrExtractPass::ApplyImpl";
}
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle | ||
|
||
// Registers InferenceAttrExtractPass under the name | ||
// "inference_attr_extract_pass" (presumably so it can be looked up by that | ||
// name through the pass registry). | ||
REGISTER_PASS(inference_attr_extract_pass, | ||
paddle::framework::ir::InferenceAttrExtractPass); |
30 changes: 30 additions & 0 deletions
30
paddle/fluid/framework/ir/ipu/inference_attr_extract_pass.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
// IPU pass that transfers inference settings from a graph to the IPU | ||
// backend. Its ApplyImpl (see inference_attr_extract_pass.cc) reads the | ||
// parameter scope plus the "num_ipus", "enable_pipelining" and | ||
// "batches_per_step" graph attributes. | ||
class InferenceAttrExtractPass : public IPUPassBase { | ||
protected: | ||
// Pass entry point overriding the base-class hook; `graph` supplies the | ||
// attributes listed above. | ||
void ApplyImpl(ir::Graph* graph) const override; | ||
}; | ||
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/ir/ipu/inference_compile_pass.h" | ||
|
||
#include "paddle/fluid/framework/ir/pass_tester_helper.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
// Drives the IPU inference compilation pipeline over `graph`.
//
// Order of work:
//   1. forward_graph_extract_pass
//   2. popart_canonicalization_pass
//   3. ipu_inplace_pass, ipu_graph_builder_pass, ipu_runtime_replacer_pass,
//      each of which receives its own heap-allocated copy of this pass's
//      "feed_list" and "fetch_list" attributes.
// The graph is dumped through VLOG(10) before and after the pipeline.
void InferenceCompilePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceCompilePass::ApplyImpl";
  VLOG(10) << "Raw Graph: ";
  VLOG(10) << DebugString(graph);

  // // graph_viz_pass
  // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("/before_pass.dot"));
  // graph_viz_pass->Apply(graph);

  // Local copies of the attributes set on this pass by its caller.
  const std::vector<std::string> feeds =
      Get<std::vector<std::string>>("feed_list");
  const std::vector<std::string> fetches =
      Get<std::vector<std::string>>("fetch_list");

  // Stages that take no extra attributes.
  auto extract_stage =
      PassRegistry::Instance().Get("forward_graph_extract_pass");
  extract_stage->Apply(graph);

  auto canonicalize_stage =
      PassRegistry::Instance().Get("popart_canonicalization_pass");
  canonicalize_stage->Apply(graph);

  // Stages that need to know the feed/fetch variable names.
  const std::vector<std::string> compile_stages = {
      "ipu_inplace_pass", "ipu_graph_builder_pass",
      "ipu_runtime_replacer_pass"};
  for (const auto& stage_name : compile_stages) {
    auto stage = PassRegistry::Instance().Get(stage_name);
    // Each stage is handed a freshly allocated copy of both lists.
    stage->Set("feed_list", new std::vector<std::string>(feeds));
    stage->Set("fetch_list", new std::vector<std::string>(fetches));
    stage->Apply(graph);
  }

  // // graph_viz_pass
  // graph_viz_pass->Erase("graph_viz_path");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("after_pass.dot"));
  // graph_viz_pass->Apply(graph);

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);
  VLOG(10) << "leave InferenceCompilePass::ApplyImpl";
}
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle | ||
|
||
// Registers InferenceCompilePass as "inference_compile_pass" and declares | ||
// "feed_list" and "fetch_list" as required pass attributes; ApplyImpl reads | ||
// both. | ||
REGISTER_PASS(inference_compile_pass, | ||
paddle::framework::ir::InferenceCompilePass) | ||
.RequirePassAttr("feed_list") | ||
.RequirePassAttr("fetch_list"); | ||
// NOTE(review): graph_viz_pass is only referenced from commented-out debug | ||
// code in this file. | ||
USE_PASS(graph_viz_pass); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
134 changes: 134 additions & 0 deletions
134
paddle/fluid/framework/ir/ipu/inference_graph_extract_pass.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/ir/ipu/inference_graph_extract_pass.h" | ||
|
||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" | ||
#include "paddle/fluid/framework/ir/pass_tester_helper.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
// Temporarily strips the feed/fetch nodes from `graph`, runs
// inference_compile_pass on what is left, then re-inserts the saved nodes.
//
// Steps:
//   1. Collect every node named "feed" or "fetch". For op nodes, record the
//      name of the fed variable (outputs[0]) or fetched variable (inputs[0]).
//   2. Detach those nodes from the graph, keeping ownership locally. The data
//      type of an INT64 fed variable is rewritten to INT32.
//   3. Drop nodes that have neither inputs nor outputs.
//   4. Run inference_compile_pass with the collected feed/fetch name lists.
//   5. Re-add the saved nodes and link each feed/fetch op to the node(s)
//      carrying the recorded variable name.
void InferenceGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceGraphExtractPass::ApplyImpl";

  // // graph_viz_pass
  // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("before_pass.dot"));
  // graph_viz_pass->Apply(graph);

  // save feed and fetch nodes
  std::unique_ptr<ir::Node> feed_var;
  std::unique_ptr<ir::Node> fetch_var;
  std::map<std::string, std::unique_ptr<ir::Node>> feed_ops;
  std::map<std::string, std::unique_ptr<ir::Node>> fetch_ops;

  // Get feed_list and fetch_list
  std::vector<std::string> feed_list;
  std::vector<std::string> fetch_list;
  std::unordered_set<ir::Node*> feed_fetch_nodes;
  for (auto* node : graph->Nodes()) {
    if (node->Name() == "feed") {
      if (node->IsOp()) {
        feed_list.push_back(node->outputs[0]->Name());
      }
      feed_fetch_nodes.insert(node);
    } else if (node->Name() == "fetch") {
      if (node->IsOp()) {
        fetch_list.push_back(node->inputs[0]->Name());
      }
      feed_fetch_nodes.insert(node);
    }
  }

  // Detach the collected nodes. A node is either "feed" or "fetch", so the
  // branches are exclusive and `node` is never touched after its ownership
  // has been handed to one of the containers above.
  for (auto* node : feed_fetch_nodes) {
    if (node->Name() == "feed") {
      if (node->IsOp()) {
        // int64->int32
        if (node->outputs[0]->Var()->GetDataType() == proto::VarType::INT64) {
          node->outputs[0]->Var()->SetDataType(proto::VarType::INT32);
        }
        node->outputs[0]->inputs.clear();
        feed_ops.emplace(node->outputs[0]->Name(), graph->RemoveNode(node));
      } else {
        feed_var = graph->RemoveNode(node);
      }
    } else if (node->Name() == "fetch") {
      if (node->IsOp()) {
        node->inputs[0]->outputs.clear();
        fetch_ops.emplace(node->inputs[0]->Name(), graph->RemoveNode(node));
      } else {
        fetch_var = graph->RemoveNode(node);
      }
    }
  }

  // Remove useless nodes
  std::unordered_set<const Node*> useless_nodes;
  for (auto* node : graph->Nodes()) {
    if (node->inputs.empty() && node->outputs.empty()) {
      useless_nodes.insert(node);
    }
  }
  GraphSafeRemoveNodes(graph, useless_nodes);

  auto inference_compile_pass =
      PassRegistry::Instance().Get("inference_compile_pass");
  inference_compile_pass->Set("feed_list",
                              new std::vector<std::string>(feed_list));
  inference_compile_pass->Set("fetch_list",
                              new std::vector<std::string>(fetch_list));
  inference_compile_pass->Apply(graph);

  // Restore the feed/fetch variable nodes. Either may be absent if the graph
  // never contained one, so never hand a null pointer to AddNode.
  if (feed_var) {
    graph->AddNode(feed_var.release());
  }
  if (fetch_var) {
    graph->AddNode(fetch_var.release());
  }

  // Returns a snapshot of the nodes currently named `name`. A snapshot is
  // used because AddNode mutates the graph's node collection, so nodes must
  // not be added while iterating graph->Nodes() directly.
  auto nodes_named = [graph](const std::string& name) {
    std::vector<ir::Node*> matches;
    for (auto* node : graph->Nodes()) {
      if (node->Name() == name) {
        matches.push_back(node);
      }
    }
    return matches;
  };

  for (const auto& feed_name : feed_list) {
    auto& saved_op = feed_ops.at(feed_name);
    if (!saved_op) {
      continue;  // already re-inserted (duplicate name in feed_list)
    }
    ir::Node* feed_op_node = nullptr;
    for (auto* node : nodes_named(feed_name)) {
      if (feed_op_node == nullptr) {
        // Re-insert the op once, on the first match only.
        feed_op_node = graph->AddNode(saved_op.release());
      }
      node->inputs.push_back(feed_op_node);
    }
  }
  for (const auto& fetch_name : fetch_list) {
    auto& saved_op = fetch_ops.at(fetch_name);
    if (!saved_op) {
      continue;  // already re-inserted (duplicate name in fetch_list)
    }
    ir::Node* fetch_op_node = nullptr;
    for (auto* node : nodes_named(fetch_name)) {
      if (fetch_op_node == nullptr) {
        // Re-insert the op once, on the first match only.
        fetch_op_node = graph->AddNode(saved_op.release());
      }
      node->outputs.push_back(fetch_op_node);
    }
  }

  // // graph_viz_pass
  // graph_viz_pass->Erase("graph_viz_path");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("/paddle/after_pass.dot"));
  // graph_viz_pass->Apply(graph);

  VLOG(10) << "leave InferenceGraphExtractPass::ApplyImpl";
}
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle | ||
|
||
// Registers InferenceGraphExtractPass under the name | ||
// "inference_graph_extract_pass". | ||
REGISTER_PASS(inference_graph_extract_pass, | ||
paddle::framework::ir::InferenceGraphExtractPass); | ||
// NOTE(review): graph_viz_pass is only referenced from commented-out debug | ||
// code in this file. | ||
USE_PASS(graph_viz_pass); |
Oops, something went wrong.