-
Notifications
You must be signed in to change notification settings - Fork 333
[Feature] Add ptx_cp_async_barrier_noinc intrinsic and related functionality #809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,99 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| /*! | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| * \file warp_specialized_rewriter.h | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| * \brief tools for warp-specialized-related analysis and transformation | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma once | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "arith/ir_visitor_with_analyzer.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "tir/analysis/var_use_def_analysis.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/ffi/reflection/registry.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/analysis.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/builtin.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/op.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/stmt_functor.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/transform.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <utility> | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "../op/builtin.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "./common/collector.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "runtime/thread_storage_scope.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "tir/transforms/ir_utils.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace tvm { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace tl { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| using namespace tir; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| using namespace runtime; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| using arith::IRVisitorWithAnalyzer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| class WarpSpecializedDetector : public IRVisitorWithAnalyzer { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| // return true means this aws will be disabled | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| WarpSpecializedDetector detector; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| detector.VisitStmt(stmt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (detector.has_warp_specialization_) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| LOG(WARNING) << "Auto warp specialization will be disabled because warp " | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| "specialization is manually enabled"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| return true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (detector.has_tma_op_ && detector.has_mbarrier_op_) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| LOG(WARNING) << "Auto warp specialization will be disabled because TMA " | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| "and mbarrier are both present"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| return true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| WarpSpecializedDetector() { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_tma_op_ = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_mbarrier_op_ = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_warp_specialization_ = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| void VisitStmt_(const EvaluateNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (const CallNode *call = op->value.as<CallNode>()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (call->op.same_as(create_list_of_mbarrier()) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| call->op.same_as(mbarrier_wait_parity()) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| call->op.same_as(builtin::ptx_arrive_barrier()) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| call->op.same_as(builtin::ptx_cp_async_barrier())) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_mbarrier_op_ = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+59
to
+64
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new intrinsic if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier()) ||
call->op.same_as(ptx_cp_async_barrier_noinc())) {
has_mbarrier_op_ = true;
} |
||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| IRVisitorWithAnalyzer::VisitStmt_(op); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+56
to
+68
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Detector misses new noinc intrinsic; may under-detect mbarrier usage. Add ptx_cp_async_barrier_noinc() (and optionally other mbarrier-related ops) to the mbarrier set to keep gating correct with the new intrinsic. - if (call->op.same_as(create_list_of_mbarrier()) ||
- call->op.same_as(mbarrier_wait_parity()) ||
- call->op.same_as(builtin::ptx_arrive_barrier()) ||
- call->op.same_as(builtin::ptx_cp_async_barrier())) {
+ if (call->op.same_as(create_list_of_mbarrier()) ||
+ call->op.same_as(mbarrier_wait_parity()) ||
+ call->op.same_as(builtin::ptx_arrive_barrier()) ||
+ call->op.same_as(builtin::ptx_cp_async_barrier()) ||
+ call->op.same_as(builtin::ptx_cp_async_barrier_noinc())) {
has_mbarrier_op_ = true;
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||
| void VisitExpr_(const CallNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| op->op.same_as(set_max_nreg())) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_tma_op_ = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| IRVisitorWithAnalyzer::VisitExpr_(op); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+69
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Classifying set_max_nreg as “TMA” causes false disables.
- if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
- op->op.same_as(set_max_nreg())) {
+ if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
has_tma_op_ = true;
}If you still want to track register hints, add a separate flag (e.g., 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| void VisitStmt_(const AttrStmtNode *op) final { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (op->attr_key == "warp_specialize" && | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| op->value.as<IntImmNode>()->value == 1) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_warp_specialization_ = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (op->attr_key == tir::attr::thread_extent) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| IterVar iv = Downcast<IterVar>(op->node); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (iv->thread_tag == "threadIdx.x") { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(iv->dom->extent.as<IntImmNode>()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| thread_var_ = iv; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| IRVisitorWithAnalyzer::VisitStmt_(op); | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+77
to
+90
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainAttribute key for manual WS may be inconsistent. Detector checks I can align the detector once we confirm the canonical key. 🏁 Script executed: #!/bin/bash
# Verify which attribute keys are used across the repo
rg -nP '"warp_specialize"\s*|kWarpSpecializationScope' -C2Length of output: 4489 Replace "warp_specialize" with attr::kWarpSpecializationScope in the detector Repo-wide usage and tests use "kWarpSpecializationScope" (attr::kWarpSpecializationScope); "warp_specialize" only appears in ir.cc. Update src/transform/warp_specialized_rewriter.h (VisitStmt_, ~lines 77–90) to check attr::kWarpSpecializationScope — or check both keys if backward-compat is required. 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool has_tma_op_{false}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| IterVar thread_var_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool has_mbarrier_op_{false}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool has_warp_specialization_{false}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace tl | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace tvm | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -350,3 +350,9 @@ def sync_grid(): | |||||||||||||||||||||||||||||||
| """Synchronize all threads in a grid. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid")) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]): | ||||||||||||||||||||||||||||||||
| """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+355
to
+358
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normalize barrier argument like mbarrier_arrive; passing an int currently breaks codegen. As written, callers can pass an int (per the type hint), which will be emitted as a literal (e.g., 0) into tl::mbarrier_cp_async_arrive_noinc and fail to bind to the expected reference. Mirror mbarrier_arrive’s normalization to always pass a handle. Apply this diff: -def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
- """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
- """
- return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
+def cp_async_barrier_noinc(mbarrier: Union[int, PrimExpr, tir.Call]):
+ """Perform a PTX async-copy barrier using cp.async.mbarrier.arrive.noinc."""
+ if isinstance(mbarrier, (tir.Call, tir.BufferLoad)):
+ mb = mbarrier
+ elif isinstance(mbarrier, (tir.PrimExpr, int)):
+ mb = get_mbarrier(mbarrier)
+ elif isinstance(mbarrier, tir.Buffer):
+ mb = tir.BufferLoad(mbarrier, [0])
+ else:
+ raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}")
+ return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), mb)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic to get the integer pointer
smem_int_mbarfromsmem_mbaris identical to the logic in the existingmbarrier_cp_async_arrivefunction. To improve maintainability and reduce code duplication, consider extracting this logic into a common helper function that bothmbarrier_cp_async_arriveandmbarrier_cp_async_arrive_noinccan use.