[CINN / Fusion] Support fusion of Pattern with multi downstream #66034

Merged
merged 130 commits
Aug 5, 2024
Changes from all commits
Commits
130 commits
0922d8d
[CINN] Support horizontal fusion
jiahy0825 Apr 11, 2024
e219ce3
Change data type
jiahy0825 Apr 12, 2024
f6bae91
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jiahy0825 Apr 22, 2024
a3bb64b
Support horizontal fusion
jiahy0825 Apr 22, 2024
2c0d754
Fix compile error
jiahy0825 Apr 22, 2024
e586c0e
Merge commit 'refs/pull/63415/head' of https://github.com/PaddlePaddl…
2742195759 Apr 22, 2024
278effd
add topo sort in backend fusion
2742195759 Apr 23, 2024
06fb22c
update
feifei-111 Apr 24, 2024
ee3e694
reorder yield_store op adding
Apr 24, 2024
0752aa1
rename fusion to group
Apr 24, 2024
770f496
remove IsNotOutputNodeMatcher
Apr 24, 2024
fec3505
Merge pull request #10 from Fridge003/multi-down
feifei-111 Apr 24, 2024
f9f7f94
update
feifei-111 Apr 24, 2024
37c55e2
Merge branch 'multi_downstream' of https://github.com/feifei-111/Padd…
feifei-111 Apr 24, 2024
b65bdd3
merge upstream
feifei-111 Apr 25, 2024
0d0097f
update
feifei-111 Apr 25, 2024
d634fe5
update
feifei-111 Apr 25, 2024
1ccbfb5
update policy manager
feifei-111 Apr 25, 2024
e39186c
update
feifei-111 Apr 26, 2024
7d2e686
add reverse topo search algorithm for op fusion
Apr 26, 2024
18c5f51
Merge pull request #11 from Fridge003/multi-down
feifei-111 Apr 26, 2024
2a98eea
horizontal support dynamic shape and enhance fusion ability
2742195759 Apr 26, 2024
7ffbdf8
fix
2742195759 Apr 26, 2024
c30917a
fix
2742195759 Apr 27, 2024
3a5ff20
xx
2742195759 Apr 28, 2024
9745be1
move logic of reverse topo sort to pattern_graph.cc
Apr 28, 2024
4eda8bc
Merge pull request #12 from Fridge003/multi-down
feifei-111 Apr 28, 2024
303a766
fix some bugs
2742195759 Apr 28, 2024
0c56af5
Merge
2742195759 Apr 28, 2024
3b74908
update
feifei-111 Apr 28, 2024
03d4fa2
Merge branch 'multi_downstream' of https://github.com/feifei-111/Padd…
feifei-111 Apr 28, 2024
18a3eb8
skip multi-downstream nodes when doing trivial sink
Apr 28, 2024
2a85bbf
Merge pull request #13 from Fridge003/multi-down
feifei-111 Apr 28, 2024
cfa49e2
fix
2742195759 Apr 28, 2024
64590a6
xxxx
2742195759 Apr 28, 2024
4c4eeb7
fix
2742195759 Apr 28, 2024
c17322a
update
feifei-111 Apr 29, 2024
d173c22
Merge branch 'multi_downstream' of https://github.com/feifei-111/Padd…
feifei-111 Apr 29, 2024
ada835e
LiftToAnchorPattern Implementation
Apr 29, 2024
b36e48d
Merge pull request #14 from Fridge003/multi-down
feifei-111 Apr 29, 2024
a7a0b0c
update
feifei-111 Apr 29, 2024
b0172b4
update
feifei-111 Apr 29, 2024
8e7d785
update LiftToAnchorPattern
Apr 29, 2024
eb6ef40
Merge pull request #15 from Fridge003/multi-down
feifei-111 Apr 29, 2024
f6f58ee
horizontal operator fusion enhance
2742195759 Apr 29, 2024
5b48ee7
merge
2742195759 Apr 30, 2024
affebad
Implementation of anchor pattern recomputing mechanism
Apr 30, 2024
af8078c
Merge pull request #16 from Fridge003/multi-down
feifei-111 Apr 30, 2024
c70b7b1
update
feifei-111 May 5, 2024
c3be008
merge upstream
feifei-111 May 5, 2024
ab9c830
merge xk pr
feifei-111 May 5, 2024
7d4c58f
update
feifei-111 May 6, 2024
23048b5
update
feifei-111 May 6, 2024
d04e6f2
update
feifei-111 May 7, 2024
7f25695
update
feifei-111 May 7, 2024
5e4e24a
fix compile err
feifei-111 May 7, 2024
0b62207
update
feifei-111 May 7, 2024
5495a08
fix horizontal fusion
feifei-111 May 7, 2024
5289176
fix syntax err
feifei-111 May 7, 2024
dcb9500
register anchor policy
feifei-111 May 8, 2024
aa9b9ff
update
feifei-111 May 8, 2024
e10f83a
update
feifei-111 May 8, 2024
0f8b3f5
support LiftToAnchorPattern for reduce tree pattern
May 8, 2024
f84c7f4
Merge pull request #17 from Fridge003/multi-down
feifei-111 May 8, 2024
2fcb779
update
feifei-111 May 8, 2024
168ae82
Merge branch 'multi_downstream' of https://github.com/feifei-111/Padd…
feifei-111 May 8, 2024
3c35a1f
update
feifei-111 May 8, 2024
80f31ef
fix add_store_in_group_op
feifei-111 May 8, 2024
8d12de7
fix pir all path test
feifei-111 May 9, 2024
bdcc958
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
feifei-111 May 9, 2024
2b61681
update
feifei-111 May 9, 2024
04a9223
fix split recompute
feifei-111 May 9, 2024
78a514c
update
feifei-111 May 9, 2024
3723824
update
feifei-111 May 9, 2024
74ed2e8
update
feifei-111 May 10, 2024
0937f0d
fix compile
feifei-111 May 10, 2024
2102d72
update
feifei-111 May 10, 2024
49963ef
fix recompute matcher er
feifei-111 May 10, 2024
79eec1c
update
feifei-111 May 10, 2024
108fad7
update
feifei-111 May 10, 2024
afdb4e0
reduce logs
feifei-111 May 10, 2024
5cead78
fix SearchAnchorTransformRecursively
feifei-111 May 10, 2024
86e2920
update
feifei-111 May 11, 2024
46778a0
recover add_store_in_fusion_op
feifei-111 May 11, 2024
cd05e81
update
feifei-111 May 11, 2024
1dc23e6
update
feifei-111 May 11, 2024
54cd100
update
feifei-111 May 11, 2024
c328973
fix conf
feifei-111 May 11, 2024
4a11b1d
fix conf
feifei-111 May 24, 2024
b36ec30
refine codes and add interpreter
feifei-111 May 27, 2024
61fb855
update
feifei-111 May 29, 2024
6f264e9
update
feifei-111 May 29, 2024
c6d1edc
support cluster pass and add tracker to fusionOp
feifei-111 May 30, 2024
78c7abb
update
feifei-111 Jun 4, 2024
84d9b06
update backend
feifei-111 Jun 5, 2024
3e6d6ea
update
feifei-111 Jun 5, 2024
e381852
update
feifei-111 Jun 5, 2024
ab12d11
update
feifei-111 Jun 6, 2024
9ae7679
update
feifei-111 Jun 11, 2024
ff8384b
update
feifei-111 Jun 12, 2024
e74f49e
update
feifei-111 Jun 12, 2024
e668e4c
merge dev
feifei-111 Jun 12, 2024
f0dbf24
update
feifei-111 Jun 12, 2024
d16b61a
update
feifei-111 Jun 13, 2024
c318112
fix compile err
feifei-111 Jun 18, 2024
e8c30ae
update
feifei-111 Jun 18, 2024
5d0b5b6
update
feifei-111 Jun 18, 2024
5f10212
fix compile err
feifei-111 Jun 19, 2024
b712f70
update
feifei-111 Jun 19, 2024
9ee5952
update
feifei-111 Jun 20, 2024
3f9f7d8
update
feifei-111 Jun 20, 2024
1731238
add_test
feifei-111 Jun 21, 2024
c03571b
fix
2742195759 Jun 25, 2024
ec2d6f4
Merge
2742195759 Jul 2, 2024
9e44e06
merge
2742195759 Jul 2, 2024
9ca3465
fix
2742195759 Jul 3, 2024
0d8c62e
fix
huangjiyi Jul 15, 2024
93c4e80
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
huangjiyi Jul 15, 2024
6164f3a
fix
huangjiyi Jul 15, 2024
486d511
fix reshape tmp
huangjiyi Jul 16, 2024
b781b3e
fix test_graph
huangjiyi Jul 22, 2024
f0d5004
fix test_sd_resnet_block
huangjiyi Jul 23, 2024
640fca7
fix shared tracker conflict
huangjiyi Jul 24, 2024
27fbd9f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
huangjiyi Jul 30, 2024
63946aa
revert fix reshape tmp
huangjiyi Jul 30, 2024
cf4f9c4
fix softmax
huangjiyi Jul 31, 2024
0c7c524
fix
huangjiyi Jul 31, 2024
e6cd91e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
huangjiyi Aug 1, 2024
930f061
fix test_sub_graph_23
huangjiyi Aug 4, 2024
49bc25a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
huangjiyi Aug 4, 2024
21 changes: 21 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
@@ -20,6 +20,7 @@
#include <vector>
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/operator_fusion/fusion_tracker/tracker.h"
#include "paddle/pir/include/core/attribute_base.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"
@@ -100,6 +101,26 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage {
ParamKey data_;
};

struct FusionTrackerPtrAttributeStorage : public pir::AttributeStorage {
using ParamKey = cinn::fusion::FusionTrackerPtr;
explicit FusionTrackerPtrAttributeStorage(const ParamKey& key) : data_(key) {}

static FusionTrackerPtrAttributeStorage* Construct(const ParamKey& key) {
return new FusionTrackerPtrAttributeStorage(key);
}

static std::size_t HashValue(const ParamKey& key) {
return std::hash<ParamKey>()(key);
}

bool operator==(const ParamKey& key) const { return data_ == key; }

const ParamKey& GetAsKey() const { return data_; }

private:
ParamKey data_;
};

struct CINNKernelInfoAttributeStorage : public pir::AttributeStorage {
using ParamKey = cinn::hlir::framework::pir::CINNKernelInfo;
explicit CINNKernelInfoAttributeStorage(const ParamKey& key) : data_(key) {}
9 changes: 7 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -37,7 +37,8 @@ namespace dialect {
using DenseTensorType = paddle::dialect::DenseTensorType;

const char* GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"};
const char* FusionOp::attributes_name[GroupOp::attributes_num] = {"group_info"};
const char* FusionOp::attributes_name[FusionOp::attributes_num] = {
"group_info", "fusion_tracker"};
const char* ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"};
const char* SplitOp::attributes_name[SplitOp::attributes_num] = {
"num_or_sections", "axis"};
@@ -146,13 +147,17 @@ void FusionOp::Build(pir::Builder& builder,
void FusionOp::Build(pir::Builder& builder, // NOLINT
pir::OperationArgument& argument, // NOLINT
const std::vector<pir::Type>& output_types,
const cinn::dialect::GroupInfo& group_info) {
const cinn::dialect::GroupInfo& group_info,
const cinn::fusion::FusionTrackerPtr& tracker) {
argument.AddRegion(nullptr);
argument.output_types = output_types;

argument.AddAttribute("group_info",
cinn::dialect::GroupInfoAttribute::get(
pir::IrContext::Instance(), group_info));
argument.AddAttribute("fusion_tracker",
cinn::dialect::FusionTrackerPtrAttribute::get(
pir::IrContext::Instance(), tracker));
}

pir::Block* FusionOp::block() {
5 changes: 3 additions & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -67,7 +67,7 @@ class IR_API FusionOp
public:
using Op::Op;
static const char *name() { return "cinn_op.fusion"; }
static constexpr uint32_t attributes_num = 1;
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
@@ -76,7 +76,8 @@ class IR_API FusionOp
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Type> &output_types,
const cinn::dialect::GroupInfo &group_info);
const cinn::dialect::GroupInfo &group_info,
const cinn::fusion::FusionTrackerPtr &tracker);

pir::Block *block();
pir::Block *block() const;
6 changes: 6 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc
@@ -16,6 +16,11 @@

namespace cinn {
namespace dialect {

const cinn::fusion::FusionTrackerPtr &FusionTrackerPtrAttribute::data() const {
return storage()->GetAsKey();
}

const GroupInfo &GroupInfoAttribute::data() const {
return storage()->GetAsKey();
}
@@ -29,3 +34,4 @@ CINNKernelInfoAttribute::data() const {

IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CINNKernelInfoAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::FusionTrackerPtrAttribute)
17 changes: 17 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
@@ -14,10 +14,26 @@

#pragma once
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/operator_fusion/fusion_tracker/tracker.h"
#include "paddle/pir/include/core/attribute_base.h"

namespace cinn {
namespace dialect {
class FusionTrackerPtrAttribute : public pir::Attribute {
public:
using Attribute::Attribute;

DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(FusionTrackerPtrAttribute,
FusionTrackerPtrAttributeStorage);

bool operator<(const FusionTrackerPtrAttribute& right) const {
return storage() < right.storage();
}

static std::string name() { return "fusion_tracker"; }

const cinn::fusion::FusionTrackerPtr& data() const;
};

class GroupInfoAttribute : public pir::Attribute {
public:
@@ -56,3 +72,4 @@ class CINNKernelInfoAttribute {

IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CINNKernelInfoAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::FusionTrackerPtrAttribute)
5 changes: 5 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
@@ -61,6 +61,7 @@ void OperatorDialect::initialize() {
RegisterOp<GenerateXShapeOp>();
RegisterAttribute<GroupInfoAttribute>();
RegisterAttribute<CINNKernelInfoAttribute>();
RegisterAttribute<FusionTrackerPtrAttribute>();
}

void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {}
@@ -81,6 +82,10 @@ void OperatorDialect::PrintAttribute(pir::Attribute attr,

os << "(" << cinn_kernel_info.data().fn_ptr;
os << ')';
} else if (attr.isa<FusionTrackerPtrAttribute>()) {
auto tracker = attr.dyn_cast<FusionTrackerPtrAttribute>();
os << "(" << tracker;
os << ')';
} else {
PADDLE_THROW(::common::errors::Unimplemented(
"cinn dialect only support GroupInfo and CINNKernelInfo"));
(changes in an additional file; filename not captured)
@@ -28,6 +28,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/accuracy_check_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_group_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/conv2d_transpose_filter_pass.h"
(changes in an additional file; filename not captured)
@@ -33,10 +33,6 @@ class AddYieldStoreInFusionOpPattern
bool MatchAndRewrite(::pir::YieldOp op,
pir::PatternRewriter& rewriter) const override {
for (auto i = 0; i < op->num_operands(); ++i) {
if (op->operand_source(i).use_count() == 1) {
continue;
}

rewriter.SetInsertionPointAfter(op->operand_source(i).defining_op());
auto store_op = rewriter.Build<cinn::dialect::YieldStoreOp>(
op->operand_source(i), op->operand_source(i).type());
@@ -69,13 +65,6 @@ class AddStoreInFusionOpPass : public pir::Pass {
for (auto& block : op->region(i)) {
for (auto& op : block) {
if (op.isa<cinn::dialect::FusionOp>()) {
auto fusion_op = op.dyn_cast<cinn::dialect::FusionOp>();
if (fusion_op.GetOperators().size() == 2 &&
fusion_op.GetOperators()
.front()
->isa<cinn::dialect::ReshapeOp>()) {
continue;
}
auto [_, num_rewrites] =
pir::ApplyPatternsGreedily(&op, patterns_, cfg);
AddStatistics(num_rewrites);
(changes in an additional file; filename not captured)
@@ -56,6 +56,7 @@ namespace cinn::dialect {

class GroupInfoAttribute;
class CINNKernelInfoAttribute;
class FusionTrackerPtrAttribute;

} // namespace cinn::dialect

@@ -86,7 +87,8 @@ class UnclassifiedAttribute {};
__macro(paddle::dialect::PlaceAttribute) \
__macro(paddle::dialect::DataLayoutAttribute) \
__macro(cinn::dialect::GroupInfoAttribute) \
__macro(cinn::dialect::CINNKernelInfoAttribute)
__macro(cinn::dialect::CINNKernelInfoAttribute) \
__macro(cinn::dialect::FusionTrackerPtrAttribute)
// clang-format on

using AttrAdtTypeIdBase = ::common::AdtBaseTypeId<
(changes in an additional file; filename not captured)
@@ -33,7 +33,8 @@
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/operator_fusion/group_cluster.h"
#include "paddle/cinn/operator_fusion/cluster_interface.h"
#include "paddle/cinn/operator_fusion/fusion_tracker/tracker.h"
#include "paddle/common/ddim.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -109,6 +110,8 @@ struct GroupClusterNode {
std::unordered_set<::pir::Value> GetOutsideInput() const {
return GetListOutsideInput(ops);
}

cinn::fusion::FusionTrackerPtr tracker;
};

std::vector<::pir::Value> GenerateOutputValue(
@@ -199,7 +202,7 @@ ::pir::GroupOpsVec CloneOps(
return vec_new_op_list;
}

::pir::Operation* ReplaceWithGroupOp(
::pir::Operation* ReplaceWithFusionOp(
pir::PatternRewriter* rewriter,
const ::pir::GroupOpsVec& group_ops,
const GroupClusterNode& node,
@@ -218,8 +221,8 @@ ::pir::Operation* ReplaceWithGroupOp(
// step 2: Replace the old op with GroupOp.

auto output_types = BuildOutType(output_value);
auto new_fusion_op =
rewriter->Build<cinn::dialect::FusionOp>(output_types, group_info);
auto new_fusion_op = rewriter->Build<cinn::dialect::FusionOp>(
output_types, group_info, node.tracker);
pir::Block* fusion_block = new_fusion_op.block();

for (auto op : vec_new_op_list) {
Expand All @@ -245,32 +248,42 @@ ::pir::Operation* ReplaceWithGroupOp(
}

std::vector<GroupClusterNode> GroupSplit(cinn::dialect::GroupOp group_op) {
std::function<cinn::fusion::FrontendContent(pir::Operation*)> func =
[](pir::Operation* op) { return cinn::fusion::FrontendContent(op); };
std::function<cinn::fusion::PatternContent(pir::Operation*)> func =
[](pir::Operation* op) { return cinn::fusion::PatternContent(op); };
const auto& contents = cinn::fusion::MapVector(group_op.GetOperators(), func);
auto cluster_result = cinn::fusion::ClusterOps(contents, {});
std::vector<std::vector<pir::Operation*>> result;
std::transform(
cluster_result.begin(),
cluster_result.end(),
std::back_inserter(result),
[](const cinn::fusion::PatternNodePtr<cinn::fusion::FrontendStage> node) {
return cinn::fusion::GetOpsInPattern(node->stmt_pattern());
});
std::vector<std::vector<pir::Operation*>> op_sets;
std::vector<cinn::fusion::FusionTrackerPtr> trackers;
std::transform(cluster_result.begin(),
cluster_result.end(),
std::back_inserter(op_sets),
[](const cinn::fusion::PatternNodePtr node) {
return cinn::fusion::GetOpsInPattern(node->stmt_pattern());
});
std::transform(cluster_result.begin(),
cluster_result.end(),
std::back_inserter(trackers),
[](const cinn::fusion::PatternNodePtr node)
-> cinn::fusion::FusionTrackerPtr {
return cinn::fusion::GetFusionTracker(node->stmt_pattern());
});

// Each stmts corresponds to each fusion op(cluster node).
// Concat all the ops of patterns in the stmts, and make them the op list of
// cluster node.
VLOG(4) << "Start Creating Cluster Nodes!";
std::vector<GroupClusterNode> output_cluster_nodes;
for (const auto& op_set : result) {
for (int i = 0; i < op_sets.size(); i++) {
auto op_set = op_sets[i];
GroupClusterNode cluster_node;
for (const auto* op : op_set) {
cluster_node.ops.push_back(const_cast<pir::Operation*>(op));
auto op_kind = cinn::hlir::framework::pir::CompatibleInfo::OpKind(*op);
cluster_node.group_kind =
cluster_node.group_kind > op_kind ? cluster_node.group_kind : op_kind;
}
// Deep copy trackers to avoid shared tracker conflict in different node
cluster_node.tracker = trackers[i]->Clone();
output_cluster_nodes.push_back(cluster_node);
}
VLOG(4) << "Finished Creating Cluster Nodes!";
@@ -320,6 +333,22 @@ std::unordered_map<::pir::Value, size_t> BuildValueOrderByYieldOp(
return all_output_values;
}

void UpdateTracker(std::vector<pir::Operation*> uniq_ops,
fusion::FusionTrackerPtr tracker) {
std::map<pir::Operation*, int> op2idx;
for (int i = 0; i < uniq_ops.size(); ++i) {
op2idx[uniq_ops[i]] = i;
}
for (const auto& t : tracker->instructions_) {
if (t->type() == fusion::T_InitPattern) {
auto init_instr =
cinn::fusion::dynamic_cast_instr_with_err<fusion::InitPatternInstr>(
t);
init_instr->set_idx(op2idx[init_instr->op_]);
}
}
}

} // namespace

class CinnGroupClusterPattern
@@ -349,7 +378,9 @@ class CinnGroupClusterPattern
VLOG(4) << "cluster node output size: " << output_values.size();
auto uniq_ops = SortByOriginalOrderAndUniq(group_op, node.ops);

auto new_group_op = ReplaceWithGroupOp(
UpdateTracker(uniq_ops, node.tracker);

auto new_group_op = ReplaceWithFusionOp(
&rewriter, uniq_ops, node, output_values, &ir_mapping);

// TODO(Hongqing-work): delete this after fix bug of
(changes in an additional file; filename not captured)
@@ -73,6 +73,7 @@ void SetBroadcastLeafGroup(
"ops:%d.",
origin_group->ops().size(),
new_group->ops().size()));

UpdateGroupShapeExprs(
new_group, origin_group, value_dim_exprs_list, value_to_dim_expr_idx);

@@ -85,6 +86,7 @@
new_group->GetShapeOrDataExprs(v));
}
}

group_list->emplace_back(new_group);
}

(changes in an additional file; filename not captured)
@@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"

#include "paddle/cinn/adt/generate_map_expr.h"
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h"
@@ -107,6 +108,7 @@ OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) {
: group_op_kind;
}
}

PADDLE_ENFORCE_GT(fusion_op.attributes().count("group_info"),
0UL,
::common::errors::InvalidArgument(
@@ -117,7 +119,12 @@
.data();

const auto& fn_name = attr.fn_name;
auto group = std::make_shared<OpLoweringGroup>(ops, fn_name);
auto group = std::make_shared<OpLoweringGroup>(
ops,
fn_name,
fusion_op_ptr->attribute("fusion_tracker")
.dyn_cast<cinn::dialect::FusionTrackerPtrAttribute>()
.data());

group_op_kind =
static_cast<int>(attr.op_pattern_kind) > static_cast<int>(group_op_kind)
(changes in an additional file; filename not captured)
@@ -1085,6 +1085,12 @@ struct PirToPyCodeConverterHelper {
ss << "self." << name << "()";
return ss.str();
}
std::string operator()(TypeId<cinn::dialect::FusionTrackerPtrAttribute>) {
const auto& name = cinn::dialect::FusionTrackerPtrAttribute::name();
std::stringstream ss;
ss << "self." << name << "()";
return ss.str();
}
std::string operator()(TypeId<UnclassifiedAttribute>) {
return "self.UnclassifiedAttribute()";
}
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_group.cc
@@ -133,7 +133,8 @@ std::shared_ptr<OpLoweringGroup> OpLoweringGroup::Clone(
const std::string name_suffix) const {
const auto new_fn_name = this->fn_name_ + "_cloned_" + name_suffix;
// Construct Base information for new Group
auto new_group = std::make_shared<OpLoweringGroup>(this->ops_, new_fn_name);
auto new_group = std::make_shared<OpLoweringGroup>(
this->ops_, new_fn_name, this->fusion_tracker_ptr);

new_group->output_ops_ = this->output_ops_;
new_group->output_values_ = this->output_values_;