[Feature] Add ptx_cp_async_barrier_noinc intrinsic and related functionality #809

Conversation
- Introduced a new intrinsic `ptx_cp_async_barrier_noinc` for handling the `cp.async.mbarrier.arrive.noinc` operation in TileLang.
- Updated the CUDA code generation to support the new barrier operation.
- Added a corresponding function in the TileLang Python API for ease of use.
- Enhanced the barrier handling in CUDA templates to include the new no-increment operation, improving synchronization capabilities in parallel execution contexts.
**Walkthrough**

Adds a new TL builtin op for PTX cp.async barrier no-increment, wires it through the Python API, TIR, CUDA codegen, and CUDA templates. Introduces WarpSpecializedDetector in a new header, adjusts transform passes to use it, and adds early-exit gating in register allocation injection when warp specialization is detected.
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
  autonumber
  actor User
  participant TL as tilelang.builtin
  participant TIR as tl.op registry
  participant CG as CUDA CodeGen
  participant RT as CUDA Templates
  Note over User,RT: cp.async barrier (noinc) flow
  User->>TL: cp_async_barrier_noinc(barrier_id)
  TL->>TIR: tir.call_intrin("tl.ptx_cp_async_barrier_noinc", barrier_id)
  TIR-->>CG: CallNode(tl.ptx_cp_async_barrier_noinc)
  CG->>RT: extern call mbarrier_cp_async_arrive_noinc(smem_mbar)
  RT->>RT: asm("cp.async.mbarrier.arrive.noinc ...")
  RT-->>User: completes
```
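The name wiring in the diagram above can be sketched as a toy lookup table. This is a hedged illustration only: the real flow goes through TVM's op registry and the CUDA codegen visitor, and `lower_to_extern_call` is a hypothetical stand-in, not tilelang code.

```python
# Toy sketch of the intrinsic -> extern-call mapping shown in the diagram.
# The intrinsic and extern names are taken from the diagram; the dispatch
# mechanism here is illustrative, not TVM's actual registry.

CUDA_EXTERN_FOR_INTRIN = {
    "tl.ptx_cp_async_barrier": "tl::mbarrier_cp_async_arrive",
    "tl.ptx_cp_async_barrier_noinc": "tl::mbarrier_cp_async_arrive_noinc",
}

def lower_to_extern_call(intrin_name: str, barrier_arg: str) -> str:
    """Map a TL intrinsic call to the CUDA extern call the codegen emits."""
    extern = CUDA_EXTERN_FOR_INTRIN[intrin_name]
    return f"{extern}({barrier_arg});"

print(lower_to_extern_call("tl.ptx_cp_async_barrier_noinc", "smem_mbar"))
# → tl::mbarrier_cp_async_arrive_noinc(smem_mbar);
```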
```mermaid
sequenceDiagram
  autonumber
  participant Pass as SetMaxNRegInjector::Inject
  participant Det as WarpSpecializedDetector
  participant IR as PrimFunc
  rect rgba(220,235,245,0.5)
    Note right of Pass: Early gating
    Pass->>Det: Detect(IR->body)
    alt warp-specialized detected
      Pass-->>IR: return IR (unchanged)
    else not detected
      Pass->>Pass: proceed with set_max_nreg injection
      Pass-->>IR: return transformed IR
    end
  end
```
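The early-exit gating in the second diagram amounts to: if the detector flags hand-written warp specialization, return the function untouched; otherwise transform it. A minimal sketch, with plain dicts standing in for PrimFunc and the detector (hypothetical names, not the real C++ pass):

```python
# Sketch of the gating pattern: detector fires -> pass is a no-op.

def detect_warp_specialized(body: dict) -> bool:
    # Stand-in for WarpSpecializedDetector::Detect
    return bool(body.get("warp_specialize", False))

def inject_set_max_nreg(func: dict) -> dict:
    if detect_warp_specialized(func["body"]):
        return func  # hand-written WS detected: leave the IR unchanged
    out = dict(func)
    out["body"] = {**func["body"], "set_max_nreg": True}
    return out

manual_ws = {"body": {"warp_specialize": True}}
auto_ws = {"body": {"warp_specialize": False}}
assert inject_set_max_nreg(manual_ws) is manual_ws        # unchanged
assert inject_set_max_nreg(auto_ws)["body"]["set_max_nreg"]  # transformed
```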
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
**Pre-merge checks (2 passed, 1 warning)**

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes
Hello @chengyupku, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces comprehensive support for the `cp.async.mbarrier.arrive.noinc` PTX instruction within the TileLang framework. This includes defining a new intrinsic, integrating it into the CUDA code generation pipeline, providing a Python-level API for ease of use, and implementing the low-level CUDA template function. Additionally, a significant refactoring of the warp specialization detection logic has been performed to enhance modularity and improve handling of specific optimization scenarios.
Highlights

- **New Intrinsic for Asynchronous Barrier:** Introduced the `ptx_cp_async_barrier_noinc` intrinsic to support the `cp.async.mbarrier.arrive.noinc` operation in TileLang, enabling more granular control over asynchronous memory barriers in CUDA.
- **CUDA Code Generation Support:** Updated the CUDA code generator to correctly translate the new `ptx_cp_async_barrier_noinc` intrinsic into the corresponding `tl::mbarrier_cp_async_arrive_noinc` external call.
- **Python API Integration:** Added a user-friendly Python API function `cp_async_barrier_noinc` in TileLang, allowing developers to easily utilize the new asynchronous barrier operation.
- **CUDA Template Enhancement:** Implemented the `mbarrier_cp_async_arrive_noinc` function within CUDA templates, which directly emits the PTX `cp.async.mbarrier.arrive.noinc` instruction, enhancing synchronization capabilities in parallel execution contexts.
- **Warp Specialization Refactoring:** Refactored the `WarpSpecializedDetector` class into its own header file (`warp_specialized_rewriter.h`) to improve code organization and reusability, and integrated it into the `SetMaxNRegInjector` to handle manual warp specialization.
Code Review
This pull request introduces the ptx_cp_async_barrier_noinc intrinsic, which is a useful addition for advanced synchronization patterns in CUDA. The changes are well-structured, including updates to the TIR builtin, CUDA codegen, and the Python API. The refactoring to move WarpSpecializedDetector into its own header file improves code organization.
I have two main points of feedback:
- A potential logic error in `WarpSpecializedDetector` where the new intrinsic is not being checked, which could lead to incorrect behavior of the auto warp specialization pass.
- A minor code duplication issue in the CUDA template for the new barrier, which could be refactored to improve maintainability.
Overall, this is a good contribution. Please see my detailed comments below.
```cpp
if (call->op.same_as(create_list_of_mbarrier()) ||
    call->op.same_as(mbarrier_wait_parity()) ||
    call->op.same_as(builtin::ptx_arrive_barrier()) ||
    call->op.same_as(builtin::ptx_cp_async_barrier())) {
  has_mbarrier_op_ = true;
}
```
The new intrinsic ptx_cp_async_barrier_noinc seems to be missing from the check for mbarrier operations in WarpSpecializedDetector. This could lead to auto warp specialization not being disabled when it should be (i.e., when both TMA and this new mbarrier op are present). You should add it to the list of checks to ensure correct behavior.
```cpp
if (call->op.same_as(create_list_of_mbarrier()) ||
    call->op.same_as(mbarrier_wait_parity()) ||
    call->op.same_as(builtin::ptx_arrive_barrier()) ||
    call->op.same_as(builtin::ptx_cp_async_barrier()) ||
    call->op.same_as(ptx_cp_async_barrier_noinc())) {
  has_mbarrier_op_ = true;
}
```

```cpp
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
  smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
  smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
```
The logic to get the integer pointer `smem_int_mbar` from `smem_mbar` is identical to the logic in the existing `mbarrier_cp_async_arrive` function. To improve maintainability and reduce code duplication, consider extracting this logic into a common helper function that both `mbarrier_cp_async_arrive` and `mbarrier_cp_async_arrive_noinc` can use.
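The suggested refactor is the classic hoist-shared-logic-into-a-helper move. A hypothetical Python analogue of the shape (with `to_uint` standing in for `smem_ptr_to_uint`, and an int vs. object branch standing in for the `std::is_pointer_v` dispatch):

```python
# Hypothetical analogue of factoring the duplicated smem_mbar -> uint
# conversion into one helper shared by both arrive variants.

def to_uint(addr: int) -> int:          # stands in for smem_ptr_to_uint
    return addr & 0xFFFFFFFF

def _mbar_to_uint(smem_mbar) -> int:
    # Shared helper: accept either a "pointer" (int) or an "object" with
    # .addr, mirroring the std::is_pointer_v branch in the C++ template.
    if isinstance(smem_mbar, int):
        return to_uint(smem_mbar)
    return to_uint(smem_mbar.addr)

def arrive(smem_mbar):
    return ("arrive", _mbar_to_uint(smem_mbar))

def arrive_noinc(smem_mbar):
    return ("arrive.noinc", _mbar_to_uint(smem_mbar))
```

Both entry points now go through one conversion path, so a future fix to the pointer handling lands in a single place.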
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/warp_specialized_rewriter.cc (1)
**588-590: WgMMA detection defaults to true → likely false positives.**

`has_wgmma_{true}` means `HasWgMMA()` returns true even when no GEMM op is present. Set the default to false:

```diff
-  bool has_wgmma_{true};
+  bool has_wgmma_{false};
```
🧹 Nitpick comments (9)

src/tl_templates/cuda/barrier.h (1)
**124-130: Gate the synclog call to avoid build-time dependency in non-debug builds.**

If synclog isn't available, this will fail to link. Suggest guarding.

Apply this diff:

```diff
   asm volatile("{\n\t"
                "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t"
                "}"
                :
                : "r"(smem_int_mbar));
-  cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar);
+#ifdef TL_ENABLE_SYNCLOG
+  cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_int_mbar);
+#endif
```

src/transform/warp_specialized_rewriter.cc (3)
**253-257: Plumb support for cp.async.mbarrier.arrive.noinc to avoid double-increment scenarios.**

The emitter still only uses `ptx_cp_async_barrier()`; with the new noinc intrinsic available, consider allowing a noinc path when a separate arrive will also be issued.

Apply this minimal refactor so call sites can opt into noinc without churn:

```diff
-static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
-  auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
-                   {makeGetBarrier(std::move(barrier_id))});
-  return Evaluate(call);
-}
+static Stmt makeCpAsyncBarrier(PrimExpr barrier_id, bool noinc = false) {
+  const Op &op = noinc ? builtin::ptx_cp_async_barrier_noinc()
+                       : builtin::ptx_cp_async_barrier();
+  auto call = Call(DataType::Handle(), op,
+                   {makeGetBarrier(std::move(barrier_id))});
+  return Evaluate(call);
+}
```
**866-871: Remove unused/overshadowed variable.**

`auto for_node = result.as<For>();` is unused and overshadowed below.

```diff
-  auto for_node = result.as<For>();
```
**1276-1287: Clarify variable naming for gating.**

`warp_specialized` actually holds "disable AWS" (Detect() returning true means disable). Rename for readability:

```diff
-  bool warp_specialized =
-      WarpSpecializedDetector::Detect(f->body);
+  bool disable_aws =
+      WarpSpecializedDetector::Detect(f->body);
-  if (!warp_specialized) {
+  if (!disable_aws) {
```

src/transform/annotate_warp_group_reg_alloc.cc (2)
**54-58: Early exit gating reads better if the intent is explicit.**

Rename the flag to reflect "disable AWS" semantics and keep logs centralized in the detector.

```diff
-  bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
-  if (warp_specialized) {
+  bool disable_aws = WarpSpecializedDetector::Detect(f->body);
+  if (disable_aws) {
     // Should handle set_max_nreg when using hand-written warp specialized
     return f;
   }
```
**106-121: SIMT-copy awareness is hardcoded to false.**

Register hints may be suboptimal when SIMT copy is present. Consider exposing a lightweight "has SIMT copy" detector (similar to ProducerTraitsCollector) and using it here to gate injection.

I can extract a detector helper into the shared header and thread it here if you want.

src/transform/warp_specialized_rewriter.h (3)
**34-48: Unused parameter in Detect().**

`skip_thread_partition` is unused; silence warnings until it's implemented.

```diff
 static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
   WarpSpecializedDetector detector;
+  (void)skip_thread_partition;
   detector.VisitStmt(stmt);
```
**27-30: Avoid `using namespace` in headers.**

Header-wide `using namespace tir;` and `using namespace runtime;` pollute includers. Prefer qualified names or a localized `using` in source files.

I can send a follow-up patch replacing unqualified symbol uses if you agree.
**92-96: Dead field: `thread_var_` is set but never used.**

Drop it or wire it into future logic to avoid confusion.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)

- `src/op/builtin.cc` (1 hunk)
- `src/op/builtin.h` (1 hunk)
- `src/target/codegen_cuda.cc` (1 hunk)
- `src/tl_templates/cuda/barrier.h` (1 hunk)
- `src/transform/annotate_warp_group_reg_alloc.cc` (2 hunks)
- `src/transform/warp_specialized_rewriter.cc` (1 hunk)
- `src/transform/warp_specialized_rewriter.h` (1 hunk)
- `tilelang/language/builtin.py` (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (3)

tilelang/language/builtin.py (1)
- tilelang/language/tir/op.py: `call_intrin` (119-144)

src/transform/warp_specialized_rewriter.h (2)
- src/op/builtin.h: `tl` (22-330), `attr` (24-28)
- src/target/codegen_cuda.cc: `VisitStmt_` (208-227, 1636-1675, 1677-1737, 1739-1755), `op` (129-144, 1389-1391, 1392-1394)

src/transform/annotate_warp_group_reg_alloc.cc (1)
- src/transform/warp_specialized_rewriter.cc: `f` (1121-1140, 1121-1122)
🔇 Additional comments (5)
src/op/builtin.h (1)

**180-187: New TL intrinsic declaration is correct and consistent with the registry.**

Signature and placement look good. No further issues.
src/target/codegen_cuda.cc (1)

**1068-1071: Codegen hook for noinc variant mirrors the existing path. LGTM.**

The extern emits `tl::mbarrier_cp_async_arrive_noinc`; consistent with the templates.
src/op/builtin.cc (1)

**93-97: Op registration looks correct.**

`.num_inputs(1)` and the kOpaque effect are appropriate for this intrinsic.
src/tl_templates/cuda/barrier.h (1)

**125-126: Confirm the `.shared::cta` qualifier and target SM.**

- Valid: `cp.async.mbarrier.arrive.noinc.shared::cta.b64` is a supported form; `.noinc` omits the implicit increment (the incrementing variant uses `.shared.b64`).
- Action: Confirm this is intentional and that the build/target supports mbarrier/cp.async (PTX ISA 7.0 / sm_80+ GPUs). If you depend on Hopper/SM90-only extensions, target sm_90.
- Location: src/tl_templates/cuda/barrier.h:125-126: `"cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n\t"`

src/transform/warp_specialized_rewriter.cc (1)
**691-741: Confirm barrier increment semantics when both cp.async barrier and mbarrier.arrive are emitted.**

In the producer path, a cp.async barrier is emitted, and conditionally an arrive barrier is also emitted. Depending on ISA semantics, this may double-increment. If noinc is intended for the cp.async path, wire the bool suggested above accordingly.

Would you like me to wire the exact conditions (e.g., release_after=false ⇒ noinc=true) once we confirm the intended sequencing?
Also applies to: 758-767
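The double-increment concern can be modeled with a toy arrival counter: a barrier phase completes when the expected number of arrivals has been counted, and an incrementing cp.async arrive plus an explicit arrive over-counts unless the cp.async path uses noinc. This is a simplified model for intuition, not full PTX mbarrier semantics:

```python
# Toy arrival counter: the phase flips once `expected` arrivals are seen.
class MBarrier:
    def __init__(self, expected: int):
        self.expected = expected
        self.arrivals = 0

    def arrive(self):
        self.arrivals += 1

    def cp_async_arrive(self, noinc: bool):
        # The incrementing variant also counts as an arrival;
        # the noinc variant only attaches the async-copy group.
        if not noinc:
            self.arrivals += 1

    def phase_complete(self) -> bool:
        return self.arrivals == self.expected

mb = MBarrier(expected=1)
mb.cp_async_arrive(noinc=True)  # track the copy group, no extra count
mb.arrive()                     # the single intended arrival
assert mb.phase_complete()      # exactly one arrival: no double increment
```

Replaying the same sequence with `noinc=False` counts two arrivals against an expectation of one, which is the over-count the review comment is asking about.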
```cpp
private:
  void VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(create_list_of_mbarrier()) ||
          call->op.same_as(mbarrier_wait_parity()) ||
          call->op.same_as(builtin::ptx_arrive_barrier()) ||
          call->op.same_as(builtin::ptx_cp_async_barrier())) {
        has_mbarrier_op_ = true;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }
```
**Detector misses new noinc intrinsic; may under-detect mbarrier usage.**

Add `ptx_cp_async_barrier_noinc()` (and optionally other mbarrier-related ops) to the mbarrier set to keep gating correct with the new intrinsic.
```diff
- if (call->op.same_as(create_list_of_mbarrier()) ||
-     call->op.same_as(mbarrier_wait_parity()) ||
-     call->op.same_as(builtin::ptx_arrive_barrier()) ||
-     call->op.same_as(builtin::ptx_cp_async_barrier())) {
+ if (call->op.same_as(create_list_of_mbarrier()) ||
+     call->op.same_as(mbarrier_wait_parity()) ||
+     call->op.same_as(builtin::ptx_arrive_barrier()) ||
+     call->op.same_as(builtin::ptx_cp_async_barrier()) ||
+     call->op.same_as(builtin::ptx_cp_async_barrier_noinc())) {
    has_mbarrier_op_ = true;
  }
```
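The failure mode here is the usual one for hand-maintained op sets: the detector's membership list and the actual set of mbarrier intrinsics drift apart. A toy version of the check (op names are illustrative stand-ins for the registered builtins):

```python
# Toy detector over call-op names; shows why a newly added intrinsic must
# also be added to the membership set, or its usage goes undetected.

MBARRIER_OPS = {
    "tl.create_list_of_mbarrier",
    "tl.mbarrier_wait_parity",
    "tir.ptx_arrive_barrier",
    "tir.ptx_cp_async_barrier",
    "tl.ptx_cp_async_barrier_noinc",  # the fix: include the new intrinsic
}

def has_mbarrier_op(call_ops: list[str]) -> bool:
    return any(op in MBARRIER_OPS for op in call_ops)

assert has_mbarrier_op(["tl.ptx_cp_async_barrier_noinc"])  # now detected
```

Keeping such sets next to the intrinsic declarations (or deriving them from a shared table) is one way to prevent the drift from recurring.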
```cpp
void VisitExpr_(const CallNode *op) final {
  if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
      op->op.same_as(set_max_nreg())) {
    has_tma_op_ = true;
  }
  IRVisitorWithAnalyzer::VisitExpr_(op);
}
```
**Classifying set_max_nreg as "TMA" causes false disables.**

`set_max_nreg()` is unrelated to TMA. Including it can disable AWS whenever any mbarrier call appears with a register hint.
```diff
- if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
-     op->op.same_as(set_max_nreg())) {
+ if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
    has_tma_op_ = true;
  }
```

If you still want to track register hints, add a separate flag (e.g., `has_set_max_nreg_`) and keep it out of the TMA/mbarrier cross check.
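Splitting the flags keeps the TMA-and-mbarrier cross-check clean while still recording register hints. A sketch with hypothetical op names:

```python
# Sketch of tracking set_max_nreg separately so a register hint no longer
# counts as a TMA op in the TMA x mbarrier cross-check.

TMA_OPS = {"tl.tma_load", "tl.tma_load_im2col"}

def classify(call_ops):
    has_tma = any(op in TMA_OPS for op in call_ops)
    has_set_max_nreg = "tl.set_max_nreg" in call_ops
    return has_tma, has_set_max_nreg

has_tma, has_nreg = classify(["tl.set_max_nreg", "tir.ptx_arrive_barrier"])
assert not has_tma   # a register hint alone is no longer classified as TMA
assert has_nreg      # but it is still recorded for callers that care
```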
```cpp
void VisitStmt_(const AttrStmtNode *op) final {
  if (op->attr_key == "warp_specialize" &&
      op->value.as<IntImmNode>()->value == 1) {
    has_warp_specialization_ = true;
  }
  if (op->attr_key == tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    if (iv->thread_tag == "threadIdx.x") {
      ICHECK(iv->dom->extent.as<IntImmNode>());
      thread_var_ = iv;
    }
  }
  IRVisitorWithAnalyzer::VisitStmt_(op);
}
```
💡 Verification agent
🧩 Analysis chain
**Attribute key for manual WS may be inconsistent.**

The detector checks `op->attr_key == "warp_specialize"`, while the pass uses `attr::kWarpSpecializationScope`. Ensure we're keying off the right attribute for "manual WS enabled".
I can align the detector once we confirm the canonical key.
🏁 Script executed:
```bash
#!/bin/bash
# Verify which attribute keys are used across the repo
rg -nP '"warp_specialize"\s*|kWarpSpecializationScope' -C2
```

Length of output: 4489
**Replace `"warp_specialize"` with `attr::kWarpSpecializationScope` in the detector.**

Repo-wide usage and tests use `kWarpSpecializationScope` (`attr::kWarpSpecializationScope`); `"warp_specialize"` only appears in ir.cc. Update src/transform/warp_specialized_rewriter.h (`VisitStmt_`, ~lines 77-90) to check `attr::kWarpSpecializationScope`, or check both keys if backward compatibility is required.
```python
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
    """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc."""
    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
```
**Normalize the barrier argument like `mbarrier_arrive`; passing an int currently breaks codegen.**

As written, callers can pass an int (per the type hint), which will be emitted as a literal (e.g., 0) into `tl::mbarrier_cp_async_arrive_noinc` and fail to bind to the expected reference. Mirror `mbarrier_arrive`'s normalization to always pass a handle.
Apply this diff:
```diff
-def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
-    """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
-    """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
+def cp_async_barrier_noinc(mbarrier: Union[int, PrimExpr, tir.Call]):
+    """Perform a PTX async-copy barrier using cp.async.mbarrier.arrive.noinc."""
+    if isinstance(mbarrier, (tir.Call, tir.BufferLoad)):
+        mb = mbarrier
+    elif isinstance(mbarrier, (tir.PrimExpr, int)):
+        mb = get_mbarrier(mbarrier)
+    elif isinstance(mbarrier, tir.Buffer):
+        mb = tir.BufferLoad(mbarrier, [0])
+    else:
+        raise TypeError(f"mbarrier must be an integer or a tir.Call, but got {type(mbarrier)}")
+    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), mb)
```
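The normalization pattern itself, accept an int or an already-built handle but always hand the intrinsic a handle, can be sketched independently of tilelang. Here `Handle`, `get_mbarrier`, and `normalize_mbarrier` are mock stand-ins, not the real tilelang types:

```python
# Mock normalization mirroring the suggested fix: bare ints are routed
# through a get_mbarrier-style lookup so the intrinsic always receives a
# handle, never an integer literal.

class Handle:
    def __init__(self, desc: str):
        self.desc = desc

def get_mbarrier(barrier_id: int) -> Handle:  # stand-in for tl.get_mbarrier
    return Handle(f"mbarrier[{barrier_id}]")

def normalize_mbarrier(mbarrier) -> Handle:
    if isinstance(mbarrier, Handle):
        return mbarrier               # already a handle: pass through
    if isinstance(mbarrier, int):
        return get_mbarrier(mbarrier) # int -> handle lookup
    raise TypeError(f"unsupported mbarrier argument: {type(mbarrier)!r}")

assert normalize_mbarrier(0).desc == "mbarrier[0]"
h = Handle("mbarrier[3]")
assert normalize_mbarrier(h) is h
```

The key property is that every accepted input converges on one type before the intrinsic call, so codegen never sees a bare literal.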
[Feature] Add `ptx_cp_async_barrier_noinc` intrinsic and related functionality (tile-ai#809)

- Introduced a new intrinsic `ptx_cp_async_barrier_noinc` for handling the `cp.async.mbarrier.arrive.noinc` operation in TileLang.
- Updated the CUDA code generation to support the new barrier operation.
- Added a corresponding function in the TileLang Python API for ease of use.
- Enhanced the barrier handling in CUDA templates to include the new no-increment operation, improving synchronization capabilities in parallel execution contexts.