Skip to content

Commit

Permalink
Merge pull request #51 from Fridge003/cinn
Browse files Browse the repository at this point in the history
Implement OpMergeWithOp
  • Loading branch information
2742195759 authored Mar 18, 2024
2 parents 31a1124 + ffada84 commit b1cd524
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 22 deletions.
1 change: 0 additions & 1 deletion cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ function(gen_cinncore LINKTYPE)
${LINKTYPE}
SRCS
${core_src}
${group_pattern_util}
DEPS
glog
${llvm_libs}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ if(NOT CINN_ONLY)
cinn_runtime_dialect
pir_compiler)

cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
${cinn_transforms_deps})
cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs}
${group_pattern_util} DEPS ${cinn_transforms_deps})

cc_library(
add_cinn_pass
Expand Down
132 changes: 113 additions & 19 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@

#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h"

#include "paddle/cinn/frontend/group_pattern.h"
#include "paddle/cinn/frontend/group_pattern_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
Expand All @@ -48,6 +50,8 @@
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"

PD_DECLARE_bool(cinn_new_cluster_op_method);

namespace cinn {
namespace dialect {
namespace ir {
Expand Down Expand Up @@ -781,24 +785,112 @@ std::vector<GroupClusterNode> NodeMergeWithNode(
return second_stage_output;
}

// std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op)
// {
// const auto& ops = [&]{
// std::vector<const pir::Operation*> ops;
// for (const auto& op : *group_op.block()) {
// ops.push_back(&op);
// }
// return ops;
// }();
// auto cluster_policy = [&]{
// auto* program = group_op.GetParentProgram();
// const auto* shape_analysis =
// &pir::ShapeAnalysisManager::Instance().Get(program); return
// frontend::MakeLoopAlignableClusteringPolicy(shape_analysis);
// }();
// const auto cluster_result = frontend::ClusterOps(ops,
// std::move(cluster_policy));
// }
// This structure is the visitor function of fetching pattern's operator list.
// For IS or PS patterns, directly use their operator list;
// For Reduce patterns, the operator list is the concatenation of reduce op and
// its inputs.
struct GetPatternOpList {
std::vector<const pir::Operation*> operator()(
const api::InjectiveSourcePattern<frontend::FrontendPattern>& pattern) {
return pattern.ops;
}

std::vector<const pir::Operation*> operator()(
const api::PartialShardablePattern<frontend::FrontendPattern>& pattern) {
return pattern.ops;
}

std::vector<const pir::Operation*> operator()(
const api::ReductionPattern<frontend::FrontendPattern>& pattern) {
struct InputOpsVisitor {
std::vector<const pir::Operation*> operator()(
const api::InjectiveSourcePattern<frontend::FrontendPattern>& input) {
return input.ops;
}

std::vector<const pir::Operation*> operator()(
const api::PartialShardablePattern<frontend::FrontendPattern>&
input) {
return input.ops;
}

std::vector<const pir::Operation*> operator()(
const std::monostate& input) {
return {};
}
};

std::vector<const pir::Operation*> ops_list = {
pattern.reduce_op_pattern.reduce_op};
std::vector<const pir::Operation*> input_ops =
std::visit(InputOpsVisitor(), pattern.input);
ops_list.insert(ops_list.end(), input_ops.begin(), input_ops.end());

return ops_list;
}
};

std::vector<GroupClusterNode> NewOpMergeWithOp(
cinn::dialect::GroupOp group_op) {
const auto& ops = [&] {
std::vector<const pir::Operation*> ops;
for (const auto& op : *group_op.block()) {
ops.push_back(&op);
}
return ops;
}();

auto shardable_axes_provider = [&] {
auto* program = group_op->GetParentProgram();
const auto* shape_analysis =
&pir::ShapeAnalysisManager::Instance().Get(program);
return frontend::MakeDefaultShardableAxesProvider(shape_analysis);
}();

auto cluster_policy = [&] {
auto* program = group_op->GetParentProgram();
const auto* shape_analysis =
&pir::ShapeAnalysisManager::Instance().Get(program);
return frontend::MakeLoopAlignableClusteringPolicy(shape_analysis);
}();

VLOG(4) << "Start Clustering Ops!";
const auto cluster_result = frontend::ClusterOps(
ops, std::move(shardable_axes_provider), std::move(cluster_policy));
VLOG(4) << "Finished Clustering Ops!";

// Each stmts corresponds to each fusion op(cluster node).
// Concat all the ops of patterns in the stmts, and make them the op list of
// cluster node.
VLOG(4) << "Start Creating Cluster Nodes!";
std::vector<GroupClusterNode> output_cluster_nodes;
for (const auto& stmts_pattern : cluster_result.loop_alignable_list) {
GroupClusterNode cluster_node;
std::set<const pir::Operation*>
node_ops_set; // The set of all ops in the cluster node, for deleting
// repeated elements.
bool is_reduce_node =
false; // A flag indicating whether current node is a reduce node.
for (const auto& pattern : stmts_pattern.stmts) {
std::vector<const pir::Operation*> pattern_ops =
std::visit(GetPatternOpList(), pattern);
node_ops_set.insert(pattern_ops.begin(), pattern_ops.end());
if (std::holds_alternative<
api::ReductionPattern<frontend::FrontendPattern>>(pattern)) {
is_reduce_node = true;
}
}
for (const auto& op : node_ops_set) {
cluster_node.ops.push_back(const_cast<pir::Operation*>(op));
}
cluster_node.group_kind = is_reduce_node
? cinn::hlir::framework::kReduction
: cinn::hlir::framework::kInjective;
output_cluster_nodes.push_back(cluster_node);
}
VLOG(4) << "Finished Creating Cluster Nodes!";
return output_cluster_nodes;
}

std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op) {
// op merge with op
Expand Down Expand Up @@ -866,7 +958,9 @@ std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op) {

std::vector<GroupClusterNode> GroupSplit(cinn::dialect::GroupOp group_op) {
// stage 1
auto first_stage_output = OpMergeWithOp(group_op);
auto first_stage_output = FLAGS_cinn_new_cluster_op_method
? NewOpMergeWithOp(group_op)
: OpMergeWithOp(group_op);

if (first_stage_output.size() <= 1) {
return first_stage_output;
Expand Down
5 changes: 5 additions & 0 deletions paddle/cinn/runtime/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ PD_DEFINE_bool(group_schedule_tiling_first,
BoolFromEnv("FLAGS_group_schedule_tiling_first", false),
"Whether to enable new group scheduler tiling first strategy.");

PD_DEFINE_bool(cinn_new_cluster_op_method,
BoolFromEnv("FLAGS_cinn_new_cluster_op_method", false),
"Whether to enable newly developed clustering method of group "
"op for cinn.");

PD_DEFINE_bool(support_reduce_stride_read,
BoolFromEnv("FLAGS_support_reduce_stride_read", false),
"Whether to enable new group scheduler tiling first strategy.");
Expand Down

0 comments on commit b1cd524

Please sign in to comment.