Skip to content

Commit

Permalink
refine code
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed Aug 10, 2023
1 parent 68f2863 commit 7abe8db
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 25 deletions.
8 changes: 4 additions & 4 deletions paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class DrrPatternContext {
std::vector<std::shared_ptr<const drr::Op>> owned_ops_;
};

class Op : public std::enable_shared_from_this<Op> {
class Op {
public:
const std::string& name() const { return op_type_name_; }

Expand Down Expand Up @@ -143,7 +143,7 @@ class Op : public std::enable_shared_from_this<Op> {
PatternGraph* pattern_graph_;
};

class Tensor : public std::enable_shared_from_this<Tensor> {
class Tensor {
public:
const std::string& DebugName() const;

Expand All @@ -155,7 +155,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {

const std::string& name() const { return name_; }

void SetName(const std::string& name) { name_ = name; }
void set_name(const std::string& name) { name_ = name; }

OpCall* producer() const { return producer_; }

Expand All @@ -182,7 +182,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
PatternGraph* pattern_graph_;
};

class OpCall : public std::enable_shared_from_this<OpCall> {
class OpCall {
public:
OpCall(const Op* op,
const std::vector<const Tensor*>& inputs,
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pattern_rewrite/drr/api/tensor_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/ir/pattern_rewrite/drr/api/tensor_interface.h"
#include "paddle/ir/pattern_rewrite/drr/ir_tensor.h"
#include "paddle/ir/pattern_rewrite/drr/ir_value.h"

namespace ir {
namespace drr {
Expand Down
32 changes: 17 additions & 15 deletions paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
// Match
auto* anchor = source_pattern_graph_->AnchorNode();
IR_ENFORCE(anchor);
std::unordered_set<OpCall*> drr_visited;
std::unordered_set<Operation*> ir_visited;
std::queue<OpCall*> drr_q;
std::queue<ir::Operation*> ir_q;
drr_q.push(anchor.get());
std::unordered_set<const OpCall*> drr_visited;
std::unordered_set<const Operation*> ir_visited;
std::queue<const OpCall*> drr_q;
std::queue<const ir::Operation*> ir_q;
drr_q.push(anchor);
ir_q.push(op);
drr_visited.insert(anchor.get());
drr_visited.insert(anchor);
ir_visited.insert(op);
match_context_impl_->BindIrOperation(op->name(),
std::make_shared<IrOperation>(op));
Expand All @@ -87,17 +87,18 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
const auto& drr_input_tensors = drr_node->inputs();
auto ir_input_value_size = ir_node->num_operands();
// check input's size
if (drr_input_tensors.sizes() != ir_input_value_size) {
if (drr_input_tensors.size() != ir_input_value_size) {
Matched = false;
break;
}
for (int i = 0; i < drr_input_tensors.size(); ++i) {
for (size_t i = 0; i < drr_input_tensors.size(); ++i) {
if (!Matched) break;
// check brother ops
auto drr_brother_ops = drr_input_tensors[i]->consumers();
auto ir_input_value = ir_node->operand(i);
match_context_impl_->BindIrValue(drr_input_tensors[i]->name(),
ir_input_value);
auto ir_input_value = ir_node->operand(i).source();
match_context_impl_->BindIrValue(
drr_input_tensors[i]->name(),
std::make_shared<IrValue>(ir_input_value));
if (drr_brother_ops.size() != ir_input_value.use_count()) {
Matched = false;
break;
Expand Down Expand Up @@ -154,17 +155,18 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
const auto& drr_output_tensors = drr_node->outputs();
auto ir_output_value_size = ir_node->num_results();
// check output's size
if (drr_output_tensors.sizes() != ir_output_value_size) {
if (drr_output_tensors.size() != ir_output_value_size) {
Matched = false;
break;
}
for (int i = 0; i < drr_output_tensors.size(); ++i) {
for (size_t i = 0; i < drr_output_tensors.size(); ++i) {
if (!Matched) break;
// check child ops
auto drr_child_ops = drr_output_tensors[i]->consumers();
auto ir_output_value = ir_node->result(i);
match_context_impl_->BindIrValue(drr_output_tensors[i]->name(),
ir_output_value);
match_context_impl_->BindIrValue(
drr_output_tensors[i]->name(),
std::make_shared<IrValue>(ir_output_value));
if (drr_child_ops.size() != ir_output_value.use_count()) {
Matched = false;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ class IrDtype {

class IrValue : public TensorInterface {
public:
explicit IrValue(const ir::Value* value)
explicit IrValue(const ir::Value& value)
: shape_(
&value->type().dyn_cast<paddle::dialect::DenseTensorType>().dims()),
dtype_(&value->type()
&value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims()),
dtype_(&value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()) {}

Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pattern_rewrite/drr/match_context_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/pattern_rewrite/drr/api/tensor_interface.h"
#include "paddle/ir/pattern_rewrite/drr/ir_operation.h"
#include "paddle/ir/pattern_rewrite/drr/ir_tensor.h"
#include "paddle/ir/pattern_rewrite/drr/ir_value.h"

namespace ir {
namespace drr {
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pattern_rewrite/drr/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void PatternGraph::UpdateTmpTensor(const id_type &tmp_tensor_id,

auto tmp_tensor = id2owned_tensor_[tmp_tensor_id];
id2owned_tensor_.erase(tmp_tensor_id);
tmp_tensor->SetName(new_tensor_id);
tmp_tensor->set_name(new_tensor_id);
id2owned_tensor_[new_tensor_id] = tmp_tensor;
}

Expand Down

0 comments on commit 7abe8db

Please sign in to comment.