Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#14 from gongshaotian/op_extension
Browse files Browse the repository at this point in the history
 ADD duplicate TransposeOp merge testing code
  • Loading branch information
yuanlehome authored Aug 17, 2023
2 parents a3b3895 + 82242b7 commit 5282144
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
11 changes: 9 additions & 2 deletions paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
SourceOp op,
const std::shared_ptr<MatchContextImpl>& source_pattern_match_ctx) const {
// Match
auto* anchor = source_pattern_graph_->AnchorNode();
const auto* anchor = source_pattern_graph_->AnchorNode();
IR_ENFORCE(anchor);
std::unordered_set<const OpCall*> drr_visited;
std::unordered_set<const Operation*> ir_visited;
Expand All @@ -81,7 +81,9 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
bool Matched = true;
size_t step = 0;
while (!drr_q.empty()) {
if (!Matched) break;
if (!Matched) {
break;
}
IR_ENFORCE(drr_q.size() == ir_q.size());
// if (drr_q.size() != ir_q.size()) {
// Matched = false;
Expand Down Expand Up @@ -178,6 +180,11 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {

for (size_t i = 0; i < drr_output_tensors.size(); ++i) {
if (!Matched) break;

if (drr_output_tensors[i]->consumers().empty()){
continue;
}

// check child ops
auto drr_child_ops = drr_output_tensors[i]->consumers();
auto ir_output_value = ir_node->result(i);
Expand Down
12 changes: 12 additions & 0 deletions paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ Operation* CreateOperation(const OpCall& op_call,
std::make_shared<IrValue>(reshape_op->result(1)));
return reshape_op;
}
else if(op_call.name() == "pd.transpose") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values = GetIrValuesByDrrTensors(inputs, *res_match_ctx);
Operation* transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
std::vector<int>{0, 2, 1, 3});
res_match_ctx->BindIrValue(
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(transpose_op->result(0)));
return transpose_op;
}

LOG(ERROR) << "Unknown op " << op_call.name();
return nullptr;
}
Expand Down
49 changes: 44 additions & 5 deletions test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/ir/transforms/dead_code_elimination_pass.h"


struct RemoveRedundentReshapeFunctor {
void operator()(ir::drr::DrrPatternContext *ctx) {
// Source patterns:待匹配的子图
Expand Down Expand Up @@ -55,10 +56,10 @@ struct FoldBroadcastToConstantFunctor {
ir::drr::SourcePattern pat = ctx->SourcePattern();
// Source Pattern 中可匹配的类型包括 Op 和 Tensor
const auto &fill_constant = pat.Op(
"fill_constant",
"pd.fill_constant",
{{"value", pat.Attr("value_1")}, {"dtype", pat.Attr("dtype_1")}});
const auto &broadcast_to =
pat.Op("expand", {{"shape", pat.Attr("shape_1")}});
pat.Op("pd.expand", {{"shape", pat.Attr("shape_1")}});
// 匹配fill_constant+broadcast_to,同时对输出张量标记为ret,方便后面加约束
pat.Tensor("ret") = broadcast_to(fill_constant());
// Constrains:本Pass无额外的约束规则
Expand All @@ -68,7 +69,7 @@ struct FoldBroadcastToConstantFunctor {
// broadcast_to(fill_constant()),注意shape属性已更新 所有 ret
// 参数均在Source Pattern中使用,对 ret 的赋值等同于对 ret 的 producer
// op的删除和重连接
const auto &folded_fill_constant = res.Op("fill_constant",
const auto &folded_fill_constant = res.Op("pd.fill_constant",
{{"shape", res.Attr("shape_1")},
{"value", res.Attr("value_1")},
{"dtype", res.Attr("dtype_1")}});
Expand All @@ -85,6 +86,34 @@ class FoldBroadcastToConstantPattern
FoldBroadcastToConstantFunctor>::DrrRewritePattern;
};

struct RemoveRedundentTransposeFunctor{
void operator()(ir::drr::DrrPatternContext *ctx){
// Source pattern: 待匹配的子图
ir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &transpose1 = pat.Op("pd.transpose", {{"perm", pat.Attr("perm_1")}});
const auto &transpose2 = pat.Op("pd.transpose", {{"perm", pat.Attr("perm_2")}});

pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose")));

// Result patterns: 要替换的子图
ir::drr::ResultPattern res = pat.ResultPattern();
const auto &tranpose_continuous = res.Op("pd.transpose",
{{"perm", pat.Attr("perm_2")}}); // TODO 先简单用perm2替换

res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose"));
}
};

class RemoveRedundentTransposePattern
: public ir::drr::DrrRewritePattern<paddle::dialect::TransposeOp,
RemoveRedundentTransposeFunctor> {
public:
using ir::drr::DrrRewritePattern<
paddle::dialect::TransposeOp,
RemoveRedundentTransposeFunctor>::DrrRewritePattern;
};


void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::FullOp full_input_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16, 16},
Expand All @@ -103,7 +132,16 @@ void BuildProgram(ir::Builder &builder) { // NOLINT
paddle::dialect::ReluOp relu_op =
builder.Build<paddle::dialect::ReluOp>(reshape_op_second.out());

builder.Build<paddle::dialect::FetchOp>(relu_op.out(), "out", 0);
paddle::dialect::TransposeOp transpose_op_first =
builder.Build<paddle::dialect::TransposeOp>(relu_op.out(), std::vector<int>{0, 2, 1, 3});

paddle::dialect::TransposeOp transpose_op_second =
builder.Build<paddle::dialect::TransposeOp>(transpose_op_first.out(), std::vector<int>{0, 1, 2, 3});

paddle::dialect::ReluOp relu_op_second =
builder.Build<paddle::dialect::ReluOp>(transpose_op_second.out());

builder.Build<paddle::dialect::FetchOp>(relu_op_second.out(), "out", 0);
}

class DrrPatternRewritePass : public ir::Pass {
Expand All @@ -113,6 +151,7 @@ class DrrPatternRewritePass : public ir::Pass {
bool Initialize(ir::IrContext *context) override {
ir::RewritePatternSet ps(context);
ps.Add(std::make_unique<RemoveRedundentReshapePattern>(context));
ps.Add(std::make_unique<RemoveRedundentTransposePattern>(context));

patterns_ = ir::FrozenRewritePatternSet(std::move(ps));
return true;
Expand Down Expand Up @@ -140,7 +179,7 @@ TEST(DrrTest, drr) {
ir::Builder builder = ir::Builder(ctx, program.block());
BuildProgram(builder);

EXPECT_EQ(program.block()->size(), 7u);
EXPECT_EQ(program.block()->size(), 10u);

ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<DrrPatternRewritePass>());
Expand Down

0 comments on commit 5282144

Please sign in to comment.