Skip to content

Commit

Permalink
Implement OpMergeWithOp
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbaizhou committed Mar 15, 2024
1 parent 31a1124 commit 1d45795
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 72 deletions.
1 change: 0 additions & 1 deletion cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ function(gen_cinncore LINKTYPE)
${LINKTYPE}
SRCS
${core_src}
${group_pattern_util}
DEPS
glog
${llvm_libs}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ if(NOT CINN_ONLY)
cinn_runtime_dialect
pir_compiler)

cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
${cinn_transforms_deps})
cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs}
${group_pattern_util} DEPS ${cinn_transforms_deps})

cc_library(
add_cinn_pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h"

#include "paddle/cinn/frontend/group_pattern.h"
#include "paddle/cinn/frontend/group_pattern_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
Expand Down Expand Up @@ -781,87 +782,111 @@ std::vector<GroupClusterNode> NodeMergeWithNode(
return second_stage_output;
}

// std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op)
// {
// const auto& ops = [&]{
// std::vector<const pir::Operation*> ops;
// for (const auto& op : *group_op.block()) {
// ops.push_back(&op);
// }
// return ops;
// }();
// auto cluster_policy = [&]{
// auto* program = group_op.GetParentProgram();
// const auto* shape_analysis =
// &pir::ShapeAnalysisManager::Instance().Get(program); return
// frontend::MakeLoopAlignableClusteringPolicy(shape_analysis);
// }();
// const auto cluster_result = frontend::ClusterOps(ops,
// std::move(cluster_policy));
// }

std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op) {
// op merge with op
auto inner_values = GetInnerGeneValue(group_op.GetOperators());

std::unordered_map<::pir::Operation*, GroupClusterNode> op_path;

auto op_list = group_op.GetOperators();

std::vector<GroupClusterNode> first_stage_output;

std::unordered_set<::pir::Operation*> yield_output_ops;
std::unordered_set<::pir::Operation*> first_output_ops;
auto yield_op = op_list.back();
for (size_t i = 0; i < yield_op->num_operands(); ++i) {
if (yield_op->operand_source(i).defining_op()->result(0).use_count() == 1) {
yield_output_ops.insert(yield_op->operand_source(i).defining_op());
}
// This structure is the visitor function of fetching pattern's operator list.
// For IS or PS patterns, directly use their operator list;
// For Reduce patterns, the operator list is the concatenation of reduce op and
// its inputs.
struct GetPatternOpList {
std::vector<const pir::Operation*> operator()(
const api::InjectiveSourcePattern<frontend::FrontendPattern>& pattern) {
return pattern.ops;
}

// first stage op fuse op
for (auto* op : op_list) {
if (op->isa<::pir::YieldOp>()) {
continue;
}

auto& cluster_node = op_path[op];
auto& op_list = cluster_node.ops;
std::vector<const pir::Operation*> operator()(
const api::PartialShardablePattern<frontend::FrontendPattern>& pattern) {
return pattern.ops;
}

// process cluster node
ScheduleInfoNode sch_node;
GetClusterNodeBasicInfo(op, &cluster_node, &sch_node);
std::vector<const pir::Operation*> operator()(
const api::ReductionPattern<frontend::FrontendPattern>& pattern) {
struct InputOpsVisitor {
std::vector<const pir::Operation*> operator()(
const api::InjectiveSourcePattern<frontend::FrontendPattern>& input) {
return input.ops;
}

// process current Node and pre Node
auto pre_ops = GetPreOps(inner_values, op);
for (auto pre_op : pre_ops) {
if (!op_path.count(pre_op)) {
continue;
std::vector<const pir::Operation*> operator()(
const api::PartialShardablePattern<frontend::FrontendPattern>&
input) {
return input.ops;
}

if (CanOpMergeNode(op_path, pre_op, op)) {
cluster_node.MergePreNode(op_path.at(pre_op), sch_node);
std::vector<const pir::Operation*> operator()(
const std::monostate& input) {
return {};
}
}
};

op_list.push_back(op);
std::vector<const pir::Operation*> ops_list = {
pattern.reduce_op_pattern.reduce_op};
std::vector<const pir::Operation*> input_ops =
std::visit(InputOpsVisitor(), pattern.input);
ops_list.insert(ops_list.end(), input_ops.begin(), input_ops.end());

if (yield_output_ops.count(op) ||
cinn::hlir::framework::pir::CompatibleInfo::OpKind(*op) ==
cinn::hlir::framework::kReduction) {
// TODO(phlrain): yield output no need to push into first stage output,
// Update here
VLOG(4) << "Split Group by yield output ops: "
<< yield_output_ops.count(op);
if (!first_output_ops.count(op)) {
first_stage_output.push_back(op_path[op]);
first_output_ops.insert(op);
return ops_list;
}
};

std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op) {
const auto& ops = [&] {
std::vector<const pir::Operation*> ops;
for (const auto& op : *group_op.block()) {
ops.push_back(&op);
}
return ops;
}();

auto shardable_axes_provider = [&] {
auto* program = group_op->GetParentProgram();
const auto* shape_analysis =
&pir::ShapeAnalysisManager::Instance().Get(program);
return frontend::MakeDefaultShardableAxesProvider(shape_analysis);
}();

auto cluster_policy = [&] {
auto* program = group_op->GetParentProgram();
const auto* shape_analysis =
&pir::ShapeAnalysisManager::Instance().Get(program);
return frontend::MakeLoopAlignableClusteringPolicy(shape_analysis);
}();

VLOG(4) << "Start Clustering Ops!";
const auto cluster_result = frontend::ClusterOps(
ops, std::move(shardable_axes_provider), std::move(cluster_policy));
VLOG(4) << "Finished Clustering Ops!";

// Each stmts corresponds to each fusion op(cluster node).
// Concat all the ops of patterns in the stmts, and make them the op list of
// cluster node.
VLOG(4) << "Start Creating Cluster Nodes!";
std::vector<GroupClusterNode> output_cluster_nodes;
for (const auto& stmts_pattern : cluster_result.loop_alignable_list) {
GroupClusterNode cluster_node;
std::set<const pir::Operation*>
node_ops_set; // The set of all ops in the cluster node, for deleting
// repeated elements.
bool is_reduce_node =
false; // A flag indicating whether current node is a reduce node.
for (const auto& pattern : stmts_pattern.stmts) {
std::vector<const pir::Operation*> pattern_ops =
std::visit(GetPatternOpList(), pattern);
node_ops_set.insert(pattern_ops.begin(), pattern_ops.end());
if (std::holds_alternative<
api::ReductionPattern<frontend::FrontendPattern>>(pattern)) {
is_reduce_node = true;
}
}
for (const auto& op : node_ops_set) {
// TODO(TianChao): Delete const in the definition of pattern.ops
cluster_node.ops.push_back(const_cast<pir::Operation*>(op));
}
cluster_node.group_kind = is_reduce_node
? cinn::hlir::framework::kReduction
: cinn::hlir::framework::kInjective;
output_cluster_nodes.push_back(cluster_node);
}

VLOG(4) << "first stage output size " << first_stage_output.size();
return first_stage_output;
VLOG(4) << "Finished Creating Cluster Nodes!";
return output_cluster_nodes;
}

std::vector<GroupClusterNode> GroupSplit(cinn::dialect::GroupOp group_op) {
Expand Down

0 comments on commit 1d45795

Please sign in to comment.