-
Notifications
You must be signed in to change notification settings - Fork 333
[Enhancement] Refactored buffer detection logic in warp_specialized_rewriter.cc #685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,69 +23,97 @@ using arith::IRVisitorWithAnalyzer; | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| enum class Role { kConsumer, kProducer, kBoth }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class TMAFinder : public StmtExprVisitor { | ||||||||||||||||||||||||||||
| class ProducerBufferDetector : public StmtExprVisitor { | ||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||
| void clear() { has_tma_load_ = false; } | ||||||||||||||||||||||||||||
| ProducerBufferDetector( | ||||||||||||||||||||||||||||
| std::unordered_set<const BufferNode *> cur_producer_buffers) | ||||||||||||||||||||||||||||
| : cur_producer_buffers_(cur_producer_buffers) {} | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void clear() { has_producer_buffer_ = false; } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void VisitExpr_(const CallNode *call) final { | ||||||||||||||||||||||||||||
| if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { | ||||||||||||||||||||||||||||
| has_tma_load_ = true; | ||||||||||||||||||||||||||||
| has_producer_buffer_ = true; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| StmtExprVisitor::VisitExpr_(call); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| bool has_tma_load_ = false; | ||||||||||||||||||||||||||||
| void VisitExpr_(const BufferLoadNode *op) final { | ||||||||||||||||||||||||||||
| if (cur_producer_buffers_.count(op->buffer.get())) { | ||||||||||||||||||||||||||||
| has_producer_buffer_ = true; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| StmtExprVisitor::VisitExpr_(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| bool has_producer_buffer_ = false; | ||||||||||||||||||||||||||||
| std::unordered_set<const BufferNode *> cur_producer_buffers_; | ||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To correspond with the change in the constructor to pass by reference, this member variable should also be a
Suggested change
|
||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class ProducerUsedBufferFinder : public StmtExprVisitor { | ||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||
| auto FindProducerusedBuffer(Stmt stmt) { | ||||||||||||||||||||||||||||
| VisitStmt(stmt); | ||||||||||||||||||||||||||||
| return used_in_producer_cond_; | ||||||||||||||||||||||||||||
| producer_buffers_.clear(); | ||||||||||||||||||||||||||||
| std::unordered_set<const BufferNode *> last_producer_buffers_; | ||||||||||||||||||||||||||||
| for (;;) { | ||||||||||||||||||||||||||||
| VisitStmt(stmt); | ||||||||||||||||||||||||||||
| if (producer_buffers_ == last_producer_buffers_) { | ||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| last_producer_buffers_ = producer_buffers_; | ||||||||||||||||||||||||||||
|
Comment on lines
+57
to
+62
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current fixed-point loop copies the
Suggested change
|
||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| return producer_buffers_; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void InsertBuffer(const PrimExpr &expr) { | ||||||||||||||||||||||||||||
| // Find the buffer that is used in the condition | ||||||||||||||||||||||||||||
| VarUseDefAnalyzer usage(Array<Var>{}); | ||||||||||||||||||||||||||||
| usage(expr); | ||||||||||||||||||||||||||||
| for (const auto &buffer : usage.buffer_use_count_) { | ||||||||||||||||||||||||||||
| used_in_producer_cond_.insert(buffer.first); | ||||||||||||||||||||||||||||
| producer_buffers_.insert(buffer.first); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void VisitStmt_(const IfThenElseNode *op) final { | ||||||||||||||||||||||||||||
| TMAFinder tma_finder; | ||||||||||||||||||||||||||||
| tma_finder(op->then_case); | ||||||||||||||||||||||||||||
| ProducerBufferDetector producer_buffer_detector(producer_buffers_); | ||||||||||||||||||||||||||||
| producer_buffer_detector(op->then_case); | ||||||||||||||||||||||||||||
| if (op->else_case.defined()) { | ||||||||||||||||||||||||||||
| tma_finder(op->else_case.value()); | ||||||||||||||||||||||||||||
| producer_buffer_detector(op->else_case.value()); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if (tma_finder.has_tma_load_) { | ||||||||||||||||||||||||||||
| if (producer_buffer_detector.has_producer_buffer_) { | ||||||||||||||||||||||||||||
| InsertBuffer(op->condition); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| StmtExprVisitor::VisitStmt_(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void VisitStmt_(const ForNode *op) final { | ||||||||||||||||||||||||||||
| TMAFinder tma_finder; | ||||||||||||||||||||||||||||
| tma_finder(op->body); | ||||||||||||||||||||||||||||
| if (tma_finder.has_tma_load_) { | ||||||||||||||||||||||||||||
| ProducerBufferDetector producer_buffer_detector(producer_buffers_); | ||||||||||||||||||||||||||||
| producer_buffer_detector(op->body); | ||||||||||||||||||||||||||||
| if (producer_buffer_detector.has_producer_buffer_) { | ||||||||||||||||||||||||||||
| InsertBuffer(op->min); | ||||||||||||||||||||||||||||
| InsertBuffer(op->extent); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| StmtExprVisitor::VisitStmt_(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void VisitStmt_(const BufferStoreNode *op) final { | ||||||||||||||||||||||||||||
| if (producer_buffers_.count(op->buffer.get())) { | ||||||||||||||||||||||||||||
| InsertBuffer(op->value); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| StmtExprVisitor::VisitStmt_(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void VisitExpr_(const CallNode *op) final { | ||||||||||||||||||||||||||||
| if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { | ||||||||||||||||||||||||||||
| for (auto arg : op->args) { | ||||||||||||||||||||||||||||
| if (auto buffer_load = arg.as<BufferLoadNode>()) { | ||||||||||||||||||||||||||||
| used_in_producer_cond_.insert(buffer_load->buffer.get()); | ||||||||||||||||||||||||||||
| producer_buffers_.insert(buffer_load->buffer.get()); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||
| std::unordered_set<const BufferNode *> used_in_producer_cond_; | ||||||||||||||||||||||||||||
| std::unordered_set<const BufferNode *> producer_buffers_; | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class WarpSpecializedRoleMarker : public StmtVisitor { | ||||||||||||||||||||||||||||
|
|
@@ -95,7 +123,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void Prepare(const Stmt &stmt) { | ||||||||||||||||||||||||||||
| ProducerUsedBufferFinder finder; | ||||||||||||||||||||||||||||
| used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt); | ||||||||||||||||||||||||||||
| producer_buffers_ = finder.FindProducerusedBuffer(stmt); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Role GetRole(const StmtNode *stmt) const { | ||||||||||||||||||||||||||||
|
|
@@ -123,7 +151,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { | |||||||||||||||||||||||||||
| void VisitStmt_(const BufferStoreNode *op) final { | ||||||||||||||||||||||||||||
| bool is_shared_store = | ||||||||||||||||||||||||||||
| op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; | ||||||||||||||||||||||||||||
| if (used_in_producer_cond_.count(op->buffer.get())) { | ||||||||||||||||||||||||||||
| if (producer_buffers_.count(op->buffer.get())) { | ||||||||||||||||||||||||||||
| SetRole(op, Role::kBoth); | ||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
@@ -207,7 +235,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { | |||||||||||||||||||||||||||
| std::unordered_map<const StmtNode *, Role> map_; | ||||||||||||||||||||||||||||
| bool has_simt_copy_ = false; | ||||||||||||||||||||||||||||
| bool has_bulk_copy_ = false; | ||||||||||||||||||||||||||||
| std::unordered_set<const BufferNode *> used_in_producer_cond_; | ||||||||||||||||||||||||||||
| std::unordered_set<const BufferNode *> producer_buffers_; | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| static PrimExpr makeGetBarrier(PrimExpr barrier_id) { | ||||||||||||||||||||||||||||
|
|
@@ -1112,7 +1140,7 @@ class WarpSpecializedRewriter : public StmtExprMutator { | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| auto inc_reg_stmt = Evaluate(0); | ||||||||||||||||||||||||||||
| auto dec_reg_stmt = Evaluate(0); | ||||||||||||||||||||||||||||
| if (dec_reg >= 0 && inc_reg >= 0) { | ||||||||||||||||||||||||||||
| if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) { | ||||||||||||||||||||||||||||
| inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(), | ||||||||||||||||||||||||||||
| {inc_reg == 0 ? 240 : inc_reg, 1})); | ||||||||||||||||||||||||||||
| dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(), | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve performance, consider passing
cur_producer_buffersbyconstreference to avoid copying the set. This is safe becauseProducerBufferDetectoris short-lived and the referenced set will outlive it.