Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions src/transform/warp_specialized_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,69 +23,97 @@ using arith::IRVisitorWithAnalyzer;

enum class Role { kConsumer, kProducer, kBoth };

class TMAFinder : public StmtExprVisitor {
class ProducerBufferDetector : public StmtExprVisitor {
public:
void clear() { has_tma_load_ = false; }
ProducerBufferDetector(
std::unordered_set<const BufferNode *> cur_producer_buffers)
: cur_producer_buffers_(cur_producer_buffers) {}
Comment on lines +28 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

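To improve performance, consider taking cur_producer_buffers by const reference so the set is not copied into the constructor parameter. On its own this removes only one of the two copies: the member initializer still copies the set into cur_producer_buffers_ unless that member is also declared as a const reference. Holding a reference is safe here because each ProducerBufferDetector is a short-lived local and the set it refers to outlives it.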

Suggested change
ProducerBufferDetector(
std::unordered_set<const BufferNode *> cur_producer_buffers)
: cur_producer_buffers_(cur_producer_buffers) {}
ProducerBufferDetector(
const std::unordered_set<const BufferNode *>& cur_producer_buffers)
: cur_producer_buffers_(cur_producer_buffers) {}


void clear() { has_producer_buffer_ = false; }

void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
has_tma_load_ = true;
has_producer_buffer_ = true;
}
StmtExprVisitor::VisitExpr_(call);
}

bool has_tma_load_ = false;
void VisitExpr_(const BufferLoadNode *op) final {
if (cur_producer_buffers_.count(op->buffer.get())) {
has_producer_buffer_ = true;
}
StmtExprVisitor::VisitExpr_(op);
}

bool has_producer_buffer_ = false;
std::unordered_set<const BufferNode *> cur_producer_buffers_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

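To match the constructor change to pass by const reference, this member variable should also be a const reference; otherwise the member initializer still makes a copy of the set. The detector only reads the set (via count()), so a const reference is sufficient. Note that a reference member requires the referenced set to outlive the detector.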

Suggested change
std::unordered_set<const BufferNode *> cur_producer_buffers_;
const std::unordered_set<const BufferNode *>& cur_producer_buffers_;

};

class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
auto FindProducerusedBuffer(Stmt stmt) {
VisitStmt(stmt);
return used_in_producer_cond_;
producer_buffers_.clear();
std::unordered_set<const BufferNode *> last_producer_buffers_;
for (;;) {
VisitStmt(stmt);
if (producer_buffers_ == last_producer_buffers_) {
break;
}
last_producer_buffers_ = producer_buffers_;
Comment on lines +57 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

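The current fixed-point loop copies the producer_buffers_ set and performs an element-wise comparison in each iteration, which can be inefficient for large sets. A cheaper approach is to compare the size of the set before and after visiting the statement. This is equivalent because buffers are only ever inserted into producer_buffers_ while the loop runs (it is cleared once, before the loop), so the set has reached a fixed point exactly when its size stops changing. It avoids both the set copy and the element-wise comparison, and the last_producer_buffers_ local is no longer needed.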

Suggested change
for (;;) {
VisitStmt(stmt);
if (producer_buffers_ == last_producer_buffers_) {
break;
}
last_producer_buffers_ = producer_buffers_;
for (;;) {
size_t num_buffers_before = producer_buffers_.size();
VisitStmt(stmt);
if (producer_buffers_.size() == num_buffers_before) {
break;
}
}

}
return producer_buffers_;
}

void InsertBuffer(const PrimExpr &expr) {
// Find the buffer that is used in the condition
VarUseDefAnalyzer usage(Array<Var>{});
usage(expr);
for (const auto &buffer : usage.buffer_use_count_) {
used_in_producer_cond_.insert(buffer.first);
producer_buffers_.insert(buffer.first);
}
}

void VisitStmt_(const IfThenElseNode *op) final {
TMAFinder tma_finder;
tma_finder(op->then_case);
ProducerBufferDetector producer_buffer_detector(producer_buffers_);
producer_buffer_detector(op->then_case);
if (op->else_case.defined()) {
tma_finder(op->else_case.value());
producer_buffer_detector(op->else_case.value());
}
if (tma_finder.has_tma_load_) {
if (producer_buffer_detector.has_producer_buffer_) {
InsertBuffer(op->condition);
}
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const ForNode *op) final {
TMAFinder tma_finder;
tma_finder(op->body);
if (tma_finder.has_tma_load_) {
ProducerBufferDetector producer_buffer_detector(producer_buffers_);
producer_buffer_detector(op->body);
if (producer_buffer_detector.has_producer_buffer_) {
InsertBuffer(op->min);
InsertBuffer(op->extent);
}
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const BufferStoreNode *op) final {
if (producer_buffers_.count(op->buffer.get())) {
InsertBuffer(op->value);
}
StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
for (auto arg : op->args) {
if (auto buffer_load = arg.as<BufferLoadNode>()) {
used_in_producer_cond_.insert(buffer_load->buffer.get());
producer_buffers_.insert(buffer_load->buffer.get());
}
}
}
}

private:
std::unordered_set<const BufferNode *> used_in_producer_cond_;
std::unordered_set<const BufferNode *> producer_buffers_;
};

class WarpSpecializedRoleMarker : public StmtVisitor {
Expand All @@ -95,7 +123,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {

void Prepare(const Stmt &stmt) {
ProducerUsedBufferFinder finder;
used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt);
producer_buffers_ = finder.FindProducerusedBuffer(stmt);
}

Role GetRole(const StmtNode *stmt) const {
Expand Down Expand Up @@ -123,7 +151,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (used_in_producer_cond_.count(op->buffer.get())) {
if (producer_buffers_.count(op->buffer.get())) {
SetRole(op, Role::kBoth);
return;
}
Expand Down Expand Up @@ -207,7 +235,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false;
bool has_bulk_copy_ = false;
std::unordered_set<const BufferNode *> used_in_producer_cond_;
std::unordered_set<const BufferNode *> producer_buffers_;
};

static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
Expand Down Expand Up @@ -1112,7 +1140,7 @@ class WarpSpecializedRewriter : public StmtExprMutator {

auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0);
if (dec_reg >= 0 && inc_reg >= 0) {
if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
Expand Down