
[GDN] Fix potential ood for long inputs#692

Merged
yzhangcs merged 3 commits into main from gdn-ood on Dec 20, 2025

Conversation

yzhangcs (Member) commented Dec 20, 2025

This PR fixes the following illegal memory access:

(screenshot of the illegal memory access error attached in the original PR)

Summary by CodeRabbit

  • Refactor
    • Optimized computational kernel performance with improved indexing and synchronization logic
    • Enhanced backward pass kernel calculations with refined accumulation strategies
    • Updated public function return type to optionally include a third tensor value for extended functionality



coderabbitai bot commented Dec 20, 2025

Walkthrough

The changes refactor pointer arithmetic in Triton kernels by replacing stride-based indexing with explicit dimension-product multiplications across forward and backward delta-rule kernels. The return type annotation for chunk_gated_delta_rule_fwd_h is expanded to include an optional third tensor, and the WY representation backward kernel's accumulation logic is restructured with new synchronization barriers and reworked tensor product patterns.

Changes

  • Delta-H Kernel Pointer Arithmetic — fla/ops/common/chunk_delta_h.py: Replaced stride-based offset calculations with explicit dimension-product indexing (e.g., H*K*V, H*K, H*V factors) in the forward kernel (chunk_gated_delta_rule_fwd_kernel_h_blockdim64) and the backward kernel (chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64). Updated block pointer constructions for the h, v, k, w, and gradient tensors. Updated the return type annotation of chunk_gated_delta_rule_fwd_h to `tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]`.
  • WY Backward Kernel Algebra — fla/ops/gated_delta_rule/wy_fast.py: Reworked the b_A and b_dA computation in prepare_wy_repr_bwd_kernel: replaced b_kb-based accumulation with direct b_k and b_A tensor operations; introduced explicit synchronization barriers between the initialization and accumulation phases; adjusted the b_dk computation pattern; added per-block scaling of b_A by the b_b factor; preserved alias wiring for the backward API.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pointer arithmetic verification: Carefully verify that the new dimension-product based indexing in both forward and backward kernels produces the same memory access patterns and boundary conditions as the original stride-based approach.
  • Return type compatibility: Ensure callers of chunk_gated_delta_rule_fwd_h handle the optional third tensor correctly.
  • Backward algebra correctness: The reworked tensor product patterns (replacing b_kb with direct b_k and b_A operations, new b_dk computation) require verification that accumulated gradients remain mathematically equivalent.
  • Synchronization barriers: Confirm that added tl.debug_barrier calls do not introduce unintended serialization or deadlocks.

Possibly related PRs

Poem

🐰 Stride-bounds fade to products bright,
Where H and K dance in the night,
Pointers leap with barriers true,
Our kernels gleam, reborn anew!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Title Check — ❓ Inconclusive: The title '[GDN] Fix potential ood for long inputs' uses the acronym 'ood' without defining it, which makes the change vague; it refers to a real aspect of the change but lacks sufficient clarity about what is being fixed. Spell out 'ood' in the title (e.g., 'out-of-bounds' or 'out-of-domain') to make the primary issue explicit to reviewers scanning the history.

✅ Passed checks (1 passed)
  • Description Check — ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.

📜 Recent review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 042d998 and 2614e79.

📒 Files selected for processing (2)
  • fla/ops/common/chunk_delta_h.py (11 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/gated_delta_rule/wy_fast.py (2)
fla/ops/delta_rule/wy_fast.py (1)
  • prepare_wy_repr_bwd (245-286)
fla/ops/comba/wy_fast.py (1)
  • prepare_wy_repr_bwd (372-424)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (8)
fla/ops/gated_delta_rule/wy_fast.py (3)

215-217: Synchronization barrier placement looks correct.

The tl.debug_barrier() ensures all threads complete the b_dA computation and dtype conversion before entering the second loop. This is important for correctness when threads share data.


218-231: Restructured accumulation logic looks correct.

The refactored computation properly:

  1. Accumulates b_A via tl.dot(b_k, b_kt) for the k @ k^T term
  2. Computes b_dk by combining the scaled term with the transposed dot product
  3. Accumulates with the previously stored dk from the first loop

The separation of b_kt and b_ktb improves clarity.


234-239: Post-loop scaling and gradient computation looks correct.

The b_A *= b_b[:, None] applies the beta scaling to the accumulated matrix after the loop completes. The b_dg gradient computation via axis sums aligns with the mathematical derivation for gating gradients.
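The algebraic equivalence the review relies on — scaling rows of k before the dot (the old b_kb = b_k * b_b accumulation) versus accumulating first and row-scaling afterwards (the new b_A *= b_b[:, None]) — can be checked numerically. This is a NumPy sketch, not the Triton kernel itself; shapes and names are illustrative:

```python
import numpy as np

# Row-scaling before the matrix product equals row-scaling after it:
# (diag(b) @ K) @ K^T == diag(b) @ (K @ K^T).
rng = np.random.default_rng(0)
BT, BK = 16, 32
b_k = rng.standard_normal((BT, BK)).astype(np.float32)  # stand-in for a k block
b_b = rng.standard_normal(BT).astype(np.float32)        # stand-in for the beta factors

old = (b_k * b_b[:, None]) @ b_k.T   # pre-scaled accumulation (b_kb-style)
new = (b_k @ b_k.T) * b_b[:, None]   # accumulate, then scale per row

print(np.allclose(old, new, atol=1e-5))  # True: the two orderings agree
```

This is why the kernel can defer the beta scaling until after the loop without changing the accumulated matrix.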

fla/ops/common/chunk_delta_h.py (5)

82-88: Critical fix for integer overflow on long sequences.

The explicit .to(tl.int64) cast before multiplying by dimension products (K, V) prevents integer overflow when computing pointer offsets. For example, with T=65536, H=32, K=128, the offset bos * H * K could exceed 2^31 and overflow int32.

This is the core fix for the "potential ood for long inputs" mentioned in the PR title.
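The failure mode can be sketched in NumPy (sizes here are illustrative, not taken from the PR): multiplying a large token offset by H*K in 32-bit arithmetic silently wraps, while casting to int64 first, as the fixed kernel does, keeps the offset exact.

```python
import numpy as np

# A flattened token offset deep into a long, batched sequence.
H, K = 32, 128
bos = np.array([600_000], dtype=np.int32)

wrapped = bos * H * K                  # stays int32 and wraps past 2**31 - 1
exact = bos.astype(np.int64) * H * K   # cast to int64 first, as the fix does

print(int(wrapped[0]) == int(exact[0]))  # False: the int32 product overflowed
```

A negative wrapped offset fed into a block pointer is exactly the kind of illegal memory access the PR screenshot shows.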


110-141: Consistent application of explicit dimension products in block pointers.

The block pointer offsets (i_t * H*K*V, H*K, H*V) are now explicit rather than relying on stride-based inference. This maintains consistency with the int64 base offset computation and prevents any intermediate overflow.
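For intuition, the explicit dimension products are the same numbers a contiguous tensor's strides encode. A small NumPy check (names and sizes illustrative) shows the two indexing schemes agree for a [T, H, K, V] layout:

```python
import numpy as np

# Flat offset of block (i_t, i_h) in a contiguous [T, H, K, V] tensor:
# explicit dimension products vs. NumPy's own strides.
T, H, K, V = 8, 4, 16, 32
x = np.arange(T * H * K * V, dtype=np.float32).reshape(T, H, K, V)

i_t, i_h = 3, 2
explicit = i_t * H * K * V + i_h * K * V                          # dimension products
via_strides = (i_t * x.strides[0] + i_h * x.strides[1]) // x.itemsize  # byte strides

print(explicit == via_strides)  # True: identical element offsets
```

Writing the products out in the kernel just makes this arithmetic visible at the call site, so the int64 base offset and the block-pointer offsets cannot drift apart.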


293-301: Backward kernel receives consistent int64 treatment.

The same overflow prevention pattern is applied to all backward kernel pointer calculations. This ensures both forward and backward passes handle long sequences correctly.


322-331: Block pointer offset consistency in backward kernel.

The backward kernel block pointers use explicit i_t*H*K*V offsets, matching the forward kernel pattern. This ensures the entire address computation chain is consistent.


477-477: Return type annotation correctly updated.

The expanded return type tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] accurately reflects the function's actual return values: (h, v_new, final_state) where final_state is None when output_final_state=False.



@gemini-code-assist
Contributor

Summary of Changes

Hello @yzhangcs, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a potential out-of-bounds access (the "ood" in the title) that could arise when processing long input sequences within the Gated Delta Network (GDN) operations. The changes primarily focus on refining memory access patterns and correcting calculations within the forward and backward passes of the chunked gated delta rule, ensuring robustness and correctness across input lengths.

Highlights

  • Offset Calculation Refinement: Memory offset calculations for various tensors (h, v, k, w, v_new, q, do, dv, dv2, dh, gk) in chunk_delta_h.py have been reordered and simplified. This change explicitly integrates stride values directly into tl.make_block_ptr calls, removing intermediate stride variables and preventing potential out-of-bounds access for long inputs.
  • Backward Pass Kernel Correction: The prepare_wy_repr_bwd_kernel in wy_fast.py received significant updates to its calculation logic for b_A and b_dk. These revisions aim to improve the accuracy of gradient computations and mitigate out-of-bounds issues, especially when processing long input sequences.
  • Function Signature Update: The return type hint for the chunk_gated_delta_rule_fwd_h function was updated to include an optional torch.Tensor, indicating a potential new output or a more flexible return structure for the forward pass.


@gemini-code-assist bot left a comment

Code Review

This pull request effectively addresses a potential integer overflow issue in chunk_delta_h.py for long inputs by ensuring calculations are performed using 64-bit integers. The changes also include refactoring in both chunk_delta_h.py and wy_fast.py that appears to improve performance and work around potential compiler issues. The updated function signature in chunk_delta_h.py is also a good correctness fix. Overall, these are solid improvements to the codebase.

 # main recurrence
 for i_t in range(NT):
-    p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
+    p_h1 = tl.make_block_ptr(h + i_t * H*K*V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
Severity: medium

For improved readability and maintainability, consider re-introducing the stride_h, stride_k, and stride_v variables. Using named variables for strides (e.g., stride_h = H*K*V) can make the pointer arithmetic in the loops clearer and less prone to errors, as the expressions are replaced by a descriptive name. This would apply to all places where these strides were used in both the forward and backward kernels. If there's a performance or compiler-related reason for inlining them, a comment explaining it would be beneficial.
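The suggestion amounts to naming the dimension products once and reusing them. A hedged sketch of what that would look like (the stride names mirror the ones the comment proposes; the values are illustrative, and plain Python arithmetic stands in for the Triton pointer math):

```python
# Named strides for a [T, H, K, V]-shaped h / k / v layout, as the review
# suggests re-introducing. Each offset then reads as intent, not arithmetic.
H, K, V = 4, 64, 64
stride_h = H * K * V   # elements between consecutive time chunks of h
stride_k = H * K       # elements between consecutive tokens of k
stride_v = H * V       # elements between consecutive tokens of v

i_t = 5
offset_h = i_t * stride_h   # equivalent to the inlined i_t * H*K*V
print(offset_h == i_t * H * K * V)  # True: same offset, clearer name
```

If the products were inlined deliberately (e.g., to steer a Triton compiler pass), a comment at the definition site would preserve that rationale.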

b_A = tl.zeros([BT, BT], dtype=tl.float32)
b_dA = tl.where(m_A, -b_dA, 0).to(k.dtype.element_ty)

tl.debug_barrier()
Severity: medium

Please consider adding a comment explaining why tl.debug_barrier() is necessary here. It would be helpful for future maintainers to understand the reason, for example, if it's to prevent a specific compiler optimization pass from causing issues (similar to the explanation for safe_dot), or to ensure a specific order of memory operations.

@yzhangcs yzhangcs merged commit 854c4ce into main Dec 20, 2025
5 checks passed
@yzhangcs yzhangcs deleted the gdn-ood branch December 20, 2025 18:04