Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OpMergeWithOp #51

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ function(gen_cinncore LINKTYPE)
${LINKTYPE}
SRCS
${core_src}
${group_pattern_util}
DEPS
glog
${llvm_libs}
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ if(NOT CINN_ONLY)
cinn_runtime_dialect
pir_compiler)

cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
${cinn_transforms_deps})
cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs}
${group_pattern_util} DEPS ${cinn_transforms_deps})

cc_library(
add_cinn_pass
Expand Down
132 changes: 113 additions & 19 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@

#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h"

#include "paddle/cinn/frontend/group_pattern.h"
#include "paddle/cinn/frontend/group_pattern_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
Expand All @@ -48,6 +50,8 @@
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"

PD_DECLARE_bool(cinn_new_cluster_op_method);

namespace cinn {
namespace dialect {
namespace ir {
Expand Down Expand Up @@ -781,24 +785,112 @@ std::vector<GroupClusterNode> NodeMergeWithNode(
return second_stage_output;
}

// std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op)
// {
// const auto& ops = [&]{
// std::vector<const pir::Operation*> ops;
// for (const auto& op : *group_op.block()) {
// ops.push_back(&op);
// }
// return ops;
// }();
// auto cluster_policy = [&]{
// auto* program = group_op.GetParentProgram();
// const auto* shape_analysis =
// &pir::ShapeAnalysisManager::Instance().Get(program); return
// frontend::MakeLoopAlignableClusteringPolicy(shape_analysis);
// }();
// const auto cluster_result = frontend::ClusterOps(ops,
// std::move(cluster_policy));
// }
// This structure is the visitor function of fetching pattern's operator list.
// For IS or PS patterns, directly use their operator list;
// For Reduce patterns, the operator list is the concatenation of reduce op and
// its inputs.
struct GetPatternOpList {
std::vector<const pir::Operation*> operator()(
const api::InjectiveSourcePattern<frontend::FrontendPattern>& pattern) {
return pattern.ops;
}

std::vector<const pir::Operation*> operator()(
const api::PartialShardablePattern<frontend::FrontendPattern>& pattern) {
return pattern.ops;
}

std::vector<const pir::Operation*> operator()(
const api::ReductionPattern<frontend::FrontendPattern>& pattern) {
struct InputOpsVisitor {
std::vector<const pir::Operation*> operator()(
const api::InjectiveSourcePattern<frontend::FrontendPattern>& input) {
return input.ops;
}

std::vector<const pir::Operation*> operator()(
const api::PartialShardablePattern<frontend::FrontendPattern>&
input) {
return input.ops;
}

std::vector<const pir::Operation*> operator()(
const std::monostate& input) {
return {};
}
};

std::vector<const pir::Operation*> ops_list = {
pattern.reduce_op_pattern.reduce_op};
std::vector<const pir::Operation*> input_ops =
std::visit(InputOpsVisitor(), pattern.input);
ops_list.insert(ops_list.end(), input_ops.begin(), input_ops.end());

return ops_list;
}
};

// Clusters the ops of `group_op` with frontend::ClusterOps and turns every
// entry of the result's `loop_alignable_list` into one GroupClusterNode.
//
// For each produced node:
//   - `ops` is the de-duplicated union of the op lists of all patterns in the
//     entry (see GetPatternOpList), listed in the order the ops appear in the
//     group block so that the result does not depend on pointer addresses;
//   - `group_kind` is kReduction when at least one pattern is a
//     ReductionPattern, otherwise kInjective.
std::vector<GroupClusterNode> NewOpMergeWithOp(
    cinn::dialect::GroupOp group_op) {
  // All ops of the group block, in block iteration order.
  const auto& ops = [&] {
    std::vector<const pir::Operation*> ops;
    for (const auto& op : *group_op.block()) {
      ops.push_back(&op);
    }
    return ops;
  }();

  // Both factories consume the same shape analysis, so look it up once.
  auto* program = group_op->GetParentProgram();
  const auto* shape_analysis =
      &pir::ShapeAnalysisManager::Instance().Get(program);
  auto shardable_axes_provider =
      frontend::MakeDefaultShardableAxesProvider(shape_analysis);
  auto cluster_policy =
      frontend::MakeLoopAlignableClusteringPolicy(shape_analysis);

  VLOG(4) << "Start Clustering Ops!";
  const auto cluster_result = frontend::ClusterOps(
      ops, std::move(shardable_axes_provider), std::move(cluster_policy));
  VLOG(4) << "Finished Clustering Ops!";

  // Each stmts corresponds to one fusion op (cluster node).
  // Merge the ops of all patterns in the stmts into the op list of the
  // cluster node.
  VLOG(4) << "Start Creating Cluster Nodes!";
  std::vector<GroupClusterNode> output_cluster_nodes;
  for (const auto& stmts_pattern : cluster_result.loop_alignable_list) {
    // Ops of this cluster that have not been emitted yet. The set only
    // removes duplicates shared between patterns; its iteration order
    // (pointer order) is deliberately not used for the final op list.
    std::set<const pir::Operation*> pending_ops;
    // Whether any pattern of this cluster is a reduction.
    bool is_reduce_node = false;
    for (const auto& pattern : stmts_pattern.stmts) {
      const std::vector<const pir::Operation*> pattern_ops =
          std::visit(GetPatternOpList(), pattern);
      pending_ops.insert(pattern_ops.begin(), pattern_ops.end());
      if (std::holds_alternative<
              api::ReductionPattern<frontend::FrontendPattern>>(pattern)) {
        is_reduce_node = true;
      }
    }

    GroupClusterNode cluster_node;
    // Emit the cluster's ops in group-block order.
    for (const pir::Operation* op : ops) {
      if (pending_ops.empty()) {
        break;
      }
      if (pending_ops.erase(op) > 0) {
        cluster_node.ops.push_back(const_cast<pir::Operation*>(op));
      }
    }
    // Defensive: keep any op the clustering reported that is not part of the
    // group block, so no op is ever dropped from the node.
    for (const pir::Operation* op : pending_ops) {
      cluster_node.ops.push_back(const_cast<pir::Operation*>(op));
    }

    cluster_node.group_kind = is_reduce_node
                                  ? cinn::hlir::framework::kReduction
                                  : cinn::hlir::framework::kInjective;
    output_cluster_nodes.push_back(std::move(cluster_node));
  }
  VLOG(4) << "Finished Creating Cluster Nodes!";
  return output_cluster_nodes;
}

std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op) {
// op merge with op
Expand Down Expand Up @@ -866,7 +958,9 @@ std::vector<GroupClusterNode> OpMergeWithOp(cinn::dialect::GroupOp group_op) {

std::vector<GroupClusterNode> GroupSplit(cinn::dialect::GroupOp group_op) {
// stage 1
auto first_stage_output = OpMergeWithOp(group_op);
auto first_stage_output = FLAGS_cinn_new_cluster_op_method
? NewOpMergeWithOp(group_op)
: OpMergeWithOp(group_op);

if (first_stage_output.size() <= 1) {
return first_stage_output;
Expand Down
5 changes: 5 additions & 0 deletions paddle/cinn/runtime/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ PD_DEFINE_bool(group_schedule_tiling_first,
BoolFromEnv("FLAGS_group_schedule_tiling_first", false),
"Whether to enable new group scheduler tiling first strategy.");

// Selects the first-stage clustering routine in GroupSplit
// (cinn_group_cluster_pass.cc): NewOpMergeWithOp when true, OpMergeWithOp
// when false.
// NOTE(review): the default presumably comes from the
// FLAGS_cinn_new_cluster_op_method environment variable and is false when it
// is unset — confirm against BoolFromEnv.
PD_DEFINE_bool(cinn_new_cluster_op_method,
BoolFromEnv("FLAGS_cinn_new_cluster_op_method", false),
"Whether to enable newly developed clustering method of group "
"op for cinn.");

// NOTE(review): help text is derived from the flag name; confirm the exact
// semantics with the owner of the reduce stride-read feature.
PD_DEFINE_bool(support_reduce_stride_read,
               BoolFromEnv("FLAGS_support_reduce_stride_read", false),
               "Whether to enable stride read support for reduce.");
Expand Down