
Conversation


@LeiWang1999 LeiWang1999 commented Nov 20, 2025

This pull request introduces conditional compilation and improved error handling for CUDA atomic operations and debug utilities, ensuring compatibility with different CUDA versions and architectures. The main changes add checks for CUDA version and architecture before using certain features, and provide clear error messages when unsupported features are accessed.

CUDA atomic operations compatibility and error handling:

  • Added the TL_NOT_IMPLEMENTED() macro in atomic.h to print a message and trigger a breakpoint when an unsupported atomic operation is called, improving debuggability.
  • Wrapped calls to cuda::atomic_ref-based atomic operations (AtomicMax, AtomicMaxRet, AtomicMin, AtomicMinRet, AtomicAdd, AtomicAddRet, AtomicLoad, AtomicStore) in #if CUDART_VERSION >= 11080 checks, calling TL_NOT_IMPLEMENTED() when the CUDA version is too old. This prevents compilation or runtime errors on unsupported CUDA versions.
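The guard pattern described above can be sketched as follows. This is a hedged host-side mock, not the actual atomic.h code: CUDART_VERSION normally comes from cuda_runtime.h, the real macro presumably traps on device, and the real supported path uses cuda::atomic_ref<T, cuda::thread_scope_device>::fetch_max. The function name AtomicMaxSketch and the mocked version value are illustrative only.

```cpp
#include <cassert>
#include <cstdio>
#include <cstdlib>

// Mocked for illustration: in real builds CUDART_VERSION is defined by
// cuda_runtime.h (e.g. 11080 for CUDA 11.8).
#ifndef CUDART_VERSION
#define CUDART_VERSION 11080
#endif

// Sketch of the TL_NOT_IMPLEMENTED() idea: emit a diagnostic and halt.
// The device-side version would use printf plus a breakpoint/trap instead.
#define TL_NOT_IMPLEMENTED()                                                 \
  do {                                                                       \
    std::fprintf(stderr, "tl atomic op: not implemented for this CUDA\n");   \
    std::abort();                                                            \
  } while (0)

// Host stand-in for the guarded AtomicMax: on the supported path the real
// code builds a cuda::atomic_ref and calls fetch_max; pre-11.8 it fails
// loudly instead of miscompiling.
void AtomicMaxSketch(int *address, int val) {
#if CUDART_VERSION >= 11080
  if (val > *address) // stand-in for aref.fetch_max(val, memory_order)
    *address = val;
#else
  TL_NOT_IMPLEMENTED();
#endif
}
```

Because the check is a preprocessor guard, the unsupported branch is compiled out entirely on new toolkits; older toolkits compile only the diagnostic path.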

Debug utilities and architecture checks:

  • Wrapped all uses of fp8_e4_t and fp8_e5_t debug printing and buffer-value specializations in #if __CUDA_ARCH_LIST__ >= 890 checks, and included cuda_fp8.h only when this architecture is targeted. This ensures these features are only available on supported architectures.
  • Removed an unconditional include of cuda_fp8.h from gemm_mma.h, which is now only included when needed by architecture checks elsewhere.
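The architecture gate can be sketched like this. It is a host-side mock for illustration: __CUDA_ARCH_LIST__ is normally defined by nvcc from the target architectures (here we pretend a single SM 8.9 target), and the flag name kFp8DebugAvailable is hypothetical, standing in for the real cuda_fp8.h include and the fp8_e4_t / fp8_e5_t debug specializations.

```cpp
#include <cassert>

// Mocked for illustration: nvcc normally defines __CUDA_ARCH_LIST__ from the
// -gencode targets; we pretend a single SM 8.9 (Ada) target here.
#ifndef __CUDA_ARCH_LIST__
#define __CUDA_ARCH_LIST__ 890
#endif

#if __CUDA_ARCH_LIST__ >= 890
// The real code includes <cuda_fp8.h> here and defines the debug_print_var /
// debug_print_buffer_value specializations for fp8_e4_t and fp8_e5_t.
constexpr bool kFp8DebugAvailable = true;
#else
// On pre-SM-8.9 targets the FP8 debug helpers simply do not exist.
constexpr bool kFp8DebugAvailable = false;
#endif
```

One consequence of this style of gating: any caller of the FP8 debug helpers must sit behind the same guard, since the symbols vanish entirely on older targets.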

Summary by CodeRabbit

  • Chores
    • Updated CUDA atomic operations with version compatibility checks to ensure proper support across different CUDA versions.
    • Enhanced FP8 type support with architecture-specific compilation guards for optimal compatibility.
    • Removed redundant includes from template headers to streamline builds.



coderabbitai bot commented Nov 20, 2025

Walkthrough

This pull request restricts CUDA atomic operations and FP8-specific debug functionality to supported CUDA and GPU architecture versions via preprocessor guards. It introduces a TL_NOT_IMPLEMENTED() macro to handle unsupported version paths and removes an unconditional FP8 header include, replacing it with architecture-gated includes.

Changes

  • CUDA Atomic Operations with Version Gating — src/tl_templates/cuda/atomic.h
    Added the TL_NOT_IMPLEMENTED() macro; wrapped the AtomicMax, AtomicMaxRet, AtomicMin, AtomicMinRet, AtomicAdd, AtomicAddRet, AtomicLoad, and AtomicStore implementations in #if CUDART_VERSION >= 11080 guards, routing pre-11.8 CUDA versions to TL_NOT_IMPLEMENTED() paths.
  • FP8 Debug Specializations with Architecture Gating — src/tl_templates/cuda/debug.h
    Added #if __CUDA_ARCH_LIST__ >= 890 guards around the cuda_fp8.h include and the specialized debug_print_var / debug_print_buffer_value implementations for fp8_e4_t and fp8_e5_t.
  • FP8 Header Cleanup — src/tl_templates/cuda/gemm_mma.h
    Removed the unconditional include of cuda_fp8.h; all conditional includes remain intact.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

  • Attention areas:
    • Verify correctness of CUDART_VERSION threshold (11080) across all guarded atomic operations
    • Confirm __CUDA_ARCH_LIST__ >= 890 is the appropriate architecture requirement for FP8 specializations
    • Ensure TL_NOT_IMPLEMENTED() macro definition and behavior are appropriate for the fallback paths
    • Check that removal of unconditional cuda_fp8.h include from gemm_mma.h doesn't break existing conditional includes elsewhere


Poem

🐰 A rabbit hops through CUDA's maze,
With version gates and architect praise,
FP8 confined to SM_89's grace,
Atomic ops skip unsupported space—
TL_NOT_IMPLEMENTED marks the way! 🚀

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Title check ⚠️ Warning — The title '[Compatibility] Support CUDA 11.3' is only partially related to the changeset and is misleading: the changes primarily guard against unsupported CUDA versions (11.3 and older) via TL_NOT_IMPLEMENTED() and version/architecture checks (CUDART_VERSION >= 11080 and __CUDA_ARCH_LIST__ >= 890), rather than adding positive support for CUDA 11.3. Resolution: revise the title to reflect the actual intent, such as '[Compatibility] Add conditional compilation for CUDA version/architecture checks' or '[Compatibility] Route unsupported CUDA operations to error handler'.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped: CodeRabbit's high-level summary is enabled.
  • Docstring Coverage ✅ Passed — No functions found in the changed files to evaluate docstring coverage; skipping the check.




@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9e67b86 and dc369c0.

📒 Files selected for processing (3)
  • src/tl_templates/cuda/atomic.h (8 hunks)
  • src/tl_templates/cuda/debug.h (5 hunks)
  • src/tl_templates/cuda/gemm_mma.h (0 hunks)
💤 Files with no reviewable changes (1)
  • src/tl_templates/cuda/gemm_mma.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/tl_templates/cuda/atomic.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (2)
src/tl_templates/cuda/debug.h (1)

3-5: LGTM! FP8 architecture guards properly implemented.

The conditional compilation guards correctly restrict FP8 type support to architectures >= 890 (SM 8.9+), preventing compilation errors on older GPUs that lack FP8 hardware support. The pattern is consistently applied to the header include and all FP8 specializations.

Also applies to: 123-144, 251-273

src/tl_templates/cuda/atomic.h (1)

15-19: LGTM! Clear runtime diagnostic for unsupported paths.

The TL_NOT_IMPLEMENTED() macro provides clear diagnostics when unsupported atomic operations are invoked on older CUDA versions, using a breakpoint to halt execution for debugging.

Comment on lines +70 to +75
#if CUDART_VERSION >= 11080
  cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
  aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
#else
  TL_NOT_IMPLEMENTED();
#endif

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify CUDA version requirement matches PR objectives.

The code gates atomic operations behind CUDART_VERSION >= 11080 (CUDA 11.8), but the PR title indicates support for CUDA 11.3 (version 11030). Please confirm whether CUDA 11.3 lacks cuda::atomic_ref support, or if the version check should be lowered to 11030.

Also applies to: 100-106, 163-169, 240-245, 318-324, 650-655, 661-666


🌐 Web query:

What CUDA version introduced cuda::atomic_ref support?

💡 Result:

cuda::atomic_ref was introduced in CUDA Toolkit 11.6. [1][2]

Sources:
[1] CUDA Core Compute Libraries — Synchronization Primitives (cuda::atomic_ref listed as CUDA 11.6).
[2] libcu++ 1.7.0 release notes (libcu++ 1.7.0 / CUDA Toolkit 11.6 adds atomic_ref).


Confirm PR target version—CUDA 11.3 cannot support cuda::atomic_ref operations.

The cuda::atomic_ref API was introduced in CUDA Toolkit 11.6, not 11.3. The version gates at CUDART_VERSION >= 11080 (CUDA 11.8) are unnecessarily restrictive, but lowering them to 11030 (CUDA 11.3) will not work—the API didn't exist until 11.6.

Either the PR target version should be CUDA 11.6+, or alternative implementations must be provided for CUDA 11.3-11.5. The current TL_NOT_IMPLEMENTED() fallback would fail at runtime for CUDA 11.3 users.

Applies to: lines 70-75, 100-106, 163-169, 240-245, 318-324, 650-655, 661-666

🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 70-75, 100-106, 163-169, 240-245,
318-324, 650-655 and 661-666 the code gates use cuda::atomic_ref which was
introduced in CUDA 11.6; update the CUDART_VERSION checks to >= 11060 if we
intend to require CUDA 11.6+ and document that PR target or CI must use 11.6+,
otherwise replace the TL_NOT_IMPLEMENTED() fallbacks with a portable atomic
implementation (e.g. an atomicCAS loop or intrinsics) for CUDA 11.3–11.5 to
perform the same operations, and ensure all version guards and documentation
reflect the chosen approach.
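One shape the suggested portable fallback could take (hypothetical; not code from this PR) is a compare-and-swap loop that emulates fetch_max. The sketch below uses std::atomic on the host so it stays self-contained; a device-side version for CUDA 11.3-11.5 would spin on atomicCAS instead, with bit-level reinterpretation for floating-point types.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical CAS-loop fallback for fetch_max, as the review prompt suggests
// for toolkits where cuda::atomic_ref is unavailable. Host sketch with
// std::atomic; a device version would loop on atomicCAS.
int FetchMaxFallback(std::atomic<int> &ref, int val) {
  int observed = ref.load(std::memory_order_relaxed);
  // Retry until either our value is no longer the larger one, or the CAS
  // installs it. compare_exchange_weak reloads `observed` on failure, so the
  // loop condition re-checks against the freshest value.
  while (val > observed &&
         !ref.compare_exchange_weak(observed, val,
                                    std::memory_order_relaxed)) {
  }
  return observed; // value seen before any update, matching fetch_max
}
```

The return value mirrors fetch_max semantics (the prior value), which is what the *Ret variants of the atomic helpers would need; the non-returning variants could simply discard it.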

Comment on lines +132 to +138
#if CUDART_VERSION >= 11080
  cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
  aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
  return static_cast<T1>(
      aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
  TL_NOT_IMPLEMENTED();
#endif

⚠️ Potential issue | 🔴 Critical

Critical: Return statement in void function.

AtomicMin is declared as returning void (line 111), but lines 134-135 contain a return statement with a value. This will cause a compilation error.

Apply this diff to remove the erroneous return statement:

 #if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
-    return static_cast<T1>(
-        aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+    aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
 #else
     TL_NOT_IMPLEMENTED();
 #endif
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 132 to 138, the AtomicMin
implementation is in a function declared void but contains a returned value;
remove the erroneous return so the function does not return a value. Replace the
line returning static_cast<T1>(aref.fetch_min(...)) with a call to
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)) and
discard the result (or cast val beforehand) so the operation executes without
returning anything; keep the TL_NOT_IMPLEMENTED() fallback as-is.

