Skip to content

Commit

Permalink
[PIR] op's next op should not be a while op in constant_folding_pass (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored Jan 30, 2024
1 parent d6cf7e3 commit ea2a78a
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand Down Expand Up @@ -133,9 +134,16 @@ class ConstantFoldingPattern : public pir::RewritePattern {
if (!op->result(i).type().isa<paddle::dialect::DenseTensorType>()) {
return false;
}
// 6. next op should not be a while op
for (auto it = op->result(i).use_begin(); it != op->result(i).use_end();
++it) {
if (it.owner()->isa<paddle::dialect::WhileOp>()) {
return false;
}
}
}

// 6. maybe affect performence
// 7. maybe affect performence
if (op->isa<paddle::dialect::FullOp>()) {
auto next_ops = pir::GetUseOpsForOutput(op, 0);
for (auto [next_op, _] : next_ops) {
Expand Down

0 comments on commit ea2a78a

Please sign in to comment.