Skip to content

Commit

Permalink
split pass file (PaddlePaddle#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
gglin001 authored Jul 23, 2021
1 parent 3965526 commit 53c3b20
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 40 deletions.
18 changes: 2 additions & 16 deletions paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h"

#include <glog/logging.h>

#include <algorithm>
Expand All @@ -24,28 +26,12 @@
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"

// debug
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

class ForwardGraphExtractPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
};

void ForwardGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "enter ForwardGraphExtractPass::ApplyImpl";
// find forward ops
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class ForwardGraphExtractPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
41 changes: 17 additions & 24 deletions paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h"

#include <algorithm>
#include <array>
#include <fstream>
Expand All @@ -22,37 +24,26 @@
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

// debug
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

class IpuRuntimeReplacerPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};

void IpuRuntimeReplacerPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "enter IpuRuntimeReplacerPass::ApplyImpl";

VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);

// graph_viz_pass
auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
graph_viz_pass->Set("graph_viz_path",
new std::string("/home/Paddle/demos/before_ipu_runtime_replacer_pass.dot"));
graph_viz_pass->Apply(graph);
// // graph_viz_pass
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set(
// "graph_viz_path",
// new std::string(
// "/home/Paddle/demos/before_ipu_runtime_replacer_pass.dot"));
// graph_viz_pass->Apply(graph);

std::vector<std::string> feed_list;
feed_list = Get<std::vector<std::string>>("feed_list");
Expand All @@ -67,7 +58,7 @@ void IpuRuntimeReplacerPass::ApplyImpl(ir::Graph* graph) const {
ipu_rt_op_desc.Flush();

// Create a new node for the ipu_runtime_op.
auto *ipu_rt_node = graph->CreateOpNode(&ipu_rt_op_desc);
auto* ipu_rt_node = graph->CreateOpNode(&ipu_rt_op_desc);

for (auto* node : graph->Nodes()) {
if (node->IsVar()) {
Expand Down Expand Up @@ -100,11 +91,13 @@ void IpuRuntimeReplacerPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "Post Graph: ";
VLOG(10) << DebugString(graph);

// graph_viz_pass
graph_viz_pass->Erase("graph_viz_path");
graph_viz_pass->Set("graph_viz_path",
new std::string("/home/Paddle/demos/after_ipu_runtime_replacer_pass.dot"));
graph_viz_pass->Apply(graph);
// // graph_viz_pass
// graph_viz_pass->Erase("graph_viz_path");
// graph_viz_pass->Set(
// "graph_viz_path",
// new std::string(
// "/home/Paddle/demos/after_ipu_runtime_replacer_pass.dot"));
// graph_viz_pass->Apply(graph);

VLOG(10) << "leave IpuRuntimeReplacerPass::ApplyImpl";
}
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class IpuRuntimeReplacerPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle

0 comments on commit 53c3b20

Please sign in to comment.