[CINN / Fusion] Support fusion of Pattern with multi downstream (#66034)
* [CINN] Support horizontal fusion

* Change data type

* Support horizontal fusion

* Fix compile error

* add topo sort in backend fusion

* update

* reorder yield_store op adding

* rename fusion to group

* remove IsNotOutputNodeMatcher

* update

* update

* update

* update policy manager

* update

* add reverse topo search algorithm for op fusion

* horizontal support dynamic shape and enhance fusion ability

* fix

* xx

* move logic of reverse topo sort to pattern_graph.cc

* fix some bugs

* update

* skip multi-downstream nodes when doing trivial sink

* fix

* xxxx

* fix

* update

* LiftToAnchorPattern Implementation

* update

* update LiftToAnchorPattern

* horizontal operator fusion enhance

* Implementation of anchor pattern recomputing mechanism

* update

* update

* update

* update

* update

* fix compile err

* update

* fix horizontal fusion

* fix syntax err

* register anchor policy

* update

* update

* support LiftToAnchorPattern for reduce tree pattern

* update

* update

* fix add_store_in_group_op

* fix pir all path test

* update

* fix split recompute

* update

* update

* update

* fix compile

* update

* fix recompute matcher er

* update

* update

* reduce logs

* fix SearchAnchorTransformRecursively

* update

* recover add_store_in_fusion_op

* update

* update

* update

* refine codes and add interpreter

* update

* update

* support cluster pass and add tracker to fusionOp

* update

* update backend

* update

* update

* update

* update

* update

* update

* update

* update

* fix compile err

* update

* update

* fix compile err

* update

* update

* update

* add_test

* fix

* fix

* fix

* fix

* fix reshape tmp

* fix test_graph

* fix test_sd_resnet_block

* fix shared tracker conflict

* revert fix reshape tmp

* fix softmax

* fix

* fix test_sub_graph_23

---------

Co-authored-by: jiahongyu <jiahongyu@baidu.com>
Co-authored-by: xiongkun <xiongkun03@baidu.com>
Co-authored-by: feifei-111 <2364819892@qq.com>
Co-authored-by: zhangbaizhou <zhangbaizhou@baidu.com>
5 people authored Aug 5, 2024
1 parent ba16a30 commit 55629f8
Showing 68 changed files with 3,733 additions and 2,129 deletions.
21 changes: 21 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
@@ -20,6 +20,7 @@
#include <vector>
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/operator_fusion/fusion_tracker/tracker.h"
#include "paddle/pir/include/core/attribute_base.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"
@@ -100,6 +101,26 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage {
ParamKey data_;
};

struct FusionTrackerPtrAttributeStorage : public pir::AttributeStorage {
using ParamKey = cinn::fusion::FusionTrackerPtr;
explicit FusionTrackerPtrAttributeStorage(const ParamKey& key) : data_(key) {}

static FusionTrackerPtrAttributeStorage* Construct(const ParamKey& key) {
return new FusionTrackerPtrAttributeStorage(key);
}

static std::size_t HashValue(const ParamKey& key) {
return std::hash<ParamKey>()(key);
}

bool operator==(const ParamKey& key) const { return data_ == key; }

const ParamKey& GetAsKey() const { return data_; }

private:
ParamKey data_;
};

struct CINNKernelInfoAttributeStorage : public pir::AttributeStorage {
using ParamKey = cinn::hlir::framework::pir::CINNKernelInfo;
explicit CINNKernelInfoAttributeStorage(const ParamKey& key) : data_(key) {}
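For orientation, the new storage mirrors the existing GroupInfoAttributeStorage: it wraps a shared cinn::fusion::FusionTrackerPtr and hands it back through GetAsKey()/data(). A minimal round-trip sketch, assuming a tracker that was already built elsewhere — the get() and data() calls are the ones used later in this commit, everything else is illustrative:

    // Sketch only: attach a tracker to an attribute and read it back.
    // `tracker` is assumed to be a valid cinn::fusion::FusionTrackerPtr.
    pir::IrContext* ctx = pir::IrContext::Instance();
    auto attr = cinn::dialect::FusionTrackerPtrAttribute::get(ctx, tracker);

    // Recover the shared pointer later, e.g. during lowering:
    const cinn::fusion::FusionTrackerPtr& recovered = attr.data();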
9 changes: 7 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -37,7 +37,8 @@ namespace dialect {
using DenseTensorType = paddle::dialect::DenseTensorType;

const char* GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"};
const char* FusionOp::attributes_name[GroupOp::attributes_num] = {"group_info"};
const char* FusionOp::attributes_name[FusionOp::attributes_num] = {
"group_info", "fusion_tracker"};
const char* ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"};
const char* SplitOp::attributes_name[SplitOp::attributes_num] = {
"num_or_sections", "axis"};
@@ -146,13 +147,17 @@ void FusionOp::Build(pir::Builder& builder,
void FusionOp::Build(pir::Builder& builder, // NOLINT
pir::OperationArgument& argument, // NOLINT
const std::vector<pir::Type>& output_types,
const cinn::dialect::GroupInfo& group_info) {
const cinn::dialect::GroupInfo& group_info,
const cinn::fusion::FusionTrackerPtr& tracker) {
argument.AddRegion(nullptr);
argument.output_types = output_types;

argument.AddAttribute("group_info",
cinn::dialect::GroupInfoAttribute::get(
pir::IrContext::Instance(), group_info));
argument.AddAttribute("fusion_tracker",
cinn::dialect::FusionTrackerPtrAttribute::get(
pir::IrContext::Instance(), tracker));
}

pir::Block* FusionOp::block() {
5 changes: 3 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -67,7 +67,7 @@ class IR_API FusionOp
public:
using Op::Op;
static const char *name() { return "cinn_op.fusion"; }
static constexpr uint32_t attributes_num = 1;
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
@@ -76,7 +87,8 @@ class IR_API FusionOp
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Type> &output_types,
const cinn::dialect::GroupInfo &group_info);
const cinn::dialect::GroupInfo &group_info,
const cinn::fusion::FusionTrackerPtr &tracker);

pir::Block *block();
pir::Block *block() const;
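Call sites now have to supply the tracker alongside the group info. A hedged sketch of the new build call, lifted from the cluster-pass change further down in this commit (rewriter, output_types, group_info, and node.tracker are assumed to be in scope there):

    // Build a FusionOp that carries both its group info and its fusion tracker.
    auto new_fusion_op = rewriter->Build<cinn::dialect::FusionOp>(
        output_types, group_info, node.tracker);
    pir::Block* fusion_block = new_fusion_op.block();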
6 changes: 6 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc
@@ -16,6 +16,11 @@

namespace cinn {
namespace dialect {

const cinn::fusion::FusionTrackerPtr &FusionTrackerPtrAttribute::data() const {
return storage()->GetAsKey();
}

const GroupInfo &GroupInfoAttribute::data() const {
return storage()->GetAsKey();
}
@@ -29,3 +34,4 @@ CINNKernelInfoAttribute::data() const {

IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CINNKernelInfoAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::FusionTrackerPtrAttribute)
17 changes: 17 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
@@ -14,10 +14,26 @@

#pragma once
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/operator_fusion/fusion_tracker/tracker.h"
#include "paddle/pir/include/core/attribute_base.h"

namespace cinn {
namespace dialect {
class FusionTrackerPtrAttribute : public pir::Attribute {
public:
using Attribute::Attribute;

DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(FusionTrackerPtrAttribute,
FusionTrackerPtrAttributeStorage);

bool operator<(const FusionTrackerPtrAttribute& right) const {
return storage() < right.storage();
}

static std::string name() { return "fusion_tracker"; }

const cinn::fusion::FusionTrackerPtr& data() const;
};

class GroupInfoAttribute : public pir::Attribute {
public:
@@ -56,3 +72,4 @@ class CINNKernelInfoAttribute {

IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CINNKernelInfoAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::FusionTrackerPtrAttribute)
5 changes: 5 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
@@ -61,6 +61,7 @@ void OperatorDialect::initialize() {
RegisterOp<GenerateXShapeOp>();
RegisterAttribute<GroupInfoAttribute>();
RegisterAttribute<CINNKernelInfoAttribute>();
RegisterAttribute<FusionTrackerPtrAttribute>();
}

void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {}
@@ -81,6 +82,10 @@ void OperatorDialect::PrintAttribute(pir::Attribute attr,

os << "(" << cinn_kernel_info.data().fn_ptr;
os << ')';
} else if (attr.isa<FusionTrackerPtrAttribute>()) {
auto tracker = attr.dyn_cast<FusionTrackerPtrAttribute>();
os << "(" << tracker;
os << ')';
} else {
PADDLE_THROW(::common::errors::Unimplemented(
"cinn dialect only support GroupInfo and CINNKernelInfo"));
@@ -28,6 +28,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_group_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/conv2d_transpose_filter_pass.h"
@@ -33,10 +33,6 @@ class AddYieldStoreInFusionOpPattern
bool MatchAndRewrite(::pir::YieldOp op,
pir::PatternRewriter& rewriter) const override {
for (auto i = 0; i < op->num_operands(); ++i) {
if (op->operand_source(i).use_count() == 1) {
continue;
}

rewriter.SetInsertionPointAfter(op->operand_source(i).defining_op());
auto store_op = rewriter.Build<cinn::dialect::YieldStoreOp>(
op->operand_source(i), op->operand_source(i).type());
@@ -69,13 +65,6 @@ class AddStoreInFusionOpPass : public pir::Pass {
for (auto& block : op->region(i)) {
for (auto& op : block) {
if (op.isa<cinn::dialect::FusionOp>()) {
auto fusion_op = op.dyn_cast<cinn::dialect::FusionOp>();
if (fusion_op.GetOperators().size() == 2 &&
fusion_op.GetOperators()
.front()
->isa<cinn::dialect::ReshapeOp>()) {
continue;
}
auto [_, num_rewrites] =
pir::ApplyPatternsGreedily(&op, patterns_, cfg);
AddStatistics(num_rewrites);
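With the use_count() == 1 early-continue and the reshape-only FusionOp special case removed, this pass now inserts a YieldStoreOp in front of every yield operand of a fusion op. A rough sketch of the resulting rewrite loop, using only the calls visible in this hunk; the final rewiring of the yield operand sits outside the shown diff, so it is left as a comment:

    // Rough sketch of the simplified pattern body after this commit.
    bool MatchAndRewrite(::pir::YieldOp op,
                         pir::PatternRewriter& rewriter) const override {
      for (auto i = 0; i < op->num_operands(); ++i) {
        rewriter.SetInsertionPointAfter(op->operand_source(i).defining_op());
        auto store_op = rewriter.Build<cinn::dialect::YieldStoreOp>(
            op->operand_source(i), op->operand_source(i).type());
        // ... the original pattern then makes the yield consume store_op's
        // result; that part of the code is not shown in this hunk.
      }
      return true;  // assumption: the pattern reports a successful rewrite
    }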
@@ -56,6 +56,7 @@ namespace cinn::dialect {

class GroupInfoAttribute;
class CINNKernelInfoAttribute;
class FusionTrackerPtrAttribute;

} // namespace cinn::dialect

@@ -86,7 +87,8 @@ class UnclassifiedAttribute {};
__macro(paddle::dialect::PlaceAttribute) \
__macro(paddle::dialect::DataLayoutAttribute) \
__macro(cinn::dialect::GroupInfoAttribute) \
__macro(cinn::dialect::CINNKernelInfoAttribute)
__macro(cinn::dialect::CINNKernelInfoAttribute) \
__macro(cinn::dialect::FusionTrackerPtrAttribute)
// clang-format on

using AttrAdtTypeIdBase = ::common::AdtBaseTypeId<
@@ -33,7 +33,8 @@
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/operator_fusion/group_cluster.h"
#include "paddle/cinn/operator_fusion/cluster_interface.h"
#include "paddle/cinn/operator_fusion/fusion_tracker/tracker.h"
#include "paddle/common/ddim.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -109,6 +110,8 @@ struct GroupClusterNode {
std::unordered_set<::pir::Value> GetOutsideInput() const {
return GetListOutsideInput(ops);
}

cinn::fusion::FusionTrackerPtr tracker;
};

std::vector<::pir::Value> GenerateOutputValue(
Expand Down Expand Up @@ -199,7 +202,7 @@ ::pir::GroupOpsVec CloneOps(
return vec_new_op_list;
}

::pir::Operation* ReplaceWithGroupOp(
::pir::Operation* ReplaceWithFusionOp(
pir::PatternRewriter* rewriter,
const ::pir::GroupOpsVec& group_ops,
const GroupClusterNode& node,
@@ -218,8 +221,8 @@ ::pir::Operation* ReplaceWithGroupOp(
// step 2: Replace the old op with GroupOp.

auto output_types = BuildOutType(output_value);
auto new_fusion_op =
rewriter->Build<cinn::dialect::FusionOp>(output_types, group_info);
auto new_fusion_op = rewriter->Build<cinn::dialect::FusionOp>(
output_types, group_info, node.tracker);
pir::Block* fusion_block = new_fusion_op.block();

for (auto op : vec_new_op_list) {
@@ -245,32 +248,42 @@ ::pir::Operation* ReplaceWithGroupOp(
}

std::vector<GroupClusterNode> GroupSplit(cinn::dialect::GroupOp group_op) {
std::function<cinn::fusion::FrontendContent(pir::Operation*)> func =
[](pir::Operation* op) { return cinn::fusion::FrontendContent(op); };
std::function<cinn::fusion::PatternContent(pir::Operation*)> func =
[](pir::Operation* op) { return cinn::fusion::PatternContent(op); };
const auto& contents = cinn::fusion::MapVector(group_op.GetOperators(), func);
auto cluster_result = cinn::fusion::ClusterOps(contents, {});
std::vector<std::vector<pir::Operation*>> result;
std::transform(
cluster_result.begin(),
cluster_result.end(),
std::back_inserter(result),
[](const cinn::fusion::PatternNodePtr<cinn::fusion::FrontendStage> node) {
return cinn::fusion::GetOpsInPattern(node->stmt_pattern());
});
std::vector<std::vector<pir::Operation*>> op_sets;
std::vector<cinn::fusion::FusionTrackerPtr> trackers;
std::transform(cluster_result.begin(),
cluster_result.end(),
std::back_inserter(op_sets),
[](const cinn::fusion::PatternNodePtr node) {
return cinn::fusion::GetOpsInPattern(node->stmt_pattern());
});
std::transform(cluster_result.begin(),
cluster_result.end(),
std::back_inserter(trackers),
[](const cinn::fusion::PatternNodePtr node)
-> cinn::fusion::FusionTrackerPtr {
return cinn::fusion::GetFusionTracker(node->stmt_pattern());
});

// Each stmts corresponds to each fusion op(cluster node).
// Concat all the ops of patterns in the stmts, and make them the op list of
// cluster node.
VLOG(4) << "Start Creating Cluster Nodes!";
std::vector<GroupClusterNode> output_cluster_nodes;
for (const auto& op_set : result) {
for (int i = 0; i < op_sets.size(); i++) {
auto op_set = op_sets[i];
GroupClusterNode cluster_node;
for (const auto* op : op_set) {
cluster_node.ops.push_back(const_cast<pir::Operation*>(op));
auto op_kind = cinn::hlir::framework::pir::CompatibleInfo::OpKind(*op);
cluster_node.group_kind =
cluster_node.group_kind > op_kind ? cluster_node.group_kind : op_kind;
}
// Deep copy trackers to avoid shared tracker conflict in different node
cluster_node.tracker = trackers[i]->Clone();
output_cluster_nodes.push_back(cluster_node);
}
VLOG(4) << "Finished Creating Cluster Nodes!";
@@ -320,6 +333,22 @@ std::unordered_map<::pir::Value, size_t> BuildValueOrderByYieldOp(
return all_output_values;
}

void UpdateTracker(std::vector<pir::Operation*> uniq_ops,
fusion::FusionTrackerPtr tracker) {
std::map<pir::Operation*, int> op2idx;
for (int i = 0; i < uniq_ops.size(); ++i) {
op2idx[uniq_ops[i]] = i;
}
for (const auto& t : tracker->instructions_) {
if (t->type() == fusion::T_InitPattern) {
auto init_instr =
cinn::fusion::dynamic_cast_instr_with_err<fusion::InitPatternInstr>(
t);
init_instr->set_idx(op2idx[init_instr->op_]);
}
}
}

} // namespace

class CinnGroupClusterPattern
@@ -349,7 +378,9 @@ class CinnGroupClusterPattern
VLOG(4) << "cluster node output size: " << output_values.size();
auto uniq_ops = SortByOriginalOrderAndUniq(group_op, node.ops);

auto new_group_op = ReplaceWithGroupOp(
UpdateTracker(uniq_ops, node.tracker);

auto new_group_op = ReplaceWithFusionOp(
&rewriter, uniq_ops, node, output_values, &ir_mapping);

// TODO(Hongqing-work): delete this after fix bug of
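Taken together, the cluster pass now threads a tracker through the whole flow: each GroupClusterNode gets a deep copy of its pattern's tracker (avoiding the shared-tracker conflict mentioned in the commit log), UpdateTracker re-points the InitPattern instructions at indices into the deduplicated op list, and ReplaceWithFusionOp stores the tracker on the new FusionOp. A condensed sketch of that per-node flow, using only names that appear in this file's diff:

    // Condensed per-node flow (not the literal pass code).
    GroupClusterNode node;
    node.tracker = trackers[i]->Clone();  // deep copy keeps trackers independent

    auto uniq_ops = SortByOriginalOrderAndUniq(group_op, node.ops);
    UpdateTracker(uniq_ops, node.tracker);  // InitPattern instrs now index uniq_ops

    auto new_group_op = ReplaceWithFusionOp(
        &rewriter, uniq_ops, node, output_values, &ir_mapping);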
@@ -73,6 +73,7 @@ void SetBroadcastLeafGroup(
"ops:%d.",
origin_group->ops().size(),
new_group->ops().size()));

UpdateGroupShapeExprs(
new_group, origin_group, value_dim_exprs_list, value_to_dim_expr_idx);

@@ -85,6 +86,7 @@
new_group->GetShapeOrDataExprs(v));
}
}

group_list->emplace_back(new_group);
}

@@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"

#include "paddle/cinn/adt/generate_map_expr.h"
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h"
@@ -107,6 +108,7 @@ OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) {
: group_op_kind;
}
}

PADDLE_ENFORCE_GT(fusion_op.attributes().count("group_info"),
0UL,
::common::errors::InvalidArgument(
@@ -117,7 +119,12 @@ OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) {
.data();

const auto& fn_name = attr.fn_name;
auto group = std::make_shared<OpLoweringGroup>(ops, fn_name);
auto group = std::make_shared<OpLoweringGroup>(
ops,
fn_name,
fusion_op_ptr->attribute("fusion_tracker")
.dyn_cast<cinn::dialect::FusionTrackerPtrAttribute>()
.data());

group_op_kind =
static_cast<int>(attr.op_pattern_kind) > static_cast<int>(group_op_kind)
@@ -1085,6 +1085,12 @@ struct PirToPyCodeConverterHelper {
ss << "self." << name << "()";
return ss.str();
}
std::string operator()(TypeId<cinn::dialect::FusionTrackerPtrAttribute>) {
const auto& name = cinn::dialect::FusionTrackerPtrAttribute::name();
std::stringstream ss;
ss << "self." << name << "()";
return ss.str();
}
std::string operator()(TypeId<UnclassifiedAttribute>) {
return "self.UnclassifiedAttribute()";
}
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_group.cc
@@ -133,7 +133,8 @@ std::shared_ptr<OpLoweringGroup> OpLoweringGroup::Clone(
const std::string name_suffix) const {
const auto new_fn_name = this->fn_name_ + "_cloned_" + name_suffix;
// Construct Base information for new Group
auto new_group = std::make_shared<OpLoweringGroup>(this->ops_, new_fn_name);
auto new_group = std::make_shared<OpLoweringGroup>(
this->ops_, new_fn_name, this->fusion_tracker_ptr);

new_group->output_ops_ = this->output_ops_;
new_group->output_values_ = this->output_values_;