Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed Aug 9, 2023
1 parent 43e1f02 commit 68f2863
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
5 changes: 0 additions & 5 deletions paddle/ir/pattern_rewrite/drr/api/match_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ const TensorInterface& MatchContext::Tensor(
return impl_->Tensor(tensor_name);
}

const IrOperation& MatchContext::Operation(
const std::string& op_name) const {
return impl_->Operation(op_name);
}

template <typename T>
T MatchContext::Attr(const std::string& attr_name) const {
return impl_->Attr<T>(attr_name);
Expand Down
2 changes: 0 additions & 2 deletions paddle/ir/pattern_rewrite/drr/api/match_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class MatchContext final {

const TensorInterface& Tensor(const std::string& tensor_name) const;

const IrOperation& Operation(const std::string& op_name) const;

template <typename T>
T Attr(const std::string& attr_name) const;

Expand Down
24 changes: 12 additions & 12 deletions paddle/ir/pattern_rewrite/drr/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@

#include "paddle/ir/pattern_rewrite/drr/pattern_graph.h"

#include <glog/logging.h>
#include <iostream>

#include "paddle/ir/core/enforce.h"
#include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h"

namespace ir {
namespace drr {

const drr::OpCall &
PatternGraph::AddOpCall(const std::shared_ptr<drr::OpCall> &op_call) {
const drr::OpCall &PatternGraph::AddOpCall(
const std::shared_ptr<drr::OpCall> &op_call) {
owned_op_call_.push_back(op_call);
for (const auto &input : op_call->inputs()) {
const auto &tensor_id = input->name();
CHECK(id2owned_tensor_.count(tensor_id));
IR_ENFORCE(id2owned_tensor_.count(tensor_id));
id2owned_tensor_.at(tensor_id)->AddConsumer(op_call.get());

if (input->producer() == nullptr) {
Expand All @@ -39,24 +39,24 @@ PatternGraph::AddOpCall(const std::shared_ptr<drr::OpCall> &op_call) {
}
for (auto &output : op_call->outputs()) {
const auto &out_tensor_id = output->name();
CHECK(id2owned_tensor_.count(out_tensor_id));
IR_ENFORCE(id2owned_tensor_.count(out_tensor_id));
id2owned_tensor_[output->name()]->set_producer(op_call.get());
}
return *owned_op_call_.back();
}

const drr::Tensor &
PatternGraph::AddTensor(const std::shared_ptr<drr::Tensor> &tensor) {
const drr::Tensor &PatternGraph::AddTensor(
const std::shared_ptr<drr::Tensor> &tensor) {
if (id2owned_tensor_.find(tensor->name()) == id2owned_tensor_.end()) {
id2owned_tensor_[tensor->name()] = tensor;
output_tensors_.insert(tensor->name());
}
return *id2owned_tensor_[tensor->name()];
}

drr::Tensor &
PatternGraph::AddTmpTensor(const std::shared_ptr<drr::Tensor> &tensor) {
CHECK(id2owned_tensor_.find(tensor->name()) == id2owned_tensor_.end());
drr::Tensor &PatternGraph::AddTmpTensor(
const std::shared_ptr<drr::Tensor> &tensor) {
IR_ENFORCE(id2owned_tensor_.count(tensor->name()) == 0);
id2owned_tensor_[tensor->name()] = tensor;
output_tensors_.insert(tensor->name());
return *id2owned_tensor_[tensor->name()];
Expand Down Expand Up @@ -124,5 +124,5 @@ const OpCall *SourcePatternGraph::AnchorNode() const {
return id2owned_tensor_.at(*output_tensors_.begin())->producer();
}

} // namespace drr
} // namespace ir
} // namespace drr
} // namespace ir

0 comments on commit 68f2863

Please sign in to comment.