Skip to content

Commit

Permalink
refactor Paddle inference (PaddlePaddle#129)
Browse files Browse the repository at this point in the history
* refactor Paddle inference

* cmake
  • Loading branch information
yaozhixin authored Sep 6, 2021
1 parent 42bac06 commit 059ce01
Show file tree
Hide file tree
Showing 17 changed files with 577 additions and 37 deletions.
8 changes: 0 additions & 8 deletions paddle/fluid/framework/ipu/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,6 @@ void Compiler::LowerBody(const ir::Graph* graph) {
auto* op_desc = node->Op();
auto op_type = op_desc->Type();
VLOG(10) << "node->type: " << op_type;
// Paddle inference
if (op_type == "feed" || op_type == "fetch") {
continue;
}

auto itr = name_function_.find(op_type);
if (itr != name_function_.end()) {
Expand Down Expand Up @@ -234,10 +230,6 @@ void Compiler::LowerWeights(const ir::Graph* graph, const Scope* scope_) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
if (node->Var()->Persistable()) {
auto var_name = node->Var()->Name();
// Paddle inference
if (var_name == "feed" || var_name == "fetch") {
continue;
}
auto var = scope_->FindVar(var_name);
if (var) {
auto tensor = var->Get<framework::LoDTensor>();
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ if(WITH_IPU)
pass_library(optimizer_extract_pass base DIR ipu DEPS ipu_pass_base)
pass_library(ipu_graph_builder_pass base DIR ipu DEPS ipu_pass_base)
pass_library(ipu_runtime_replacer_pass base DIR ipu DEPS ipu_pass_base)
pass_library(inference_extract_pass base DIR ipu DEPS ipu_pass_base)
pass_library(inference_graph_extract_pass base DIR ipu DEPS ipu_pass_base)
pass_library(inference_compile_pass base DIR ipu DEPS ipu_pass_base)
pass_library(inference_attr_extract_pass base DIR ipu DEPS ipu_pass_base)
pass_library(popart_canonicalization_pass base DIR ipu DEPS ipu_pass_base)
target_link_libraries(popart_canonicalization_pass -Wl,--whole-archive popart_canonicalization_utils -Wl,--no-whole-archive)
pass_library(ipu_inplace_pass base DIR ipu DEPS ipu_pass_base)
target_link_libraries(popart_canonicalization_pass -Wl,--whole-archive popart_canonicalization_utils -Wl,--no-whole-archive)
endif()

cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
Expand Down
77 changes: 77 additions & 0 deletions paddle/fluid/framework/ir/ipu/inference_attr_extract_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/inference_attr_extract_pass.h"

#include "paddle/fluid/framework/ipu/ipu_backend.h"
#include "paddle/fluid/framework/ipu/ipu_strategy.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

// Copies the inference configuration stored as attributes on `graph` into the
// IpuBackend singleton.
//
// Reads from `graph`:
//   - kParamScopeAttr      (Scope)  forwarded to IpuBackend::SetScope
//   - "num_ipus"           (int)    selects Manual vs. Off virtual graph mode
//   - "enable_pipelining"  (bool)   forwarded to the popart options
//   - "batches_per_step"   (int)    only read when pipelining is enabled
//
// Throws (via PADDLE_THROW / PADDLE_ENFORCE_GE) when the scope attribute is
// missing or when pipelining is enabled with batches_per_step < num_ipus.
void InferenceAttrExtractPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceAttrExtractPass::ApplyImpl";

  std::shared_ptr<ipu::IpuBackend> ipu_backend = ipu::IpuBackend::GetInstance();

  // The backend needs the parameter scope; bail out early when it is absent.
  if (!graph->Has(kParamScopeAttr)) {
    PADDLE_THROW(platform::errors::Unimplemented("Can not find the scope."));
  }
  auto& scope = graph->Get<Scope>(kParamScopeAttr);
  ipu_backend->SetScope(scope);

  // TODO(yaozhixin): ipu_backend manages ipu_strategy
  // The strategy object is function-static, so it is shared by every
  // invocation of this pass.
  // NOTE(review): batches_per_step is only written when pipelining is enabled,
  // so a value set by an earlier invocation is kept otherwise -- confirm this
  // is intended.
  static std::shared_ptr<ipu::IpuStrategy> ipu_strategy_instance_(
      new ipu::IpuStrategy());

  ipu_strategy_instance_->is_training = false;

  const auto num_ipus = graph->Get<int>("num_ipus");
  ipu_strategy_instance_->num_ipus = num_ipus;
  // More than one IPU switches the virtual graph mode to Manual.
  ipu_strategy_instance_->popart_options_.virtualGraphMode =
      (num_ipus > 1) ? ipu::VirtualGraphMode::Manual
                     : ipu::VirtualGraphMode::Off;

  const auto enable_pipelining = graph->Get<bool>("enable_pipelining");
  ipu_strategy_instance_->popart_options_.enablePipelining = enable_pipelining;
  if (enable_pipelining) {
    const auto batches_per_step = graph->Get<int>("batches_per_step");
    PADDLE_ENFORCE_GE(
        batches_per_step, num_ipus,
        platform::errors::InvalidArgument("Batches per step should be equal or "
                                          "greater than the number of IPUs"));
    ipu_strategy_instance_->batches_per_step = batches_per_step;
  }

  ipu_backend->SetIpuStrategy(*ipu_strategy_instance_);

  VLOG(10) << "leave InferenceAttrExtractPass::ApplyImpl";
}

} // namespace ir
} // namespace framework
} // namespace paddle

// Registers InferenceAttrExtractPass under the name
// "inference_attr_extract_pass".
REGISTER_PASS(inference_attr_extract_pass,
paddle::framework::ir::InferenceAttrExtractPass);
30 changes: 30 additions & 0 deletions paddle/fluid/framework/ir/ipu/inference_attr_extract_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

// IPU pass whose ApplyImpl (see inference_attr_extract_pass.cc) reads the
// scope, "num_ipus", "enable_pipelining" and "batches_per_step" attributes
// from the graph and hands them to the IpuBackend.
class InferenceAttrExtractPass : public IPUPassBase {
protected:
// Pass entry point; overrides the hook declared by the pass base class.
void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
78 changes: 78 additions & 0 deletions paddle/fluid/framework/ir/ipu/inference_compile_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/inference_compile_pass.h"

#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Runs the IPU compilation pipeline on `graph`.
//
// The "feed_list" and "fetch_list" attributes of this pass are copied and
// handed to each of the compile-stage passes (ipu_inplace_pass,
// ipu_graph_builder_pass, ipu_runtime_replacer_pass), which run after
// forward_graph_extract_pass and popart_canonicalization_pass.
void InferenceCompilePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceCompilePass::ApplyImpl";
  VLOG(10) << "Raw Graph: ";
  VLOG(10) << DebugString(graph);

  // Debug aid (disabled): dump the graph before the pipeline runs.
  // // graph_viz_pass
  // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("/before_pass.dot"));
  // graph_viz_pass->Apply(graph);

  // Local copies of the feed/fetch variable names given to this pass.
  std::vector<std::string> feed_names =
      Get<std::vector<std::string>>("feed_list");
  std::vector<std::string> fetch_names =
      Get<std::vector<std::string>>("fetch_list");

  // Stage 1: passes that take no feed/fetch attributes.
  auto extract_stage =
      PassRegistry::Instance().Get("forward_graph_extract_pass");
  extract_stage->Apply(graph);

  auto canonicalize_stage =
      PassRegistry::Instance().Get("popart_canonicalization_pass");
  canonicalize_stage->Apply(graph);

  // Stage 2: passes that each receive their own copy of the feed/fetch lists.
  const std::vector<std::string> stage_names = {"ipu_inplace_pass",
                                                "ipu_graph_builder_pass",
                                                "ipu_runtime_replacer_pass"};
  for (std::size_t idx = 0; idx < stage_names.size(); ++idx) {
    auto stage = PassRegistry::Instance().Get(stage_names[idx]);
    stage->Set("feed_list", new std::vector<std::string>(feed_names));
    stage->Set("fetch_list", new std::vector<std::string>(fetch_names));
    stage->Apply(graph);
  }

  // Debug aid (disabled): dump the graph after the pipeline ran.
  // // graph_viz_pass
  // graph_viz_pass->Erase("graph_viz_path");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("after_pass.dot"));
  // graph_viz_pass->Apply(graph);

  VLOG(10) << "Post Graph: ";
  VLOG(10) << DebugString(graph);
  VLOG(10) << "leave InferenceCompilePass::ApplyImpl";
}

} // namespace ir
} // namespace framework
} // namespace paddle

// Registers InferenceCompilePass under the name "inference_compile_pass" and
// declares the "feed_list" / "fetch_list" pass attributes that ApplyImpl reads.
REGISTER_PASS(inference_compile_pass,
paddle::framework::ir::InferenceCompilePass)
.RequirePassAttr("feed_list")
.RequirePassAttr("fetch_list");
// graph_viz_pass is only referenced by the commented-out debug code in
// ApplyImpl.
USE_PASS(graph_viz_pass);
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace paddle {
namespace framework {
namespace ir {

class InferenceExtractPass : public IPUPassBase {
class InferenceCompilePass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
Expand Down
134 changes: 134 additions & 0 deletions paddle/fluid/framework/ir/ipu/inference_graph_extract_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/inference_graph_extract_pass.h"

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Detaches the "feed"/"fetch" nodes from `graph`, runs inference_compile_pass
// on what remains, and then puts the feed/fetch nodes back.
//
// Steps:
//   1. Collect every node named "feed" or "fetch"; for op nodes record the
//      name of the variable they produce (feed) or consume (fetch).
//   2. Remove those nodes from the graph while keeping ownership locally.
//      Feed outputs of type INT64 are switched to INT32 on the way.
//   3. Drop nodes that have neither inputs nor outputs.
//   4. Run inference_compile_pass with the collected feed/fetch name lists.
//   5. Re-add the saved nodes and reconnect each feed/fetch op to the node
//      carrying the recorded variable name.
void InferenceGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
  VLOG(10) << "enter InferenceGraphExtractPass::ApplyImpl";

  // // graph_viz_pass
  // auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("before_pass.dot"));
  // graph_viz_pass->Apply(graph);

  // Owners of the nodes taken out of the graph in step 2. Op nodes are keyed
  // by the name of the variable they were attached to.
  std::unique_ptr<ir::Node> feed_var;
  std::unique_ptr<ir::Node> fetch_var;
  std::map<std::string, std::unique_ptr<ir::Node>> feed_ops;
  std::map<std::string, std::unique_ptr<ir::Node>> fetch_ops;

  // Step 1: gather feed/fetch nodes first so the graph is not modified while
  // it is being iterated.
  // NOTE(review): feed_list/fetch_list follow the iteration order of
  // graph->Nodes() -- confirm downstream passes do not rely on a fixed order.
  std::vector<std::string> feed_list;
  std::vector<std::string> fetch_list;
  std::unordered_set<ir::Node*> feed_fetch_nodes;
  for (auto* node : graph->Nodes()) {
    if (node->Name() == "feed") {
      if (node->IsOp()) {
        feed_list.push_back(node->outputs[0]->Name());
      }
      feed_fetch_nodes.insert(node);
    }
    if (node->Name() == "fetch") {
      if (node->IsOp()) {
        fetch_list.push_back(node->inputs[0]->Name());
      }
      feed_fetch_nodes.insert(node);
    }
  }

  // Step 2: detach the collected nodes and keep them alive locally.
  for (auto* node : feed_fetch_nodes) {
    if (node->Name() == "feed") {
      if (node->IsOp()) {
        // int64->int32
        if (node->outputs[0]->Var()->GetDataType() == proto::VarType::INT64) {
          node->outputs[0]->Var()->SetDataType(proto::VarType::INT32);
        }
        // The fed variable no longer has a producer inside the graph.
        node->outputs[0]->inputs.clear();
        feed_ops.emplace(node->outputs[0]->Name(), graph->RemoveNode(node));
      } else {
        feed_var.reset(graph->RemoveNode(node).release());
      }
    }
    if (node->Name() == "fetch") {
      if (node->IsOp()) {
        // The fetched variable no longer has a consumer inside the graph.
        node->inputs[0]->outputs.clear();
        fetch_ops.emplace(node->inputs[0]->Name(), graph->RemoveNode(node));
      } else {
        fetch_var.reset(graph->RemoveNode(node).release());
      }
    }
  }

  // Step 3: remove nodes that are not connected to anything.
  std::unordered_set<const Node*> useless_nodes;
  for (auto* node : graph->Nodes()) {
    if (node->inputs.empty() && node->outputs.empty()) {
      useless_nodes.insert(node);
    }
  }
  GraphSafeRemoveNodes(graph, useless_nodes);

  // Step 4: compile the remaining graph.
  auto inference_compile_pass =
      PassRegistry::Instance().Get("inference_compile_pass");
  inference_compile_pass->Set("feed_list",
                              new std::vector<std::string>(feed_list));
  inference_compile_pass->Set("fetch_list",
                              new std::vector<std::string>(fetch_list));
  inference_compile_pass->Apply(graph);

  // Step 5: restore the feed/fetch nodes.
  // Only re-add the variable nodes that were actually found in step 2, so a
  // null node is never handed to the graph.
  if (feed_var) {
    graph->AddNode(feed_var.release());
  }
  if (fetch_var) {
    graph->AddNode(fetch_var.release());
  }

  // Work on a copy of the node collection: AddNode() changes the graph's
  // nodes, so adding while iterating graph->Nodes() directly is unsafe.
  const auto nodes_after_compile = graph->Nodes();
  const auto find_node_named =
      [&nodes_after_compile](const std::string& name) -> ir::Node* {
    for (auto* candidate : nodes_after_compile) {
      if (candidate->Name() == name) {
        return candidate;
      }
    }
    return nullptr;
  };

  for (const auto& feed_name : feed_list) {
    auto owner = feed_ops.find(feed_name);
    // Each saved op can be re-attached once; skip if it is already gone.
    if (owner == feed_ops.end() || !owner->second) {
      continue;
    }
    if (auto* target = find_node_named(feed_name)) {
      auto* feed_op_node = graph->AddNode(owner->second.release());
      target->inputs.push_back(feed_op_node);
    }
  }
  for (const auto& fetch_name : fetch_list) {
    auto owner = fetch_ops.find(fetch_name);
    if (owner == fetch_ops.end() || !owner->second) {
      continue;
    }
    if (auto* target = find_node_named(fetch_name)) {
      auto* fetch_op_node = graph->AddNode(owner->second.release());
      target->outputs.push_back(fetch_op_node);
    }
  }

  // // graph_viz_pass
  // graph_viz_pass->Erase("graph_viz_path");
  // graph_viz_pass->Set("graph_viz_path",
  //                     new std::string("/paddle/after_pass.dot"));
  // graph_viz_pass->Apply(graph);

  VLOG(10) << "leave InferenceGraphExtractPass::ApplyImpl";
}

} // namespace ir
} // namespace framework
} // namespace paddle

// Registers InferenceGraphExtractPass under the name
// "inference_graph_extract_pass".
REGISTER_PASS(inference_graph_extract_pass,
paddle::framework::ir::InferenceGraphExtractPass);
// graph_viz_pass is only referenced by the commented-out debug code in
// ApplyImpl.
USE_PASS(graph_viz_pass);
Loading

0 comments on commit 059ce01

Please sign in to comment.