Skip to content

Commit

Permalink
Modify according to review
Browse files Browse the repository at this point in the history
  • Loading branch information
gongshaotian committed Aug 17, 2023
1 parent 8312452 commit 82242b7
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 15 deletions.
2 changes: 0 additions & 2 deletions paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ class Op {
public:
const std::string& name() const { return op_type_name_; }

const std::unordered_map<std::string, Attribute>& attribute() const { return attributes_; }

void operator()(const Tensor& arg, const Tensor* out) const;

Tensor& operator()() const;
Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
while (!drr_q.empty()) {
if (!Matched) {
break;
};
}
IR_ENFORCE(drr_q.size() == ir_q.size());
// if (drr_q.size() != ir_q.size()) {
// Matched = false;
Expand Down Expand Up @@ -181,7 +181,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
for (size_t i = 0; i < drr_output_tensors.size(); ++i) {
if (!Matched) break;

if (drr_output_tensors[i]->consumers().size() == 0){
if (drr_output_tensors[i]->consumers().empty()){
continue;
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ Operation* CreateOperation(const OpCall& op_call,
Operation* transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
std::vector<int>{0, 2, 1, 3});
auto out = transpose_op->result(0);
res_match_ctx->BindIrValue(op_call.outputs()[0]->name(),
std::make_shared<IrValue>(out));
res_match_ctx->BindIrValue(
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(transpose_op->result(0)));
return transpose_op;
}

Expand Down
9 changes: 1 addition & 8 deletions test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,14 @@ void BuildProgram(ir::Builder &builder) { // NOLINT
builder.Build<paddle::dialect::FetchOp>(relu_op_second.out(), "out", 0);
}

//
std::unique_ptr<RemoveRedundentTransposePattern> CreateDrrPatternRewritePass(
ir::IrContext *ir_ctx) {
return std::make_unique<RemoveRedundentTransposePattern>(ir_ctx, 1);
}


class DrrPatternRewritePass : public ir::Pass {
public:
DrrPatternRewritePass() : ir::Pass("DrrPatternRewritePass", 1) {}

bool Initialize(ir::IrContext *context) override {
ir::RewritePatternSet ps(context);
ps.Add(std::make_unique<RemoveRedundentReshapePattern>(context));
ps.Add(std::make_unique<RemoveRedundentReshapePattern>(context));
ps.Add(std::make_unique<RemoveRedundentTransposePattern>(context));

patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
return true;
Expand Down

0 comments on commit 82242b7

Please sign in to comment.