[PIR] lowering group op through op fusion && fusion merge pass and kernel jit pass #58193

Merged · 33 commits · Oct 23, 2023

Changes from all commits

Commits (33):
5de7cd4  update (phlrain, Oct 11, 2023)
d48461a  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 11, 2023)
229adb9  update (phlrain, Oct 12, 2023)
7136bc9  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 12, 2023)
abb1f18  revert some code (phlrain, Oct 12, 2023)
5ce86e5  update (phlrain, Oct 13, 2023)
ed6b73f  polish print (phlrain, Oct 13, 2023)
91ae2dd  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 13, 2023)
589315d  fix compile bug (phlrain, Oct 13, 2023)
cdb89d0  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 13, 2023)
2d7d3fb  update (phlrain, Oct 16, 2023)
8d42efc  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 16, 2023)
f73c664  update (phlrain, Oct 16, 2023)
3d220bc  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 16, 2023)
8af3afe  polish code (phlrain, Oct 16, 2023)
9d26bab  update (phlrain, Oct 17, 2023)
1c1cf65  Merge commit 'refs/pull/58043/head' of https://github.com/PaddlePaddl… (phlrain, Oct 17, 2023)
de815f6  update (phlrain, Oct 17, 2023)
7a9e4ce  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 17, 2023)
479d6a1  update (phlrain, Oct 18, 2023)
49eb6e4  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 18, 2023)
c77e1c5  update (phlrain, Oct 18, 2023)
24a04ae  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 18, 2023)
cab0618  update (phlrain, Oct 18, 2023)
6acae1b  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 18, 2023)
755e6ad  remove useless code (phlrain, Oct 19, 2023)
b4ab665  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 19, 2023)
c7d70ec  remove useless code (phlrain, Oct 19, 2023)
a97889b  remove useless code (phlrain, Oct 19, 2023)
418e97c  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 19, 2023)
f0fd7ec  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 23, 2023)
cb1e8ad  fix compile bug (phlrain, Oct 23, 2023)
e8ae3fe  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 23, 2023)
5 changes: 4 additions & 1 deletion paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
@@ -4,7 +4,10 @@ if(NOT CINN_ONLY)
     SRCS
     group_with_group_merge_pass.cc
     op_with_group_merge_pass.cc
+    cinn_group_lowering_pass.cc
     tensor_node.cc
     DEPS
-    pd_op_dialect)
+    pd_op_dialect
+    pir_compiler
+    cinn_runtime_dialect)
 endif()
191 changes: 191 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.cc (new file)
@@ -0,0 +1,191 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h"

#include <unordered_map>

#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/pir_compiler.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

namespace cinn {
namespace dialect {
namespace ir {

std::vector<pir::Value> GetBlockOutsideInput(
const std::vector<pir::Operation*> op_list) {
Review comment (Contributor): pass the vector by const reference to avoid copying it. Suggested change:
-    const std::vector<pir::Operation*> op_list) {
+    const std::vector<pir::Operation*>& op_list) {

std::vector<pir::Value> vec_res;
std::unordered_set<::pir::Value> block_inner_output;
for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_results(); ++i) {
block_inner_output.insert(op_list[k]->result(i));
}
}

for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_operands(); ++i) {
if (!block_inner_output.count(op_list[k]->operand_source(i))) {
vec_res.push_back(op_list[k]->operand_source(i));
}
}
}

return vec_res;
}

std::vector<pir::Value> GetBlockOutsideOutput(
const std::vector<pir::Operation*> op_list) {
Review comment (Contributor): same const-reference fix. Suggested change:
-    const std::vector<pir::Operation*> op_list) {
+    const std::vector<pir::Operation*>& op_list) {

std::vector<pir::Value> vec_res;
std::unordered_set<::pir::Value> block_inner_output;
for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_operands(); ++i) {
block_inner_output.insert(op_list[k]->operand_source(i));
}
}

for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_results(); ++i) {
if (!block_inner_output.count(op_list[k]->result(i))) {
vec_res.push_back(op_list[k]->result(i));
}
}
}

return vec_res;
}

std::vector<pir::Operation*> GetOpListNotIncludeYield(
const std::vector<pir::Operation*>& op_list) {
std::vector<pir::Operation*> vec_res;
for (size_t i = 0; i < op_list.size(); ++i) {
if (!op_list[i]->isa<pir::YieldOp>()) {
vec_res.push_back(op_list[i]);
}
}

return vec_res;
}

std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
Review comment (Contributor): Suggested change:
-std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
+std::unique_ptr<pir::Program> CINNGroupLoweringPass(const ::pir::Program* program) {
Could the parameter be made const here?

::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<paddle::dialect::KernelDialect>();

std::string jit_op_name = cinn::dialect::JitKernelOp::name();
::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);

auto ir_program = std::make_unique<::pir::Program>(ctx);
std::unordered_map<pir::Value, pir::Value> value_map;
std::vector<cinn::hlir::framework::PIRCompiler*> compiler_list;

auto target = cinn::common::DefaultNVGPUTarget();
auto scope = cinn::hlir::framework::BuildScope(target, *program);

for (auto it = program->block()->begin(); it != program->block()->end();
++it) {
if ((*it)->isa<cinn::dialect::GroupOp>()) {
// GetOpList and Call cinn CodeGen
auto group_op = (*it)->dyn_cast<cinn::dialect::GroupOp>();

// op fusion
auto op_fusion = cinn::dialect::ir::OpFusionPassInternal(
GetOpListNotIncludeYield(group_op.ops()));

// fusion merge
auto group_list =
cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion);

PADDLE_ENFORCE_EQ(group_list.size(),
1u,
phi::errors::Unimplemented(
"Only support one group after group fusion"));
for (auto group : group_list) {
auto ir_compiler =
new cinn::hlir::framework::PIRCompiler(*program, target, scope);
Review comment (Contributor): There is a potential memory-leak risk here. As discussed before, we have not yet settled on how to manage PIRCompiler objects in a unified way, but for now we could simply hold them in a singleton; that way the new-ed objects are guaranteed to be deleted on exit.

auto group1 =
std::make_shared<cinn::hlir::framework::pir::Group>(group->nodes);
auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group1});
compiler_list.push_back(ir_compiler);
std::unordered_map<std::string, ::pir::Attribute> op_attrs{
{cinn::dialect::JitKernelOp::kAttrName,
cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])},
};

// Generate jit kernel op input and output
auto vec_ins = GetBlockOutsideInput(group->nodes);

std::vector<pir::Value> vec_new_ins;
for (size_t i = 0; i < vec_ins.size(); ++i) {
vec_new_ins.push_back(value_map.at(vec_ins[i]));
}

auto vec_outs = GetBlockOutsideOutput(group->nodes);

std::vector<pir::Type> vec_types;
for (auto& out : vec_outs) {
vec_types.push_back(out.type());
}

::pir::Operation* cinn_op =
::pir::Operation::Create(vec_new_ins, op_attrs, vec_types, op_info);

// for (size_t i = 0; i < vec_outs.size(); ++i) {
// value_map[vec_outs[i]] = cinn_op->result(i);
// }

// auto yield_op = group_op.ops().back()->dyn_cast<pir::YieldOp>();
for (size_t i = 0; i < group_op.num_results(); ++i) {
value_map[group_op.result(i)] = cinn_op->result(i);
}

ir_program->block()->push_back(cinn_op);
}

} else {
std::vector<pir::Value> vec_ins;

for (size_t i = 0; i < (*it)->num_operands(); ++i) {
vec_ins.push_back(value_map.at((*it)->operand_source(i)));
}

std::vector<pir::Type> vec_types;
for (size_t i = 0; i < (*it)->num_results(); ++i) {
vec_types.push_back((*it)->result(i).type());
}

::pir::OpInfo info1 = ctx->GetRegisteredOpInfo((*it)->name());
::pir::Operation* op = ::pir::Operation::Create(
vec_ins, (*it)->attributes(), vec_types, info1);

ir_program->block()->push_back(op);

value_map[(*it)->result(0)] = op->result(0);
}
}
return ir_program;
}

} // namespace ir
} // namespace dialect
} // namespace cinn
27 changes: 27 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h (new file)
@@ -0,0 +1,27 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/core/program.h"

namespace cinn {
namespace dialect {
namespace ir {

std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program);

} // namespace ir
} // namespace dialect
} // namespace cinn
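
For orientation, a hedged usage sketch of this entry point; the wrapper function below is an assumption for illustration, and only CINNGroupLoweringPass itself comes from this PR:

// Sketch: run the new lowering pass on a program that already contains
// cinn_op.GroupOp regions. Each group is op-fused, fusion-merged, JIT-compiled
// via PIRCompiler, and replaced by a cinn_op.jit_kernel op in a fresh program.
#include <memory>
#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_lowering_pass.h"
#include "paddle/pir/core/program.h"

std::unique_ptr<::pir::Program> LowerGroupsToJitKernels(::pir::Program* program) {
  return cinn::dialect::ir::CINNGroupLoweringPass(program);
}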
paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc
@@ -1023,9 +1023,7 @@ class FusionPassRegistrar final : public Registrar {
 // code generation.
 class GeneralFusionMergePassHelper {
  public:
-  explicit GeneralFusionMergePassHelper(const ::pir::Program* graph,
-                                        const GroupList& group_list)
-      : graph_(graph) {
+  explicit GeneralFusionMergePassHelper(const GroupList& group_list) {
     fusion_groups_ = group_list;
     // init input to consumers.
     InitInputToConsumers();
@@ -2099,7 +2097,6 @@
     }
   }
 
-  const ::pir::Program* graph_;
   GroupList fusion_groups_;
   std::unordered_map<GroupPtr, int> fusion_groups_index_;
   std::unordered_set<const ::pir::Operation*> output_nodes_set_;
@@ -2108,14 +2105,13 @@
       input_to_consumers_;
 };
 
-GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph,
-                                         const GroupList& group_list) {
+GroupList GeneralFusionMergePassInternal(const GroupList& group_list) {
   if (group_list.size() <= 1) {
     VLOG(3) << "Don't do Fusoin Merge Pass...!";
     return group_list;
   }
 
-  GeneralFusionMergePassHelper fusion_merge_pass_helper(graph, group_list);
+  GeneralFusionMergePassHelper fusion_merge_pass_helper(group_list);
   auto res = fusion_merge_pass_helper();
 
   return res;
paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc
@@ -12,15 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h"
+
 #include <limits.h>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
 
-#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h"
-
 #include "paddle/phi/core/enforce.h"
 #include "paddle/pir/core/builtin_attribute.h"
 #include "paddle/pir/core/operation.h"
@@ -40,6 +40,8 @@ std::unordered_map<std::string, OpPatternKind> OpKindMap = {
     {"pd_op.full", OpPatternKind::kElementWise},
     {"pd_op.relu", OpPatternKind::kElementWise},
     {"pd_op.exp", OpPatternKind::kElementWise},
+    {"pd_op.sin", OpPatternKind::kElementWise},
+    {"pd_op.cos", OpPatternKind::kElementWise},
     {"pd_op.sum", OpPatternKind::kReduction},
     {"cinn_op.reduce_sum", OpPatternKind::kReduction},
     {"cinn_op.reduce_max", OpPatternKind::kReduction},
@@ -143,19 +145,18 @@ using ConditionFunction =
 // code generation.
 class OpFusionPassHelper {
  public:
-  explicit OpFusionPassHelper(const ::pir::Program& graph) {
+  explicit OpFusionPassHelper(const std::vector<pir::Operation*>& op_list) {
     // init fusion relation
     InitFusionRelation();
     // filter node data, create group for each node
     // auto nodes_inorder = std::get<0>(graph->topological_order());
 
-    for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) {
-      auto node = *it;
-      local_ops_.insert(node);
+    for (auto it = op_list.begin(); it != op_list.end(); ++it) {
+      local_ops_.insert(*it);
     }
 
     int index = 0;
-    for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) {
+    for (auto it = op_list.begin(); it != op_list.end(); ++it) {
       auto node = *it;
       if (node) {
         nodes_.push_back(node);
@@ -491,9 +492,9 @@ class OpFusionPassHelper {
   std::unordered_map<OpPatternKind, FusionRelation> fusion_relation_map_;
 };
 
-GroupList OpFusionPassInternal(const ::pir::Program& program) {
+GroupList OpFusionPassInternal(const std::vector<pir::Operation*>& op_list) {
   VLOG(3) << "OpFusionPass...!";
-  auto op_fusion_helper = OpFusionPassHelper(program);
+  auto op_fusion_helper = OpFusionPassHelper(op_list);
   auto res = op_fusion_helper();
 
   for (size_t i = 0; i < res.size(); ++i) {
@@ -502,27 +503,11 @@ GroupList OpFusionPassInternal(const ::pir::Program& program) {
     for (size_t j = 0; j < group->nodes.size(); ++j) {
     }
   }
-
-  // for (auto& group : graph->fusion_groups) {
-  //   VLOG(3) << "Group Id : " << group->group_id;
-  //   for (const auto& producer : group->producer_groups()) {
-  //     VLOG(3) << "  producer group -> " << producer->group_id;
-  //   }
-  //   for (const auto& consumer : group->consumer_groups()) {
-  //     VLOG(3) << "  consumer group -> " << consumer->group_id;
-  //   }
-  // }
   VLOG(3) << "OpFusionPass Finish...!";
 
   return res;
 }
 
-// void BuildNonFusedGroupsPassInternal(framework::Graph* graph) {
-//   auto op_fusion_helper = OpFusionPassHelper(graph);
-//   VLOG(3) << "Apply OpFusionPass to generate initial non-fusion groups";
-//   graph->fusion_groups = op_fusion_helper(false);
-// }
-
 }  // namespace ir
 }  // namespace dialect
 }  // namespace cinn
paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h
@@ -24,10 +24,9 @@ namespace ir {
 using GroupPtr = std::shared_ptr<Group>;
 using GroupList = std::vector<GroupPtr>;
 
-GroupList OpFusionPassInternal(const ::pir::Program& program);
+GroupList OpFusionPassInternal(const std::vector<pir::Operation*>& op_list);
 
-GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph,
-                                         const GroupList& group_list);
+GroupList GeneralFusionMergePassInternal(const GroupList& group_list);
 
 }  // namespace ir
 }  // namespace dialect
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.h
@@ -15,6 +15,7 @@
 #pragma once
 #include <string>
 #include <unordered_map>
+#include "paddle/cinn/backends/compiler.h"
 #include "paddle/cinn/common/context.h"
 #include "paddle/cinn/common/type.h"
 #include "paddle/cinn/utils/type_defs.h"
@@ -29,6 +30,7 @@ struct CUDAJITInfo {
   void* fn_ptr;
   std::vector<int> block_dims;
   std::vector<int> grid_dims;
+  backends::Compiler* compiler;
 };
 
 struct CompatibleInfo {
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/framework/pir_compiler.cc
@@ -61,9 +61,11 @@ std::vector<pir::CUDAJITInfo> PIRCompiler::BuildCUDAJITInfo(
 
   auto fn_ptrs = compiler_->GetFnPtr();
 
+  auto* compilter_ptr = compiler_.release();
Review comment (Contributor): After this release(), doesn't jit_info merely hold a raw pointer to the resource, without actually being responsible for destructing (i.e. delete-ing) it?

Reply (Collaborator, author): Yes; a new PR will handle this part.
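
A hypothetical sketch of that deferred fix, using shared ownership instead of release(); the stand-in types below are illustrative and not Paddle code:

#include <memory>
#include <vector>

namespace backends { class Compiler {}; }  // stand-in for cinn::backends::Compiler

// If the jit info held a shared_ptr rather than a raw Compiler*, the compiler
// released above would be freed automatically with the last holder.
struct JitInfoSketch {
  void* fn_ptr = nullptr;
  std::shared_ptr<backends::Compiler> compiler;  // shared, self-deleting ownership
};

std::vector<JitInfoSketch> BuildJitInfos(std::unique_ptr<backends::Compiler> compiler) {
  std::shared_ptr<backends::Compiler> shared = std::move(compiler);  // no release(), no leak
  std::vector<JitInfoSketch> infos(2);
  for (auto& info : infos) {
    info.compiler = shared;  // every info shares ownership of the compiler
  }
  return infos;  // the Compiler is deleted when the last JitInfoSketch dies
}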

   for (int idx = 0; idx < groups.size(); ++idx) {
     pir::CUDAJITInfo jit_info;
     jit_info.fn_ptr = fn_ptrs[idx];
+    jit_info.compiler = compilter_ptr;
 
     lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo(
         &(jit_info.block_dims));