Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR+CINN]Fix cinn_op.GroupOp insert bug for WriteAfterRead #62529

Merged
merged 7 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ pir::Block* GroupOp::block() {
return &region.front();
}

std::vector<pir::Operation*> GroupOp::GetOperators() {
pir::Block* GroupOp::block() const {
pir::Region& region = (*this)->region(0);
CHECK(!region.empty());
return &region.front();
}

std::vector<pir::Operation*> GroupOp::GetOperators() const {
std::vector<pir::Operation*> rt_ops;
for (auto& op : *block()) {
rt_ops.push_back(&op);
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class IR_API GroupOp
const cinn::dialect::GroupInfo &group_info);

pir::Block *block();
std::vector<pir::Operation *> GetOperators();
pir::Block *block() const;
std::vector<pir::Operation *> GetOperators() const;

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);

Expand Down
48 changes: 48 additions & 0 deletions paddle/fluid/pir/transforms/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace {
using GroupOpsVec = std::vector<pir::Operation*>;
using CompatibleInfo = cinn::hlir::framework::pir::CompatibleInfo;

void VerifyOperationOrder(const pir::Block& block);

class BuildCinnPass : public pir::Pass {
public:
BuildCinnPass() : pir::Pass("build_cinn_pass", /*opt_level=*/1) {}
Expand All @@ -33,6 +35,7 @@ class BuildCinnPass : public pir::Pass {
for (uint32_t i = 0; i < op->num_regions(); ++i) {
for (auto& block : op->region(i)) {
ProcessBlock(&block);
VerifyOperationOrder(block);
}
}
}
Expand All @@ -56,6 +59,51 @@ class BuildCinnPass : public pir::Pass {
}
}
};

void VerifyOperationOrder(const pir::Block& block) {
auto order_info =
[&]() -> std::unordered_map<const pir::Operation*, int64_t> {
std::unordered_map<const pir::Operation*, int64_t> map;
// initialize the position index with block size by default.
const int64_t block_size = block.size();
for (auto& op : block) map[&op] = block_size;
return map;
}();
const auto& CheckOpOrder = [&](const pir::Operation* op) -> void {
const pir::Operation* current_op = op;
for (auto& value : op->operands_source()) {
if (!value || !value.defining_op()) continue;
pir::Operation* defining_op = value.defining_op();
if (order_info.count(defining_op) == 0) continue;
if (op->GetParentOp() &&
op->GetParentOp()->isa<cinn::dialect::GroupOp>()) {
current_op = op->GetParentOp();
}
CHECK(order_info.at(defining_op) < order_info.at(current_op))
<< "The order of operations is not correct!"
<< " Received defining_op(" << defining_op->id() << " "
<< order_info.at(defining_op) << ") is behind current_op("
<< current_op->id() << " " << order_info.at(current_op) << ")";
}
};
const auto& CheckGroupOpOrder = [&](pir::Operation* op) -> void {
auto group_op = op->dyn_cast<cinn::dialect::GroupOp>();
for (auto& inner_op : *group_op.block()) {
CheckOpOrder(&inner_op);
}
};

int64_t index = 0;
for (auto& op : block) {
order_info[&op] = index++;
if (op.isa<cinn::dialect::GroupOp>()) {
CheckGroupOpOrder(&op);
} else {
CheckOpOrder(&op);
}
}
}

} // namespace

namespace pir {
Expand Down
70 changes: 70 additions & 0 deletions paddle/fluid/pir/transforms/sub_graph_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <memory>

#include <iterator>
#include <queue>
#include <regex>
#include <set>
Expand Down Expand Up @@ -513,6 +514,74 @@ pir::Operation* FindInsertPoint(const GroupOpsVec& group_ops,
}
return insert_point_op;
}

struct IncrementalOrder {
bool operator()(const pir::Operation* lhs, const pir::Operation* rhs) const {
CHECK(lhs->GetParent() == rhs->GetParent())
<< "lhs and rhs should have same parent block.";
auto lhs_iter = lhs->operator Block::ConstIterator();
auto rhs_iter = rhs->operator Block::ConstIterator();
auto end_iter = lhs->GetParent()->end();
while (lhs_iter != end_iter) {
lhs_iter++;
if (lhs_iter == rhs_iter) return true;
if (lhs_iter == end_iter) return false;
}
CHECK(false) << "rhs " << rhs->id() << " is not reachable from lhs "
<< lhs->id();
return false;
}
};

std::unordered_set<pir::Operation*> GetUpstreamOpsAfterPosition(
const pir::Operation* position_op,
const pir::Block* block,
const pir::Operation* op,
std::unordered_set<pir::Operation*>* visited_ops) {
std::unordered_set<pir::Operation*> ops;
const auto& IsInBlock = [](const pir::Operation* src_op,
const pir::Block* block) {
for (auto& op : *block) {
if (src_op == &op) return true;
}
return false;
};

for (auto value : op->operands_source()) {
if (!value || !value.defining_op()) continue;
pir::Operation* defining_op = value.defining_op();
if (visited_ops->count(defining_op)) continue;
visited_ops->insert(defining_op);
if (!IsInBlock(defining_op, block)) continue;
if (IncrementalOrder()(defining_op, position_op)) continue;

ops.insert(defining_op);
auto recursive_ops = GetUpstreamOpsAfterPosition(
position_op, block, defining_op, visited_ops);
ops.insert(recursive_ops.begin(), recursive_ops.end());
}
return ops;
}

void MoveUpstreamOpBeforeGroup(const GroupOpsVec& group_ops,
pir::Block* block,
pir::Operation* insert_point_op) {
const auto moved_ops = [&]() {
std::set<pir::Operation*, IncrementalOrder> ops_set;
std::unordered_set<pir::Operation*> visited_ops;
for (auto& op : group_ops) {
auto upstream_ops =
GetUpstreamOpsAfterPosition(insert_point_op, block, op, &visited_ops);
ops_set.insert(upstream_ops.begin(), upstream_ops.end());
}
return ops_set;
}();

for (auto& op : moved_ops) {
VLOG(5) << "Move " << op->name() << " before " << insert_point_op->name();
op->MoveTo(block, insert_point_op->operator Block::Iterator());
}
}
} // namespace

void ReplaceWithGroupOp(pir::Block* block,
Expand All @@ -527,6 +596,7 @@ void ReplaceWithGroupOp(pir::Block* block,

// step 1: Analysis and insert group op before insert_point.
auto* insert_point = FindInsertPoint(group_ops, outputs);
MoveUpstreamOpBeforeGroup(group_ops, block, insert_point);
builder.set_insertion_point(insert_point);
VLOG(6) << "Insert GroupOp after " << insert_point->name();

Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/include/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class IR_API alignas(8) Operation final

void Verify();

uint64_t id() { return id_; }
uint64_t id() const { return id_; }

private:
DISABLE_COPY_AND_ASSIGN(Operation);
Expand Down