Skip to content

Commit

Permalink
Merge pull request #8 from zyfncg/drr_ir
Browse files Browse the repository at this point in the history
[DRR] Add logic of Rewrite
  • Loading branch information
yuanlehome authored Aug 10, 2023
2 parents 7abe8db + 4d0d26a commit 4099c7a
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 26 deletions.
84 changes: 66 additions & 18 deletions paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h"
#include "paddle/ir/pattern_rewrite/drr/api/match_context.h"
#include "paddle/ir/pattern_rewrite/drr/ir_operation.h"
#include "paddle/ir/pattern_rewrite/drr/ir_operation_creator.h"
#include "paddle/ir/pattern_rewrite/drr/match_context_impl.h"
#include "paddle/ir/pattern_rewrite/drr/pattern_graph.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
Expand All @@ -47,7 +48,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
source_pattern_graph_->Print();
result_pattern_graph_->Print();

match_context_impl_ = std::make_unique<MatchContextImpl>();
source_pattern_match_ctx_ = std::make_unique<MatchContextImpl>();
}

bool Match(SourceOp op) const override {
Expand All @@ -62,8 +63,8 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(op);
drr_visited.insert(anchor);
ir_visited.insert(op);
match_context_impl_->BindIrOperation(op->name(),
std::make_shared<IrOperation>(op));
source_pattern_match_ctx_->BindIrOperation(
anchor, std::make_shared<IrOperation>(op));
bool Matched = true;
size_t step = 0;
while (!drr_q.empty()) {
Expand Down Expand Up @@ -96,7 +97,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
// check brother ops
auto drr_brother_ops = drr_input_tensors[i]->consumers();
auto ir_input_value = ir_node->operand(i).source();
match_context_impl_->BindIrValue(
source_pattern_match_ctx_->BindIrValue(
drr_input_tensors[i]->name(),
std::make_shared<IrValue>(ir_input_value));
if (drr_brother_ops.size() != ir_input_value.use_count()) {
Expand All @@ -114,7 +115,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
}
// todo()
if (drr_brother_op->name() == ir_op->name()) {
found = {true, ir_op};
found = std::make_pair(true, ir_op);
break;
}
}
Expand All @@ -123,9 +124,8 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(found.second);
drr_visited.insert(drr_brother_op);
ir_visited.insert(found.second);
match_context_impl_->BindIrOperation(
found.second->name(),
std::make_shared<IrOperation>(found.second));
source_pattern_match_ctx_->BindIrOperation(
drr_brother_op, std::make_shared<IrOperation>(found.second));
} else {
Matched = false;
break;
Expand All @@ -144,9 +144,8 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(ir_ancestor_op);
drr_visited.insert(drr_ancestor_op);
ir_visited.insert(ir_ancestor_op);
match_context_impl_->BindIrOperation(
ir_ancestor_op->name(),
std::make_shared<IrOperation>(ir_ancestor_op));
source_pattern_match_ctx_->BindIrOperation(
drr_ancestor_op, std::make_shared<IrOperation>(ir_ancestor_op));
}
}

Expand All @@ -164,7 +163,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
// check child ops
auto drr_child_ops = drr_output_tensors[i]->consumers();
auto ir_output_value = ir_node->result(i);
match_context_impl_->BindIrValue(
source_pattern_match_ctx_->BindIrValue(
drr_output_tensors[i]->name(),
std::make_shared<IrValue>(ir_output_value));
if (drr_child_ops.size() != ir_output_value.use_count()) {
Expand All @@ -191,9 +190,8 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(found.second);
drr_visited.insert(drr_child_op);
ir_visited.insert(found.second);
match_context_impl_->BindIrOperation(
found.second->name(),
std::make_shared<IrOperation>(found.second));
source_pattern_match_ctx_->BindIrOperation(
drr_child_op, std::make_shared<IrOperation>(found.second));
} else {
Matched = false;
break;
Expand All @@ -211,7 +209,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
// Matched = Matched && step == source_pattern_graph_->CountOfOpCalls();

// Constraints
MatchContext match_context{match_context_impl_};
MatchContext match_context{source_pattern_match_ctx_};
for (const auto& constraint : constraints_) {
Matched = constraint(match_context);
if (!Matched) break;
Expand All @@ -222,14 +220,64 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {

void Rewrite(SourceOp op,
ir::PatternRewriter& rewriter) const override { // NOLINT
// Rewrite
// 1. Create Operations in result_pattern_graph
MatchContextImpl res_match_ctx = CreateOperations(
*result_pattern_graph_, *source_pattern_match_ctx_, rewriter);

// 2. Replace Output Values in source_pattern_graph by Output Values in
// result_pattern_graph
ReplaceOutputTensor(*source_pattern_match_ctx_, res_match_ctx, rewriter);

// 3. Delete Operations in source_pattern_graph
DeleteSourcePatternOp(*source_pattern_match_ctx_, rewriter);
}

MatchContextImpl CreateOperations(
const ResultPatternGraph& result_pattern_graph,
const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
MatchContextImpl res_match_ctx;
// add input tensors info for res_match_ctx;
const auto& input_tensors = result_pattern_graph.input_tensors();
for (const auto& in_tensor : input_tensors) {
res_match_ctx.BindIrValue(
in_tensor,
std::make_shared<IrValue>(src_match_ctx.GetIrValue(in_tensor)));
}

// topo order visit result_pattern_graph
GraphTopo graph_topo_visit(&result_pattern_graph);
graph_topo_visit.WalkGraphNodesTopoOrder(
[&rewriter, &res_match_ctx](const OpCall& op_call) {
CreateOperation(op_call, rewriter, &res_match_ctx);
});

return res_match_ctx;
}

void ReplaceOutputTensor(const MatchContextImpl& src_match_ctx,
const MatchContextImpl& res_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
for (const auto& output_name : source_pattern_graph_->output_tensors()) {
const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name);
const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name);
rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get());
}
}

void DeleteSourcePatternOp(const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
for (const auto& kv : src_match_ctx.operation_map()) {
rewriter.EraseOp(kv.second->get());
}
}

private:
std::shared_ptr<MatchContextImpl> match_context_impl_;
std::shared_ptr<SourcePatternGraph> source_pattern_graph_;
std::vector<Constraint> constraints_;
std::shared_ptr<ResultPatternGraph> result_pattern_graph_;

std::shared_ptr<MatchContextImpl> source_pattern_match_ctx_;
};

} // namespace drr
Expand Down
64 changes: 64 additions & 0 deletions paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h"
#include "paddle/ir/pattern_rewrite/drr/match_context_impl.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"

#include "paddle/fluid/ir/dialect/pd_op.h"

namespace ir {
namespace drr {

Value GetIrValueByDrrTensor(const Tensor& tensor,
const MatchContextImpl& res_match_ctx) {
return res_match_ctx.GetIrValue(tensor.name()).get();
}

std::vector<Value> GetIrValuesByDrrTensors(
const std::vector<const Tensor*>& tensors,
const MatchContextImpl& res_match_ctx) {
std::vector<Value> ir_values;
ir_values.reserve(tensors.size());
for (const auto* tensor : tensors) {
ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx));
}
return ir_values;
}

Operation* CreateOperation(const OpCall& op_call,
ir::PatternRewriter& rewriter, // NOLINT
MatchContextImpl* res_match_ctx) {
if (op_call.name() == "pd.reshape") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
GetIrValuesByDrrTensors(inputs, *res_match_ctx);
// TODO(zyfncg): support attr in build op.
Operation* reshape_op = rewriter.Build<paddle::dialect::ReshapeOp>(
ir::OpResult(ir_values[0].impl()), std::vector<int64_t>{16, 3, 4, 16});
auto out = reshape_op->result(0);
res_match_ctx->BindIrValue(op_call.outputs()[0]->name(),
std::make_shared<IrValue>(out));
return reshape_op;
}
LOG(ERROR) << "Unknown op " << op_call.name();
return nullptr;
}

} // namespace drr
} // namespace ir
6 changes: 5 additions & 1 deletion paddle/ir/pattern_rewrite/drr/ir_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class IrDtype {
class IrValue : public TensorInterface {
public:
explicit IrValue(const ir::Value& value)
: shape_(
: value_(value),
shape_(
&value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims()),
dtype_(&value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
Expand All @@ -57,7 +58,10 @@ class IrValue : public TensorInterface {
ShapeInterface Shape() const override { return ShapeInterface(&shape_); }
DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); }

const Value& get() const { return value_; }

private:
const Value value_;
const IrShape shape_;
const IrDtype dtype_;
};
Expand Down
21 changes: 16 additions & 5 deletions paddle/ir/pattern_rewrite/drr/match_context_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
namespace ir {
namespace drr {

class OpCall;
template <class T>
struct CppTypeToIrAttribute;

Expand All @@ -48,8 +49,8 @@ class MatchContextImpl final {
return *tensor_map_.at(tensor_name);
}

const IrOperation& Operation(const std::string& op_name) const {
return *operation_map_.at(op_name);
const IrOperation& Operation(const OpCall* op_call) const {
return *operation_map_.at(op_call);
}

template <typename T>
Expand All @@ -59,14 +60,23 @@ class MatchContextImpl final {
.data();
}

const IrValue& GetIrValue(const std::string& tensor_name) const {
return *tensor_map_.at(tensor_name);
}

const std::unordered_map<const OpCall*, std::shared_ptr<IrOperation>>&
operation_map() const {
return operation_map_;
}

void BindIrValue(const std::string& value_name,
const std::shared_ptr<IrValue>& value) {
tensor_map_.emplace(value_name, value);
}

void BindIrOperation(const std::string& op_name,
void BindIrOperation(const OpCall* op_call,
const std::shared_ptr<IrOperation>& op) {
operation_map_.emplace(op_name, op);
operation_map_.emplace(op_call, op);
}

void BindIrAttr(const std::string& attr_name, ir::Attribute attr) {
Expand All @@ -75,7 +85,8 @@ class MatchContextImpl final {

private:
std::unordered_map<std::string, std::shared_ptr<IrValue>> tensor_map_;
std::unordered_map<std::string, std::shared_ptr<IrOperation>> operation_map_;
std::unordered_map<const OpCall*, std::shared_ptr<IrOperation>>
operation_map_;
std::unordered_map<std::string, ir::Attribute> attr_map_;
};

Expand Down
20 changes: 20 additions & 0 deletions paddle/ir/pattern_rewrite/drr/pattern_graph.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -41,6 +42,14 @@ class PatternGraph {
void UpdateTmpTensor(const id_type& tmp_tensor_id,
const id_type& new_tensor_id);

const std::unordered_set<id_type>& input_tensors() const {
return input_tensors_;
}

const std::unordered_set<id_type>& output_tensors() const {
return output_tensors_;
}

size_t CountOfOpCalls() const;

void Print() const;
Expand All @@ -52,6 +61,17 @@ class PatternGraph {
std::unordered_set<id_type> output_tensors_;
};

class GraphTopo {
public:
explicit GraphTopo(const PatternGraph* graph) : graph_(graph) {}

void WalkGraphNodesTopoOrder(
const std::function<void(const OpCall&)>& VisitNode) const {}

private:
const PatternGraph* graph_;
};

class SourcePatternGraph : public PatternGraph {
public:
const OpCall* AnchorNode() const;
Expand Down
2 changes: 0 additions & 2 deletions test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include "paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"

#include "paddle/fluid/ir/dialect/pd_op.h"

struct RemoveRedundentReshapeFunctor {
void operator()(ir::drr::DrrPatternContext *ctx) {
// Source patterns:待匹配的子图
Expand Down

0 comments on commit 4099c7a

Please sign in to comment.