Skip to content

Commit

Permalink
fix test_sub_graph_23
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi committed Aug 4, 2024
1 parent e6cd91e commit 930f061
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
17 changes: 17 additions & 0 deletions paddle/cinn/operator_fusion/graph_transformer/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,23 @@ struct HorizontalFusionMatcher {
}
};

struct LEOneElementWiseDownstreamMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
size_t count = 0;
for (const auto& downsteram : node->downstream()) {
if (StmtPatternGraphMatcher<TrivialPattern>()(graph, downsteram)) {
auto ops = std::get<TrivialPattern>(downsteram->stmt_pattern()).ops();
bool is_elementwise =
std::all_of(ops.begin(), ops.end(), [](pir::Operation* op) {
return GetOpPatternKind(op) == hlir::framework::kElementWise;
});
count += is_elementwise;
}
}
return (count <= 1);
}
};

struct NonSinkNodeMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
return !node->downstream().empty();
Expand Down
6 changes: 6 additions & 0 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ void PatternGraph::SinkTrivialPattern() {
DownstreamSmallerThan<2>,
NonSinkNodeMatcher>,
MergeTrivialPatternOperation>(this);

GraphTransformer<NodePattern,
And<StmtPatternGraphMatcher<TrivialPattern>,
NonSinkNodeMatcher,
LEOneElementWiseDownstreamMatcher>,
MergeTrivialPatternOperation>(this);
}

void PatternGraph::ReduceLiftReduceTree() {
Expand Down

0 comments on commit 930f061

Please sign in to comment.