Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_cp_async_barrier_noinc)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
8 changes: 8 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ TVM_DLL const Op &ptx_ldmatrix();
*/
TVM_DLL const Op &ptx_stmatrix();

/*!
* \brief tvm intrinsic for ptx async copy barrier using
* cp.async.mbarrier.arrive.noinc
*
* This op is used to represent a ptx async copy barrier operation in tilelang.
*/
TVM_DLL const Op &ptx_cp_async_barrier_noinc();

/*!
* \brief Pack two b16 value into a b32 value
*
Expand Down
2 changes: 2 additions & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
}
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive_noinc");
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
Expand Down
16 changes: 16 additions & 0 deletions src/tl_templates/cuda/barrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,22 @@ TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) {
: "r"(smem_int_mbar));
}

template <typename BarrierType = uint64_t>
TL_DEVICE void mbarrier_cp_async_arrive_noinc(BarrierType &smem_mbar) {
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
Comment on lines +118 to +123
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to get the integer pointer smem_int_mbar from smem_mbar is identical to the logic in the existing mbarrier_cp_async_arrive function. To improve maintainability and reduce code duplication, consider extracting this logic into a common helper function that both mbarrier_cp_async_arrive and mbarrier_cp_async_arrive_noinc can use.

asm volatile("{\n\t"
"cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t"
"}"
:
: "r"(smem_int_mbar));
cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar);
}

TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
Expand Down
13 changes: 6 additions & 7 deletions src/transform/annotate_warp_group_reg_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
* \file annotate_warp_group_reg_alloc.cc
* \brief Annotate warp group reg alloc for warp specialization
*/
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "warp_specialized_rewriter.h"
#include <unordered_set>
#include <utility>
#include <vector>

#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

Expand Down Expand Up @@ -57,6 +51,11 @@ class SetMaxNRegCollector : public StmtExprVisitor {
class SetMaxNRegInjector : public StmtExprMutator {
public:
static PrimFunc Inject(PrimFunc f) {
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
if (warp_specialized) {
// Should handle set_max_nreg when using hand-written warp specialized
return f;
}
auto T = SetMaxNRegInjector();
T.nreg_ = SetMaxNRegCollector::Collect(f);
f.CopyOnWrite()->body = T(f->body);
Expand Down
83 changes: 1 addition & 82 deletions src/transform/warp_specialized_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,7 @@
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/

#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <utility>

#include "../op/builtin.h"
#include "./common/collector.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include "warp_specialized_rewriter.h"

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -1284,73 +1270,6 @@ class WarpSpecializedRewriter : public StmtExprMutator {
bool disable_shuffle_elect_ = false;
};

class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
// return true means this aws will be disabled
static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
if (detector.has_warp_specialization_) {
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
"specialization is manually enabled";
return true;
}
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
"and mbarrier are both present";
return true;
}
return false;
}

WarpSpecializedDetector() {
has_tma_op_ = false;
has_mbarrier_op_ = false;
has_warp_specialization_ = false;
}

private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier())) {
has_mbarrier_op_ = true;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}

void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
op->op.same_as(set_max_nreg())) {
has_tma_op_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}

void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "warp_specialize" &&
op->value.as<IntImmNode>()->value == 1) {
has_warp_specialization_ = true;
}
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}

bool has_tma_op_{false};
IterVar thread_var_;
bool has_mbarrier_op_{false};
bool has_warp_specialization_{false};
};

using namespace tir::transform;

tvm::transform::Pass WarpSpecialized() {
Expand Down
99 changes: 99 additions & 0 deletions src/transform/warp_specialized_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*!
* \file warp_specialized_rewriter.h
* \brief tools for warp-specialized-related analysis and transformation
*/

#pragma once

#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <utility>

#include "../op/builtin.h"
#include "./common/collector.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;
using namespace runtime;
using arith::IRVisitorWithAnalyzer;

class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
// return true means this aws will be disabled
static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
if (detector.has_warp_specialization_) {
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
"specialization is manually enabled";
return true;
}
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
"and mbarrier are both present";
return true;
}
return false;
}

WarpSpecializedDetector() {
has_tma_op_ = false;
has_mbarrier_op_ = false;
has_warp_specialization_ = false;
}

private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier())) {
has_mbarrier_op_ = true;
}
Comment on lines +59 to +64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new intrinsic ptx_cp_async_barrier_noinc seems to be missing from the check for mbarrier operations in WarpSpecializedDetector. This could lead to auto warp specialization not being disabled when it should be (i.e., when both TMA and this new mbarrier op are present). You should add it to the list of checks to ensure correct behavior.

      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier()) ||
          call->op.same_as(ptx_cp_async_barrier_noinc())) {
        has_mbarrier_op_ = true;
      }

}
IRVisitorWithAnalyzer::VisitStmt_(op);
}

Comment on lines +56 to +68
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Detector misses new noinc intrinsic; may under-detect mbarrier usage.

Add ptx_cp_async_barrier_noinc() (and optionally other mbarrier-related ops) to the mbarrier set to keep gating correct with the new intrinsic.

-      if (call->op.same_as(create_list_of_mbarrier()) ||
-          call->op.same_as(mbarrier_wait_parity()) ||
-          call->op.same_as(builtin::ptx_arrive_barrier()) ||
-          call->op.same_as(builtin::ptx_cp_async_barrier())) {
+      if (call->op.same_as(create_list_of_mbarrier()) ||
+          call->op.same_as(mbarrier_wait_parity()) ||
+          call->op.same_as(builtin::ptx_arrive_barrier()) ||
+          call->op.same_as(builtin::ptx_cp_async_barrier()) ||
+          call->op.same_as(builtin::ptx_cp_async_barrier_noinc())) {
         has_mbarrier_op_ = true;
       }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier())) {
has_mbarrier_op_ = true;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier_noinc())) {
has_mbarrier_op_ = true;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
🤖 Prompt for AI Agents
In src/transform/warp_specialized_rewriter.h around lines 56 to 68, the
VisitStmt_ detector currently checks a set of mbarrier intrinsics but misses the
new ptx_cp_async_barrier_noinc intrinsic; update the conditional that sets
has_mbarrier_op_ to also check
call->op.same_as(builtin::ptx_cp_async_barrier_noinc()) (and add any other new
mbarrier-related intrinsics as needed) so the detector correctly flags usage;
ensure you include the appropriate declaration/namespace for the new builtin
symbol and run tests to validate gating behavior.

void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
op->op.same_as(set_max_nreg())) {
has_tma_op_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
Comment on lines +69 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Classifying set_max_nreg as “TMA” causes false disables.

set_max_nreg() is unrelated to TMA. Including it can disable AWS whenever any mbarrier call appears with a register hint.

-    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
-        op->op.same_as(set_max_nreg())) {
+    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
       has_tma_op_ = true;
     }

If you still want to track register hints, add a separate flag (e.g., has_set_max_nreg_) and keep it out of the TMA/mbarrier cross check.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
op->op.same_as(set_max_nreg())) {
has_tma_op_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
has_tma_op_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}


void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "warp_specialize" &&
op->value.as<IntImmNode>()->value == 1) {
has_warp_specialization_ = true;
}
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
Comment on lines +77 to +90
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Attribute key for manual WS may be inconsistent.

Detector checks op->attr_key == "warp_specialize", while the pass uses attr::kWarpSpecializationScope. Ensure we’re keying off the right attribute for “manual WS enabled”.

I can align the detector once we confirm the canonical key.


🏁 Script executed:

#!/bin/bash
# Verify which attribute keys are used across the repo
rg -nP '"warp_specialize"\s*|kWarpSpecializationScope' -C2

Length of output: 4489


Replace "warp_specialize" with attr::kWarpSpecializationScope in the detector

Repo-wide usage and tests use "kWarpSpecializationScope" (attr::kWarpSpecializationScope); "warp_specialize" only appears in ir.cc. Update src/transform/warp_specialized_rewriter.h (VisitStmt_, ~lines 77–90) to check attr::kWarpSpecializationScope — or check both keys if backward-compat is required.

🤖 Prompt for AI Agents
In src/transform/warp_specialized_rewriter.h around lines 77 to 90, the
VisitStmt_ detector currently checks the literal attr key "warp_specialize"
which is inconsistent with repo-wide usage; change the conditional to check
tir::attr::kWarpSpecializationScope instead (or check both "warp_specialize" and
tir::attr::kWarpSpecializationScope to preserve backward compatibility). Update
the if statement to compare op->attr_key against
tir::attr::kWarpSpecializationScope (and optionally also accept the legacy
"warp_specialize"), keeping the existing check for value.as<IntImmNode>()->value
== 1 and setting has_warp_specialization_ unchanged. Ensure to include the
proper namespace qualification (tir::attr::kWarpSpecializationScope) and adjust
any includes if necessary.


bool has_tma_op_{false};
IterVar thread_var_;
bool has_mbarrier_op_{false};
bool has_warp_specialization_{false};
};

} // namespace tl
} // namespace tvm
6 changes: 6 additions & 0 deletions tilelang/language/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,9 @@ def sync_grid():
"""Synchronize all threads in a grid.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))


def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
Comment on lines +355 to +358
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Normalize barrier argument like mbarrier_arrive; passing an int currently breaks codegen.

As written, callers can pass an int (per the type hint), which will be emitted as a literal (e.g., 0) into tl::mbarrier_cp_async_arrive_noinc and fail to bind to the expected reference. Mirror mbarrier_arrive’s normalization to always pass a handle.

Apply this diff:

-def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
-    """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
-    """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
+def cp_async_barrier_noinc(mbarrier: Union[int, PrimExpr, tir.Call]):
+    """Perform a PTX async-copy barrier using cp.async.mbarrier.arrive.noinc."""
+    if isinstance(mbarrier, (tir.Call, tir.BufferLoad)):
+        mb = mbarrier
+    elif isinstance(mbarrier, (tir.PrimExpr, int)):
+        mb = get_mbarrier(mbarrier)
+    elif isinstance(mbarrier, tir.Buffer):
+        mb = tir.BufferLoad(mbarrier, [0])
+    else:
+        raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}")
+    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), mb)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
def cp_async_barrier_noinc(mbarrier: Union[int, PrimExpr, tir.Call]):
"""Perform a PTX async-copy barrier using cp.async.mbarrier.arrive.noinc."""
if isinstance(mbarrier, (tir.Call, tir.BufferLoad)):
mb = mbarrier
elif isinstance(mbarrier, (tir.PrimExpr, int)):
mb = get_mbarrier(mbarrier)
elif isinstance(mbarrier, tir.Buffer):
mb = tir.BufferLoad(mbarrier, [0])
else:
raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}")
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), mb)
🤖 Prompt for AI Agents
In tilelang/language/builtin.py around lines 355-358, the cp_async_barrier_noinc
function currently accepts an int and emits a literal which breaks codegen;
mirror mbarrier_arrive’s normalization so the barrier_id is always passed as a
handle. Change the function to detect when barrier_id is a plain int/PrimExpr
literal (or otherwise not already a handle) and wrap it the same way
mbarrier_arrive does (i.e., convert the integer/literal into a tir handle via
the same tir.call_intrin wrap used in mbarrier_arrive) before returning the
call_intrin for tl.ptx_cp_async_barrier_noinc.

Loading