[CINN / PIR] Cinn trivalop fuse #62088

Merged: 308 commits, Mar 26, 2024
Changes from 250 commits
a745eb0
Merge pull request #40 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 8, 2024
d4bc74a
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 8, 2024
b1c9cb8
implement FuseFilteredStmtPatterns
jiahy0825 Mar 8, 2024
badeae6
Merge pull request #41 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 8, 2024
9d56d41
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 9, 2024
83d1e79
update
feifei-111 Mar 9, 2024
b0d6347
Merge pull request #42 from feifei-111/cinn-trivalop-fuse
feifei-111 Mar 9, 2024
8fc1551
split trivial op into a single file.
2742195759 Mar 10, 2024
f59d49c
fix compiler complaints
jiahy0825 Mar 10, 2024
97735a1
Merge pull request #43 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 10, 2024
666da6d
rename StmtIter to StmtPtr
jiahy0825 Mar 10, 2024
6be51d0
Merge pull request #44 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 10, 2024
cff8bb6
declare group_pattern.InferShardableAxes
jiahy0825 Mar 10, 2024
8e74d2e
refine signature of group_pattern.InferShardableAxes
jiahy0825 Mar 10, 2024
c6bcf2d
Merge pull request #45 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 10, 2024
6bf5f0e
move group_pattern.InferShardableAxes to group_pattern_util.InferShar…
jiahy0825 Mar 10, 2024
c947ada
Merge pull request #46 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 10, 2024
de23d96
implement group_pattern_util.InferShardableAxes
jiahy0825 Mar 10, 2024
604afab
Merge pull request #47 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 10, 2024
5b7dc57
add group_pattern_util.InferShardableAxesFromSink
jiahy0825 Mar 10, 2024
53ba3ed
Merge pull request #48 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 10, 2024
2417813
ReversedInferShardableAxes support sinks
jiahy0825 Mar 10, 2024
b8e7939
update op lower
feifei-111 Mar 10, 2024
e22f81d
support multiple sinks in group_pattern_util.InferShardableAxes
jiahy0825 Mar 10, 2024
da2b472
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
jiahy0825 Mar 10, 2024
c84c50c
update
feifei-111 Mar 10, 2024
302ba60
fix link error
2742195759 Mar 10, 2024
2f0c384
update
feifei-111 Mar 10, 2024
0a97ad9
merge origin
jiahy0825 Mar 11, 2024
a6cfd99
Merge pull request #50 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 11, 2024
0ad3c13
fix conf
feifei-111 Mar 11, 2024
c1f01d2
remove FusionOp to OpList
2742195759 Mar 11, 2024
444abed
erge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pad…
2742195759 Mar 11, 2024
5875b9e
update
feifei-111 Mar 11, 2024
274086f
update
feifei-111 Mar 11, 2024
efdd3e8
Merge pull request #49 from feifei-111/cinn-trivalop-fuse
feifei-111 Mar 11, 2024
b08377a
update
feifei-111 Mar 11, 2024
f47ca40
update
feifei-111 Mar 11, 2024
9661fb2
Merge pull request #52 from feifei-111/cinn-trivalop-fuse
feifei-111 Mar 11, 2024
22be208
update
feifei-111 Mar 11, 2024
7713da3
Merge pull request #53 from feifei-111/cinn-trivalop-fuse
feifei-111 Mar 11, 2024
e012d74
declare group_pattern_util.h
jiahy0825 Mar 11, 2024
ed7d12c
fix compiler complains
jiahy0825 Mar 11, 2024
5ff4943
Merge pull request #54 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 11, 2024
7640fe7
declare group_pattern_util.ClusteringHelper
jiahy0825 Mar 11, 2024
d3d6926
refine signature of group_pattern_util.ClusterIntoGroupPatternsFromOp…
jiahy0825 Mar 11, 2024
c0dd054
Merge pull request #55 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 11, 2024
e75a6bf
update op lowr
feifei-111 Mar 12, 2024
e96c0fd
Merge pull request #56 from feifei-111/cinn-trivalop-fuse
feifei-111 Mar 12, 2024
f2c12b8
add todo
2742195759 Mar 12, 2024
dc30e81
minor refine by group_pattern_util.OpSet
jiahy0825 Mar 12, 2024
85fb051
update
feifei-111 Mar 12, 2024
720e34d
update
feifei-111 Mar 12, 2024
9d49bef
update
feifei-111 Mar 12, 2024
2a6a72a
update (#57)
feifei-111 Mar 12, 2024
8c02464
update
feifei-111 Mar 12, 2024
2455e57
update
feifei-111 Mar 12, 2024
1603102
update
feifei-111 Mar 12, 2024
efe91cc
Cinn trivalop fuse (#58)
feifei-111 Mar 12, 2024
f95a83e
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 12, 2024
ac906c7
fix
2742195759 Mar 12, 2024
e521bae
fix
2742195759 Mar 12, 2024
1884745
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 12, 2024
eae1c4d
refactor StmtFusionHelper by OpTopo
jiahy0825 Mar 12, 2024
8662273
Complete: CreateReduceExpr function.
2742195759 Mar 12, 2024
b1ba43b
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 12, 2024
d87830c
update
feifei-111 Mar 12, 2024
60e2715
recursive done.
2742195759 Mar 12, 2024
99d370c
update
feifei-111 Mar 12, 2024
483edae
Cinn trivalop fuse (#59)
feifei-111 Mar 12, 2024
18453e0
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 12, 2024
fc94466
fix
2742195759 Mar 12, 2024
d602b58
clean all the TODO.
2742195759 Mar 12, 2024
9b821ae
update
feifei-111 Mar 12, 2024
1adce1e
fix cluster
2742195759 Mar 12, 2024
7027d1b
remove unused OpTopo.downstream_disconnected_ops
jiahy0825 Mar 12, 2024
b49ac16
update
feifei-111 Mar 12, 2024
185f288
Cinn trivalop fuse (#60)
feifei-111 Mar 12, 2024
35b6933
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 12, 2024
dbc7a90
fix compile rror
2742195759 Mar 12, 2024
62432a1
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 12, 2024
7b07854
update
feifei-111 Mar 12, 2024
55f975c
Cinn trivalop fuse (#61)
feifei-111 Mar 12, 2024
5419f4c
add R + T skeleon
2742195759 Mar 12, 2024
af04c62
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
2742195759 Mar 12, 2024
0ffee51
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 13, 2024
d8424e8
add search utils.
2742195759 Mar 13, 2024
1ecd28d
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 13, 2024
6681fe1
update
feifei-111 Mar 13, 2024
b47aacb
Cinn trivalop fuse (#62)
feifei-111 Mar 13, 2024
36823db
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 13, 2024
4131c45
push
2742195759 Mar 13, 2024
76c9f3c
merge
2742195759 Mar 13, 2024
3c5a716
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 13, 2024
70c42bb
update
feifei-111 Mar 13, 2024
e1eebd0
fix
2742195759 Mar 13, 2024
02df615
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
2742195759 Mar 13, 2024
699a171
fix transformer
2742195759 Mar 13, 2024
7f49672
fix
2742195759 Mar 13, 2024
1c2d2e6
Implement iterator vars fetching in ReduceOp
Mar 13, 2024
b8c65dc
small fix
Mar 13, 2024
e5a421e
add GetOuterIterVars API
Mar 13, 2024
98e5195
Merge pull request #63 from Fridge003/cinn_lower
2742195759 Mar 13, 2024
cec5d2b
fix
2742195759 Mar 13, 2024
d832ad5
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
2742195759 Mar 13, 2024
e2fb978
fix compile complain
2742195759 Mar 13, 2024
6fd91aa
modify GetOutputIters of TrivialOp
Mar 13, 2024
c6e2129
Merge pull request #64 from Fridge003/cinn_lower
2742195759 Mar 13, 2024
5c1fc98
remove dumplicate code in visit
2742195759 Mar 13, 2024
a12a1ae
merge
2742195759 Mar 13, 2024
4170593
implement ClusterIntoGroupPatternsFromOpList
jiahy0825 Mar 13, 2024
4f1cd70
Fix most error in trivial_op.cc.
2742195759 Mar 13, 2024
c453011
CreateReduceExpr is OK!
2742195759 Mar 14, 2024
8ab605d
fix
2742195759 Mar 14, 2024
2249346
add CheckIterEq
Mar 14, 2024
02505f3
Merge pull request #65 from Fridge003/cinn_lower
2742195759 Mar 14, 2024
2dc9de1
implement group_pattern_util.ClusteringEngine and groupp_pattern_util…
jiahy0825 Mar 14, 2024
4a52ccb
SinkTrivialTransform OK!
2742195759 Mar 14, 2024
7a4b02c
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
2742195759 Mar 14, 2024
8bafa17
update
feifei-111 Mar 14, 2024
146ff67
update
feifei-111 Mar 14, 2024
a4ef084
fix init_tensor name problem.
2742195759 Mar 14, 2024
599941f
update
feifei-111 Mar 14, 2024
7bd854e
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 14, 2024
8f21bec
fix compiler complains
jiahy0825 Mar 14, 2024
23e8341
merge origin repo xiongkun
jiahy0825 Mar 14, 2024
5180b55
Merge pull request #67 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 14, 2024
cda4b1b
refactor ShardableAxesSignature by group_pattern.SoleOutputShardableAxes
jiahy0825 Mar 14, 2024
35506a8
Merge pull request #68 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 14, 2024
b4d91ce
split trivial_op.cc
Mar 14, 2024
7d43e4d
Merge pull request #66 from Fridge003/cinn_lower
2742195759 Mar 14, 2024
a6a85e9
update
feifei-111 Mar 14, 2024
b052007
implement group_pattern_util.MakeShardableAxesSignature4ReduceOp
jiahy0825 Mar 14, 2024
c371dca
update
feifei-111 Mar 14, 2024
27a647c
implement group_pattern_util.MakeEmptyShardableAxesSignature
jiahy0825 Mar 14, 2024
0937581
Merge pull request #69 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 14, 2024
f252733
add helper class group_pattern_util.ShardableAxesProvider
jiahy0825 Mar 14, 2024
00adbcd
implement group_pattern_util.MakeShardableAxesSignature4BroadcastOp
jiahy0825 Mar 14, 2024
08b45af
Merge pull request #70 from tc20042008/xk-cinn-trivalop-fuse
tc20042008 Mar 14, 2024
7600787
update
feifei-111 Mar 15, 2024
cc70542
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 15, 2024
233c21c
update
feifei-111 Mar 15, 2024
a071108
fix softmax error.!
2742195759 Mar 15, 2024
1ed7a04
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
2742195759 Mar 15, 2024
0f0bdba
fix
2742195759 Mar 15, 2024
ce46657
update
feifei-111 Mar 15, 2024
1d51303
fix
2742195759 Mar 15, 2024
31a1124
merge
2742195759 Mar 15, 2024
298c45f
fix
2742195759 Mar 18, 2024
0c5c0d0
merge
2742195759 Mar 18, 2024
ffada84
Implement new OpMergeWithOp and add a relevant flag
Mar 18, 2024
b1cd524
Merge pull request #51 from Fridge003/cinn
2742195759 Mar 18, 2024
52369fc
update
feifei-111 Mar 18, 2024
9e91367
update
feifei-111 Mar 18, 2024
ba4b084
update
feifei-111 Mar 18, 2024
2d346ea
fix reduce_load error. add splitReduceTransform
2742195759 Mar 18, 2024
8cc0e11
merge and fix conflict
2742195759 Mar 18, 2024
99e6471
fix conflict
2742195759 Mar 19, 2024
f5f959a
update
feifei-111 Mar 19, 2024
04b3b74
update
feifei-111 Mar 19, 2024
d988733
update
feifei-111 Mar 19, 2024
eead87b
disable horizontal fusion
feifei-111 Mar 19, 2024
e06fb8b
fix
2742195759 Mar 19, 2024
9c595c9
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
2742195759 Mar 19, 2024
2ecdf38
Add some VLOG
2742195759 Mar 19, 2024
e9f1de7
merge
2742195759 Mar 19, 2024
064c055
Fix group cluster bug (#71)
Fridge003 Mar 19, 2024
3eda7f5
fix
2742195759 Mar 19, 2024
97c6292
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
2742195759 Mar 19, 2024
a2c31c2
fix dyshape
2742195759 Mar 19, 2024
c890a73
fix
2742195759 Mar 20, 2024
47526c1
init split cluster files
feifei-111 Mar 20, 2024
6cd6f1b
update
feifei-111 Mar 20, 2024
f8a6f7c
update
feifei-111 Mar 20, 2024
890c560
update
feifei-111 Mar 20, 2024
05aeb8f
spliting
Mar 20, 2024
dfee88f
Merge pull request #1 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
384dafc
update
feifei-111 Mar 20, 2024
6a01a19
spliting
Mar 20, 2024
d86e15e
spliting
Mar 20, 2024
bcbf191
Merge pull request #2 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
1f343bf
pattern utils
Mar 20, 2024
3f16743
Merge pull request #3 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
5af2ed7
update
feifei-111 Mar 20, 2024
06e9dd9
update
feifei-111 Mar 20, 2024
cc730d3
clean cmake
Mar 20, 2024
e27ea6a
Merge pull request #4 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
8a32dcf
update
feifei-111 Mar 20, 2024
48e5924
update
feifei-111 Mar 20, 2024
a095ef1
update
feifei-111 Mar 20, 2024
ff4dbf3
fix clustering_engine
Mar 20, 2024
925cb78
Merge pull request #5 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
1f662b8
fix fusion_helper
Mar 20, 2024
2ec08f8
Merge pull request #6 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
9f804b7
update
feifei-111 Mar 20, 2024
4a767ec
update
feifei-111 Mar 20, 2024
a3fb6ff
fix
Mar 20, 2024
6c64295
Merge pull request #7 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
670c2a0
update
feifei-111 Mar 20, 2024
af835ee
Merge branch 'cinn-trivalop-fuse' of https://github.com/feifei-111/Pa…
feifei-111 Mar 20, 2024
c6649c2
update
feifei-111 Mar 20, 2024
4585e6c
update
feifei-111 Mar 20, 2024
277b202
update
feifei-111 Mar 20, 2024
c02814b
fix
Mar 20, 2024
45cfe69
Merge pull request #8 from Fridge003/cinn_tmp
feifei-111 Mar 20, 2024
5d3d33e
fix some erros
2742195759 Mar 20, 2024
93ae0ef
update
feifei-111 Mar 20, 2024
2711dd4
Merge branch 'cinn-trivalop-fuse' of https://github.com/feifei-111/Pa…
feifei-111 Mar 20, 2024
583f86f
update
feifei-111 Mar 20, 2024
6af8711
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 20, 2024
b8b27d3
fix split with num problem
2742195759 Mar 20, 2024
ac9dd9d
update
feifei-111 Mar 20, 2024
7fedcf4
fix
2742195759 Mar 20, 2024
83a7036
Merge remote-tracking branch 'origin/develop' into cinn-trivalop-fuse
2742195759 Mar 20, 2024
04ff057
fix static issues
Mar 21, 2024
3166622
fix
Mar 21, 2024
11bebda
Merge pull request #9 from Fridge003/cinn_tmp
feifei-111 Mar 21, 2024
1d008f0
init split cluster files (#72)
feifei-111 Mar 21, 2024
5947d36
update
feifei-111 Mar 21, 2024
476fc69
update
feifei-111 Mar 21, 2024
7827491
update
feifei-111 Mar 21, 2024
f194613
update
feifei-111 Mar 21, 2024
7338026
update
feifei-111 Mar 21, 2024
4cf4062
update
feifei-111 Mar 21, 2024
e7cba69
update
feifei-111 Mar 21, 2024
98b3435
update
feifei-111 Mar 21, 2024
4d5e4d8
update
feifei-111 Mar 21, 2024
76e562a
update
feifei-111 Mar 21, 2024
56a112c
split shardable axes provider (#73)
Fridge003 Mar 21, 2024
bb0d5bb
update
feifei-111 Mar 21, 2024
ff22b61
update
feifei-111 Mar 22, 2024
4b0af2c
fix conlict develop
Mar 22, 2024
66aaca7
Merge pull request #74 from Fridge003/cinn
2742195759 Mar 22, 2024
71b0104
update
feifei-111 Mar 22, 2024
f261bbd
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 22, 2024
3ce0cec
fix broadcast (#75)
Fridge003 Mar 22, 2024
5da4797
update
feifei-111 Mar 22, 2024
92f33f9
Merge branch 'cinn-trivalop-fuse' of https://github.com/2742195759/Pa…
feifei-111 Mar 22, 2024
ea6782a
update
feifei-111 Mar 24, 2024
7af6629
Merge remote-tracking branch 'xiongkun/cinn-trivalop-fuse' into cinn-…
2742195759 Mar 25, 2024
bf1e66f
fix
2742195759 Mar 25, 2024
cb41920
fix code format
2742195759 Mar 25, 2024
a9843c0
fix code format
2742195759 Mar 25, 2024
177772a
remove unittest
2742195759 Mar 25, 2024
03d85f7
update
feifei-111 Mar 25, 2024
a0439ff
update
feifei-111 Mar 25, 2024
4fce3e6
update (#77)
feifei-111 Mar 25, 2024
fba58f5
update
feifei-111 Mar 25, 2024
cbf4df9
update
feifei-111 Mar 25, 2024
dddc198
fix
2742195759 Mar 25, 2024
77 changes: 77 additions & 0 deletions paddle/cinn/api/op_topo_pattern.h
@@ -0,0 +1,77 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <list>
#include <variant>
#include <vector>

namespace cinn::api {

template <typename T>
struct ErrorPattern {};

// ElementWise/Broadcast/Injective Ops without reduction ancestors.
template <typename T>
struct InjectiveSourcePattern {};

// Reduce op
template <typename T>
struct SingleReductionOpPattern {};

// ElementWise/Broadcast ops which have shardable dimensions and reduction
// ancestors.
template <typename T>
struct PartialShardablePattern {};

// Reduce base pattern
template <typename T>
struct ReductionPattern {
  using Nothing = std::monostate;
  std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern<T>>
      input;
  SingleReductionOpPattern<T> reduce_op_pattern;

  bool HasFusedInput() const {
    return !std::holds_alternative<Nothing>(this->input);
  }
};

// Stmt := IS | R | PS
// Ops in a StmtPattern will be lowered into inlined CUDA code.
template <typename T>
using StmtPattern = std::variant<InjectiveSourcePattern<T>,
                                 ReductionPattern<T>,
                                 PartialShardablePattern<T>>;

// Stmts := [Stmt]
template <typename T>
using StmtPatternVec = std::vector<StmtPattern<T>>;
// fuse rules:
// 1. IS * IS -> IS
// 2. PS * PS -> PS
// 3. IS * PS -> PS
// 4. IS * R -> R
// 5. PS * R -> R
// lifting rules:
// 1. R -> Stmts
// 2. PS -> Stmts
// 3. Stmts * Stmts -> Stmts
// OpTopoPattern := Error | Stmts

template <typename T>
using OpTopoPattern = std::variant<ErrorPattern<T>, StmtPatternVec<T>>;

} // namespace cinn::api
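
The fuse rules listed in the comments of op_topo_pattern.h can be read as a small lookup over pattern kinds. The following is a minimal sketch, not code from this PR: `IS`, `PS`, `R`, `Unfusable`, `FuseVisitor` and `Fuse` are hypothetical names, the tag structs carry no payload, and treating the left operand as the upstream pattern is an assumption the header does not state. Pairs that the comment does not list are mapped to `Unfusable` here, which is also an assumption.

```cpp
#include <variant>

// Hypothetical stand-ins for the empty pattern templates in the header.
struct IS {};         // InjectiveSourcePattern
struct PS {};         // PartialShardablePattern
struct R {};          // ReductionPattern (simplified, no payload)
struct Unfusable {};  // Assumption: result for pairs the comment does not list.

using Stmt = std::variant<IS, R, PS>;
using FuseResult = std::variant<Unfusable, IS, R, PS>;

// Overload set encoding the five fuse rules as (upstream, downstream) pairs.
struct FuseVisitor {
  FuseResult operator()(IS, IS) const { return IS{}; }  // 1. IS * IS -> IS
  FuseResult operator()(PS, PS) const { return PS{}; }  // 2. PS * PS -> PS
  FuseResult operator()(IS, PS) const { return PS{}; }  // 3. IS * PS -> PS
  FuseResult operator()(IS, R) const { return R{}; }    // 4. IS * R -> R
  FuseResult operator()(PS, R) const { return R{}; }    // 5. PS * R -> R

  // Fallback for every other pair; the exact overloads above win when they
  // match because non-template functions are preferred over templates.
  template <typename A, typename B>
  FuseResult operator()(A, B) const {
    return Unfusable{};
  }
};

// Dispatches on the runtime alternatives of both operands.
inline FuseResult Fuse(const Stmt& upstream, const Stmt& downstream) {
  return std::visit(FuseVisitor{}, upstream, downstream);
}
```

With these definitions, `Fuse(IS{}, PS{})` holds a `PS`, while `Fuse(R{}, IS{})` falls through to `Unfusable` because no listed rule has `R` on the left.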
23 changes: 1 addition & 22 deletions paddle/cinn/ast_gen_ius/ast_gen.cc
@@ -100,13 +100,6 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
     const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
     VLOG(4) << "ast gen: tensor init_body is " << init_body;
     for (int i = 0; i < shape.size(); ++i) {
-      bool is_keep_dim = axis[i]->is_keepdim;
-      if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
-        // if tiling first, we need to replace the reduce axis with 0, but don't
-        // deal with the non-reduce axis
-        optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
-        continue;
-      }
       if (!FLAGS_group_schedule_tiling_first &&
           FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
         optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
@@ -144,13 +137,6 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
     // for same axis so we re-create objects
     std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len);
     for (int i = 0; i < shape.size(); ++i) {
-      bool is_keep_dim = axis[i]->is_keepdim;
-      if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
-        // if tiling first, we need to replace the reduce axis with 0, but don't
-        // deal with the non-reduce axis
-        optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
-        continue;
-      }
       if (!FLAGS_group_schedule_tiling_first &&
           FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
         optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
@@ -185,10 +171,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
     std::vector<ir::Var> non_reduce_axis_vars = [&]() {
       std::vector<ir::Var> res;
       for (int i = 0; i < shape.size(); ++i) {
-        bool is_keep_dim = axis[i]->is_keepdim;
-        if (!is_keep_dim) {
-          res.push_back(axis[i]);
-        }
+        res.push_back(axis[i]);
       }
       return res;
     }();
@@ -240,10 +223,6 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
     // Put the two parts together
     ir::Expr body = ir::Block::Make({init_body, reduce_body});
     for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
-      bool is_keep_dim = axis[i]->is_keepdim;
-      if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
-        continue;
-      }
       if ((!FLAGS_group_schedule_tiling_first || !FLAGS_cinn_bucket_compile) &&
           shape[i] == Expr(1)) {
         continue;
1 change: 1 addition & 0 deletions paddle/cinn/backends/codegen_cuda_util.cc
@@ -78,6 +78,7 @@ detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName(

 void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
     ir::Expr func, ir::Expr predicate) {
+  VLOG(4) << "Process Lowered Func" << func;
   ir::_LoweredFunc_ *func_node = func.as_lowered_func();
   CHECK(func_node);
   if (!func_node->cuda_axis_info.valid()) {
1 change: 1 addition & 0 deletions paddle/cinn/frontend/CMakeLists.txt
@@ -62,6 +62,7 @@ add_subdirectory(paddle)
 add_subdirectory(decomposer)
 add_subdirectory(op_mappers)
 add_subdirectory(pass)
+add_subdirectory(cluster_ops)

 cinn_cc_test(test_op_mapper_registry SRCS op_mapper_registry_test.cc DEPS
              cinncore)
13 changes: 13 additions & 0 deletions paddle/cinn/frontend/cluster_ops/CMakeLists.txt
@@ -0,0 +1,13 @@
gather_srcs(
  cluster_ops_src
  SRCS
  common_utils.cc
  shardable_axes_inferer.cc
  shardable_axes_provider.cc
  shardable_axes_utils.cc
  pattern_utils.cc
  fusion_helper.cc
  cluster_policy.cc
  clustering_engine.cc)

cc_library(cluster_ops SRCS ${cluster_ops_src})
57 changes: 57 additions & 0 deletions paddle/cinn/frontend/cluster_ops/cluster_ops.h
@@ -0,0 +1,57 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/frontend/cluster_ops/clustering_engine.h"

namespace cinn::frontend {

cluster_ops::ClusteringResult ClusterOps(
    const cinn::dialect::GroupOp& group_op) {
  const auto& ops = [&] {
    std::vector<const pir::Operation*> ops;
    for (const auto& op : group_op.GetOperators()) {
      ops.push_back(op);
    }
    return ops;
  }();

  VLOG(4) << "Start Cluster Ops!";
  VLOG(4) << "Input Group with size " << ops.size() << " :\n"
          << cluster_ops::OpsDebugStr(ops);

  auto shardable_axes_provider = [&] {
    auto* program = group_op->GetParentProgram();
    const auto* shape_analysis =
        &pir::ShapeAnalysisManager::Instance().Get(program);
    return cluster_ops::MakeDefaultShardableAxesProvider(shape_analysis);
  }();

  auto cluster_policy = [&] {
    auto* program = group_op->GetParentProgram();
    const auto* shape_analysis =
        &pir::ShapeAnalysisManager::Instance().Get(program);
    return cluster_ops::MakeLoopAlignableClusteringPolicy(shape_analysis);
  }();

  cluster_ops::ShardableAxesInferer inferer(shardable_axes_provider);
  cluster_ops::ClusteringEngine engine(ops, inferer, cluster_policy);

  auto result = engine.ClusterOps();
  VLOG(4) << result.DebugStr();
  VLOG(4) << "Finished Cluster Ops!";
  return result;
}
} // namespace cinn::frontend
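
ClusterOps is defined in a header and builds its local `ops` vector with an immediately invoked lambda. A minimal sketch of that idiom follows, under stated assumptions: `FakeOp` and `DescribeGroup` are hypothetical stand-ins (no PIR types are used), and the `inline` keyword is an addition that the diff above does not contain. A non-inline function defined in a header breaks the one-definition rule once the header is included from more than one translation unit, so marking such a function `inline` may be worth considering.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for pir::Operation; only a name is kept.
struct FakeOp {
  std::string name;
};

// `inline` allows this definition to appear in several translation units.
inline std::string DescribeGroup(const std::vector<FakeOp>& group) {
  // Immediately invoked lambda: the vector is filled inside the lambda and
  // then bound as const, so it cannot be modified afterwards by accident.
  const auto ops = [&] {
    std::vector<const FakeOp*> result;
    result.reserve(group.size());
    for (const auto& op : group) {
      result.push_back(&op);
    }
    return result;
  }();

  // Build a debug string similar in spirit to the VLOG output in ClusterOps.
  std::string text =
      "Input Group with size " + std::to_string(ops.size()) + ":";
  for (const FakeOp* op : ops) {
    text += " " + op->name;
  }
  return text;
}
```

The lambda keeps the mutable construction step local, which is the same reason the original binds `ops` through `const auto&`.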