Skip to content

Commit

Permalink
fix reshape tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi committed Jul 16, 2024
1 parent 6164f3a commit 486d511
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions paddle/cinn/operator_fusion/graph_transformer/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,16 @@ struct LiftToAnchorPatternOperation {
PatternNodePtr operator()(PatternGraph* graph, PatternNodePtr node) {
auto origin_name = node->id();
std::vector<pir::Operation*> ops = GetOpsInPattern(node->stmt_pattern());
// TODO(@wuzhanfei) move sink_op into pattern (currently, part of pattern
// type has sink and the others not) then, update logic here
PADDLE_ENFORCE_EQ(
node->sink_op()->num_results(),
1,
phi::errors::PreconditionNotMet(
"Op with multi output value can not lift to AnchorPattern"));
// TODO(huangjiyi): rm if condition after xshape removed in pd_op.reshape
if (node->sink_op()->name() != "pd_op.reshape") {
// TODO(@wuzhanfei) move sink_op into pattern (currently, part of pattern
// type has sink and the others not) then, update logic here
PADDLE_ENFORCE_EQ(
node->sink_op()->num_results(),
1,
phi::errors::PreconditionNotMet(
"Op with multi output value can not lift to AnchorPattern"));
}
pir::Value anchor = node->sink_op()->result(0);
node->set_stmt_pattern(AnchorPattern(
ops,
Expand Down

0 comments on commit 486d511

Please sign in to comment.