From 4e425115c55f8dd4b258d2cdae4a4764c2420166 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Thu, 31 Jul 2025 11:50:48 +0000 Subject: [PATCH] [Enhancement] Refactored buffer detection logic in warp_specialized_rewriter.cc - Renamed TMAFinder to ProducerBufferDetector and improved handling of CallNode and BufferLoadNode. - This change aims to enhance code maintainability and performance by more accurately tracking producer buffer usage. --- src/transform/warp_specialized_rewriter.cc | 68 +++++++++++++++------- 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index a5c9cf8bb..c2799bfed 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -23,24 +23,45 @@ using arith::IRVisitorWithAnalyzer; enum class Role { kConsumer, kProducer, kBoth }; -class TMAFinder : public StmtExprVisitor { +class ProducerBufferDetector : public StmtExprVisitor { public: - void clear() { has_tma_load_ = false; } + ProducerBufferDetector( + std::unordered_set cur_producer_buffers) + : cur_producer_buffers_(cur_producer_buffers) {} + + void clear() { has_producer_buffer_ = false; } void VisitExpr_(const CallNode *call) final { if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { - has_tma_load_ = true; + has_producer_buffer_ = true; } + StmtExprVisitor::VisitExpr_(call); } - bool has_tma_load_ = false; + void VisitExpr_(const BufferLoadNode *op) final { + if (cur_producer_buffers_.count(op->buffer.get())) { + has_producer_buffer_ = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool has_producer_buffer_ = false; + std::unordered_set cur_producer_buffers_; }; class ProducerUsedBufferFinder : public StmtExprVisitor { public: auto FindProducerusedBuffer(Stmt stmt) { - VisitStmt(stmt); - return used_in_producer_cond_; + producer_buffers_.clear(); + std::unordered_set last_producer_buffers_; + for (;;) { + VisitStmt(stmt); + if (producer_buffers_ == last_producer_buffers_) { + break; + } + last_producer_buffers_ = producer_buffers_; + } + return producer_buffers_; } void InsertBuffer(const PrimExpr &expr) { @@ -48,44 +69,51 @@ class ProducerUsedBufferFinder : public StmtExprVisitor { VarUseDefAnalyzer usage(Array{}); usage(expr); for (const auto &buffer : usage.buffer_use_count_) { - used_in_producer_cond_.insert(buffer.first); + producer_buffers_.insert(buffer.first); } } void VisitStmt_(const IfThenElseNode *op) final { - TMAFinder tma_finder; - tma_finder(op->then_case); + ProducerBufferDetector producer_buffer_detector(producer_buffers_); + producer_buffer_detector(op->then_case); if (op->else_case.defined()) { - tma_finder(op->else_case.value()); + producer_buffer_detector(op->else_case.value()); } - if (tma_finder.has_tma_load_) { + if (producer_buffer_detector.has_producer_buffer_) { InsertBuffer(op->condition); } StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const ForNode *op) final { - TMAFinder tma_finder; - tma_finder(op->body); - if (tma_finder.has_tma_load_) { + ProducerBufferDetector producer_buffer_detector(producer_buffers_); + producer_buffer_detector(op->body); + if (producer_buffer_detector.has_producer_buffer_) { InsertBuffer(op->min); InsertBuffer(op->extent); } StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const BufferStoreNode *op) final { + if (producer_buffers_.count(op->buffer.get())) { + InsertBuffer(op->value); + } + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const CallNode *op) final { if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { for (auto arg : op->args) { if (auto buffer_load = arg.as()) { - used_in_producer_cond_.insert(buffer_load->buffer.get()); + producer_buffers_.insert(buffer_load->buffer.get()); } } } } private: - std::unordered_set used_in_producer_cond_; + std::unordered_set producer_buffers_; }; class WarpSpecializedRoleMarker : public StmtVisitor { @@ -95,7 +123,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { void Prepare(const Stmt &stmt) { ProducerUsedBufferFinder finder; - used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt); + producer_buffers_ = finder.FindProducerusedBuffer(stmt); } Role GetRole(const StmtNode *stmt) const { @@ -123,7 +151,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { void VisitStmt_(const BufferStoreNode *op) final { bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; - if (used_in_producer_cond_.count(op->buffer.get())) { + if (producer_buffers_.count(op->buffer.get())) { SetRole(op, Role::kBoth); return; } @@ -207,7 +235,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { std::unordered_map map_; bool has_simt_copy_ = false; bool has_bulk_copy_ = false; - std::unordered_set used_in_producer_cond_; + std::unordered_set producer_buffers_; }; static PrimExpr makeGetBarrier(PrimExpr barrier_id) { @@ -1112,7 +1140,7 @@ class WarpSpecializedRewriter : public StmtExprMutator { auto inc_reg_stmt = Evaluate(0); auto dec_reg_stmt = Evaluate(0); - if (dec_reg >= 0 && inc_reg >= 0) { + if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) { inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(), {inc_reg == 0 ? 240 : inc_reg, 1})); dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),