From 7abe8dbc1a289edd763386828a5e207647ff59a5 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Thu, 10 Aug 2023 06:37:38 +0000 Subject: [PATCH] refine code --- .../drr/api/drr_pattern_context.h | 8 ++--- .../drr/api/tensor_interface.cc | 2 +- .../pattern_rewrite/drr/drr_rewrite_pattern.h | 32 ++++++++++--------- .../drr/{ir_tensor.h => ir_value.h} | 6 ++-- .../pattern_rewrite/drr/match_context_impl.h | 2 +- .../ir/pattern_rewrite/drr/pattern_graph.cc | 2 +- 6 files changed, 27 insertions(+), 25 deletions(-) rename paddle/ir/pattern_rewrite/drr/{ir_tensor.h => ir_value.h} (91%) diff --git a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h index 8cf9c9794bc0b..efe7f1866483e 100755 --- a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h +++ b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h @@ -108,7 +108,7 @@ class DrrPatternContext { std::vector> owned_ops_; }; -class Op : public std::enable_shared_from_this { +class Op { public: const std::string& name() const { return op_type_name_; } @@ -143,7 +143,7 @@ class Op : public std::enable_shared_from_this { PatternGraph* pattern_graph_; }; -class Tensor : public std::enable_shared_from_this { +class Tensor { public: const std::string& DebugName() const; @@ -155,7 +155,7 @@ class Tensor : public std::enable_shared_from_this { const std::string& name() const { return name_; } - void SetName(const std::string& name) { name_ = name; } + void set_name(const std::string& name) { name_ = name; } OpCall* producer() const { return producer_; } @@ -182,7 +182,7 @@ class Tensor : public std::enable_shared_from_this { PatternGraph* pattern_graph_; }; -class OpCall : public std::enable_shared_from_this { +class OpCall { public: OpCall(const Op* op, const std::vector& inputs, diff --git a/paddle/ir/pattern_rewrite/drr/api/tensor_interface.cc b/paddle/ir/pattern_rewrite/drr/api/tensor_interface.cc index d541322e45169..4d2b7a4a73574 100644 --- a/paddle/ir/pattern_rewrite/drr/api/tensor_interface.cc +++ b/paddle/ir/pattern_rewrite/drr/api/tensor_interface.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/ir/pattern_rewrite/drr/api/tensor_interface.h" -#include "paddle/ir/pattern_rewrite/drr/ir_tensor.h" +#include "paddle/ir/pattern_rewrite/drr/ir_value.h" namespace ir { namespace drr { diff --git a/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h b/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h index e775cafe12ec1..8b1a82ef413ed 100644 --- a/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h +++ b/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h @@ -54,13 +54,13 @@ class DrrRewritePattern : public ir::OpRewritePattern { // Match auto* anchor = source_pattern_graph_->AnchorNode(); IR_ENFORCE(anchor); - std::unordered_set drr_visited; - std::unordered_set ir_visited; - std::queue drr_q; - std::queue ir_q; - drr_q.push(anchor.get()); + std::unordered_set drr_visited; + std::unordered_set ir_visited; + std::queue drr_q; + std::queue ir_q; + drr_q.push(anchor); ir_q.push(op); - drr_visited.insert(anchor.get()); + drr_visited.insert(anchor); ir_visited.insert(op); match_context_impl_->BindIrOperation(op->name(), std::make_shared(op)); @@ -87,17 +87,18 @@ class DrrRewritePattern : public ir::OpRewritePattern { const auto& drr_input_tensors = drr_node->inputs(); auto ir_input_value_size = ir_node->num_operands(); // check input's size - if (drr_input_tensors.sizes() != ir_input_value_size) { + if (drr_input_tensors.size() != ir_input_value_size) { Matched = false; break; } - for (int i = 0; i < drr_input_tensors.size(); ++i) { + for (size_t i = 0; i < drr_input_tensors.size(); ++i) { if (!Matched) break; // check brother ops auto drr_brother_ops = drr_input_tensors[i]->consumers(); - auto ir_input_value = ir_node->operand(i); - match_context_impl_->BindIrValue(drr_input_tensors[i]->name(), - ir_input_value); + auto ir_input_value = ir_node->operand(i).source(); + match_context_impl_->BindIrValue( + drr_input_tensors[i]->name(), + std::make_shared(ir_input_value)); if (drr_brother_ops.size() != ir_input_value.use_count()) { Matched = false; break; @@ -154,17 +155,18 @@ class DrrRewritePattern : public ir::OpRewritePattern { const auto& drr_output_tensors = drr_node->outputs(); auto ir_output_value_size = ir_node->num_results(); // check output's size - if (drr_output_tensors.sizes() != ir_output_value_size) { + if (drr_output_tensors.size() != ir_output_value_size) { Matched = false; break; } - for (int i = 0; i < drr_output_tensors.size(); ++i) { + for (size_t i = 0; i < drr_output_tensors.size(); ++i) { if (!Matched) break; // check child ops auto drr_child_ops = drr_output_tensors[i]->consumers(); auto ir_output_value = ir_node->result(i); - match_context_impl_->BindIrValue(drr_output_tensors[i]->name(), - ir_output_value); + match_context_impl_->BindIrValue( + drr_output_tensors[i]->name(), + std::make_shared(ir_output_value)); if (drr_child_ops.size() != ir_output_value.use_count()) { Matched = false; break; diff --git a/paddle/ir/pattern_rewrite/drr/ir_tensor.h b/paddle/ir/pattern_rewrite/drr/ir_value.h similarity index 91% rename from paddle/ir/pattern_rewrite/drr/ir_tensor.h rename to paddle/ir/pattern_rewrite/drr/ir_value.h index 1f78c43f139be..7a9250a268aab 100644 --- a/paddle/ir/pattern_rewrite/drr/ir_tensor.h +++ b/paddle/ir/pattern_rewrite/drr/ir_value.h @@ -47,10 +47,10 @@ class IrDtype { class IrValue : public TensorInterface { public: - explicit IrValue(const ir::Value* value) + explicit IrValue(const ir::Value& value) : shape_( - &value->type().dyn_cast().dims()), - dtype_(&value->type() + &value.type().dyn_cast().dims()), + dtype_(&value.type() .dyn_cast() .dtype()) {} diff --git a/paddle/ir/pattern_rewrite/drr/match_context_impl.h b/paddle/ir/pattern_rewrite/drr/match_context_impl.h index 962849ef23448..5ce94e89b37aa 100644 --- a/paddle/ir/pattern_rewrite/drr/match_context_impl.h +++ b/paddle/ir/pattern_rewrite/drr/match_context_impl.h @@ -20,7 +20,7 @@ #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/pattern_rewrite/drr/api/tensor_interface.h" #include "paddle/ir/pattern_rewrite/drr/ir_operation.h" -#include "paddle/ir/pattern_rewrite/drr/ir_tensor.h" +#include "paddle/ir/pattern_rewrite/drr/ir_value.h" namespace ir { namespace drr { diff --git a/paddle/ir/pattern_rewrite/drr/pattern_graph.cc b/paddle/ir/pattern_rewrite/drr/pattern_graph.cc index 22eef7254ee44..8d8ce90dbb142 100755 --- a/paddle/ir/pattern_rewrite/drr/pattern_graph.cc +++ b/paddle/ir/pattern_rewrite/drr/pattern_graph.cc @@ -77,7 +77,7 @@ void PatternGraph::UpdateTmpTensor(const id_type &tmp_tensor_id, auto tmp_tensor = id2owned_tensor_[tmp_tensor_id]; id2owned_tensor_.erase(tmp_tensor_id); - tmp_tensor->SetName(new_tensor_id); + tmp_tensor->set_name(new_tensor_id); id2owned_tensor_[new_tensor_id] = tmp_tensor; }