Temporary workaround to disable TritonGPUHoistTMEMAlloc in b_dk += tl.dot(tl.trans(b_dA), b_kb) #687

Merged
zhiyuan1i merged 8 commits into fla-org:main from rucnyz:fix-blackwell-gdn on Dec 19, 2025

Conversation

@rucnyz
Contributor

@rucnyz rucnyz commented Dec 18, 2025

Temporarily fixes #638

In the original code b_dk += tl.dot(tl.trans(b_dA), b_kb), the compiler fuses the addition into the dot op, generating IR like %136 (new b_dk) = tt.dot %134 (tl.trans(b_dA)), %127 (b_kb), %135 (old b_dk), with %135 used as the accumulator.
The TritonGPUHoistTMEMAlloc pass appears to detect that %135 (input) and %136 (output) share the same TMEM buffer and attempts an optimization. On B200, however, it miscalculates dominance and moves the definition of %135 after its use in %136, causing the "operand #2 does not dominate this use" error.
The inline_asm forces the compiler to keep the dot and the add as separate instructions. Only inline_asm seems to work: other approaches such as + 0.0, * 1.0, and even bitcasts are optimized away by the compiler, which re-enables the fusion.
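
For illustration, here is a minimal sketch of the pattern described above, wrapped in a hypothetical helper (dk_update) so it is self-contained; the inline-asm call mirrors the one proposed in the diff further down this thread and is not the final merged code:

import triton
import triton.language as tl


@triton.jit
def dk_update(b_dk, b_dA, b_kb):
    # Problematic form (the compiler folds the add into a single tt.dot with
    # b_dk as the accumulator, which trips TritonGPUHoistTMEMAlloc on B200):
    #     b_dk += tl.dot(tl.trans(b_dA), b_kb)
    # Workaround: route the dot result through a no-op PTX move so the add
    # can no longer be fused back into the dot as an accumulator.
    b_dk += tl.inline_asm_elementwise(
        asm="mov.f32 $0, $1;",  # identity move, applied element-wise
        constraints="=r,r",
        args=[tl.dot(tl.trans(b_dA), b_kb)],
        dtype=tl.float32,
        is_pure=True,
        pack=1,
    )
    return b_dk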

Summary by CodeRabbit

  • Bug Fixes

    • Fixed incorrect dot-product fusion on NVIDIA Blackwell GPUs by adding a targeted runtime workaround that avoids a Triton compiler error so kernels compile and run correctly.
  • New Features

    • Added Blackwell device detection and a safe dot-product path that selects the appropriate implementation at runtime.


@coderabbitai
Contributor

coderabbitai bot commented Dec 18, 2025

Caution

Review failed

The pull request is closed.

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds an IS_NVIDIA_BLACKWELL device flag and a hardware-aware safe_dot wrapper used in wy_fast to avoid a Triton/MLIR dominance/fusion compiler error on NVIDIA Blackwell GPUs; replaces one tl.dot call with safe_dot.

Changes

Cohort / File(s) and Summary
  • Gated delta rule kernel (fla/ops/gated_delta_rule/wy_fast.py): Adds safe_dot(a, b) and replaces tl.dot(tl.trans(b_dA), b_kb) with safe_dot(tl.trans(b_dA), b_kb). When IS_NVIDIA_BLACKWELL is true, safe_dot uses a Triton inline-assembly workaround (the docstring describes the rationale); otherwise it falls back to tl.dot.
  • Device detection flag (fla/utils.py): Adds boolean IS_NVIDIA_BLACKWELL (true when IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10) and exposes it via the module aliasing mechanism alongside the existing device flags.
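
For reference, a minimal sketch of what the walkthrough describes. The IS_NVIDIA check below is a simplified stand-in (in the repository it comes from the existing device detection in fla/utils.py), and the inline-asm body follows the diff discussed later in this thread:

import torch
import triton
import triton.language as tl

# Simplified stand-in for the repo's existing IS_NVIDIA flag.
IS_NVIDIA = torch.cuda.is_available() and torch.version.cuda is not None
# Blackwell reports compute capability 10.x.
IS_NVIDIA_BLACKWELL = IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10

if IS_NVIDIA_BLACKWELL:
    @triton.jit
    def safe_dot(a, b):
        # Wrap the dot result in a no-op inline-asm move so the surrounding
        # add is not fused back into tt.dot (avoids the TMEM dominance bug).
        return tl.inline_asm_elementwise(
            asm="mov.f32 $0, $1;",
            constraints="=r,r",
            args=[tl.dot(a, b)],
            dtype=tl.float32,
            is_pure=True,
            pack=1,
        )
else:
    @triton.jit
    def safe_dot(a, b):
        # Plain dot everywhere else (AMD, older NVIDIA generations, etc.).
        return tl.dot(a, b)

At the call site in wy_fast.py the replacement then reads b_dk += safe_dot(tl.trans(b_dA), b_kb), as noted in the first row above.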

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20–30 minutes

  • Review Triton inline-assembly in safe_dot for correctness, operand ordering, and conformance with Triton inline-asm conventions.
  • Verify the replacement site in wy_fast.py doesn't miss other similar tl.dot usages that might need the workaround.
  • Confirm IS_NVIDIA_BLACKWELL detection logic and alias registration align with existing patterns (e.g., Hopper flag) and handle environments without CUDA safely.

Poem

🐰 I stitched a tiny asm along the dot,

where compilers tangled, I fixed the spot.
Blackwell bunnies hop, kernels run bright,
no more dominance woes in the night.
🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately describes the primary change: a workaround for the TritonGPUHoistTMEMAlloc fusion issue affecting the specific line b_dk += tl.dot(tl.trans(b_dA), b_kb).
  • Linked Issues Check ✅ Passed: The PR addresses issue #638 by implementing an inline assembly workaround that prevents the problematic TritonGPUHoistTMEMAlloc fusion, allowing Blackwell compilation to proceed without MLIR dominance errors.
  • Out of Scope Changes Check ✅ Passed: All changes are directly scoped to the issue: the safe_dot helper function and the IS_NVIDIA_BLACKWELL detection flag apply the workaround on Blackwell GPUs only.

📜 Recent review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 79036e4 and 235a439.

📒 Files selected for processing (1)
  • fla/ops/gated_delta_rule/wy_fast.py (2 hunks)


@gemini-code-assist
Contributor

Summary of Changes

Hello @rucnyz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a temporary fix to address a compiler bug specific to B200 GPUs. The issue stems from an incorrect optimization by the TritonGPUHoistTMEMAlloc pass, which attempts to fuse dot product and addition operations but fails due to a dominance calculation error. The solution involves using Triton's inline_asm_elementwise to explicitly prevent this fusion, thereby resolving the compilation error and allowing the code to run correctly on the affected hardware.

Highlights

  • Compiler Bug Workaround: This pull request implements a temporary workaround for a specific compiler bug on B200 GPUs. The bug occurs when the TritonGPUHoistTMEMAlloc pass incorrectly calculates dominance, leading to an 'operand does not dominate this use' error during the fusion of dot product and addition operations.
  • Inline Assembly for Control: To circumvent the compiler's problematic optimization, tl.inline_asm_elementwise is now used. This explicitly forces the separation of the tl.dot operation from the subsequent addition, preventing the compiler from fusing them and triggering the bug.
  • Targeted Fix: The change is applied to the b_dk += tl.dot(tl.trans(b_dA), b_kb) expression within the prepare_wy_repr_bwd_kernel function in fla/ops/gated_delta_rule/wy_fast.py.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a temporary workaround for a Triton compiler bug on B200 GPUs that causes an error due to incorrect instruction fusion. The change uses inline assembly to prevent the fusion of a dot and add operation. The workaround is well-explained in the pull request description and the implementation appears correct. My feedback is to add a comment in the code to document this workaround for future maintainability, explaining why it's necessary and under what conditions it could be removed.

Comment on lines 201 to 208
b_dk += tl.inline_asm_elementwise(
    asm="mov.f32 $0, $1;",
    constraints="=r,r",
    args=[tl.dot(tl.trans(b_dA), b_kb)],
    dtype=tl.float32,
    is_pure=True,
    pack=1,
)
Contributor


medium

The PR description provides excellent context for this workaround. To improve long-term maintainability, it's crucial to capture this information in the code itself. This ensures that any developer looking at this code in the future will understand the reason for this unusual construct and know when it can be removed.

I suggest adding a comment explaining the workaround and a TODO to track its removal.

Suggested change
b_dk += tl.inline_asm_elementwise(
    asm="mov.f32 $0, $1;",
    constraints="=r,r",
    args=[tl.dot(tl.trans(b_dA), b_kb)],
    dtype=tl.float32,
    is_pure=True,
    pack=1,
)
# Temporary workaround for a Triton compiler bug on B200 GPUs.
# The `TritonGPUHoistTMEMAlloc` pass incorrectly fuses the add and dot operations,
# leading to a dominance error. The inline assembly prevents this fusion.
# TODO: Remove this workaround when the compiler bug is fixed. Link to issue if available.
b_dk += tl.inline_asm_elementwise(
    asm="mov.f32 $0, $1;",
    constraints="=r,r",
    args=[tl.dot(tl.trans(b_dA), b_kb)],
    dtype=tl.float32,
    is_pure=True,
    pack=1,
)

@zhiyuan1i
Collaborator

Thanks for the contribution, but it looks like this PR will break the FLA repository's general support for other GPUs such as AMD. Could you also report this upstream to triton-lang/triton?

@yzhangcs
Member

@rucnyz could you refactor this into a small helper function? That way, we can conditionally use inline assembly based on the device type while keeping the main logic clean.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/utils.py (1)

473-490: Add IS_NVIDIA_BLACKWELL to the alias registration list.

The new IS_NVIDIA_BLACKWELL flag should be registered for lowercase aliasing, consistent with other device capability flags like IS_NVIDIA_HOPPER. This maintains the established convention where all device flags have both uppercase and lowercase variants available.

🔎 Proposed fix
 def _register_aliases():
     current_module = sys.modules[__name__]
     for key in (
         'IS_AMD',
         'IS_INTEL',
         'IS_NVIDIA',
         'IS_INTEL_ALCHEMIST',
         'IS_NVIDIA_HOPPER',
+        'IS_NVIDIA_BLACKWELL',
         'USE_CUDA_GRAPH',
         'IS_TF32_SUPPORTED',
         'IS_GATHER_SUPPORTED',
         'IS_TMA_SUPPORTED',
     ):
         if hasattr(current_module, key):
             setattr(current_module, key.lower(), getattr(current_module, key))
🧹 Nitpick comments (1)
fla/utils.py (1)

398-398: LGTM - Blackwell detection logic is correct.

The compute capability check correctly identifies Blackwell GPUs (CC 10.x), and the IS_NVIDIA guard ensures safe evaluation. The pattern matches the existing IS_NVIDIA_HOPPER flag.

Optional: For consistency with IS_TF32_SUPPORTED (line 402), consider using an explicit device index: torch.cuda.get_device_capability(0)[0] == 10. However, the current approach follows the IS_NVIDIA_HOPPER pattern and defaults to the current device, which is acceptable.
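
To illustrate the two detection variants mentioned above (IS_NVIDIA here is only a placeholder for the flag already defined in fla/utils.py):

import torch

IS_NVIDIA = torch.cuda.is_available()  # placeholder for the repo's real flag

# Pattern used in the PR, matching IS_NVIDIA_HOPPER: queries the current device.
IS_NVIDIA_BLACKWELL = IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10

# Reviewer's optional alternative, matching IS_TF32_SUPPORTED: pin the query to device 0.
IS_NVIDIA_BLACKWELL_ALT = IS_NVIDIA and torch.cuda.get_device_capability(0)[0] == 10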

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7ac536e and 435f069.

📒 Files selected for processing (2)
  • fla/ops/gated_delta_rule/wy_fast.py (2 hunks)
  • fla/utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/ops/gated_delta_rule/wy_fast.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.

Applied to files:

  • fla/utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)

12-33: Add TODO comment to track removal of workaround.

As noted in the past review comments, a TODO comment should be added to ensure this temporary workaround is removed when the Triton compiler bug is fixed. The docstring provides good context, but an explicit TODO will make the temporary nature more visible and trackable.

🔎 Suggested documentation improvement
 if IS_NVIDIA_BLACKWELL:
     """
     Compute tl.dot with SM100 workaround.
 
     On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
     the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
     See: https://github.com/fla-org/flash-linear-attention/issues/638
+    
+    TODO: Remove this workaround once the Triton compiler bug is fixed.
+    Track upstream fix at: https://github.com/triton-lang/triton/issues/[ISSUE_NUMBER]
     """
     @triton.jit
     def safe_dot(a, b):

Consider also filing an issue with the Triton project (triton-lang/triton) as suggested in the PR comments, and linking it in the TODO.

🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)

13-19: Consider moving docstring inside the function.

The docstring is currently placed after the if statement but before the function definition, which is non-standard. Per PEP 257, docstrings should be the first statement in a function body. While this placement provides context for the conditional compilation, it may confuse documentation generators and IDE tooltips.

Alternative placement suggestion
 if IS_NVIDIA_BLACKWELL:
-    """
-    Compute tl.dot with SM100 workaround.
-
-    On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
-    the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
-    See: https://github.com/fla-org/flash-linear-attention/issues/638
-    """
     @triton.jit
     def safe_dot(a, b):
+        """
+        Compute tl.dot with SM100 workaround.
+        
+        On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
+        the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
+        See: https://github.com/fla-org/flash-linear-attention/issues/638
+        """
         return tl.inline_asm_elementwise(
 else:
     @triton.jit
     def safe_dot(a, b):
+        """Compute tl.dot without workarounds (non-Blackwell GPUs)."""
         return tl.dot(a, b)
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 435f069 and dea8e1c.

📒 Files selected for processing (1)
  • fla/ops/gated_delta_rule/wy_fast.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/gated_delta_rule/wy_fast.py (1)
fla/utils.py (1)
  • check_shared_mem (448-454)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (3)
fla/ops/gated_delta_rule/wy_fast.py (3)

10-11: LGTM: Import addition is appropriate.

The IS_NVIDIA_BLACKWELL flag is correctly imported and used for conditional compilation of the workaround.


224-224: Workaround correctly applied at the problematic fusion site.

The replacement of tl.dot(tl.trans(b_dA), b_kb) with safe_dot(...) addresses the critical Triton compiler bug on Blackwell GPUs where TritonGPUHoistTMEMAlloc incorrectly hoists TMEM allocations, causing dominance errors. The function-level docstring provides sufficient context for this change.


12-29: The hardcoded dtype=tl.float32 is correct—no changes needed.

The kernel code shows that b_dA and b_kb may be converted to lower precision types (e.g., k.dtype.element_ty). However, Triton upcasts float dtypes to at least tl.float32, so tl.dot operations with float16 or bfloat16 inputs automatically return float32 results. The hardcoded dtype=tl.float32 correctly matches this auto-promoted output type.

Additionally, the suggested fix using dtype=result.dtype remains technically infeasible—Triton's inline_asm_elementwise requires the dtype parameter to be specified upfront to define the output type before the computation occurs; the result object cannot be queried before generation.

Likely an incorrect or invalid review comment.

@zhiyuan1i
Collaborator

LGTM.

@zhiyuan1i zhiyuan1i merged commit 7ec266f into fla-org:main Dec 19, 2025
1 of 3 checks passed
@wilsonyqm

Seems like the fix is being removed in #692: https://github.com/fla-org/flash-linear-attention/pull/692/files

The fix might need to be resent.



Development

Successfully merging this pull request may close these issues.

[Bug] GatedDeltaNet backward error on Blackwell: 'error: operand #2 does not dominate this use'

4 participants