@chengyupku chengyupku commented Sep 12, 2025

  • Introduced a new intrinsic ptx_cp_async_barrier_noinc for handling the cp.async.mbarrier.arrive.noinc operation in TileLang.
  • Updated the CUDA code generation to support the new barrier operation.
  • Added a corresponding function in the TileLang Python API for ease of use.
  • Enhanced the barrier handling in CUDA templates to include the new no-increment operation, improving synchronization capabilities in parallel execution contexts.
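For readers unfamiliar with the instruction, the difference between the two arrive forms can be illustrated with a toy model of the barrier's pending count. This is a minimal sketch based on the PTX ISA description: the default `cp.async.mbarrier.arrive` increments the pending count before the asynchronous arrive-on (a net-zero change), while `.noinc` skips that increment, so the arrive-on must be accounted for when the barrier is initialized. `ToyMbarrier` and its methods are hypothetical names, not part of TileLang, and the model treats the asynchronous arrive-on as if it completed immediately.

```python
class ToyMbarrier:
    """Toy model of an mbarrier's pending-arrival count (not real PTX)."""

    def __init__(self, expected: int):
        self.expected = expected  # arrivals required to complete one phase
        self.pending = expected   # arrivals still outstanding in this phase
        self.phase = 0            # parity bit, flips when a phase completes

    def arrive(self) -> None:
        """A plain arrive-on: decrement, and complete the phase at zero."""
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1
            self.pending = self.expected

    def cp_async_arrive(self, noinc: bool) -> None:
        """Model cp.async.mbarrier.arrive[.noinc] after the copies finish."""
        if not noinc:
            self.pending += 1  # default form: increment first (net zero)
        self.arrive()          # arrive-on triggered by copy completion


# Default form: the copy is not one of the expected arrivals,
# so an explicit arrive is still needed to complete the phase.
bar = ToyMbarrier(expected=1)
bar.cp_async_arrive(noinc=False)
assert (bar.pending, bar.phase) == (1, 0)
bar.arrive()
assert bar.phase == 1

# .noinc form: the copy completion is itself an expected arrival,
# so the initial count (2) must include it.
noinc_bar = ToyMbarrier(expected=2)
noinc_bar.cp_async_arrive(noinc=True)
assert noinc_bar.phase == 0
noinc_bar.arrive()
assert noinc_bar.phase == 1
```

The second case is why a `.noinc` arrive changes how the barrier's arrival count has to be chosen.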

Summary by CodeRabbit

  • New Features

    • Added support for a “no-increment” async copy barrier, enabling finer control of cp.async synchronization on CUDA.
    • Introduced a Python API to invoke the new barrier from TileLang programs.
  • Refactor

    • Centralized and simplified warp-specialization detection logic.
    • Compilation now skips register-allocation hints when warp-specialized code is detected, improving stability and preserving user intent.

coderabbitai bot commented Sep 12, 2025

Walkthrough

Adds a new TL builtin op for PTX cp.async barrier no-increment, wires it through Python API, TIR, CUDA codegen, and CUDA templates. Introduces WarpSpecializedDetector in a new header, adjusts transform passes to use it, and adds early-exit gating in register allocation injection when warp specialization is detected.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| TL builtin: cp_async_barrier_noinc / src/op/builtin.h, src/op/builtin.cc, tilelang/language/builtin.py | Declares, registers, and exposes tl.ptx_cp_async_barrier_noinc op; adds Python wrapper cp_async_barrier_noinc(...) invoking the intrinsic. |
| CUDA codegen hook / src/target/codegen_cuda.cc | Extends CodeGenTileLangCUDA::VisitExpr_(CallNode) to handle tl::ptx_cp_async_barrier_noinc, emitting tl::mbarrier_cp_async_arrive_noinc extern call. |
| CUDA template helper / src/tl_templates/cuda/barrier.h | Adds mbarrier_cp_async_arrive_noinc(...) templated device function using cp.async.mbarrier.arrive.noinc and synclog emission. |
| Register alloc transform gating / src/transform/annotate_warp_group_reg_alloc.cc | Adds early return when WarpSpecializedDetector::Detect(...) is true; simplifies includes to use warp_specialized_rewriter.h. |
| Warp specialization detector and rewriter refactor / src/transform/warp_specialized_rewriter.h, src/transform/warp_specialized_rewriter.cc | Introduces WarpSpecializedDetector class (visitor with Detect(...)); rewriter source switches to unified header include and removes previous internal detector. |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant TL as tilelang.builtin
  participant TIR as tl.op registry
  participant CG as CUDA CodeGen
  participant RT as CUDA Templates
  Note over User,RT: cp.async barrier (noinc) flow
  User->>TL: cp_async_barrier_noinc(barrier_id)
  TL->>TIR: tir.call_intrin("tl.ptx_cp_async_barrier_noinc", barrier_id)
  TIR-->>CG: CallNode(tl.ptx_cp_async_barrier_noinc)
  CG->>RT: extern call mbarrier_cp_async_arrive_noinc(smem_mbar)
  RT->>RT: asm("cp.async.mbarrier.arrive.noinc ...")
  RT-->>User: completes
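
The codegen step in the flow above amounts to mapping an intrinsic to an extern call on the barrier operand. The following is a rough sketch of that mapping only; `EXTERN_FOR_OP` and `emit_extern_call` are hypothetical names, the real logic lives in the C++ `CodeGenTileLangCUDA` visitor, and the entry for the existing incrementing op is an assumption based on the review comments.

```python
# Hypothetical stand-in for the codegen dispatch; not the real C++ visitor.
EXTERN_FOR_OP = {
    # Existing incrementing variant (registry name assumed).
    "tir.ptx_cp_async_barrier": "tl::mbarrier_cp_async_arrive",
    # New no-increment variant added by this PR.
    "tl.ptx_cp_async_barrier_noinc": "tl::mbarrier_cp_async_arrive_noinc",
}


def emit_extern_call(op_name: str, barrier_expr: str) -> str:
    """Return the CUDA statement a codegen hook could print for op_name."""
    try:
        extern = EXTERN_FOR_OP[op_name]
    except KeyError:
        raise NotImplementedError(f"no CUDA lowering for {op_name}") from None
    return f"{extern}({barrier_expr});"


# "_mbarrier[0]" is a placeholder for the lowered shared-memory barrier.
print(emit_extern_call("tl.ptx_cp_async_barrier_noinc", "_mbarrier[0]"))
```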
sequenceDiagram
  autonumber
  participant Pass as SetMaxNRegInjector::Inject
  participant Det as WarpSpecializedDetector
  participant IR as PrimFunc
  rect rgba(220,235,245,0.5)
    Note right of Pass: Early gating
    Pass->>Det: Detect(IR->body)
    alt warp-specialized detected
      Pass-->>IR: return IR (unchanged)
    else not detected
      Pass->>Pass: proceed with set_max_nreg injection
      Pass-->>IR: return transformed IR
    end
  end
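
The early-exit gating in this diagram can be sketched as a pass that returns its input untouched when the detector fires. This is a toy illustration, assuming a function body is just a list of op names; `inject_set_max_nreg`, `has_manual_marker`, and the `"set_max_nreg(hint)"` placeholder are hypothetical and do not reflect how the real pass places register hints.

```python
from typing import Callable, List

ToyBody = List[str]  # toy IR: a function body is a list of op names


def inject_set_max_nreg(body: ToyBody,
                        detect: Callable[[ToyBody], bool]) -> ToyBody:
    """Return body unchanged when detect() fires, else add a register hint."""
    if detect(body):
        # Hand-written warp specialization: leave register hints to the user.
        return body
    return ["set_max_nreg(hint)"] + body


def has_manual_marker(body: ToyBody) -> bool:
    """Stand-in detector: looks for a manual warp-specialization marker."""
    return "warp_specialize" in body


manual = ["warp_specialize", "tma_load", "gemm"]
plain = ["copy", "gemm"]
assert inject_set_max_nreg(manual, has_manual_marker) is manual
assert inject_set_max_nreg(plain, has_manual_marker)[0] == "set_max_nreg(hint)"
```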

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 23.53% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title is a concise, single-sentence summary that accurately identifies the primary change—adding the ptx_cp_async_barrier_noinc intrinsic and its related functionality (CUDA codegen and TileLang API exposure)—and it directly reflects the PR objectives and changed files. It is specific, clear, and not overly broad or vague. |

Poem

A bunny taps the barrier line, no-inc, no fuss, just right,
Hops through tiles and CUDA lanes, swift as bits in flight.
Detects the warp, then bounds away—no tweaks today, all clear!
New op squeaks, the templates speak, asm whispers in our ear.
Thump-thump—merge night’s near! 🐇✨

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @chengyupku, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for the cp.async.mbarrier.arrive.noinc PTX instruction within the TileLang framework. This includes defining a new intrinsic, integrating it into the CUDA code generation pipeline, providing a Python-level API for ease of use, and implementing the low-level CUDA template function. Additionally, a significant refactoring of the warp specialization detection logic has been performed to enhance modularity and improve handling of specific optimization scenarios.

Highlights

  • New Intrinsic for Asynchronous Barrier: Introduced ptx_cp_async_barrier_noinc intrinsic to support the cp.async.mbarrier.arrive.noinc operation in TileLang, enabling more granular control over asynchronous memory barriers in CUDA.
  • CUDA Code Generation Support: Updated the CUDA code generator to correctly translate the new ptx_cp_async_barrier_noinc intrinsic into the corresponding tl::mbarrier_cp_async_arrive_noinc external call.
  • Python API Integration: Added a user-friendly Python API function cp_async_barrier_noinc in TileLang, allowing developers to easily utilize the new asynchronous barrier operation.
  • CUDA Template Enhancement: Implemented the mbarrier_cp_async_arrive_noinc function within CUDA templates, which directly emits the PTX cp.async.mbarrier.arrive.noinc instruction, enhancing synchronization capabilities in parallel execution contexts.
  • Warp Specialization Refactoring: Refactored the WarpSpecializedDetector class into its own header file (warp_specialized_rewriter.h) to improve code organization and reusability, and integrated it into the SetMaxNRegInjector to handle manual warp specialization.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces the ptx_cp_async_barrier_noinc intrinsic, which is a useful addition for advanced synchronization patterns in CUDA. The changes are well-structured, including updates to the TIR builtin, CUDA codegen, and the Python API. The refactoring to move WarpSpecializedDetector into its own header file improves code organization.

I have two main points of feedback:

  1. A potential logic error in WarpSpecializedDetector where the new intrinsic is not being checked, which could lead to incorrect behavior of the auto warp specialization pass.
  2. A minor code duplication issue in the CUDA template for the new barrier, which could be refactored to improve maintainability.

Overall, this is a good contribution. Please see my detailed comments below.

Comment on lines +59 to +64
      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier())) {
        has_mbarrier_op_ = true;
      }

high

The new intrinsic ptx_cp_async_barrier_noinc seems to be missing from the check for mbarrier operations in WarpSpecializedDetector. This could lead to auto warp specialization not being disabled when it should be (i.e., when both TMA and this new mbarrier op are present). You should add it to the list of checks to ensure correct behavior.

      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier()) ||
          call->op.same_as(ptx_cp_async_barrier_noinc())) {
        has_mbarrier_op_ = true;
      }
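
To make the concern concrete, here is a small sketch of the detection rule as this comment describes it. It assumes the detector reports true for manual warp specialization, or when both a TMA op and an mbarrier op appear; that combination rule is inferred from the comment, not confirmed against the source. `detect` and the op-name sets are hypothetical stand-ins for the C++ visitor.

```python
# Op names mirror the C++ checks quoted above; the sets are illustrative.
OLD_MBARRIER_OPS = frozenset({
    "create_list_of_mbarrier",
    "mbarrier_wait_parity",
    "ptx_arrive_barrier",
    "ptx_cp_async_barrier",
})
NEW_MBARRIER_OPS = OLD_MBARRIER_OPS | {"ptx_cp_async_barrier_noinc"}
TMA_OPS = frozenset({"tma_load", "tma_load_im2col"})


def detect(ops, mbarrier_ops, manual_ws=False):
    """Assumed rule: manual WS, or TMA and mbarrier ops both present."""
    has_tma = any(op in TMA_OPS for op in ops)
    has_mbarrier = any(op in mbarrier_ops for op in ops)
    return manual_ws or (has_tma and has_mbarrier)


# A kernel whose only mbarrier op is the new no-increment arrive.
kernel = ["tma_load", "ptx_cp_async_barrier_noinc", "gemm"]
assert not detect(kernel, OLD_MBARRIER_OPS)  # missed without the new op
assert detect(kernel, NEW_MBARRIER_OPS)      # caught once it is listed
```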

Comment on lines +118 to +123
  uint32_t smem_int_mbar;
  if constexpr (std::is_pointer_v<BarrierType>) {
    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
  } else {
    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
  }

medium

The logic to get the integer pointer smem_int_mbar from smem_mbar is identical to the logic in the existing mbarrier_cp_async_arrive function. To improve maintainability and reduce code duplication, consider extracting this logic into a common helper function that both mbarrier_cp_async_arrive and mbarrier_cp_async_arrive_noinc can use.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/warp_specialized_rewriter.cc (1)

588-590: WgMMA detection defaults to true → likely false positives.

has_wgmma_{true} means HasWgMMA() returns true even when no GEMM op is present.

Set the default to false:

-  bool has_wgmma_{true};
+  bool has_wgmma_{false};
🧹 Nitpick comments (9)
src/tl_templates/cuda/barrier.h (1)

124-130: Gate the synclog call to avoid build-time dependency in non-debug builds.

If synclog isn’t available, this will fail to link. Suggest guarding.

Apply this diff:

   asm volatile("{\n\t"
                "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t"
                "}"
                :
                : "r"(smem_int_mbar));
-  cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar);
+#ifdef TL_ENABLE_SYNCLOG
+  cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar);
+#endif
src/transform/warp_specialized_rewriter.cc (3)

253-257: Plumb support for cp.async.mbarrier.arrive.noinc to avoid double-increment scenarios.

Emitter still only uses ptx_cp_async_barrier(); with the new noinc intrinsic available, consider allowing a noinc path when a separate arrive will also be issued.

Apply this minimal refactor so call sites can opt into noinc without churn:

-static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
-  auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
-                   {makeGetBarrier(std::move(barrier_id))});
-  return Evaluate(call);
-}
+static Stmt makeCpAsyncBarrier(PrimExpr barrier_id, bool noinc = false) {
+  const Op& op = noinc ? builtin::ptx_cp_async_barrier_noinc()
+                       : builtin::ptx_cp_async_barrier();
+  auto call = Call(DataType::Handle(), op,
+                   {makeGetBarrier(std::move(barrier_id))});
+  return Evaluate(call);
+}

866-871: Remove unused/overshadowed variable.

auto for_node = result.as<For>(); is unused and overshadowed below.

-    auto for_node = result.as<For>();

1276-1287: Clarify variable naming for gating.

warp_specialized actually holds “disable AWS” (Detect() returns true => disable). Rename for readability.

-    bool warp_specialized =
-        WarpSpecializedDetector::Detect(f->body);
+    bool disable_aws =
+        WarpSpecializedDetector::Detect(f->body);
 
-    if (!warp_specialized) {
+    if (!disable_aws) {
src/transform/annotate_warp_group_reg_alloc.cc (2)

54-58: Early exit gating reads better if the intent is explicit.

Rename the flag to reflect “disable AWS” semantics and keep logs centralized in the detector.

-    bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
-    if (warp_specialized) {
+    bool disable_aws = WarpSpecializedDetector::Detect(f->body);
+    if (disable_aws) {
       // Should handle set_max_nreg when using hand-written warp specialized
       return f;
     }

106-121: SIMT-copy awareness is hardcoded to false.

Register hints may be suboptimal when SIMT copy is present. Consider exposing a lightweight “has SIMT copy” detector (similar to ProducerTraitsCollector) and using it here to gate injection.

I can extract a detector helper into the shared header and thread it here if you want.

src/transform/warp_specialized_rewriter.h (3)

34-48: Unused parameter in Detect().

skip_thread_partition is unused; silence warnings until it’s implemented.

   static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
     WarpSpecializedDetector detector;
+    (void)skip_thread_partition;
     detector.VisitStmt(stmt);

27-30: Avoid using namespace in headers.

Header-wide using namespace tir; and using namespace runtime; pollute includers. Prefer qualified names or localized using in source files.

I can send a follow-up patch replacing unqualified symbol uses if you agree.


92-96: Dead field: thread_var_ is set but never used.

Drop it or wire it into future logic to avoid confusion.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 143b522 and 0681b26.

📒 Files selected for processing (8)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/target/codegen_cuda.cc (1 hunks)
  • src/tl_templates/cuda/barrier.h (1 hunks)
  • src/transform/annotate_warp_group_reg_alloc.cc (2 hunks)
  • src/transform/warp_specialized_rewriter.cc (1 hunks)
  • src/transform/warp_specialized_rewriter.h (1 hunks)
  • tilelang/language/builtin.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/builtin.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
src/transform/warp_specialized_rewriter.h (2)
src/op/builtin.h (2)
  • tl (22-330)
  • attr (24-28)
src/target/codegen_cuda.cc (14)
  • VisitStmt_ (208-227)
  • VisitStmt_ (208-208)
  • VisitStmt_ (1636-1675)
  • VisitStmt_ (1636-1636)
  • VisitStmt_ (1677-1737)
  • VisitStmt_ (1677-1677)
  • VisitStmt_ (1739-1755)
  • VisitStmt_ (1739-1739)
  • op (129-144)
  • op (129-129)
  • op (1389-1391)
  • op (1389-1389)
  • op (1392-1394)
  • op (1392-1392)
src/transform/annotate_warp_group_reg_alloc.cc (1)
src/transform/warp_specialized_rewriter.cc (2)
  • f (1121-1140)
  • f (1121-1122)
🔇 Additional comments (5)
src/op/builtin.h (1)

180-187: New TL intrinsic declaration is correct and consistent with the registry.

Signature and placement look good. No further issues.

src/target/codegen_cuda.cc (1)

1068-1071: Codegen hook for noinc variant mirrors the existing path. LGTM.

Extern emits to tl::mbarrier_cp_async_arrive_noinc; consistent with templates.

src/op/builtin.cc (1)

93-97: Op registration looks correct.

.num_inputs(1) and kOpaque effect are appropriate for this intrinsic.

src/tl_templates/cuda/barrier.h (1)

125-126: Confirm .shared::cta qualifier and target SM

  • Valid: cp.async.mbarrier.arrive.noinc.shared::cta.b64 is a supported form; .noinc omits the implicit increment (the incrementing variant uses .shared.b64).
  • Action: Confirm this is intentional and that the build/target supports mbarrier/cp.async (PTX ISA 7.0 / GPUs sm_80+). If you depend on Hopper/SM90-only extensions, target sm_90.
  • Location: src/tl_templates/cuda/barrier.h:125-126 — "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t"
src/transform/warp_specialized_rewriter.cc (1)

691-741: Confirm barrier increment semantics when both cp.async barrier and mbarrier.arrive are emitted.

In producer path, a cp.async barrier is emitted, and conditionally an arrive barrier is also emitted. Depending on ISA semantics, this may double-increment. If noinc is intended for the cp.async path, wire the bool suggested above accordingly.

Would you like me to wire the exact conditions (e.g., release_after=false ⇒ noinc=true) once we confirm the intended sequencing?

Also applies to: 758-767

Comment on lines +56 to +68
private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier())) {
        has_mbarrier_op_ = true;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }


⚠️ Potential issue

Detector misses new noinc intrinsic; may under-detect mbarrier usage.

Add ptx_cp_async_barrier_noinc() (and optionally other mbarrier-related ops) to the mbarrier set to keep gating correct with the new intrinsic.

-      if (call->op.same_as(create_list_of_mbarrier()) ||
-          call->op.same_as(mbarrier_wait_parity()) ||
-          call->op.same_as(builtin::ptx_arrive_barrier()) ||
-          call->op.same_as(builtin::ptx_cp_async_barrier())) {
+      if (call->op.same_as(create_list_of_mbarrier()) ||
+          call->op.same_as(mbarrier_wait_parity()) ||
+          call->op.same_as(builtin::ptx_arrive_barrier()) ||
+          call->op.same_as(builtin::ptx_cp_async_barrier()) ||
+          call->op.same_as(builtin::ptx_cp_async_barrier_noinc())) {
         has_mbarrier_op_ = true;
       }
🤖 Prompt for AI Agents
In src/transform/warp_specialized_rewriter.h around lines 56 to 68, the
VisitStmt_ detector currently checks a set of mbarrier intrinsics but misses the
new ptx_cp_async_barrier_noinc intrinsic; update the conditional that sets
has_mbarrier_op_ to also check
call->op.same_as(builtin::ptx_cp_async_barrier_noinc()) (and add any other new
mbarrier-related intrinsics as needed) so the detector correctly flags usage;
ensure you include the appropriate declaration/namespace for the new builtin
symbol and run tests to validate gating behavior.

Comment on lines +69 to +75
  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
        op->op.same_as(set_max_nreg())) {
      has_tma_op_ = true;
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

⚠️ Potential issue

Classifying set_max_nreg as “TMA” causes false disables.

set_max_nreg() is unrelated to TMA. Including it can disable AWS whenever any mbarrier call appears with a register hint.

-    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
-        op->op.same_as(set_max_nreg())) {
+    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
       has_tma_op_ = true;
     }

If you still want to track register hints, add a separate flag (e.g., has_set_max_nreg_) and keep it out of the TMA/mbarrier cross check.
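
A sketch of the separate-flag idea, under the assumption that the cross check is "TMA op and mbarrier op both seen". `ToyFlags`, `scan`, and `disables_aws` are hypothetical names, and the op-name sets only mirror the checks quoted in this review.

```python
from dataclasses import dataclass

TMA_OPS = {"tma_load", "tma_load_im2col"}
MBARRIER_OPS = {"create_list_of_mbarrier", "mbarrier_wait_parity",
                "ptx_arrive_barrier", "ptx_cp_async_barrier"}


@dataclass
class ToyFlags:
    has_tma_op: bool = False
    has_mbarrier_op: bool = False
    has_set_max_nreg: bool = False  # tracked on its own, outside the check


def scan(ops, count_hint_as_tma: bool) -> ToyFlags:
    """Collect flags; count_hint_as_tma=True mimics the reviewed behavior."""
    flags = ToyFlags()
    for op in ops:
        if op in TMA_OPS:
            flags.has_tma_op = True
        elif op in MBARRIER_OPS:
            flags.has_mbarrier_op = True
        elif op == "set_max_nreg":
            flags.has_set_max_nreg = True
            if count_hint_as_tma:
                flags.has_tma_op = True
    return flags


def disables_aws(flags: ToyFlags) -> bool:
    """Assumed cross check: TMA and mbarrier ops both present."""
    return flags.has_tma_op and flags.has_mbarrier_op


# A register hint plus an mbarrier op, with no TMA load anywhere.
program = ["set_max_nreg", "ptx_arrive_barrier"]
assert disables_aws(scan(program, count_hint_as_tma=True))       # false disable
assert not disables_aws(scan(program, count_hint_as_tma=False))  # hint kept apart
```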

Comment on lines +77 to +90
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == "warp_specialize" &&
        op->value.as<IntImmNode>()->value == 1) {
      has_warp_specialization_ = true;
    }
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

💡 Verification agent

🧩 Analysis chain

Attribute key for manual WS may be inconsistent.

Detector checks op->attr_key == "warp_specialize", while the pass uses attr::kWarpSpecializationScope. Ensure we’re keying off the right attribute for “manual WS enabled”.

I can align the detector once we confirm the canonical key.


🏁 Script executed:

#!/bin/bash
# Verify which attribute keys are used across the repo
rg -nP '"warp_specialize"\s*|kWarpSpecializationScope' -C2

Length of output: 4489


Replace "warp_specialize" with attr::kWarpSpecializationScope in the detector

Repo-wide usage and tests use "kWarpSpecializationScope" (attr::kWarpSpecializationScope); "warp_specialize" only appears in ir.cc. Update src/transform/warp_specialized_rewriter.h (VisitStmt_, ~lines 77–90) to check attr::kWarpSpecializationScope — or check both keys if backward-compat is required.
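
The "check both keys" option could look roughly like the following. This is only an illustration of the matching logic: `marks_manual_ws` is a hypothetical helper, the key strings are taken from this comment, and whether the two keys really carry the same meaning in the codebase is not confirmed here.

```python
LEGACY_KEY = "warp_specialize"
CANONICAL_KEY = "kWarpSpecializationScope"


def marks_manual_ws(attr_key: str, value: int,
                    accept_legacy: bool = True) -> bool:
    """True when an attribute (key, value) enables warp specialization."""
    keys = {CANONICAL_KEY}
    if accept_legacy:
        keys.add(LEGACY_KEY)  # keep older annotations working
    return attr_key in keys and value == 1


assert marks_manual_ws("kWarpSpecializationScope", 1)
assert marks_manual_ws("warp_specialize", 1)
assert not marks_manual_ws("warp_specialize", 1, accept_legacy=False)
assert not marks_manual_ws("kWarpSpecializationScope", 0)
```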

🤖 Prompt for AI Agents
In src/transform/warp_specialized_rewriter.h around lines 77 to 90, the
VisitStmt_ detector currently checks the literal attr key "warp_specialize"
which is inconsistent with repo-wide usage; change the conditional to check
tir::attr::kWarpSpecializationScope instead (or check both "warp_specialize" and
tir::attr::kWarpSpecializationScope to preserve backward compatibility). Update
the if statement to compare op->attr_key against
tir::attr::kWarpSpecializationScope (and optionally also accept the legacy
"warp_specialize"), keeping the existing check for value.as<IntImmNode>()->value
== 1 and setting has_warp_specialization_ unchanged. Ensure to include the
proper namespace qualification (tir::attr::kWarpSpecializationScope) and adjust
any includes if necessary.

Comment on lines +355 to +358
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
    """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)

⚠️ Potential issue

Normalize barrier argument like mbarrier_arrive; passing an int currently breaks codegen.

As written, callers can pass an int (per the type hint), which will be emitted as a literal (e.g., 0) into tl::mbarrier_cp_async_arrive_noinc and fail to bind to the expected reference. Mirror mbarrier_arrive’s normalization to always pass a handle.

Apply this diff:

-def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
-    """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
-    """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
+def cp_async_barrier_noinc(mbarrier: Union[int, PrimExpr, tir.Call]):
+    """Perform a PTX async-copy barrier using cp.async.mbarrier.arrive.noinc."""
+    if isinstance(mbarrier, (tir.Call, tir.BufferLoad)):
+        mb = mbarrier
+    elif isinstance(mbarrier, (tir.PrimExpr, int)):
+        mb = get_mbarrier(mbarrier)
+    elif isinstance(mbarrier, tir.Buffer):
+        mb = tir.BufferLoad(mbarrier, [0])
+    else:
+        raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}")
+    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), mb)
🤖 Prompt for AI Agents
In tilelang/language/builtin.py around lines 355-358, the cp_async_barrier_noinc
function currently accepts an int and emits a literal which breaks codegen;
mirror mbarrier_arrive’s normalization so the barrier_id is always passed as a
handle. Change the function to detect when barrier_id is a plain int/PrimExpr
literal (or otherwise not already a handle) and wrap it the same way
mbarrier_arrive does (i.e., convert the integer/literal into a tir handle via
the same tir.call_intrin wrap used in mbarrier_arrive) before returning the
call_intrin for tl.ptx_cp_async_barrier_noinc.

@chengyupku chengyupku merged commit ae9b706 into tile-ai:main Sep 14, 2025
5 of 7 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…onality (tile-ai#809)
