Temporary workaround to disable TritonGPUHoistTMEMAlloc in `b_dk += tl.dot(tl.trans(b_dA), b_kb)` #687
Conversation
Walkthrough

Adds an `IS_NVIDIA_BLACKWELL` device flag and a hardware-aware `safe_dot` wrapper used in `wy_fast` to avoid a Triton/MLIR dominance/fusion compiler error on NVIDIA Blackwell GPUs; replaces one `tl.dot` call with `safe_dot`.
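As a rough sketch of what such a wrapper looks like (based on the diffs quoted later in this thread; the import path for the flag is an assumption, the rest mirrors the PR):

```python
import triton
import triton.language as tl

from fla.utils import IS_NVIDIA_BLACKWELL  # device flag introduced by this PR

if IS_NVIDIA_BLACKWELL:
    @triton.jit
    def safe_dot(a, b):
        # Pass the dot result through a no-op inline-asm mov so the
        # TritonGPUHoistTMEMAlloc pass cannot fuse a surrounding add
        # into the dot's accumulator (see issue #638).
        return tl.inline_asm_elementwise(
            asm="mov.f32 $0, $1;",
            constraints="=r,r",
            args=[tl.dot(a, b)],
            dtype=tl.float32,
            is_pure=True,
            pack=1,
        )
else:
    @triton.jit
    def safe_dot(a, b):
        # Non-Blackwell devices keep the plain dot.
        return tl.dot(a, b)
```

The branch is resolved at import time, so individual kernels simply call `safe_dot` and never check the device themselves.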
Summary of Changes

Hello @rucnyz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a temporary fix to address a compiler bug specific to B200 GPUs. The issue stems from an incorrect optimization by the TritonGPUHoistTMEMAlloc pass.
Code Review
This pull request introduces a temporary workaround for a Triton compiler bug on B200 GPUs that causes an error due to incorrect instruction fusion. The change uses inline assembly to prevent the fusion of a dot and add operation. The workaround is well-explained in the pull request description and the implementation appears correct. My feedback is to add a comment in the code to document this workaround for future maintainability, explaining why it's necessary and under what conditions it could be removed.
fla/ops/gated_delta_rule/wy_fast.py
```python
b_dk += tl.inline_asm_elementwise(
    asm="mov.f32 $0, $1;",
    constraints="=r,r",
    args=[tl.dot(tl.trans(b_dA), b_kb)],
    dtype=tl.float32,
    is_pure=True,
    pack=1,
)
```
The PR description provides excellent context for this workaround. To improve long-term maintainability, it's crucial to capture this information in the code itself. This ensures that any developer looking at this code in the future will understand the reason for this unusual construct and know when it can be removed.
I suggest adding a comment explaining the workaround and a TODO to track its removal.
```suggestion
# Temporary workaround for a Triton compiler bug on B200 GPUs.
# The `TritonGPUHoistTMEMAlloc` pass incorrectly fuses the add and dot operations,
# leading to a dominance error. The inline assembly prevents this fusion.
# TODO: Remove this workaround when the compiler bug is fixed. Link to issue if available.
b_dk += tl.inline_asm_elementwise(
    asm="mov.f32 $0, $1;",
    constraints="=r,r",
    args=[tl.dot(tl.trans(b_dA), b_kb)],
    dtype=tl.float32,
    is_pure=True,
    pack=1,
)
```
Thanks for the contribution, but it looks like this PR will break the FLA repository's general support for GPUs like AMD. Could you also report this upstream to triton-lang/triton?
@rucnyz could you refactor this into a small helper function? That way, we can conditionally use inline assembly based on the device type while keeping the main logic clean.
Actionable comments posted: 0
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/utils.py (1)
473-490: Add `IS_NVIDIA_BLACKWELL` to the alias registration list.

The new `IS_NVIDIA_BLACKWELL` flag should be registered for lowercase aliasing, consistent with other device capability flags like `IS_NVIDIA_HOPPER`. This maintains the established convention where all device flags have both uppercase and lowercase variants available.

🔎 Proposed fix

```diff
 def _register_aliases():
     current_module = sys.modules[__name__]
     for key in (
         'IS_AMD',
         'IS_INTEL',
         'IS_NVIDIA',
         'IS_INTEL_ALCHEMIST',
         'IS_NVIDIA_HOPPER',
+        'IS_NVIDIA_BLACKWELL',
         'USE_CUDA_GRAPH',
         'IS_TF32_SUPPORTED',
         'IS_GATHER_SUPPORTED',
         'IS_TMA_SUPPORTED',
     ):
         if hasattr(current_module, key):
             setattr(current_module, key.lower(), getattr(current_module, key))
```
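For illustration, the alias convention means both spellings resolve to the same value once registration has run (hypothetical usage, not part of the diff):

```python
from fla import utils

# The uppercase flag and its lowercase alias point at the same boolean.
assert utils.is_nvidia_blackwell == utils.IS_NVIDIA_BLACKWELL
```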
🧹 Nitpick comments (1)
fla/utils.py (1)
398-398: LGTM - Blackwell detection logic is correct.

The compute capability check correctly identifies Blackwell GPUs (CC 10.x), and the `IS_NVIDIA` guard ensures safe evaluation. The pattern matches the existing `IS_NVIDIA_HOPPER` flag.

Optional: For consistency with `IS_TF32_SUPPORTED` (line 402), consider using an explicit device index: `torch.cuda.get_device_capability(0)[0] == 10`. However, the current approach follows the `IS_NVIDIA_HOPPER` pattern and defaults to the current device, which is acceptable.
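For readers without the file open, the check under discussion has roughly this shape (a sketch only; the `IS_NVIDIA` line below is a stand-in, not the actual code in `fla/utils.py`):

```python
import torch

# Stand-in for the repository's own NVIDIA detection.
IS_NVIDIA = torch.cuda.is_available() and torch.version.cuda is not None

# Blackwell (SM100) devices report compute capability 10.x.
IS_NVIDIA_BLACKWELL = IS_NVIDIA and torch.cuda.get_device_capability()[0] == 10
```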
📒 Files selected for processing (2)
- `fla/ops/gated_delta_rule/wy_fast.py` (2 hunks)
- `fla/utils.py` (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/ops/gated_delta_rule/wy_fast.py
Actionable comments posted: 0
♻️ Duplicate comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)
12-33: Add TODO comment to track removal of workaround.

As noted in the past review comments, a TODO comment should be added to ensure this temporary workaround is removed when the Triton compiler bug is fixed. The docstring provides good context, but an explicit TODO will make the temporary nature more visible and trackable.
🔎 Suggested documentation improvement
```diff
 if IS_NVIDIA_BLACKWELL:
     """
     Compute tl.dot with SM100 workaround.

     On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
     the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
     See: https://github.com/fla-org/flash-linear-attention/issues/638
+
+    TODO: Remove this workaround once the Triton compiler bug is fixed.
+    Track upstream fix at: https://github.com/triton-lang/triton/issues/[ISSUE_NUMBER]
     """
     @triton.jit
     def safe_dot(a, b):
```

Consider also filing an issue with the Triton project (triton-lang/triton) as suggested in the PR comments, and linking it in the TODO.
🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)
13-19: Consider moving docstring inside the function.

The docstring is currently placed after the `if` statement but before the function definition, which is non-standard. Per PEP 257, docstrings should be the first statement in a function body. While this placement provides context for the conditional compilation, it may confuse documentation generators and IDE tooltips.

Alternative placement suggestion
```diff
 if IS_NVIDIA_BLACKWELL:
-    """
-    Compute tl.dot with SM100 workaround.
-
-    On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
-    the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
-    See: https://github.com/fla-org/flash-linear-attention/issues/638
-    """
     @triton.jit
     def safe_dot(a, b):
+        """
+        Compute tl.dot with SM100 workaround.
+
+        On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent
+        the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations.
+        See: https://github.com/fla-org/flash-linear-attention/issues/638
+        """
         return tl.inline_asm_elementwise(
 else:
     @triton.jit
     def safe_dot(a, b):
+        """Compute tl.dot without workarounds (non-Blackwell GPUs)."""
         return tl.dot(a, b)
```
📒 Files selected for processing (1)
- `fla/ops/gated_delta_rule/wy_fast.py` (2 hunks)
🔇 Additional comments (3)
fla/ops/gated_delta_rule/wy_fast.py (3)
10-11: LGTM: Import addition is appropriate.

The `IS_NVIDIA_BLACKWELL` flag is correctly imported and used for conditional compilation of the workaround.
224-224: Workaround correctly applied at the problematic fusion site.

The replacement of `tl.dot(tl.trans(b_dA), b_kb)` with `safe_dot(...)` addresses the critical Triton compiler bug on Blackwell GPUs where `TritonGPUHoistTMEMAlloc` incorrectly hoists TMEM allocations, causing dominance errors. The function-level docstring provides sufficient context for this change.
12-29: The hardcoded `dtype=tl.float32` is correct; no changes needed.

The kernel code shows that `b_dA` and `b_kb` may be converted to lower-precision types (e.g., `k.dtype.element_ty`). However, Triton upcasts float dtypes to at least `tl.float32`, so `tl.dot` operations with float16 or bfloat16 inputs automatically return float32 results. The hardcoded `dtype=tl.float32` correctly matches this auto-promoted output type.

Additionally, the suggested fix using `dtype=result.dtype` remains technically infeasible: Triton's `inline_asm_elementwise` requires the `dtype` parameter to be specified upfront to define the output type before the computation occurs; the result object cannot be queried before generation.

Likely an incorrect or invalid review comment.
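A small standalone snippet (not from this PR) that illustrates the upcast behavior described above: `tl.dot` on fp16 inputs accumulates in and returns fp32, which is why the hardcoded `dtype=tl.float32` lines up.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dot_dtype_demo(x_ptr, y_ptr, o_ptr, B: tl.constexpr):
    offs = tl.arange(0, B)
    idx = offs[:, None] * B + offs[None, :]
    a = tl.load(x_ptr + idx)  # fp16 tile
    b = tl.load(y_ptr + idx)  # fp16 tile
    c = tl.dot(a, b)          # result is fp32 even for fp16/bf16 inputs
    tl.static_assert(c.dtype == tl.float32)
    tl.store(o_ptr + idx, c)


x = torch.randn(16, 16, device="cuda", dtype=torch.float16)
y = torch.randn(16, 16, device="cuda", dtype=torch.float16)
o = torch.empty(16, 16, device="cuda", dtype=torch.float32)
dot_dtype_demo[(1,)](x, y, o, B=16)
```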
LGTM.
Seems like the fix is being deleted in #692 (https://github.com/fla-org/flash-linear-attention/pull/692/files); the fix might need to be resent.
Temporarily fix #638
In the original code `b_dk += tl.dot(tl.trans(b_dA), b_kb)`, the compiler fuses the addition into the dot op, generating IR like `%136 (b_dk_new) = tt.dot %134 (tl.trans(b_dA)), %127 (b_kb), %135 (b_dk_old)`, with `%135` as the accumulator.

It seems the `TritonGPUHoistTMEMAlloc` pass identifies that `%135` (input) and `%136` (output) share the same TMEM buffer and attempts an optimization. However, on B200 it computes dominance incorrectly, moving the definition of `%135` after its use in `%136` and causing the `operand #2 does not dominate this use` error.

The inline assembly forces the compiler to keep the dot and add as separate instructions. Only inline assembly seems to work: other tricks such as `+ 0.0`, `* 1.0`, and even bitcasts are optimized away by the compiler, which re-enables the fusion.
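In diff form, the change the reviewers refer to at the call site (line 224 of `wy_fast.py`) is essentially:

```diff
-    b_dk += tl.dot(tl.trans(b_dA), b_kb)
+    b_dk += safe_dot(tl.trans(b_dA), b_kb)
```

On Blackwell, `safe_dot` routes the dot result through the no-op inline-asm mov, so the add remains a separate instruction and the hoisting pass has nothing to fuse incorrectly; on other devices it reduces to a plain `tl.dot`.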
Summary by CodeRabbit
Bug Fixes

- Avoids a Triton/MLIR dominance error on NVIDIA Blackwell (B200) GPUs by routing one `tl.dot` call in the gated delta rule kernels through a hardware-aware `safe_dot` wrapper.

New Features

- Adds an `IS_NVIDIA_BLACKWELL` device flag for Blackwell detection.